├── .gitignore
├── .travis.yml
├── LICENSE.txt
├── README.md
├── install-spark.sh
├── pyspark_testing
│   ├── __init__.py
│   ├── data
│   │   └── National_Broadband_Data_March2012_Eng.csv.gz
│   ├── driver.py
│   ├── models.py
│   └── version.py
├── requirements-test.txt
├── run_driver.sh
├── run_tests.sh
├── setup.py
└── tests
    ├── __init__.py
    └── pyspark_testing
        ├── __init__.py
        ├── integration
        │   ├── __init__.py
        │   └── test_driver.py
        └── unit
            ├── __init__.py
            └── test_models.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .eggs/
2 | pyspark_testing.egg-info/
3 | .python-version
4 | *.pyc
5 | dist/
6 | build/
7 | .DS_Store
8 | 
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 |   - "2.7"
4 | 
5 | env:
6 |   global:
7 |     - SPARK_HOME=/tmp/spark-1.5.1-bin-hadoop2.4
8 | 
9 | install:
10 |   - ./install-spark.sh
11 |   - "pip install -r requirements-test.txt"
12 | 
13 | script: ./run_tests.sh
14 | 
15 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2015 Mike Sukmanowsky
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Testing with PySpark is a pain, so let's make it a little easier by example.
2 | 
3 | This project serves as an example of some good practices to follow when
4 | developing and testing PySpark applications/driver scripts.
5 | 
6 | [![Build Status](https://travis-ci.org/msukmanowsky/pyspark-testing.svg?branch=master)](https://travis-ci.org/msukmanowsky/pyspark-testing)
7 | 
8 | ## Tip 1: Use Python packages
9 | 
10 | Spark requires that any code your driver depends on be available on the
11 | `PYTHONPATH` of the executors that launch `python` processes. This means that
12 | either every node in the cluster must be properly provisioned with all the
13 | required dependencies, or the code your driver needs must be shipped to
14 | executors via `spark-submit --py-files /path/to/myegg.egg` or `sc.addPyFile()`.
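For example, the driver can ship the packaged code itself at startup. A minimal sketch (the egg path below matches what `run_tests.sh` builds for this project, but treat it as a placeholder for your own package name and version):

```python
from pyspark import SparkContext

sc = SparkContext(appName='my-app')

# Ship our packaged code to every executor so imports work inside tasks.
# The path is a placeholder; point it at whatever your build step produces.
sc.addPyFile('dist/pyspark_testing-0.1.0-py2.7.egg')
```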
15 | 
16 | For requirements that do not change often, doing a global `pip install ...` on
17 | all nodes as part of provisioning/bootstrapping is fine, but for proprietary
18 | code that changes frequently, a better solution is needed.
19 | 
20 | To do this, you have one of two choices:
21 | 
22 | 1. Manually create a regular zip file, and ship it via `--py-files` or
23 |    `addPyFile()`
24 | 2. Build an egg (`python setup.py bdist_egg`) or a source distribution
25 |    (`python setup.py sdist`)
26 | 
27 | Building a regular zip file is fine, albeit a little more tedious than being
28 | able to run:
29 | 
30 | ```bash
31 | python setup.py clean bdist_egg
32 | ```
33 | 
34 | Of course, the other benefit of creating a package is that you can share your
35 | code if you've created something that's pip installable.
36 | 
37 | ## Tip 2: Try to avoid lambdas
38 | 
39 | By design, Spark requires a functional programming approach to driver scripts.
40 | Python functions are pickled, sent across the network and executed on remote
41 | servers when transformation methods like `map`, `filter` or `reduce` are called.
42 | 
43 | It's tempting to write the bulk of functions as:
44 | 
45 | ```python
46 | data = (sc.textFile('/path/to/data')
47 |         .map(lambda d: d.split(','))
48 |         .map(lambda d: d[0].upper())
49 |         .filter(lambda d: d == 'THING'))
50 | ```
51 | 
52 | Anonymous lambdas like these are quick and easy, but suffer from two big
53 | problems:
54 | 
55 | 1. They aren't unit testable
56 | 2. They aren't self-documenting
57 | 
58 | Instead, we could rewrite the code above like so:
59 | 
60 | ```python
61 | def parse_line(line):
62 |     parts = line.split(',')
63 |     return parts[0].upper()
64 | 
65 | def is_valid_thing(thing):
66 |     return thing == 'THING'
67 | 
68 | data = (sc.textFile('/path/to/data')
69 |         .map(parse_line)
70 |         .filter(is_valid_thing))
71 | ```
72 | 
73 | More verbose, sure, but `parse_line` and `is_valid_thing` are now easily unit
74 | testable and, arguably, self-documenting.
75 | 
76 | ## Tip 3: Abstract your data with models
77 | 
78 | The code above is good, but it's still pretty annoying that we have to deal with
79 | strings that are split and then remember the array index of fields we want to
80 | work with.
81 | 
82 | To improve on this, we could create a model that encapsulates the data
83 | structures we're playing with.
84 | 
85 | ```python
86 | class Person(object):
87 | 
88 |     __slots__ = ('first_name', 'last_name', 'birth_date')
89 | 
90 |     def __init__(self, first_name, last_name, birth_date):
91 |         self.first_name = first_name
92 |         self.last_name = last_name
93 |         self.birth_date = birth_date
94 | 
95 |     @classmethod
96 |     def from_csv_line(cls, line):
97 |         parts = line.split(',')
98 |         if len(parts) != 3:
99 |             raise Exception('Bad line')
100 | 
101 |         return cls(*parts)
102 | ```
103 | 
104 | Now we can play with a class whose attributes are known:
105 | 
106 | ```python
107 | def is_valid_person(person):
108 |     return person.first_name is not None and person.last_name is not None
109 | 
110 | 
111 | data = (sc.textFile('/path/to/data')
112 |         .map(Person.from_csv_line)
113 |         .filter(is_valid_person))
114 | ```
115 | 
116 | Astute Pythonistas will question why I didn't use a `namedtuple` and instead
117 | resorted to an object using `__slots__`. The answer is performance. In some
118 | testing we've done internally, allocating lots of slot-based objects is both
119 | faster and more memory efficient than using anything like `namedtuple`s.
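If you want to sanity-check that claim on your own interpreter, a rough micro-benchmark is easy to throw together (a sketch only, not a rigorous benchmark; absolute numbers will vary by machine and Python version):

```python
import sys
import timeit
from collections import namedtuple

PersonTuple = namedtuple('PersonTuple', ['first_name', 'last_name', 'birth_date'])


class PersonSlots(object):
    __slots__ = ('first_name', 'last_name', 'birth_date')

    def __init__(self, first_name, last_name, birth_date):
        self.first_name = first_name
        self.last_name = last_name
        self.birth_date = birth_date


args = ('Jane', 'Doe', '1970-01-01')

# Rough allocation-speed comparison: one million instantiations of each.
print('namedtuple: {:.2f}s'.format(
    timeit.timeit(lambda: PersonTuple(*args), number=1000000)))
print('__slots__:  {:.2f}s'.format(
    timeit.timeit(lambda: PersonSlots(*args), number=1000000)))

# Rough per-instance size comparison (container object only, not field values).
print('namedtuple: {} bytes'.format(sys.getsizeof(PersonTuple(*args))))
print('__slots__:  {} bytes'.format(sys.getsizeof(PersonSlots(*args))))
```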
120 | 
121 | Given that you'll often allocate millions if not billions of these objects,
122 | speed and memory are important to keep in mind.
123 | 
124 | 
125 | ## Tip 4: Use test-ready closures for database connections
126 | 
127 | When working with external databases, give yourself the ability to inject a
128 | mock connection object to make testing easier later on:
129 | 
130 | ```python
131 | def enrich_data(db_conn=None):
132 |     def _enrich_data(partition):
133 |         conn = db_conn or create_db_conn()
134 |         for datum in partition:
135 |             # do something with conn, like joining in additional data
136 |             enriched_datum = do_stuff(datum, conn)
137 |             yield enriched_datum
138 |     return _enrich_data
139 | 
140 | my_rdd.mapPartitions(enrich_data())
141 | ```
142 | 
143 | By creating a closure like this, we can still independently test `enrich_data`
144 | by passing in a `MagicMock` for the `db_conn` argument (a sketch of such a test
145 | appears at the end of this README).
146 | 
147 | ## Tip 5: Use some `unittest` magic for integration tests
148 | 
149 | How does one create an integration test that relies on Spark running? This repo
150 | serves as a perfect example! Check out:
151 | 
152 | - [the integration test harness](https://github.com/msukmanowsky/pyspark-testing/blob/master/tests/pyspark_testing/integration/__init__.py)
153 | - [the sample integration tests](https://github.com/msukmanowsky/pyspark-testing/blob/master/tests/pyspark_testing/integration/test_driver.py)
154 | - [the Travis CI config](https://github.com/msukmanowsky/pyspark-testing/blob/master/.travis.yml)
155 | 
156 | ## Notes on the data set used in this project
157 | 
158 | The data set used in this project is the
159 | [National Broadband Data Set](http://open.canada.ca/data/en/dataset/00a331db-121b-445d-b119-35dbbe3eedd9)
160 | which is provided thanks to the Government of Canada's Open Government
161 | initiative.
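## Appendix: Testing the Tip 4 closure with a mock connection

As promised in Tip 4, here is a minimal sketch of how the enrichment closure can be unit tested with a `MagicMock` standing in for the database connection. The import path and the `do_stuff` helper are hypothetical (the Tip 4 closure is illustrative and not part of this package), so adapt them to wherever you define that code:

```python
import unittest

from mock import MagicMock, patch

# Hypothetical import path; point it at the module where you define the
# Tip 4 closure and its do_stuff() helper.
from myproject.enrichment import enrich_data


class TestEnrichData(unittest.TestCase):

    def test_uses_injected_connection(self):
        mock_conn = MagicMock()

        # Patch the hypothetical per-record helper so the test exercises only
        # the closure's plumbing, not real enrichment logic.
        with patch('myproject.enrichment.do_stuff') as mock_do_stuff:
            mock_do_stuff.side_effect = lambda datum, conn: (datum, conn)

            enrich = enrich_data(db_conn=mock_conn)
            # The returned closure is a plain generator over any iterable,
            # so no SparkContext is needed to drive it.
            results = list(enrich([{'id': 1}, {'id': 2}]))

        self.assertEqual(2, len(results))
        for _, conn in results:
            self.assertIs(conn, mock_conn)
```

Because nothing here touches Spark, a test like this runs in milliseconds; the Spark-specific wiring is covered separately by the integration tests under `tests/pyspark_testing/integration/`.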
161 | -------------------------------------------------------------------------------- /install-spark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | SPARK_VERSION='1.5.1' 4 | HADOOP_VERSION='2.4' 5 | 6 | curl http://apache.mirror.gtcomm.net/spark/spark-$SPARK_VERSION/spark-$SPARK_VERSION-bin-hadoop$HADOOP_VERSION.tgz --output /tmp/spark-$SPARK_VERSION-bin-hadoop$HADOOP_VERSION.tgz 7 | cd /tmp && tar -xvzf /tmp/spark-$SPARK_VERSION-bin-hadoop$HADOOP_VERSION.tgz 8 | -------------------------------------------------------------------------------- /pyspark_testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msukmanowsky/pyspark-testing/9d81e6b024b8b9d1dc39527bd2ccba26e36f8b44/pyspark_testing/__init__.py -------------------------------------------------------------------------------- /pyspark_testing/data/National_Broadband_Data_March2012_Eng.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msukmanowsky/pyspark-testing/9d81e6b024b8b9d1dc39527bd2ccba26e36f8b44/pyspark_testing/data/National_Broadband_Data_March2012_Eng.csv.gz -------------------------------------------------------------------------------- /pyspark_testing/driver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from operator import add 3 | import warnings 4 | import pkg_resources 5 | import os 6 | import pprint 7 | 8 | try: 9 | from pyspark import SparkContext 10 | except ImportError: 11 | warnings.warn('Cannot import pyspark.SparkContext, certain driver functions ' 12 | 'will not work') 13 | 14 | 15 | from pyspark_testing.models import BroadbandCoverageInfo 16 | 17 | 18 | def data_path(): 19 | ''' 20 | Return absolute path to the data file contained in the pyspark_testing 21 | package. 22 | ''' 23 | resource_path = os.path.join('data', 'National_Broadband_Data_March2012_Eng.csv.gz') 24 | return pkg_resources.resource_filename('pyspark_testing', resource_path) 25 | 26 | 27 | def top_unserved(data, n=10): 28 | ''' 29 | What are the top n largest areas that don't have broadband connections? 30 | ''' 31 | return (data.filter(lambda d: d.unserved != 0) 32 | .sortBy(lambda d: d.population, ascending=False) 33 | .take(n)) 34 | 35 | 36 | def summary_stats(data): 37 | ''' 38 | Returns a dict of availability stats by connection type 39 | ''' 40 | def stats_gen(datum): 41 | for k in ('dsl', 'wireless', 'broadband'): 42 | if getattr(datum, '{}_available'.format(k)): 43 | yield ('{}_available'.format(k), 1) 44 | else: 45 | yield ('{}_unavailable'.format(k), 1) 46 | 47 | return data.flatMap(stats_gen).foldByKey(0, add).collectAsMap() 48 | 49 | 50 | def main(): 51 | ''' 52 | Driver entry point for spark-submit. 53 | ''' 54 | with SparkContext() as sc: 55 | data = (sc.textFile(data_path(), use_unicode=False) 56 | .map(lambda l: l.decode('latin_1')) 57 | .map(BroadbandCoverageInfo.from_csv_line)) 58 | 59 | pprint.pprint(data.first()) 60 | 61 | # What are the top 10 largest that don't have broadband connections? 62 | pprint.pprint(top_unserved(data)) 63 | 64 | # What are the overall stats for availability by connection type? 
65 |         pprint.pprint(summary_stats(data))
66 | 
67 | if __name__ == '__main__':
68 |     main()
69 | 
--------------------------------------------------------------------------------
/pyspark_testing/models.py:
--------------------------------------------------------------------------------
1 | import csv
2 | 
3 | 
4 | def safe_convert(string):
5 |     if string == '':
6 |         return None
7 | 
8 |     if string == 'T':
9 |         return True
10 |     elif string == 'F':
11 |         return False
12 | 
13 |     try:
14 |         return int(string)
15 |     except ValueError:
16 |         pass
17 | 
18 |     try:
19 |         return float(string)
20 |     except ValueError:
21 |         pass
22 | 
23 |     return string
24 | 
25 | 
26 | def unicode_csv_reader(unicode_csv_data, encoding='latin_1', **kwargs):
27 |     def encoder():
28 |         for line in unicode_csv_data:
29 |             try:
30 |                 yield line.encode(encoding)
31 |             except UnicodeEncodeError:
32 |                 raise Exception('Could not encode {!r} using {!r}'
33 |                                 .format(line, encoding))
34 | 
35 |     reader = csv.reader(encoder(), **kwargs)
36 |     for row in reader:
37 |         yield [cell.decode(encoding) for cell in row]
38 | 
39 | 
40 | class BroadbandCoverageInfo(object):
41 |     __slots__ = (
42 |         'hexagon_number', # Hexagon identifier (49,999 in total)
43 |         'gsa_number', # Geographic Service Area
44 |         'first_nation', # Name of first nations reserve
45 |         'location_name', # Name of location of hexagon
46 |         'municipality', # Municipality of hexagon
47 |         'latitude', # Latitudinal coordinate in decimal degrees
48 |         'longitude', # Longitudinal coordinate in decimal degrees
49 |         'population', # Total population in hexagon, from 2006 Census data
50 |         'unserved', # Estimated range of unserved / underserved population in hexagon (without 1.5 Mbps broadband service availability)
51 |         'is_deferral_account', # Location included in CRTC Decisions on Deferral Accounts
52 |         'dsl_available', # Indicates availability of DSL (Digital Subscriber Loop)
53 |         'broadband_available', # Indicates availability of Cable broadband service
54 |         'wireless_available', # Indicates availability of wireless broadband service
55 |     )
56 | 
57 |     def __repr__(self):
58 |         params = (p for p in self.__slots__ if not p.startswith('_'))
59 |         params = ('{}={!r}'.format(p, getattr(self, p)) for p in params)
60 |         params = ', '.join(params)
61 |         return '{}({})'.format(self.__class__.__name__, params)
62 | 
63 |     __str__ = __repr__
64 |     __unicode__ = __repr__
65 | 
66 |     def __init__(self, **kwargs):
67 |         for k, v in kwargs.iteritems():
68 |             setattr(self, k, v)
69 | 
70 |     def __eq__(self, other):
71 |         if not isinstance(other, self.__class__):
72 |             return False
73 | 
74 |         for k in self.__slots__:
75 |             if getattr(self, k) != getattr(other, k):
76 |                 return False
77 | 
78 |         return True
79 | 
80 |     def __ne__(self, other):
81 |         return not self.__eq__(other)
82 | 
83 |     @classmethod
84 |     def from_csv_line(cls, line):
85 |         # This is definitely not the most efficient way of reading CSV, but the
86 |         # file contains some tricky quote chars
87 |         reader = unicode_csv_reader((line,))
88 |         parts = next(reader)
89 |         parts = [safe_convert(p) for p in parts]
90 |         kwargs = dict(zip(cls.__slots__, parts))
91 | 
92 |         return cls(**kwargs)
93 | 
--------------------------------------------------------------------------------
/pyspark_testing/version.py:
--------------------------------------------------------------------------------
1 | def _safe_int(string):
2 |     try:
3 |         return int(string)
4 |     except ValueError:
5 |         return string
6 | 
7 | 
8 | __version__ = '0.1.0'
9 | VERSION = tuple(_safe_int(x) for x in __version__.split('.'))
10 | 
-------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | mock 2 | nose 3 | unittest2 4 | pylint 5 | -------------------------------------------------------------------------------- /run_driver.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Assumes you're working inside an active virtualenv 4 | python=`which python` 5 | PYSPARK_PYTHON=$python PYSPARK_DRIVER_PYTHON=$python $SPARK_HOME/bin/spark-submit pyspark_testing/driver.py "$@" 6 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python setup.py clean bdist_egg 3 | nosetests "$@" 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | from setuptools import setup, find_packages 5 | 6 | # Get version without importing, which avoids dependency issues 7 | def get_version(): 8 | with open('pyspark_testing/version.py') as version_file: 9 | return re.search(r"""__version__\s+=\s+(['"])(?P.+?)\1""", 10 | version_file.read()).group('version') 11 | 12 | def readme(): 13 | with open('README.md') as f: 14 | return f.read() 15 | 16 | install_requires = [] 17 | lint_requires = [ 18 | 'pep8', 19 | 'pyflakes' 20 | ] 21 | 22 | if sys.version_info.major < 3: 23 | tests_require = ['mock', 'nose', 'unittest2'] 24 | else: 25 | tests_require = ['mock', 'nose'] 26 | 27 | dependency_links = [] 28 | setup_requires = [] 29 | if 'nosetests' in sys.argv[1:]: 30 | setup_requires.append('nose') 31 | 32 | setup( 33 | name='pyspark_testing', 34 | version=get_version(), 35 | author='Mike Sukmanowsky', 36 | author_email='mike.sukmanowsky@gmail.com', 37 | url='https://github.com/msukmanowsky/pyspark-testing', 38 | description=('Examples of unit and integration testing with PySpark'), 39 | long_description=readme(), 40 | license='MIT', 41 | packages=find_packages(), 42 | install_requires=install_requires, 43 | tests_require=tests_require, 44 | setup_requires=setup_requires, 45 | extras_require={ 46 | 'test': tests_require, 47 | 'all': install_requires + tests_require, 48 | 'docs': ['sphinx'] + tests_require, 49 | 'lint': lint_requires 50 | }, 51 | dependency_links=dependency_links, 52 | test_suite='nose.collector', 53 | include_package_data=True, 54 | ) 55 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def relative_file(file_obj, *paths): 5 | path_to_file = os.path.dirname(os.path.abspath(file_obj)) 6 | return os.path.abspath(os.path.join(path_to_file, *paths)) 7 | -------------------------------------------------------------------------------- /tests/pyspark_testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msukmanowsky/pyspark-testing/9d81e6b024b8b9d1dc39527bd2ccba26e36f8b44/tests/pyspark_testing/__init__.py -------------------------------------------------------------------------------- /tests/pyspark_testing/integration/__init__.py: -------------------------------------------------------------------------------- 
1 | from __future__ import print_function 2 | from functools import partial 3 | import atexit 4 | import glob 5 | import logging 6 | import os 7 | import sys 8 | import subprocess 9 | import time 10 | import unittest 11 | 12 | from pyspark_testing.version import __version__ as version 13 | 14 | from ... import relative_file 15 | 16 | 17 | log = logging.getLogger(__name__) 18 | here = partial(relative_file, __file__) 19 | 20 | 21 | def initialize_pyspark(spark_home, app_name, add_files=None): 22 | py4j = glob.glob(os.path.join(spark_home, 'python', 'lib', 'py4j*.zip'))[0] 23 | pyspark_path = os.path.join(spark_home, 'python') 24 | 25 | add_files = add_files or [] 26 | sys.path.insert(0, py4j) 27 | sys.path.insert(0, pyspark_path) 28 | for file in add_files: 29 | sys.path.insert(0, file) 30 | 31 | from pyspark.context import SparkContext 32 | logging.getLogger('py4j.java_gateway').setLevel(logging.WARN) 33 | sc = SparkContext(appName=app_name, pyFiles=add_files) 34 | log.debug('SparkContext initialized') 35 | return sc 36 | 37 | 38 | class PySparkIntegrationTest(unittest.TestCase): 39 | 40 | @classmethod 41 | def setUpClass(cls): 42 | if not hasattr(cls, 'sc'): 43 | spark_home = os.environ['SPARK_HOME'] 44 | build_zip = here('../../../dist/pyspark_testing-{}-py2.7.egg'.format(version)) 45 | app_name = '{} Tests'.format(cls.__name__) 46 | cls.sc = initialize_pyspark(spark_home, app_name, [build_zip]) 47 | log.debug('SparkContext initialized on %s', cls.__name__) 48 | 49 | @classmethod 50 | def tearDownClass(cls): 51 | if hasattr(cls, 'sc'): 52 | cls.sc.stop() 53 | -------------------------------------------------------------------------------- /tests/pyspark_testing/integration/test_driver.py: -------------------------------------------------------------------------------- 1 | from ... import relative_file 2 | from . import PySparkIntegrationTest 3 | 4 | from pyspark_testing import driver 5 | from pyspark_testing.models import BroadbandCoverageInfo 6 | 7 | 8 | class TestDriver(PySparkIntegrationTest): 9 | 10 | def setUp(self): 11 | self.data = (self.sc.textFile(driver.data_path(), use_unicode=False) 12 | .map(lambda l: l.decode('latin_1')) 13 | .map(BroadbandCoverageInfo.from_csv_line)) 14 | 15 | 16 | # def test_top_unserved(self): 17 | # driver.top_unserved() 18 | 19 | def test_summary_stats(self): 20 | expected_stats = { 21 | 'broadband_available': 8714, 22 | 'broadband_unavailable': 41285, 23 | 'dsl_available': 14858, 24 | 'dsl_unavailable': 35141, 25 | 'wireless_available': 30971, 26 | 'wireless_unavailable': 19028 27 | } 28 | self.assertDictEqual(expected_stats, driver.summary_stats(self.data)) 29 | -------------------------------------------------------------------------------- /tests/pyspark_testing/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msukmanowsky/pyspark-testing/9d81e6b024b8b9d1dc39527bd2ccba26e36f8b44/tests/pyspark_testing/unit/__init__.py -------------------------------------------------------------------------------- /tests/pyspark_testing/unit/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from pyspark_testing.models import BroadbandCoverageInfo 4 | 5 | 6 | class TestModels(unittest.TestCase): 7 | 8 | def setUp(self): 9 | self.line = '40930,,,"Aalders Landing, NS @ 44.82\xb0N x 64.94\xb0W","Annapolis, Subd. 
D :SC",44.82,-64.94,154,"0","F","F","F","T"\r\n' 10 | self.line = self.line.decode('latin_1').strip() 11 | 12 | def test_broadband_coverage_info(self): 13 | info = BroadbandCoverageInfo.from_csv_line(self.line) 14 | expected_info = BroadbandCoverageInfo(hexagon_number=40930, gsa_number=None, first_nation=None, location_name=u'Aalders Landing, NS @ 44.82\xb0N x 64.94\xb0W', municipality=u'Annapolis, Subd. D :SC', latitude=44.82, longitude=-64.94, population=154, unserved=0, is_deferral_account=False, dsl_available=False, broadband_available=False, wireless_available=True) 15 | self.assertEqual(expected_info, info) 16 | --------------------------------------------------------------------------------