├── .gitignore
├── README.md
├── azure-pipeline-pytest.yml
├── azure-pipeline-unittest.yml
└── src
    ├── etl
    │   ├── __init__.py
    │   └── etl.py
    └── tests
        ├── __init__.py
        ├── test-requirements.txt
        ├── test_etl_pytest.py
        └── test_etl_unittest.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Xavier-Az-Learn-PySpark-UnitTests

Example project showing how to unit test a PySpark transformation with both pytest and unittest, and how to run those tests in Azure Pipelines.

--------------------------------------------------------------------------------
/azure-pipeline-pytest.yml:
--------------------------------------------------------------------------------
name: Py Spark Unit Tests

pool:
  vmImage: ubuntu-latest

stages:
- stage: Tests
  displayName: Unit Tests using Pytest

  jobs:
  - job:
    displayName: PySpark Unit Tests
    steps:
    - script: |
        sudo apt-get update
        sudo apt-get install default-jdk -y
        pip install -r $(System.DefaultWorkingDirectory)/src/tests/test-requirements.txt
        pip install --upgrade pytest pytest-azurepipelines
        cd src && pytest -v -rf --test-run-title='Unit Tests Report'
      displayName: Run Unit Tests

--------------------------------------------------------------------------------
/azure-pipeline-unittest.yml:
--------------------------------------------------------------------------------
name: Py Spark Unit Tests

pool:
  vmImage: ubuntu-latest

stages:
- stage: Tests
  displayName: Unit Tests

  jobs:
  - job:
    displayName: PySpark Unit Tests
    steps:
    - script: |
        sudo apt-get update
        sudo apt-get install default-jdk -y
        pip install -r $(System.DefaultWorkingDirectory)/src/tests/test-requirements.txt
        cd src && python -m unittest -v
      displayName: Run Unit Tests

--------------------------------------------------------------------------------
/src/etl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xavier211192/Xavier-Az-Learn-PySpark-UnitTests/78ad2f46ea442cee77bc7869bd4e528365225df7/src/etl/__init__.py

--------------------------------------------------------------------------------
/src/etl/etl.py:
--------------------------------------------------------------------------------
from pyspark.sql import functions as F


def transform_data(input_df):
    """Aggregate the item counts per location."""
    transformed_df = (
        input_df
        .groupBy('Location')
        .agg(F.sum('ItemCount').alias('TotalItemCount'))
    )
    return transformed_df

--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xavier211192/Xavier-Az-Learn-PySpark-UnitTests/78ad2f46ea442cee77bc7869bd4e528365225df7/src/tests/__init__.py

--------------------------------------------------------------------------------
/src/tests/test-requirements.txt:
--------------------------------------------------------------------------------
pyspark==3.1.2
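
Before the test modules, a quick orientation: the snippet below is a minimal, illustrative sketch of what transform_data does when exercised by hand. It groups rows by Location and sums ItemCount into TotalItemCount. It is not a file in this repository; it assumes pyspark is installed locally and that it is run from the src/ directory (as the pipeline steps above do), and it reuses a few of the sample rows from the tests below.

# Illustrative sketch only - not part of the repository.
from pyspark.sql import SparkSession

from etl.etl import transform_data  # importable when run from src/, as in the pipelines

spark = SparkSession.builder.master("local[1]").appName("etl-smoke-check").getOrCreate()

input_df = spark.createDataFrame(
    [(1, "Bangalore", "2021-12-01", 5),
     (2, "Bangalore", "2021-12-01", 3),
     (5, "Amsterdam", "2021-12-02", 10)],
    schema="StoreID int, Location string, Date string, ItemCount int",
)

# Expected output: Bangalore -> 8, Amsterdam -> 10
transform_data(input_df).show()

spark.stop()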
--------------------------------------------------------------------------------
/src/tests/test_etl_pytest.py:
--------------------------------------------------------------------------------
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType

from etl.etl import transform_data


@pytest.fixture(scope="session")
def spark():
    print("---- Set up Spark session ----")
    spark = (
        SparkSession.builder.master("local[1]")
        .appName("Unit-Tests")
        .config("spark.executor.cores", "1")
        .config("spark.executor.instances", "1")
        .config("spark.port.maxRetries", "30")
        .config("spark.sql.shuffle.partitions", "1")
        .getOrCreate()
    )
    yield spark
    print("---- Tear down Spark session ----")
    spark.stop()


@pytest.fixture(scope="session")
def input_data(spark):
    input_schema = StructType(
        [
            StructField("StoreID", IntegerType(), True),
            StructField("Location", StringType(), True),
            StructField("Date", StringType(), True),
            StructField("ItemCount", IntegerType(), True),
        ]
    )
    input_data = [
        (1, "Bangalore", "2021-12-01", 5),
        (2, "Bangalore", "2021-12-01", 3),
        (5, "Amsterdam", "2021-12-02", 10),
        (6, "Amsterdam", "2021-12-01", 1),
        (8, "Warsaw", "2021-12-02", 15),
        (7, "Warsaw", "2021-12-01", 99),
    ]
    input_df = spark.createDataFrame(data=input_data, schema=input_schema)
    return input_df


@pytest.fixture(scope="session")
def expected_data(spark):
    # Define the expected data frame.
    # Spark promotes the sum of an IntegerType column to LongType,
    # so TotalItemCount is declared as LongType.
    expected_schema = StructType(
        [
            StructField("Location", StringType(), True),
            StructField("TotalItemCount", LongType(), True),
        ]
    )
    expected_data = [("Bangalore", 8), ("Warsaw", 114), ("Amsterdam", 11)]
    expected_df = spark.createDataFrame(data=expected_data, schema=expected_schema)
    return expected_df


def test_etl(spark, input_data, expected_data):
    # Apply the transformation to the input data frame
    transformed_df = transform_data(input_data)

    # Compare the schema of transformed_df and expected_df
    field_list = lambda field: (field.name, field.dataType, field.nullable)
    fields1 = [*map(field_list, transformed_df.schema.fields)]
    fields2 = [*map(field_list, expected_data.schema.fields)]
    assert set(fields1) == set(fields2)

    # Compare the data in transformed_df and expected_df
    assert sorted(expected_data.collect()) == sorted(transformed_df.collect())

--------------------------------------------------------------------------------
/src/tests/test_etl_unittest.py:
--------------------------------------------------------------------------------
import unittest

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType

from etl.etl import transform_data


class SparkETLTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.spark = (SparkSession
                     .builder
                     .master("local[*]")
                     .appName("PySpark-unit-test")
                     .config('spark.port.maxRetries', 30)
                     .getOrCreate())

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def test_etl(self):
        input_schema = StructType([
            StructField('StoreID', IntegerType(), True),
            StructField('Location', StringType(), True),
            StructField('Date', StringType(), True),
            StructField('ItemCount', IntegerType(), True)
        ])
        input_data = [(1, "Bangalore", "2021-12-01", 5),
                      (2, "Bangalore", "2021-12-01", 3),
                      (5, "Amsterdam", "2021-12-02", 10),
                      (6, "Amsterdam", "2021-12-01", 1),
                      (8, "Warsaw", "2021-12-02", 15),
                      (7, "Warsaw", "2021-12-01", 99)]
        input_df = self.spark.createDataFrame(data=input_data, schema=input_schema)

        # Spark promotes the sum of an IntegerType column to LongType,
        # so TotalItemCount is declared as LongType.
        expected_schema = StructType([
            StructField('Location', StringType(), True),
            StructField('TotalItemCount', LongType(), True)
        ])

        expected_data = [("Bangalore", 8),
                         ("Warsaw", 114),
                         ("Amsterdam", 11)]
        expected_df = self.spark.createDataFrame(data=expected_data, schema=expected_schema)

        # Apply the transformation to the input data frame
        transformed_df = transform_data(input_df)

        # Compare the schema of transformed_df and expected_df
        field_list = lambda field: (field.name, field.dataType, field.nullable)
        fields1 = [*map(field_list, transformed_df.schema.fields)]
        fields2 = [*map(field_list, expected_df.schema.fields)]
        res = set(fields1) == set(fields2)

        # assert
        self.assertTrue(res)
        # Compare the data in transformed_df and expected_df
        self.assertEqual(sorted(expected_df.collect()),
                         sorted(transformed_df.collect()))


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
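
Both test modules build the same (name, dataType, nullable) tuples to compare schemas. If that duplication ever becomes a nuisance, the check could be pulled into a small shared helper along the lines below; the module path and function name are hypothetical, not part of the repository.

# Hypothetical shared helper (e.g. src/tests/schema_utils.py) - a sketch, not repo code.
def assert_schema_equal(actual_df, expected_df):
    """Assert that two DataFrames agree on field names, data types and nullability."""
    fields = lambda df: {(f.name, f.dataType, f.nullable) for f in df.schema.fields}
    assert fields(actual_df) == fields(expected_df), (
        f"Schema mismatch: {actual_df.schema} != {expected_df.schema}"
    )

Both test_etl functions could then call assert_schema_equal(transformed_df, expected_df) instead of building the tuple sets inline; a plain assert is also reported as a failure under unittest, though assertSetEqual would keep the unittest idiom.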