├── tests └── __init__.py ├── elephas ├── __init__.py ├── ml │ ├── __init__.py │ ├── adapter.py │ └── params.py ├── mllib │ ├── __init__.py │ └── adapter.py ├── utils │ ├── __init__.py │ ├── functional_utils.py │ ├── rdd_utils.py │ └── rwlock.py ├── hyperparam.py ├── ml_model.py ├── optimizers.py └── spark_model.py ├── setup.cfg ├── elephas.gif ├── setup.py ├── .gitignore ├── LICENSE ├── .travis.yml ├── examples ├── mnist_mlp_spark.py ├── mllib_mlp.py ├── ml_mlp.py ├── hyperparam_optimization.py ├── ml_pipeline_otto.py └── Spark_ML_Pipeline.ipynb └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elephas/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elephas/ml/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elephas/mllib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elephas/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /elephas.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/elephas/master/elephas.gif -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | setup(name='elephas', 5 | version='0.3', 6 | description='Deep learning on Spark with Keras', 7 | url='http://github.com/maxpumperla/elephas', 8 | download_url='https://github.com/maxpumperla/elephas/tarball/0.3', 9 | author='Max Pumperla', 10 | author_email='max.pumperla@googlemail.com', 11 | install_requires=['keras', 'hyperas', 'flask'], 12 | license='MIT', 13 | packages=find_packages(), 14 | zip_safe=False) 15 | -------------------------------------------------------------------------------- /elephas/utils/functional_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | 5 | 6 | def add_params(p1, p2): 7 | ''' 8 | Add two lists of parameters 9 | ''' 10 | res = [] 11 | for x, y in zip(p1, p2): 12 | res.append(x+y) 13 | return res 14 | 15 | 16 | def subtract_params(p1, p2): 17 | ''' 18 | Subtract two lists of parameters 19 | ''' 20 | res = [] 21 | for x, y in zip(p1, p2): 22 | res.append(x-y) 23 | return res 24 | 25 | 26 | def get_neutral(array): 27 | ''' 28 | Get list of zero-valued numpy arrays for 29 | specified list of numpy arrays 30 | ''' 31 | res = [] 32 | for x in array: 33 | res.append(np.zeros_like(x)) 34 | return res 35 | 36 | 37 | def divide_by(array_list, num_workers): 38 | ''' 39 | Divide a list of parameters by an integer num_workers. 
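The division happens in place on the given list of numpy arrays; for example, a list containing the single array [2., 4.] divided by num_workers=2 becomes [1., 2.].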
40 | ''' 41 | for i, x in enumerate(array_list): 42 | array_list[i] /= num_workers 43 | return array_list 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | examples/.ipynb_checkpoints 60 | examples/metastore_db 61 | 62 | examples/*.csv 63 | -------------------------------------------------------------------------------- /elephas/mllib/adapter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from pyspark.mllib.linalg import Matrices, Vectors 4 | 5 | 6 | def from_matrix(matrix): 7 | ''' Convert MLlib Matrix to numpy array ''' 8 | return matrix.toArray() 9 | 10 | 11 | def to_matrix(np_array): 12 | ''' Convert numpy array to MLlib Matrix ''' 13 | if len(np_array.shape) == 2: 14 | return Matrices.dense(np_array.shape[0], 15 | np_array.shape[1], 16 | np_array.ravel()) 17 | else: 18 | raise Exception("""An MLLib Matrix can only be created from a two-dimensional numpy array""") 19 | 20 | 21 | def from_vector(vector): 22 | ''' Convert MLlib Vector to numpy array ''' 23 | return vector.array 24 | 25 | 26 | def to_vector(np_array): 27 | ''' Convert numpy array to MLlib Vector ''' 28 | if len(np_array.shape) == 1: 29 | return Vectors.dense(np_array) 30 | else: 31 | raise Exception("""An MLLib Vector can only be created from a one-dimensional numpy array""") 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Max 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /elephas/ml/adapter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from pyspark.sql import SQLContext 4 | from pyspark.mllib.regression import LabeledPoint 5 | from ..utils.rdd_utils import from_labeled_point, to_labeled_point, lp_to_simple_rdd 6 | 7 | 8 | def to_data_frame(sc, features, labels, categorical=False): 9 | ''' 10 | Convert numpy arrays of features and labels into Spark DataFrame 11 | ''' 12 | lp_rdd = to_labeled_point(sc, features, labels, categorical) 13 | sql_context = SQLContext(sc) 14 | df = sql_context.createDataFrame(lp_rdd) 15 | return df 16 | 17 | 18 | def from_data_frame(df, categorical=False, nb_classes=None): 19 | ''' 20 | Convert DataFrame back to pair of numpy arrays 21 | ''' 22 | lp_rdd = df.rdd.map(lambda row: LabeledPoint(row.label, row.features)) 23 | features, labels = from_labeled_point(lp_rdd, categorical, nb_classes) 24 | return features, labels 25 | 26 | 27 | def df_to_simple_rdd(df, categorical=False, nb_classes=None, featuresCol='features', labelCol='label'): 28 | ''' 29 | Convert DataFrame into RDD of pairs 30 | ''' 31 | sqlContext = df.sql_ctx 32 | sqlContext.registerDataFrameAsTable(df, "temp_table") 33 | selected_df = sqlContext.sql("SELECT {0} AS features, {1} as label from temp_table".format(featuresCol, labelCol)) 34 | lp_rdd = selected_df.rdd.map(lambda row: LabeledPoint(row.label, row.features)) 35 | rdd = lp_to_simple_rdd(lp_rdd, categorical, nb_classes) 36 | return rdd 37 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | language: python 4 | python: 5 | - "2.7" 6 | # - "3.4" # Note that hyperopt currently seems to have issues with 3.4 7 | install: 8 | # code below is taken from http://conda.pydata.org/docs/travis.html 9 | # We do this conditionally because it saves us some downloading if the 10 | # version is the same. 
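# (The Python 2 or Python 3 Miniconda installer is selected below to match $TRAVIS_PYTHON_VERSION.)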
11 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 12 | wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh; 13 | else 14 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 15 | fi 16 | - bash miniconda.sh -b -p $HOME/miniconda 17 | - export PATH="$HOME/miniconda/bin:$PATH" 18 | - hash -r 19 | - conda config --set always_yes yes --set changeps1 no 20 | - conda update -q conda 21 | - conda info -a 22 | - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib pandas pytest h5py flask 23 | - source activate test-environment 24 | - pip install pytest-cov python-coveralls 25 | - pip install git+git://github.com/Theano/Theano.git 26 | - pip install keras 27 | - python setup.py install 28 | 29 | # Install Spark 30 | - wget http://apache.mirrors.tds.net/spark/spark-1.5.2/spark-1.5.2-bin-hadoop2.6.tgz -P $HOME 31 | - tar zxvf $HOME/spark-* -C $HOME 32 | - export SPARK_HOME=$HOME/spark-1.5.2-bin-hadoop2.6 33 | - export PATH=$PATH:$SPARK_HOME/bin 34 | 35 | # Just run an example for now 36 | script: 37 | - python -c "import keras.backend" 38 | - spark-submit --driver-memory 2G $PWD/examples/mnist_mlp_spark.py 39 | after_success: 40 | - coveralls 41 | -------------------------------------------------------------------------------- /elephas/utils/rdd_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from pyspark.mllib.regression import LabeledPoint 4 | import numpy as np 5 | 6 | from ..mllib.adapter import to_vector, from_vector 7 | 8 | 9 | def to_simple_rdd(sc, features, labels): 10 | ''' 11 | Convert numpy arrays of features and labels into 12 | an RDD of pairs. 13 | ''' 14 | pairs = [(x, y) for x, y in zip(features, labels)] 15 | return sc.parallelize(pairs) 16 | 17 | 18 | def to_labeled_point(sc, features, labels, categorical=False): 19 | ''' 20 | Convert numpy arrays of features and labels into 21 | a LabeledPoint RDD 22 | ''' 23 | labeled_points = [] 24 | for x, y in zip(features, labels): 25 | if categorical: 26 | lp = LabeledPoint(np.argmax(y), to_vector(x)) 27 | else: 28 | lp = LabeledPoint(y, to_vector(x)) 29 | labeled_points.append(lp) 30 | return sc.parallelize(labeled_points) 31 | 32 | 33 | def from_labeled_point(rdd, categorical=False, nb_classes=None): 34 | ''' 35 | Convert a LabeledPoint RDD back to a pair of numpy arrays 36 | ''' 37 | features = np.asarray(rdd.map(lambda lp: from_vector(lp.features)).collect()) 38 | labels = np.asarray(rdd.map(lambda lp: lp.label).collect(), dtype='int32') 39 | if categorical: 40 | if not nb_classes: 41 | nb_classes = np.max(labels)+1 42 | temp = np.zeros((len(labels), nb_classes)) 43 | for i, label in enumerate(labels): 44 | temp[i, label] = 1. 45 | labels = temp 46 | return features, labels 47 | 48 | 49 | def encode_label(label, nb_classes): 50 | ''' one-hot encoding of a label ''' 51 | encoded = np.zeros(nb_classes) 52 | encoded[label] = 1. 
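# e.g. encode_label(2, 4) -> array([0., 0., 1., 0.])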
53 | return encoded 54 | 55 | 56 | def lp_to_simple_rdd(lp_rdd, categorical=False, nb_classes=None): 57 | ''' 58 | Convert a LabeledPoint RDD into an RDD of feature-label pairs 59 | ''' 60 | if categorical: 61 | if not nb_classes: 62 | labels = np.asarray(lp_rdd.map(lambda lp: lp.label).collect(), dtype='int32') 63 | nb_classes = np.max(labels)+1 64 | rdd = lp_rdd.map(lambda lp: (from_vector(lp.features), encode_label(lp.label, nb_classes))) 65 | else: 66 | rdd = lp_rdd.map(lambda lp: (from_vector(lp.features), lp.label)) 67 | return rdd 68 | -------------------------------------------------------------------------------- /elephas/utils/rwlock.py: -------------------------------------------------------------------------------- 1 | """Simple reader-writer locks in Python 2 | Many readers can hold the lock XOR one and only one writer 3 | http://majid.info/blog/a-reader-writer-lock-for-python/ 4 | """ 5 | import threading 6 | 7 | version = """$Id: 04-1.html,v 1.3 2006/12/05 17:45:12 majid Exp $""" 8 | 9 | 10 | class RWLock: 11 | """ 12 | A simple reader-writer lock Several readers can hold the lock 13 | simultaneously, XOR one writer. Write locks have priority over reads to 14 | prevent write starvation. 15 | """ 16 | def __init__(self): 17 | self.rwlock = 0 18 | self.writers_waiting = 0 19 | self.monitor = threading.Lock() 20 | self.readers_ok = threading.Condition(self.monitor) 21 | self.writers_ok = threading.Condition(self.monitor) 22 | 23 | def acquire_read(self): 24 | """ 25 | Acquire a read lock. Several threads can hold this typeof lock. 26 | It is exclusive with write locks. 27 | """ 28 | self.monitor.acquire() 29 | while self.rwlock < 0 or self.writers_waiting: 30 | self.readers_ok.wait() 31 | self.rwlock += 1 32 | self.monitor.release() 33 | 34 | def acquire_write(self): 35 | """ 36 | Acquire a write lock. Only one thread can hold this lock, and 37 | only when no read locks are also held. 38 | """ 39 | self.monitor.acquire() 40 | while self.rwlock != 0: 41 | self.writers_waiting += 1 42 | self.writers_ok.wait() 43 | self.writers_waiting -= 1 44 | self.rwlock = -1 45 | self.monitor.release() 46 | 47 | def release(self): 48 | """ 49 | Release a lock, whether read or write. 
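If writers are waiting and no locks remain, a single waiting writer is woken; waiting readers are only notified once no writer is queued.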
50 | """ 51 | self.monitor.acquire() 52 | if self.rwlock < 0: 53 | self.rwlock = 0 54 | else: 55 | self.rwlock -= 1 56 | wake_writers = self.writers_waiting and self.rwlock == 0 57 | wake_readers = self.writers_waiting == 0 58 | self.monitor.release() 59 | if wake_writers: 60 | self.writers_ok.acquire() 61 | self.writers_ok.notify() 62 | self.writers_ok.release() 63 | elif wake_readers: 64 | self.readers_ok.acquire() 65 | self.readers_ok.notifyAll() 66 | self.readers_ok.release() 67 | -------------------------------------------------------------------------------- /examples/mnist_mlp_spark.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from keras.datasets import mnist 5 | from keras.models import Sequential 6 | from keras.layers.core import Dense, Dropout, Activation 7 | from keras.optimizers import SGD 8 | from keras.utils import np_utils 9 | 10 | from elephas.spark_model import SparkModel 11 | from elephas.utils.rdd_utils import to_simple_rdd 12 | from elephas import optimizers as elephas_optimizers 13 | 14 | from pyspark import SparkContext, SparkConf 15 | 16 | # Define basic parameters 17 | batch_size = 64 18 | nb_classes = 10 19 | nb_epoch = 10 20 | 21 | # Create Spark context 22 | conf = SparkConf().setAppName('Mnist_Spark_MLP').setMaster('local[8]') 23 | sc = SparkContext(conf=conf) 24 | 25 | # Load data 26 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 27 | 28 | x_train = x_train.reshape(60000, 784) 29 | x_test = x_test.reshape(10000, 784) 30 | x_train = x_train.astype("float32") 31 | x_test = x_test.astype("float32") 32 | x_train /= 255 33 | x_test /= 255 34 | print(x_train.shape[0], 'train samples') 35 | print(x_test.shape[0], 'test samples') 36 | 37 | # Convert class vectors to binary class matrices 38 | y_train = np_utils.to_categorical(y_train, nb_classes) 39 | y_test = np_utils.to_categorical(y_test, nb_classes) 40 | 41 | model = Sequential() 42 | model.add(Dense(128, input_dim=784)) 43 | model.add(Activation('relu')) 44 | model.add(Dropout(0.2)) 45 | model.add(Dense(128)) 46 | model.add(Activation('relu')) 47 | model.add(Dropout(0.2)) 48 | model.add(Dense(10)) 49 | model.add(Activation('softmax')) 50 | 51 | sgd = SGD(lr=0.1) 52 | 53 | # Build RDD from numpy features and labels 54 | rdd = to_simple_rdd(sc, x_train, y_train) 55 | 56 | # Initialize SparkModel from Keras model and Spark context 57 | adagrad = elephas_optimizers.Adagrad() 58 | spark_model = SparkModel(sc, 59 | model, 60 | optimizer=adagrad, 61 | frequency='epoch', 62 | mode='asynchronous', 63 | num_workers=2,master_optimizer=sgd) 64 | 65 | # Train Spark model 66 | spark_model.train(rdd, nb_epoch=nb_epoch, batch_size=batch_size, verbose=2, validation_split=0.1) 67 | 68 | # Evaluate Spark model by evaluating the underlying model 69 | score = spark_model.master_network.evaluate(x_test, y_test, verbose=2) 70 | print('Test accuracy:', score[1]) 71 | -------------------------------------------------------------------------------- /examples/mllib_mlp.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from keras.datasets import mnist 5 | from keras.models import Sequential 6 | from keras.layers.core import Dense, Dropout, Activation 7 | from keras.optimizers import RMSprop 8 | from keras.utils import np_utils 9 | 10 | from elephas.spark_model import 
SparkMLlibModel 11 | from elephas.utils.rdd_utils import to_labeled_point, lp_to_simple_rdd 12 | from elephas import optimizers as elephas_optimizers 13 | 14 | from pyspark import SparkContext, SparkConf 15 | 16 | # Define basic parameters 17 | batch_size = 64 18 | nb_classes = 10 19 | nb_epoch = 3 20 | 21 | # Load data 22 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 23 | 24 | x_train = x_train.reshape(60000, 784) 25 | x_test = x_test.reshape(10000, 784) 26 | x_train = x_train.astype("float32") 27 | x_test = x_test.astype("float32") 28 | x_train /= 255 29 | x_test /= 255 30 | print(x_train.shape[0], 'train samples') 31 | print(x_test.shape[0], 'test samples') 32 | 33 | # Convert class vectors to binary class matrices 34 | y_train = np_utils.to_categorical(y_train, nb_classes) 35 | y_test = np_utils.to_categorical(y_test, nb_classes) 36 | 37 | model = Sequential() 38 | model.add(Dense(128, input_dim=784)) 39 | model.add(Activation('relu')) 40 | model.add(Dropout(0.2)) 41 | model.add(Dense(128)) 42 | model.add(Activation('relu')) 43 | model.add(Dropout(0.2)) 44 | model.add(Dense(10)) 45 | model.add(Activation('softmax')) 46 | 47 | # Compile model 48 | rms = RMSprop() 49 | 50 | # Create Spark context 51 | conf = SparkConf().setAppName('Mnist_Spark_MLP').setMaster('local[8]') 52 | sc = SparkContext(conf=conf) 53 | 54 | # Build RDD from numpy features and labels 55 | lp_rdd = to_labeled_point(sc, x_train, y_train, categorical=True) 56 | rdd = lp_to_simple_rdd(lp_rdd, True, nb_classes) 57 | 58 | # Initialize SparkModel from Keras model and Spark context 59 | adadelta = elephas_optimizers.Adadelta() 60 | spark_model = SparkMLlibModel(sc, model, optimizer=adadelta, frequency='batch', mode='asynchronous', num_workers=2, master_optimizer=rms) 61 | 62 | # Train Spark model 63 | spark_model.train(lp_rdd, nb_epoch=20, batch_size=32, verbose=0, 64 | validation_split=0.1, categorical=True, nb_classes=nb_classes) 65 | 66 | # Evaluate Spark model by evaluating the underlying model 67 | score = spark_model.master_network.evaluate(x_test, y_test, verbose=2) 68 | print('Test accuracy:', score[1]) 69 | -------------------------------------------------------------------------------- /examples/ml_mlp.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from keras.datasets import mnist 5 | from keras.models import Sequential 6 | from keras.layers.core import Dense, Dropout, Activation 7 | from keras.optimizers import Adam 8 | from keras.utils import np_utils 9 | 10 | from elephas.ml_model import ElephasEstimator 11 | from elephas.ml.adapter import to_data_frame 12 | from elephas import optimizers as elephas_optimizers 13 | 14 | from pyspark import SparkContext, SparkConf 15 | from pyspark.mllib.evaluation import MulticlassMetrics 16 | from pyspark.ml import Pipeline 17 | 18 | 19 | # Define basic parameters 20 | batch_size = 64 21 | nb_classes = 10 22 | nb_epoch = 1 23 | 24 | # Load data 25 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 26 | 27 | x_train = x_train.reshape(60000, 784) 28 | x_test = x_test.reshape(10000, 784) 29 | x_train = x_train.astype("float32") 30 | x_test = x_test.astype("float32") 31 | x_train /= 255 32 | x_test /= 255 33 | print(x_train.shape[0], 'train samples') 34 | print(x_test.shape[0], 'test samples') 35 | 36 | # Convert class vectors to binary class matrices 37 | y_train = np_utils.to_categorical(y_train, nb_classes) 38 | y_test = 
np_utils.to_categorical(y_test, nb_classes) 39 | 40 | model = Sequential() 41 | model.add(Dense(128, input_dim=784)) 42 | model.add(Activation('relu')) 43 | model.add(Dropout(0.2)) 44 | model.add(Dense(128)) 45 | model.add(Activation('relu')) 46 | model.add(Dropout(0.2)) 47 | model.add(Dense(10)) 48 | model.add(Activation('softmax')) 49 | 50 | 51 | # Compile model 52 | adam = Adam() 53 | model.compile(loss='categorical_crossentropy', optimizer=adam) 54 | 55 | # Create Spark context 56 | conf = SparkConf().setAppName('Mnist_Spark_MLP').setMaster('local[8]') 57 | sc = SparkContext(conf=conf) 58 | 59 | # Build RDD from numpy features and labels 60 | df = to_data_frame(sc, x_train, y_train, categorical=True) 61 | test_df = to_data_frame(sc, x_test, y_test, categorical=True) 62 | 63 | # Define elephas optimizer 64 | adadelta = elephas_optimizers.Adadelta() 65 | 66 | # Initialize Spark ML Estimator 67 | estimator = ElephasEstimator() 68 | estimator.set_keras_model_config(model.to_yaml()) 69 | estimator.set_optimizer_config(adadelta.get_config()) 70 | estimator.set_nb_epoch(nb_epoch) 71 | estimator.set_batch_size(batch_size) 72 | estimator.set_num_workers(1) 73 | estimator.set_verbosity(0) 74 | estimator.set_validation_split(0.1) 75 | estimator.set_categorical_labels(True) 76 | estimator.set_nb_classes(nb_classes) 77 | 78 | # Fitting a model returns a Transformer 79 | pipeline = Pipeline(stages=[estimator]) 80 | fitted_pipeline = pipeline.fit(df) 81 | 82 | # Evaluate Spark model by evaluating the underlying model 83 | prediction = fitted_pipeline.transform(test_df) 84 | pnl = prediction.select("label", "prediction") 85 | pnl.show(100) 86 | 87 | prediction_and_label = pnl.map(lambda row: (row.label, row.prediction)) 88 | metrics = MulticlassMetrics(prediction_and_label) 89 | print(metrics.precision()) 90 | print(metrics.recall()) 91 | -------------------------------------------------------------------------------- /examples/hyperparam_optimization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from hyperopt import Trials, STATUS_OK, tpe 3 | 4 | from hyperas import optim 5 | from hyperas.distributions import choice, uniform 6 | 7 | from elephas.hyperparam import HyperParamModel 8 | 9 | from pyspark import SparkContext, SparkConf 10 | 11 | def data(): 12 | ''' 13 | Data providing function: 14 | 15 | Make sure to have every relevant import statement included here and return data as 16 | used in model function below. This function is separated from model() so that hyperopt 17 | won't reload data for each evaluation run. 18 | ''' 19 | from keras.datasets import mnist 20 | from keras.utils import np_utils 21 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 22 | X_train = X_train.reshape(60000, 784) 23 | X_test = X_test.reshape(10000, 784) 24 | X_train = X_train.astype('float32') 25 | X_test = X_test.astype('float32') 26 | X_train /= 255 27 | X_test /= 255 28 | nb_classes = 10 29 | Y_train = np_utils.to_categorical(y_train, nb_classes) 30 | Y_test = np_utils.to_categorical(y_test, nb_classes) 31 | return X_train, Y_train, X_test, Y_test 32 | 33 | 34 | def model(X_train, Y_train, X_test, Y_test): 35 | ''' 36 | Model providing function: 37 | 38 | Create Keras model with double curly brackets dropped-in as needed. 
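The double-curly expressions below, e.g. {{uniform(0, 1)}} and {{choice([64, 128])}}, are hyperas templates that get replaced by sampled hyperparameter values before the function is evaluated.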
39 | Return value has to be a valid python dictionary with two customary keys: 40 | - loss: Specify a numeric evaluation metric to be minimized 41 | - status: Just use STATUS_OK and see hyperopt documentation if not feasible 42 | A third key is optional, though recommended: 43 | - model: specify the model just created so that we can later use it again (this example additionally returns the pickled model weights). 44 | ''' 45 | from keras.models import Sequential 46 | from keras.layers.core import Dense, Dropout, Activation 47 | from keras.optimizers import RMSprop 48 | import six.moves.cPickle as pickle  # needed for the pickled weights returned below 49 | model = Sequential() 50 | model.add(Dense(512, input_shape=(784,))) 51 | model.add(Activation('relu')) 52 | model.add(Dropout({{uniform(0, 1)}})) 53 | model.add(Dense({{choice([256, 512, 1024])}})) 54 | model.add(Activation('relu')) 55 | model.add(Dropout({{uniform(0, 1)}})) 56 | model.add(Dense(10)) 57 | model.add(Activation('softmax')) 58 | 59 | rms = RMSprop() 60 | model.compile(loss='categorical_crossentropy', optimizer=rms) 61 | 62 | model.fit(X_train, Y_train, 63 | batch_size={{choice([64, 128])}}, 64 | nb_epoch=1, 65 | show_accuracy=True, 66 | verbose=2, 67 | validation_data=(X_test, Y_test)) 68 | score, acc = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0) 69 | print('Test accuracy:', acc) 70 | return {'loss': -acc, 'status': STATUS_OK, 'model': model.to_yaml(), 'weights': pickle.dumps(model.get_weights())} 71 | 72 | # Create Spark context 73 | conf = SparkConf().setAppName('Elephas_Hyperparameter_Optimization').setMaster('local[8]') 74 | sc = SparkContext(conf=conf) 75 | 76 | # Define hyper-parameter model and run optimization. 77 | hyperparam_model = HyperParamModel(sc) 78 | hyperparam_model.minimize(model=model, data=data, max_evals=5) 79 | -------------------------------------------------------------------------------- /elephas/hyperparam.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from hyperopt import Trials, rand 3 | from hyperas.ensemble import VotingModel 4 | from hyperas.optim import get_hyperopt_model_string, base_minimizer 5 | import numpy as np 6 | from keras.models import model_from_yaml 7 | import six.moves.cPickle as pickle 8 | 9 | # This module's dependency on hyperas, boto etc.
is optional 10 | 11 | class HyperParamModel(object): 12 | ''' 13 | Distributed hyper-parameter optimization with hyperas and hyperopt on Spark. 14 | ''' 15 | def __init__(self, sc, num_workers=4): 16 | self.spark_context = sc 17 | self.num_workers = num_workers 18 | 19 | def compute_trials(self, model, data, max_evals): 20 | model_string = get_hyperopt_model_string(model, data) 21 | bc_model = self.spark_context.broadcast(model_string) 22 | bc_max_evals = self.spark_context.broadcast(max_evals) 23 | 24 | hyperas_worker = HyperasWorker(bc_model, bc_max_evals) 25 | dummy_rdd = self.spark_context.parallelize([i for i in range(1, 1000)]) 26 | dummy_rdd = dummy_rdd.repartition(self.num_workers) 27 | trials_list = dummy_rdd.mapPartitions(hyperas_worker.minimize).collect() 28 | 29 | return trials_list 30 | 31 | def minimize(self, model, data, max_evals): 32 | trials_list = self.compute_trials(model, data, max_evals) 33 | 34 | best_val = 1e7 35 | for trials in trials_list: 36 | for trial in trials: 37 | val = trial.get('result').get('loss') 38 | if val < best_val: 39 | best_val = val 40 | best_model_yaml = trial.get('result').get('model') 41 | best_model_weights = trial.get('result').get('weights') 42 | 43 | best_model = model_from_yaml(best_model_yaml) 44 | best_model.set_weights(pickle.loads(best_model_weights)) 45 | 46 | return best_model 47 | 48 | def best_ensemble(self, nb_ensemble_models, model, data, max_evals, voting='hard', weights=None): 49 | model_list = self.best_models(nb_models=nb_ensemble_models, model=model, 50 | data=data, max_evals=max_evals) 51 | return VotingModel(model_list, voting, weights) 52 | 53 | def best_models(self, nb_models, model, data, max_evals): 54 | trials_list = self.compute_trials(model, data, max_evals) 55 | num_trials = sum(len(trials) for trials in trials_list) 56 | if num_trials < nb_models: 57 | nb_models = num_trials 58 | scores = [] 59 | for trials in trials_list: 60 | scores = scores + [trial.get('result').get('loss') for trial in trials] 61 | cut_off = sorted(scores)[nb_models-1]  # nb_models-th lowest loss 62 | model_list = [] 63 | for trials in trials_list: 64 | for trial in trials: 65 | if trial.get('result').get('loss') <= cut_off: 66 | model = model_from_yaml(trial.get('result').get('model')) 67 | model.set_weights(pickle.loads(trial.get('result').get('weights'))) 68 | model_list.append(model) 69 | return model_list 70 | 71 | class HyperasWorker(object): 72 | def __init__(self, bc_model, bc_max_evals): 73 | self.model_string = bc_model.value 74 | self.max_evals = bc_max_evals.value 75 | 76 | def minimize(self, dummy_iterator): 77 | trials = Trials() 78 | algo = rand.suggest 79 | 80 | elem = next(dummy_iterator) 81 | # Seed numpy's RNG from the partition element so each worker explores differently 82 | np.random.seed(elem) 83 | rand_seed = np.random.randint(elem) 84 | 85 | best_run = base_minimizer(model=None, data=None, algo=algo, max_evals=self.max_evals, 86 | trials=trials, full_model_string=self.model_string, rseed=rand_seed) 87 | yield trials 88 | -------------------------------------------------------------------------------- /examples/ml_pipeline_otto.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from pyspark.mllib.linalg import Vectors 4 | import numpy as np 5 | import random 6 | 7 | from pyspark import SparkContext, SparkConf 8 | from pyspark.sql import SQLContext 9 | from pyspark.ml.feature import StringIndexer, StandardScaler 10 | from pyspark.mllib.evaluation import MulticlassMetrics 11 | from pyspark.ml import Pipeline 12 | 13 | from keras.models import Sequential 14 | from
keras.layers.core import Dense, Dropout, Activation 15 | from keras.layers.normalization import BatchNormalization 16 | from keras.layers.advanced_activations import PReLU 17 | from keras.utils import np_utils, generic_utils 18 | 19 | from elephas.ml_model import ElephasEstimator 20 | from elephas import optimizers as elephas_optimizers 21 | 22 | data_path = "./" 23 | 24 | # Spark contexts 25 | conf = SparkConf().setAppName('Otto_Spark_ML_Pipeline').setMaster('local[8]') 26 | sc = SparkContext(conf=conf) 27 | sql_context = SQLContext(sc) 28 | 29 | # Data loader 30 | def shuffle_csv(csv_file): 31 | lines = open(csv_file).readlines() 32 | random.shuffle(lines) 33 | open(csv_file, 'w').writelines(lines) 34 | 35 | def load_data_rdd(csv_file, shuffle=True, train=True): 36 | if shuffle: 37 | shuffle_csv(csv_file) 38 | data = sc.textFile(data_path + csv_file) 39 | data = data.filter(lambda x:x.split(',')[0] != 'id').map(lambda line: line.split(',')) 40 | if train: 41 | data = data.map( 42 | lambda line: (Vectors.dense(np.asarray(line[1:-1]).astype(np.float32)), 43 | str(line[-1]).replace('Class_', '')) ) 44 | else: 45 | data = data.map(lambda line: (Vectors.dense(np.asarray(line[1:]).astype(np.float32)), "1") ) 46 | return data 47 | 48 | # Define Data frames 49 | train_df = sql_context.createDataFrame(load_data_rdd("train.csv"), ['features', 'category']) 50 | test_df = sql_context.createDataFrame(load_data_rdd("test.csv", shuffle=False, train=False), ['features', 'category']) 51 | 52 | # Preprocessing steps 53 | string_indexer = StringIndexer(inputCol="category", outputCol="index_category") 54 | scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withStd=True, withMean=True) 55 | 56 | # Keras model 57 | nb_classes = train_df.select("category").distinct().count() 58 | input_dim = len(train_df.select("features").first()[0]) 59 | 60 | model = Sequential() 61 | model.add(Dense(512, input_shape=(input_dim,))) 62 | model.add(Activation('relu')) 63 | model.add(Dropout(0.5)) 64 | model.add(Dense(512)) 65 | model.add(Activation('relu')) 66 | model.add(Dropout(0.5)) 67 | model.add(Dense(512)) 68 | model.add(Activation('relu')) 69 | model.add(Dropout(0.5)) 70 | model.add(Dense(nb_classes)) 71 | model.add(Activation('softmax')) 72 | 73 | model.compile(loss='categorical_crossentropy', optimizer='adam') 74 | 75 | 76 | # Initialize Elephas Spark ML Estimator 77 | adadelta = elephas_optimizers.Adadelta() 78 | 79 | estimator = ElephasEstimator() 80 | estimator.setFeaturesCol("scaled_features") 81 | estimator.setLabelCol("index_category") 82 | estimator.set_keras_model_config(model.to_yaml()) 83 | estimator.set_optimizer_config(adadelta.get_config()) 84 | estimator.set_nb_epoch(10) 85 | estimator.set_batch_size(128) 86 | estimator.set_num_workers(1) 87 | estimator.set_verbosity(0) 88 | estimator.set_validation_split(0.15) 89 | estimator.set_categorical_labels(True) 90 | estimator.set_nb_classes(nb_classes) 91 | 92 | # Fitting a model returns a Transformer 93 | pipeline = Pipeline(stages=[string_indexer, scaler, estimator]) 94 | fitted_pipeline = pipeline.fit(train_df) 95 | 96 | from pyspark.mllib.evaluation import MulticlassMetrics 97 | # Evaluate Spark model 98 | 99 | prediction = fitted_pipeline.transform(train_df) 100 | pnl = prediction.select("index_category", "prediction") 101 | pnl.show(100) 102 | -------------------------------------------------------------------------------- /elephas/ml/params.py: -------------------------------------------------------------------------------- 1 | 
from pyspark.ml.param.shared import Param, Params 2 | 3 | 4 | class HasKerasModelConfig(Params): 5 | ''' 6 | Mandatory field: 7 | 8 | Parameter mixin for Keras model yaml 9 | ''' 10 | def __init__(self): 11 | super(HasKerasModelConfig, self).__init__() 12 | self.keras_model_config = Param(self, "keras_model_config", "Serialized Keras model as yaml string") 13 | 14 | def set_keras_model_config(self, keras_model_config): 15 | self._paramMap[self.keras_model_config] = keras_model_config 16 | return self 17 | 18 | def get_keras_model_config(self): 19 | return self.getOrDefault(self.keras_model_config) 20 | 21 | 22 | class HasOptimizerConfig(Params): 23 | ''' 24 | Parameter mixin for Elephas optimizer config 25 | ''' 26 | def __init__(self): 27 | super(HasOptimizerConfig, self).__init__() 28 | self.optimizer_config = Param(self, "optimizer_config", "Serialized Elephas optimizer properties") 29 | 30 | def set_optimizer_config(self, optimizer_config): 31 | self._paramMap[self.optimizer_config] = optimizer_config 32 | return self 33 | 34 | def get_optimizer_config(self): 35 | return self.getOrDefault(self.optimizer_config) 36 | 37 | 38 | class HasMode(Params): 39 | ''' 40 | Parameter mixin for Elephas mode 41 | ''' 42 | def __init__(self): 43 | super(HasMode, self).__init__() 44 | self.mode = Param(self, "mode", "Elephas mode") 45 | self._setDefault(mode='asynchronous') 46 | 47 | def set_mode(self, mode): 48 | self._paramMap[self.mode] = mode 49 | return self 50 | 51 | def get_mode(self): 52 | return self.getOrDefault(self.mode) 53 | 54 | 55 | class HasFrequency(Params): 56 | ''' 57 | Parameter mixin for Elephas frequency 58 | ''' 59 | def __init__(self): 60 | super(HasFrequency, self).__init__() 61 | self.frequency = Param(self, "frequency", "Elephas frequency") 62 | self._setDefault(frequency='epoch') 63 | 64 | def set_frequency(self, frequency): 65 | self._paramMap[self.frequency] = frequency 66 | return self 67 | 68 | def get_frequency(self): 69 | return self.getOrDefault(self.frequency) 70 | 71 | 72 | class HasNumberOfClasses(Params): 73 | ''' 74 | Mandatory: 75 | 76 | Parameter mixin for number of classes 77 | ''' 78 | def __init__(self): 79 | super(HasNumberOfClasses, self).__init__() 80 | self.nb_classes = Param(self, "nb_classes", "number of classes") 81 | self._setDefault(nb_classes=10) 82 | 83 | def set_nb_classes(self, nb_classes): 84 | self._paramMap[self.nb_classes] = nb_classes 85 | return self 86 | 87 | def get_nb_classes(self): 88 | return self.getOrDefault(self.nb_classes) 89 | 90 | 91 | class HasCategoricalLabels(Params): 92 | ''' 93 | Mandatory: 94 | 95 | Parameter mixin for setting categorical features 96 | ''' 97 | def __init__(self): 98 | super(HasCategoricalLabels, self).__init__() 99 | self.categorical = Param(self, "categorical", "Boolean to indicate if labels are categorical") 100 | self._setDefault(categorical=True) 101 | 102 | def set_categorical_labels(self, categorical): 103 | self._paramMap[self.categorical] = categorical 104 | return self 105 | 106 | def get_categorical_labels(self): 107 | return self.getOrDefault(self.categorical) 108 | 109 | 110 | class HasEpochs(Params): 111 | ''' 112 | Parameter mixin for number of epochs 113 | ''' 114 | def __init__(self): 115 | super(HasEpochs, self).__init__() 116 | self.nb_epoch = Param(self, "nb_epoch", "Number of epochs to train") 117 | self._setDefault(nb_epoch=10) 118 | 119 | def set_nb_epoch(self, nb_epoch): 120 | self._paramMap[self.nb_epoch] = nb_epoch 121 | return self 122 | 123 | def get_nb_epoch(self): 124 | 
return self.getOrDefault(self.nb_epoch) 125 | 126 | 127 | class HasBatchSize(Params): 128 | ''' 129 | Parameter mixin for batch size 130 | ''' 131 | def __init__(self): 132 | super(HasBatchSize, self).__init__() 133 | self.batch_size = Param(self, "batch_size", "Batch size") 134 | self._setDefault(batch_size=32) 135 | 136 | def set_batch_size(self, batch_size): 137 | self._paramMap[self.batch_size] = batch_size 138 | return self 139 | 140 | def get_batch_size(self): 141 | return self.getOrDefault(self.batch_size) 142 | 143 | 144 | class HasVerbosity(Params): 145 | ''' 146 | Parameter mixin for output verbosity 147 | ''' 148 | def __init__(self): 149 | super(HasVerbosity, self).__init__() 150 | self.verbose = Param(self, "verbose", "Stdout verbosity") 151 | self._setDefault(verbose=0) 152 | 153 | def set_verbosity(self, verbose): 154 | self._paramMap[self.verbose] = verbose 155 | return self 156 | 157 | def get_verbosity(self): 158 | return self.getOrDefault(self.verbose) 159 | 160 | 161 | class HasValidationSplit(Params): 162 | ''' 163 | Parameter mixin for validation split percentage 164 | ''' 165 | def __init__(self): 166 | super(HasValidationSplit, self).__init__() 167 | self.validation_split = Param(self, "validation_split", "validation split percentage") 168 | self._setDefault(validation_split=0.1) 169 | 170 | def set_validation_split(self, validation_split): 171 | self._paramMap[self.validation_split] = validation_split 172 | return self 173 | 174 | def get_validation_split(self): 175 | return self.getOrDefault(self.validation_split) 176 | 177 | 178 | class HasNumberOfWorkers(Params): 179 | ''' 180 | Parameter mixin for number of workers 181 | ''' 182 | def __init__(self): 183 | super(HasNumberOfWorkers, self).__init__() 184 | self.num_workers = Param(self, "num_workers", "number of workers") 185 | self._setDefault(num_workers=8) 186 | 187 | def set_num_workers(self, num_workers): 188 | self._paramMap[self.num_workers] = num_workers 189 | return self 190 | 191 | def get_num_workers(self): 192 | return self.getOrDefault(self.num_workers) 193 | -------------------------------------------------------------------------------- /elephas/ml_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | import numpy as np 4 | 5 | from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasFeaturesCol, HasLabelCol 6 | from pyspark.ml.util import keyword_only 7 | from pyspark.sql import Row 8 | from pyspark.ml import Estimator, Model 9 | from pyspark.sql.types import StringType, DoubleType, StructField 10 | 11 | from keras.models import model_from_yaml 12 | 13 | from .spark_model import SparkModel 14 | from .utils.rdd_utils import from_vector, to_vector 15 | from .ml.adapter import df_to_simple_rdd 16 | from .ml.params import * 17 | from .optimizers import get 18 | 19 | 20 | class ElephasEstimator(Estimator, HasCategoricalLabels, HasValidationSplit, HasKerasModelConfig, HasFeaturesCol, HasLabelCol, HasMode, HasEpochs, HasBatchSize, 21 | HasFrequency, HasVerbosity, HasNumberOfClasses, HasNumberOfWorkers, HasOptimizerConfig, HasOutputCol): 22 | ''' 23 | SparkML Estimator implementation of an elephas model. This estimator takes all relevant arguments for model 24 | compilation and training. 25 | 26 | Returns a trained model in form of a SparkML Model, which is also a Transformer. 
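A typical flow (cf. examples/ml_mlp.py): configure the estimator via set_keras_model_config(model.to_yaml()), set_optimizer_config(...), set_nb_epoch(...), set_batch_size(...) and related setters, place it in a pyspark.ml Pipeline and call fit() on a DataFrame with features and label columns.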
27 | ''' 28 | @keyword_only 29 | def __init__(self, keras_model_config=None, featuresCol=None, labelCol=None, optimizer_config=None, mode=None, 30 | frequency=None, num_workers=None, nb_epoch=None, batch_size=None, verbose=None, validation_split=None, 31 | categorical=None, nb_classes=None, outputCol=None): 32 | super(ElephasEstimator, self).__init__() 33 | kwargs = self.__init__._input_kwargs 34 | self.set_params(**kwargs) 35 | 36 | @keyword_only 37 | def set_params(self, keras_model_config=None, featuresCol=None, labelCol=None, optimizer_config=None, mode=None, 38 | frequency=None, num_workers=None, nb_epoch=None, batch_size=None, verbose=None, 39 | validation_split=None, categorical=None, nb_classes=None, outputCol=None): 40 | ''' 41 | Set all provided parameters, otherwise set defaults 42 | ''' 43 | kwargs = self.set_params._input_kwargs 44 | return self._set(**kwargs) 45 | 46 | def _fit(self, df): 47 | ''' 48 | Private fit method of the Estimator, which trains the model. 49 | ''' 50 | simple_rdd = df_to_simple_rdd(df, categorical=self.get_categorical_labels(), nb_classes=self.get_nb_classes(), 51 | featuresCol=self.getFeaturesCol(), labelCol=self.getLabelCol()) 52 | simple_rdd = simple_rdd.repartition(self.get_num_workers()) 53 | optimizer = None 54 | if self.get_optimizer_config() is not None: 55 | optimizer = get(self.get_optimizer_config()['name'], self.get_optimizer_config()) 56 | 57 | keras_model = model_from_yaml(self.get_keras_model_config()) 58 | 59 | spark_model = SparkModel(simple_rdd.ctx, keras_model, optimizer=optimizer, 60 | mode=self.get_mode(), frequency=self.get_frequency(), 61 | num_workers=self.get_num_workers()) 62 | spark_model.train(simple_rdd, nb_epoch=self.get_nb_epoch(), batch_size=self.get_batch_size(), 63 | verbose=self.get_verbosity(), validation_split=self.get_validation_split()) 64 | 65 | model_weights = spark_model.master_network.get_weights() 66 | weights = simple_rdd.ctx.broadcast(model_weights) 67 | return ElephasTransformer(labelCol=self.getLabelCol(), 68 | outputCol='prediction', # TODO: Set default value 69 | keras_model_config=spark_model.master_network.to_yaml(), 70 | weights=weights) 71 | 72 | 73 | class ElephasTransformer(Model, HasKerasModelConfig, HasLabelCol, HasOutputCol): 74 | ''' 75 | SparkML Transformer implementation. Contains a trained model, 76 | with which new feature data can be transformed into labels. 77 | ''' 78 | @keyword_only 79 | def __init__(self, labelCol=None, outputCol=None, keras_model_config=None, weights=None): 80 | super(ElephasTransformer, self).__init__() 81 | kwargs = self.__init__._input_kwargs 82 | self.weights = kwargs.pop('weights') # Strip model weights from parameters to init Transformer 83 | self.set_params(**kwargs) 84 | 85 | @keyword_only 86 | def set_params(self, labelCol=None, outputCol=None, keras_model_config=None): 87 | ''' 88 | Set all provided parameters, otherwise set defaults 89 | ''' 90 | kwargs = self.set_params._input_kwargs 91 | return self._set(**kwargs) 92 | 93 | def get_model(self): 94 | return model_from_yaml(self.get_keras_model_config()) 95 | 96 | def _transform(self, df): 97 | ''' 98 | Private transform method of a Transformer. This serves as batch-prediction method for our purposes. 
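Features are collected to the driver, class predictions are computed locally from the broadcast weights, and the predictions are zipped back onto the original rows as the output column.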
99 | ''' 100 | outputCol = self.getOutputCol() 101 | labelCol = self.getLabelCol() 102 | new_schema = df.schema 103 | new_schema.add(StructField(outputCol, StringType(), True)) 104 | 105 | rdd = df.rdd.coalesce(1) 106 | features = np.asarray(rdd.map(lambda x: from_vector(x.features)).collect()) 107 | # Note that we collect, since executing this on the rdd would require model serialization once again 108 | model = model_from_yaml(self.get_keras_model_config()) 109 | model.set_weights(self.weights.value) 110 | predictions = rdd.ctx.parallelize(model.predict_classes(features)).coalesce(1) 111 | predictions = predictions.map(lambda x: tuple(str(x))) 112 | 113 | results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1]) 114 | # TODO: Zipping like this is very likely wrong 115 | # results_rdd = rdd.zip(predictions).map(lambda pair: Row(features=to_vector(pair[0].features), 116 | # label=pair[0].label, prediction=float(pair[1]))) 117 | results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema) 118 | results_df = results_df.withColumn(outputCol, results_df[outputCol].cast(DoubleType())) 119 | results_df = results_df.withColumn(labelCol, results_df[labelCol].cast(DoubleType())) 120 | 121 | return results_df 122 | -------------------------------------------------------------------------------- /elephas/optimizers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is essentially a copy of keras' optimizers.py. 3 | We have to modify the base class 'Optimizer' here, 4 | as the gradients will be provided by the Spark workers, 5 | not by one of the backends (Theano or Tensorflow). 6 | ''' 7 | from __future__ import absolute_import 8 | from keras import backend as K 9 | from keras.utils.generic_utils import get_from_module 10 | import numpy as np 11 | 12 | from six.moves import zip 13 | 14 | 15 | def clip_norm(g, c, n): 16 | ''' Clip gradients ''' 17 | if c > 0: 18 | g = K.switch(K.ge(n, c), g * c / n, g) 19 | return g 20 | 21 | 22 | def kl_divergence(p, p_hat): 23 | ''' Kullbach-Leibler divergence ''' 24 | return p_hat - p + p * K.log(p / p_hat) 25 | 26 | 27 | class Optimizer(object): 28 | ''' 29 | Optimizer for elephas models, adapted from 30 | respective Keras module. 
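Subclasses implement get_updates(params, constraints, grads), which applies gradients (weight deltas) received from Spark workers to the master parameters.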
31 | ''' 32 | def __init__(self, **kwargs): 33 | self.__dict__.update(kwargs) 34 | self.updates = [] 35 | 36 | def get_state(self): 37 | ''' Get latest status of optimizer updates ''' 38 | return [u[0].get_value() for u in self.updates] 39 | 40 | def set_state(self, value_list): 41 | ''' Set current status of optimizer ''' 42 | assert len(self.updates) == len(value_list) 43 | for u, v in zip(self.updates, value_list): 44 | u[0].set_value(v) 45 | 46 | def get_updates(self, params, constraints, grads): 47 | ''' Compute updates from gradients and constraints ''' 48 | raise NotImplementedError 49 | 50 | def get_gradients(self, grads, params): 51 | 52 | if hasattr(self, 'clipnorm') and self.clipnorm > 0: 53 | norm = K.sqrt(sum([K.sum(g ** 2) for g in grads])) 54 | grads = [clip_norm(g, self.clipnorm, norm) for g in grads] 55 | 56 | if hasattr(self, 'clipvalue') and self.clipvalue > 0: 57 | grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] 58 | 59 | return K.shared(grads) 60 | 61 | def get_config(self): 62 | ''' Get configuration dictionary ''' 63 | return {"name": self.__class__.__name__} 64 | 65 | 66 | class SGD(Optimizer): 67 | ''' SGD, optionally with nesterov momentum ''' 68 | def __init__(self, lr=0.01, momentum=0., decay=0., 69 | nesterov=False, *args, **kwargs): 70 | super(SGD, self).__init__(**kwargs) 71 | self.__dict__.update(locals()) 72 | self.iterations = 0 73 | self.lr = lr 74 | self.momentum = momentum 75 | self.decay = decay 76 | 77 | def get_updates(self, params, constraints, grads): 78 | lr = self.lr * (1.0 / (1.0 + self.decay * self.iterations)) 79 | self.updates = [(self.iterations, self.iterations + 1.)] 80 | new_weights = [] 81 | 82 | for p, g, c in zip(params, grads, constraints): 83 | m = np.zeros_like(p) # momentum 84 | v = self.momentum * m - lr * g # velocity 85 | if self.nesterov: 86 | new_p = p + self.momentum * v - lr * g 87 | else: 88 | new_p = p + v 89 | new_weights.append(c(new_p)) 90 | 91 | return new_weights 92 | 93 | def get_config(self): 94 | return {"name": self.__class__.__name__, 95 | "lr": float(self.lr), 96 | "momentum": float(self.momentum), 97 | "decay": float(self.decay), 98 | "nesterov": self.nesterov} 99 | 100 | 101 | class RMSprop(Optimizer): 102 | ''' 103 | Reference: www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf 104 | ''' 105 | def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs): 106 | super(RMSprop, self).__init__(**kwargs) 107 | self.__dict__.update(locals()) 108 | self.lr = lr 109 | self.rho = rho 110 | 111 | def get_updates(self, params, constraints, grads): 112 | accumulators = [np.zeros_like(p) for p in params] 113 | new_weights = [] 114 | 115 | for p, g, a, c in zip(params, grads, accumulators, constraints): 116 | new_a = self.rho * a + (1 - self.rho) * g ** 2 117 | self.updates.append((a, new_a)) 118 | 119 | new_p = p - self.lr * g / np.sqrt(new_a + self.epsilon) 120 | new_weights.append(c(new_p)) 121 | 122 | return new_weights 123 | 124 | def get_config(self): 125 | return {"name": self.__class__.__name__, 126 | "lr": float(self.lr), 127 | "rho": float(self.rho), 128 | "epsilon": self.epsilon} 129 | 130 | 131 | class Adagrad(Optimizer): 132 | ''' 133 | Reference: http://www.magicbroom.info/Papers/DuchiHaSi10.pdf 134 | ''' 135 | def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs): 136 | super(Adagrad, self).__init__(**kwargs) 137 | self.__dict__.update(locals()) 138 | self.lr = lr 139 | 140 | def get_updates(self, params, constraints, grads): 141 | accumulators = 
[np.zeros_like(p) for p in params] 142 | new_weights = [] 143 | for p, g, a, c in zip(params, grads, accumulators, constraints): 144 | new_a = a + g ** 2 145 | new_p = p - self.lr * g / np.sqrt(new_a + self.epsilon) 146 | new_weights.append(new_p) 147 | 148 | return new_weights 149 | 150 | def get_config(self): 151 | return {"name": self.__class__.__name__, 152 | "lr": float(self.lr), 153 | "epsilon": self.epsilon} 154 | 155 | 156 | class Adadelta(Optimizer): 157 | ''' 158 | Reference: http://arxiv.org/abs/1212.5701 159 | ''' 160 | def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs): 161 | super(Adadelta, self).__init__(**kwargs) 162 | self.__dict__.update(locals()) 163 | self.lr = lr 164 | 165 | def get_updates(self, params, constraints, grads): 166 | accumulators = [np.zeros_like(p) for p in params] 167 | delta_accumulators = [np.zeros_like(p) for p in params] 168 | new_weights = [] 169 | 170 | for p, g, a, d_a, c in zip(params, grads, accumulators, 171 | delta_accumulators, constraints): 172 | new_a = self.rho * a + (1 - self.rho) * g ** 2 173 | self.updates.append((a, new_a)) 174 | # use the new accumulator and the *old* delta_accumulator 175 | div = np.sqrt(new_a + self.epsilon) 176 | update = g * np.sqrt(d_a + self.epsilon) / div 177 | new_p = p - self.lr * update 178 | self.updates.append((p, c(new_p))) # apply constraints 179 | 180 | new_weights.append(new_p) 181 | return new_weights 182 | 183 | def get_config(self): 184 | return {"name": self.__class__.__name__, 185 | "lr": float(self.lr), 186 | "rho": self.rho, 187 | "epsilon": self.epsilon} 188 | 189 | 190 | class Adam(Optimizer): 191 | ''' 192 | Reference: http://arxiv.org/abs/1412.6980v8 193 | Default parameters follow those provided in the original paper. 194 | ''' 195 | def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, 196 | epsilon=1e-8, *args, **kwargs): 197 | super(Adam, self).__init__(**kwargs) 198 | self.__dict__.update(locals()) 199 | self.iterations = 0 200 | self.lr = lr 201 | 202 | def get_updates(self, params, constraints, grads): 203 | new_weights = [] 204 | 205 | t = self.iterations + 1 206 | lr_t = self.lr * np.sqrt(1-self.beta_2**t)/(1-self.beta_1**t) 207 | 208 | for p, g, c in zip(params, grads, constraints): 209 | m = np.zeros_like(p) # zero init of moment 210 | v = np.zeros_like(p) # zero init of velocity 211 | 212 | m_t = (self.beta_1 * m) + (1 - self.beta_1) * g 213 | v_t = (self.beta_2 * v) + (1 - self.beta_2) * (g**2) 214 | p_t = p - lr_t * m_t / (np.sqrt(v_t) + self.epsilon) 215 | new_weights.append(c(p_t)) 216 | 217 | return new_weights 218 | 219 | def get_config(self): 220 | return {"name": self.__class__.__name__, 221 | "lr": float(self.lr), 222 | "beta_1": self.beta_1, 223 | "beta_2": self.beta_2, 224 | "epsilon": self.epsilon} 225 | 226 | # aliases 227 | sgd = SGD 228 | rmsprop = RMSprop 229 | adagrad = Adagrad 230 | adadelta = Adadelta 231 | adam = Adam 232 | 233 | 234 | def get(identifier, kwargs=None): 235 | return get_from_module(identifier, globals(), 'optimizer', 236 | instantiate=True, kwargs=kwargs) 237 | -------------------------------------------------------------------------------- /elephas/spark_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | from itertools import tee 6 | import socket 7 | from multiprocessing import Process 8 | import six.moves.cPickle as pickle 9 | from six.moves import range 10 | from flask 
import Flask, request 11 | try: 12 | import urllib.request as urllib2 13 | except ImportError: 14 | import urllib2 15 | 16 | from pyspark.mllib.linalg import Matrix, Vector 17 | 18 | from .utils.rwlock import RWLock 19 | from .utils.functional_utils import subtract_params 20 | from .utils.rdd_utils import lp_to_simple_rdd 21 | from .mllib.adapter import to_matrix, from_matrix, to_vector, from_vector 22 | from .optimizers import SGD as default_optimizer 23 | 24 | from keras.models import model_from_yaml 25 | 26 | def get_server_weights(master_url='localhost:5000'): 27 | ''' 28 | Retrieve master weights from parameter server 29 | ''' 30 | request = urllib2.Request('http://{0}/parameters'.format(master_url), 31 | headers={'Content-Type': 'application/elephas'}) 32 | ret = urllib2.urlopen(request).read() 33 | weights = pickle.loads(ret) 34 | return weights 35 | 36 | 37 | def put_deltas_to_server(delta, master_url='localhost:5000'): 38 | ''' 39 | Update master parameters with deltas from training process 40 | ''' 41 | request = urllib2.Request('http://{0}/update'.format(master_url), 42 | pickle.dumps(delta, -1), headers={'Content-Type': 'application/elephas'}) 43 | return urllib2.urlopen(request).read() 44 | 45 | 46 | class SparkModel(object): 47 | ''' 48 | SparkModel is the main abstraction of elephas. Every other model 49 | should inherit from it. 50 | ''' 51 | def __init__(self, sc, master_network, optimizer=None, mode='asynchronous', frequency='epoch', 52 | num_workers=4, 53 | master_optimizer="adam", 54 | master_loss="categorical_crossentropy", 55 | master_metrics=None, 56 | custom_objects=None, 57 | *args, **kwargs): 58 | 59 | 60 | self.spark_context = sc 61 | self._master_network = master_network 62 | if custom_objects is None: 63 | custom_objects = {} 64 | if master_metrics is None: 65 | master_metrics = ["accuracy"] 66 | if optimizer is None: 67 | self.optimizer = default_optimizer() 68 | else: 69 | self.optimizer = optimizer 70 | self.mode = mode 71 | self.frequency = frequency 72 | self.num_workers = num_workers 73 | self.weights = master_network.get_weights() 74 | self.pickled_weights = None 75 | self.lock = RWLock() 76 | self.master_optimizer = master_optimizer 77 | self.master_loss = master_loss 78 | self.master_metrics = master_metrics 79 | self.custom_objects = custom_objects 80 | 81 | @staticmethod 82 | def determine_master(): 83 | ''' 84 | Get URL of parameter server, running on master 85 | ''' 86 | master_url = socket.gethostbyname(socket.gethostname()) + ':5000' 87 | return master_url 88 | 89 | def get_train_config(self, nb_epoch, batch_size, 90 | verbose, validation_split): 91 | ''' 92 | Get configuration of training parameters 93 | ''' 94 | train_config = {} 95 | train_config['nb_epoch'] = nb_epoch 96 | train_config['batch_size'] = batch_size 97 | train_config['verbose'] = verbose 98 | train_config['validation_split'] = validation_split 99 | return train_config 100 | 101 | def get_config(self): 102 | ''' 103 | Get configuration of model parameters 104 | ''' 105 | model_config = {} 106 | model_config['model'] = self.master_network.get_config() 107 | model_config['optimizer'] = self.optimizer.get_config() 108 | model_config['mode'] = self.mode 109 | return model_config 110 | 111 | @property 112 | def master_network(self): 113 | ''' Get master network ''' 114 | return self._master_network 115 | 116 | @master_network.setter 117 | def master_network(self, network): 118 | ''' Set master network ''' 119 | self._master_network = network 120 | 121 | def start_server(self): 122 | 
''' Start parameter server''' 123 | self.server = Process(target=self.start_service) 124 | self.server.start() 125 | 126 | def stop_server(self): 127 | ''' Terminate parameter server''' 128 | self.server.terminate() 129 | self.server.join() 130 | 131 | def start_service(self): 132 | ''' Define service and run flask app''' 133 | app = Flask(__name__) 134 | self.app = app 135 | 136 | @app.route('/') 137 | def home(): 138 | return 'Elephas' 139 | 140 | @app.route('/parameters', methods=['GET']) 141 | def get_parameters(): 142 | if self.mode == 'asynchronous': 143 | self.lock.acquire_read() 144 | self.pickled_weights = pickle.dumps(self.weights, -1) 145 | pickled_weights = self.pickled_weights 146 | if self.mode == 'asynchronous': 147 | self.lock.release() 148 | return pickled_weights 149 | 150 | @app.route('/update', methods=['POST']) 151 | def update_parameters(): 152 | delta = pickle.loads(request.data) 153 | if self.mode == 'asynchronous': 154 | self.lock.acquire_write() 155 | constraints = self.master_network.constraints 156 | if len(constraints) == 0: 157 | def empty(a): return a 158 | constraints = [empty for x in self.weights] 159 | self.weights = self.optimizer.get_updates(self.weights, constraints, delta) 160 | if self.mode == 'asynchronous': 161 | self.lock.release() 162 | return 'Update done' 163 | 164 | self.app.run(host='0.0.0.0', debug=True, 165 | threaded=True, use_reloader=False) 166 | 167 | def predict(self, data): 168 | ''' 169 | Get prediction probabilities for a numpy array of features 170 | ''' 171 | return self.master_network.predict(data) 172 | 173 | def predict_classes(self, data): 174 | ''' 175 | Predict classes for a numpy array of features 176 | ''' 177 | return self.master_network.predict_classes(data) 178 | 179 | def train(self, rdd, nb_epoch=10, batch_size=32, 180 | verbose=0, validation_split=0.1): 181 | ''' 182 | Train an elephas model. 
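`rdd` is expected to hold pairs of numpy arrays (features, label), e.g. as produced by `elephas.utils.rdd_utils.to_simple_rdd`. The remaining arguments are collected into a training configuration and passed on to the Keras `fit` call on each worker.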
183 | '''
184 | rdd = rdd.repartition(self.num_workers)
185 | master_url = self.determine_master()
186 |
187 | if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
188 | self._train(rdd, nb_epoch, batch_size, verbose, validation_split, master_url)
189 | else:
190 | print("""Choose from one of the modes: asynchronous, synchronous or hogwild""")
191 |
192 | def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0,
193 | validation_split=0.1, master_url='localhost:5000'):
194 | '''
195 | Protected train method to make wrapping of modes easier
196 | '''
197 | self.master_network.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
198 | if self.mode in ['asynchronous', 'hogwild']:
199 | self.start_server()
200 | yaml = self.master_network.to_yaml()
201 | train_config = self.get_train_config(nb_epoch, batch_size,
202 | verbose, validation_split)
203 | if self.mode in ['asynchronous', 'hogwild']:
204 | worker = AsynchronousSparkWorker(
205 | yaml, train_config, self.frequency, master_url,
206 | self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
207 | )
208 | rdd.mapPartitions(worker.train).collect()
209 | new_parameters = get_server_weights(master_url)
210 | elif self.mode == 'synchronous':
211 | init = self.master_network.get_weights()
212 | parameters = self.spark_context.broadcast(init)
213 | worker = SparkWorker(yaml, parameters, train_config, self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects)  # pass the full worker configuration
214 | deltas = rdd.mapPartitions(worker.train).collect()
215 | new_parameters = self.master_network.get_weights()
216 | for delta in deltas:
217 | constraints = self.master_network.constraints or [lambda x: x for _ in new_parameters]  # fall back to identity constraints, as in the asynchronous path
218 | new_parameters = self.optimizer.get_updates(new_parameters, constraints, delta)  # accumulate deltas from all workers
219 | self.master_network.set_weights(new_parameters)
220 | if self.mode in ['asynchronous', 'hogwild']:
221 | self.stop_server()
222 |
223 |
224 | class SparkWorker(object):
225 | '''
226 | Synchronous Spark worker. This code will be executed on workers.
227 | '''
228 | def __init__(self, yaml, parameters, train_config, master_optimizer, master_loss, master_metrics, custom_objects):
229 | self.yaml = yaml
230 | self.parameters = parameters
231 | self.train_config = train_config
232 | self.master_optimizer = master_optimizer
233 | self.master_loss = master_loss
234 | self.master_metrics = master_metrics
235 | self.custom_objects = custom_objects
236 |
237 | def train(self, data_iterator):
238 | '''
239 | Train a keras model on a worker
240 | '''
241 | feature_iterator, label_iterator = tee(data_iterator, 2)
242 | x_train = np.asarray([x for x, y in feature_iterator])
243 | y_train = np.asarray([y for x, y in label_iterator])
244 |
245 | model = model_from_yaml(self.yaml, self.custom_objects)
246 | model.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
247 | model.set_weights(self.parameters.value)
248 | weights_before_training = model.get_weights()
249 | if x_train.shape[0] > self.train_config.get('batch_size'):
250 | model.fit(x_train, y_train, **self.train_config)
251 | weights_after_training = model.get_weights()
252 | deltas = subtract_params(weights_before_training, weights_after_training)
253 | yield deltas
254 |
255 |
256 | class AsynchronousSparkWorker(object):
257 | '''
258 | Asynchronous Spark worker. This code will be executed on workers.
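Before each training pass the worker pulls the current weights from the parameter server, trains on its partition of the data and pushes the resulting weight deltas back to the server, once per epoch or once per batch depending on the chosen `frequency`.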
259 | '''
260 | def __init__(self, yaml, train_config, frequency, master_url, master_optimizer, master_loss, master_metrics, custom_objects):
261 | self.yaml = yaml
262 | self.train_config = train_config
263 | self.frequency = frequency
264 | self.master_url = master_url
265 | self.master_optimizer = master_optimizer
266 | self.master_loss = master_loss
267 | self.master_metrics = master_metrics
268 | self.custom_objects = custom_objects
269 |
270 |
271 | def train(self, data_iterator):
272 | '''
273 | Train a keras model on a worker and send asynchronous updates
274 | to the parameter server
275 | '''
276 | feature_iterator, label_iterator = tee(data_iterator, 2)
277 | x_train = np.asarray([x for x, y in feature_iterator])
278 | y_train = np.asarray([y for x, y in label_iterator])
279 |
280 | if x_train.size == 0:
281 | return
282 |
283 | model = model_from_yaml(self.yaml, self.custom_objects)
284 | model.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
285 |
286 | nb_epoch = self.train_config['nb_epoch']
287 | batch_size = self.train_config.get('batch_size')
288 | nb_train_sample = x_train.shape[0]  # number of samples in this partition, not the feature dimension
289 | nb_batch = int(np.ceil(nb_train_sample/float(batch_size)))
290 | index_array = np.arange(nb_train_sample)
291 | batches = [(i*batch_size, min(nb_train_sample, (i+1)*batch_size)) for i in range(0, nb_batch)]
292 |
293 | if self.frequency == 'epoch':
294 | for epoch in range(nb_epoch):
295 | weights_before_training = get_server_weights(self.master_url)
296 | model.set_weights(weights_before_training)
297 | self.train_config['nb_epoch'] = 1
298 | if x_train.shape[0] > batch_size:
299 | model.fit(x_train, y_train, **self.train_config)
300 | weights_after_training = model.get_weights()
301 | deltas = subtract_params(weights_before_training, weights_after_training)
302 | put_deltas_to_server(deltas, self.master_url)
303 | elif self.frequency == 'batch':
304 | from keras.engine.training import slice_X
305 | for epoch in range(nb_epoch):
306 | if x_train.shape[0] > batch_size:
307 | for (batch_start, batch_end) in batches:
308 | weights_before_training = get_server_weights(self.master_url)
309 | model.set_weights(weights_before_training)
310 | batch_ids = index_array[batch_start:batch_end]
311 | X = slice_X(x_train, batch_ids)
312 | y = slice_X(y_train, batch_ids)
313 | model.train_on_batch(X, y)
314 | weights_after_training = model.get_weights()
315 | deltas = subtract_params(weights_before_training, weights_after_training)
316 | put_deltas_to_server(deltas, self.master_url)
317 | else:
318 | print('Choose frequency to be either batch or epoch')
319 | yield []
320 |
321 |
322 | class SparkMLlibModel(SparkModel):
323 | '''
324 | MLlib model takes RDDs of LabeledPoints.
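Labels can be one-hot encoded on the fly by passing `categorical=True` together with `nb_classes` to `train`.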
Internally we just convert 325 | back to plain old pair RDDs and continue as in SparkModel 326 | ''' 327 | def __init__(self, sc, master_network, optimizer=None, mode='asynchronous', frequency='epoch', num_workers=4, 328 | master_optimizer="adam", 329 | master_loss="categorical_crossentropy", 330 | master_metrics=None, 331 | custom_objects=None): 332 | SparkModel.__init__(self, sc, master_network, optimizer, mode, frequency, num_workers, 333 | master_optimizer=master_optimizer, master_loss=master_loss, master_metrics=master_metrics, 334 | custom_objects=custom_objects) 335 | 336 | def train(self, labeled_points, nb_epoch=10, batch_size=32, verbose=0, validation_split=0.1, 337 | categorical=False, nb_classes=None): 338 | ''' 339 | Train an elephas model on an RDD of LabeledPoints 340 | ''' 341 | rdd = lp_to_simple_rdd(labeled_points, categorical, nb_classes) 342 | rdd = rdd.repartition(self.num_workers) 343 | self._train(rdd, nb_epoch, batch_size, verbose, validation_split) 344 | 345 | def predict(self, mllib_data): 346 | ''' 347 | Predict probabilities for an RDD of features 348 | ''' 349 | if isinstance(mllib_data, Matrix): 350 | return to_matrix(self.master_network.predict(from_matrix(mllib_data))) 351 | elif isinstance(mllib_data, Vector): 352 | return to_vector(self.master_network.predict(from_vector(mllib_data))) 353 | else: 354 | print('Provide either an MLLib matrix or vector') 355 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Elephas: Distributed Deep Learning with Keras & Spark [![Build Status](https://travis-ci.org/maxpumperla/elephas.svg?branch=master)](https://travis-ci.org/maxpumperla/elephas) 2 | 3 | Elephas is an extension of [Keras](http://keras.io), which allows you to run distributed deep learning models at scale with [Spark](http://spark.apache.org). Elephas currently supports a number of applications, including: 4 | 5 | - [Data-parallel training of deep learning models](#usage-of-data-parallel-models) 6 | - [Distributed hyper-parameter optimization](#distributed-hyper-parameter-optimization) 7 | - [Distributed training of ensemble models](#distributed-training-of-ensemble-models) 8 | 9 | 10 | Schematically, elephas works as follows. 
11 | 12 | ![Elephas](elephas.gif) 13 | 14 | Table of content: 15 | - [Elephas: Distributed Deep Learning with Keras & Spark](#elephas-distributed-deep-learning-with-keras-&-spark-) 16 | - [Introduction](#introduction) 17 | - [Getting started](#getting-started) 18 | - [Installation](#installation) 19 | - [Basic example](#basic-example) 20 | - [Spark ML example](#spark-ml-example) 21 | - [Usage of data-parallel models](#usage-of-data-parallel-models) 22 | - [Model updates (optimizers)](#model-updates-optimizers) 23 | - [Update frequency](#update-frequency) 24 | - [Update mode](#update-mode) 25 | - [Asynchronous updates with read and write locks (`mode='asynchronous'`)](#asynchronous-updates-with-read-and-write-locks-modeasynchronous) 26 | - [Asynchronous updates without locks (`mode='hogwild'`)](#asynchronous-updates-without-locks-modehogwild) 27 | - [Synchronous updates (`mode='synchronous'`)](#synchronous-updates-modesynchronous) 28 | - [Degree of parallelization (number of workers)](#degree-of-parallelization-number-of-workers) 29 | - [Distributed hyper-parameter optimization](#distributed-hyper-parameter-optimization) 30 | - [Distributed training of ensemble models](#distributed-training-of-ensemble-models) 31 | - [Discussion](#discussion) 32 | - [Future work & contributions](#future-work-&-contributions) 33 | - [Literature](#literature) 34 | 35 | ## Introduction 36 | Elephas brings deep learning with [Keras](http://keras.io) to [Spark](http://spark.apache.org). Elephas intends to keep the simplicity and high usability of Keras, thereby allowing for fast prototyping of distributed models, which can be run on massive data sets. For an introductory example, see the following [iPython notebook](https://github.com/maxpumperla/elephas/blob/master/examples/Spark_ML_Pipeline.ipynb). 37 | 38 | ἐλέφας is Greek for _ivory_ and an accompanying project to κέρας, meaning _horn_. If this seems weird mentioning, like a bad dream, you should confirm it actually is at the [Keras documentation](https://github.com/fchollet/keras/blob/master/README.md). Elephas also means _elephant_, as in stuffed yellow elephant. 39 | 40 | Elephas implements a class of data-parallel algorithms on top of Keras, using Spark's RDDs and data frames. Keras Models are initialized on the driver, then serialized and shipped to workers, alongside with data and broadcasted model parameters. Spark workers deserialize the model, train their chunk of data and send their gradients back to the driver. The "master" model on the driver is updated by an optimizer, which takes gradients either synchronously or asynchronously. 41 | 42 | ## Getting started 43 | 44 | ### Installation 45 | Install elephas from PyPI with 46 | ``` 47 | pip install elephas 48 | ``` 49 | Depending on what OS you are using, you may need to install some prerequisite modules (LAPACK, BLAS, fortran compiler) first. 50 | 51 | For example, on Ubuntu Linux: 52 | ``` 53 | sudo apt-get install liblapack-dev libblas-dev gfortran 54 | ``` 55 | 56 | A quick way to install Spark locally is to use homebrew on Mac 57 | ``` 58 | brew install spark 59 | ``` 60 | or linuxbrew on linux. 61 | ``` 62 | brew install apache-spark 63 | ``` 64 | The brew version of Spark may be outdated at times. To build from source, simply follow the instructions at the [Spark download section](http://spark.apache.org/downloads.html) or use the following commands. 
65 | ``` 66 | wget http://apache.mirrors.tds.net/spark/spark-1.5.2/spark-1.5.2-bin-hadoop2.6.tgz -P ~ 67 | sudo tar zxvf ~/spark-* -C /usr/local 68 | sudo mv /usr/local/spark-* /usr/local/spark 69 | ``` 70 | After that, make sure to put these path variables to your shell profile (e.g. `~/.zshrc`): 71 | ``` 72 | export SPARK_HOME=/usr/local/spark 73 | export PATH=$PATH:$SPARK_HOME/bin 74 | ``` 75 | 76 | ### Basic example 77 | After installing both Elephas and Spark, training a model is done schematically as follows: 78 | 79 | - Create a local pyspark context 80 | ```python 81 | from pyspark import SparkContext, SparkConf 82 | conf = SparkConf().setAppName('Elephas_App').setMaster('local[8]') 83 | sc = SparkContext(conf=conf) 84 | ``` 85 | 86 | - Define and compile a Keras model 87 | ```python 88 | model = Sequential() 89 | model.add(Dense(128, input_dim=784)) 90 | model.add(Activation('relu')) 91 | model.add(Dropout(0.2)) 92 | model.add(Dense(128)) 93 | model.add(Activation('relu')) 94 | model.add(Dropout(0.2)) 95 | model.add(Dense(10)) 96 | model.add(Activation('softmax')) 97 | model.compile(loss='categorical_crossentropy', optimizer=SGD()) 98 | ``` 99 | 100 | - Create an RDD from numpy arrays 101 | ```python 102 | from elephas.utils.rdd_utils import to_simple_rdd 103 | rdd = to_simple_rdd(sc, X_train, Y_train) 104 | ``` 105 | 106 | - A SparkModel is defined by passing Spark context and Keras model. Additionally, one has choose an optimizer used for updating the elephas model, an update frequency, a parallelization mode and the degree of parallelism, i.e. the number of workers. 107 | ```python 108 | from elephas.spark_model import SparkModel 109 | from elephas import optimizers as elephas_optimizers 110 | 111 | adagrad = elephas_optimizers.Adagrad() 112 | spark_model = SparkModel(sc,model, optimizer=adagrad, frequency='epoch', mode='asynchronous', num_workers=2) 113 | spark_model.train(rdd, nb_epoch=20, batch_size=32, verbose=0, validation_split=0.1, num_workers=8) 114 | ``` 115 | 116 | - Run your script using spark-submit 117 | ``` 118 | spark-submit --driver-memory 1G ./your_script.py 119 | ``` 120 | Increasing the driver memory even further may be necessary, as the set of parameters in a network may be very large and collecting them on the driver eats up a lot of resources. See the examples folder for a few working examples. 121 | 122 | ### Spark MLlib example 123 | Following up on the last example, to create an RDD of LabeledPoints for supervised training from pairs of numpy arrays, use 124 | ```python 125 | from elephas.utils.rdd_utils import to_labeled_point 126 | lp_rdd = to_labeled_point(sc, X_train, Y_train, categorical=True) 127 | ``` 128 | Training a given LabeledPoint-RDD is very similar to what we've seen already 129 | ```python 130 | from elephas.spark_model import SparkMLlibModel 131 | adadelta = elephas_optimizers.Adadelta() 132 | spark_model = SparkMLlibModel(sc,model, optimizer=adadelta, frequency='batch', mode='hogwild', num_workers=2) 133 | spark_model.train(lp_rdd, nb_epoch=20, batch_size=32, verbose=0, validation_split=0.1, categorical=True, nb_classes=nb_classes) 134 | ``` 135 | 136 | ### Spark ML example 137 | To train a model with a SparkML estimator on a data frame, use the following syntax. 
138 | ```python 139 | df = to_data_frame(sc, X_train, Y_train, categorical=True) 140 | test_df = to_data_frame(sc, X_test, Y_test, categorical=True) 141 | 142 | adadelta = elephas_optimizers.Adadelta() 143 | estimator = ElephasEstimator(sc,model, 144 | nb_epoch=nb_epoch, batch_size=batch_size, optimizer=adadelta, frequency='batch', mode='asynchronous', num_workers=2, 145 | verbose=0, validation_split=0.1, categorical=True, nb_classes=nb_classes) 146 | 147 | fitted_model = estimator.fit(df) 148 | ``` 149 | 150 | Fitting an estimator results in a SparkML transformer, which we can use for predictions and other evaluations by calling the transform method on it. 151 | 152 | ``` python 153 | prediction = fitted_model.transform(test_df) 154 | pnl = prediction.select("label", "prediction") 155 | pnl.show(100) 156 | 157 | prediction_and_label= pnl.map(lambda row: (row.label, row.prediction)) 158 | metrics = MulticlassMetrics(prediction_and_label) 159 | print(metrics.precision()) 160 | print(metrics.recall()) 161 | ``` 162 | 163 | ## Usage of data-parallel models 164 | 165 | In the first example above we have seen that an elephas model is instantiated like this 166 | 167 | ```python 168 | spark_model = SparkModel(sc,model, optimizer=adagrad, frequency='epoch', mode='asynchronous', num_workers=2) 169 | ``` 170 | So, apart from the canonical Spark context and Keras model, Elephas models have four parameters to tune and we will describe each of them next. 171 | 172 | ### Model updates (optimizers) 173 | 174 | `optimizer`: The optimizers module in elephas is an adaption of the same module in keras, i.e. it provides the user with the following list of optimizers: 175 | 176 | - `SGD` 177 | - `RMSprop` 178 | - `Adagrad` 179 | - `Adadelta` 180 | - `Adam` 181 | 182 | Once constructed, each of these can be passed to the *optimizer* parameter of the model. Updates in keras are computed with the help of theano, so most of the data structures in keras optimizers stem from theano. In elephas, gradients have already been computed by the respective workers, so it makes sense to entirely work with numpy arrays internally. 183 | 184 | Note that in order to set up an elephas model, you have to specify two optimizers, one for elephas and one for the underlying keras model. Individual workers produce updates according to keras optimizers and the "master" model on the driver uses elephas optimizers to aggregate them. For starters, we recommend keras models with SGD and elephas models with Adagrad or Adadelta. 185 | 186 | ### Update frequency 187 | 188 | `frequency`: The user can decide how often updates are passed to the master model by controlling the *frequency* parameter. To update every batch, choose 'batch' and to update only after every epoch, choose 'epoch'. 189 | 190 | ### Update mode 191 | 192 | `mode`: Currently, there's three different modes available in elephas, each corresponding to a different heuristic or parallelization scheme adopted, which is controlled by the *mode* parameter. The default property is 'asynchronous'. 193 | 194 | #### Asynchronous updates with read and write locks (`mode='asynchronous'`) 195 | 196 | This mode implements the algorithm described as *downpour* in [1], i.e. each worker can send updates whenever they are ready. The master model makes sure that no update gets lost, i.e. multiple updates get applied at the "same" time, by locking the master parameters while reading and writing parameters. This idea has been used in Google's DistBelief framework. 
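
To make this more concrete, the sketch below condenses what each worker effectively does during one asynchronous update cycle. The helpers `get_server_weights`, `put_deltas_to_server` and `subtract_params` are part of elephas; the wrapper function itself is purely illustrative, elephas wires this loop up for you when you call `train`.

```python
from elephas.spark_model import get_server_weights, put_deltas_to_server
from elephas.utils.functional_utils import subtract_params

def one_asynchronous_update(model, x_chunk, y_chunk, master_url='localhost:5000'):
    # Read the current master parameters (read-locked on the driver)
    weights_before = get_server_weights(master_url)
    model.set_weights(weights_before)
    # Train on this worker's chunk of the data
    model.fit(x_chunk, y_chunk, nb_epoch=1, batch_size=32, verbose=0)
    # Push the resulting weight delta back (write-locked on the driver)
    delta = subtract_params(weights_before, model.get_weights())
    put_deltas_to_server(delta, master_url)
```

The read-write lock on the parameter server guarantees that reads never observe a half-written set of parameters and that no worker's delta is silently dropped.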
197 | 198 | #### Asynchronous updates without locks (`mode='hogwild'`) 199 | Essentially the same procedure as above, but without requiring the locks. This heuristic assumes that we still fare well enough, even if we loose an update here or there. Updating parameters lock-free in a non-distributed setting for SGD goes by the name 'Hogwild!' [2], it's distributed extension is called 'Dogwild!' [3]. 200 | 201 | #### Synchronous updates (`mode='synchronous'`) 202 | 203 | In this mode each worker sends a new batch of parameter updates at the same time, which are then processed on the master. Accordingly, this algorithm is sometimes called *batch synchronous parallel* or just BSP. 204 | 205 | ### Degree of parallelization (number of workers) 206 | 207 | `num_workers`: Lastly, the degree to which we parallelize our training data is controlled by the parameter *num_workers*. 208 | 209 | ## Distributed hyper-parameter optimization 210 | 211 | Hyper-parameter optimization with elephas is based on [hyperas](https://github.com/maxpumperla/hyperas), a convenience wrapper for hyperopt and keras. Make sure to have at least version ```0.1.2``` of hyperas installed. Each Spark worker executes a number of trials, the results get collected and the best model is returned. As the distributed mode in hyperopt (using MongoDB), is somewhat difficult to configure and error prone at the time of writing, we chose to implement parallelization ourselves. Right now, the only available optimization algorithm is random search. 212 | 213 | The first part of this example is more or less directly taken from the hyperas documentation. We define data and model as functions, hyper-parameter ranges are defined through braces. See the hyperas documentation for more on how this works. 214 | 215 | ```{python} 216 | from __future__ import print_function 217 | from hyperopt import Trials, STATUS_OK, tpe 218 | from hyperas.distributions import choice, uniform 219 | 220 | def data(): 221 | ''' 222 | Data providing function: 223 | 224 | Make sure to have every relevant import statement included here and return data as 225 | used in model function below. This function is separated from model() so that hyperopt 226 | won't reload data for each evaluation run. 227 | ''' 228 | from keras.datasets import mnist 229 | from keras.utils import np_utils 230 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 231 | X_train = X_train.reshape(60000, 784) 232 | X_test = X_test.reshape(10000, 784) 233 | X_train = X_train.astype('float32') 234 | X_test = X_test.astype('float32') 235 | X_train /= 255 236 | X_test /= 255 237 | nb_classes = 10 238 | Y_train = np_utils.to_categorical(y_train, nb_classes) 239 | Y_test = np_utils.to_categorical(y_test, nb_classes) 240 | return X_train, Y_train, X_test, Y_test 241 | 242 | 243 | def model(X_train, Y_train, X_test, Y_test): 244 | ''' 245 | Model providing function: 246 | 247 | Create Keras model with double curly brackets dropped-in as needed. 248 | Return value has to be a valid python dictionary with two customary keys: 249 | - loss: Specify a numeric evaluation metric to be minimized 250 | - status: Just use STATUS_OK and see hyperopt documentation if not feasible 251 | The last one is optional, though recommended, namely: 252 | - model: specify the model just created so that we can later use it again. 
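Note that the return value below uses `pickle` to serialize the weights, so an appropriate import (e.g. `import six.moves.cPickle as pickle`) needs to be in scope as well.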
253 | ''' 254 | from keras.models import Sequential 255 | from keras.layers.core import Dense, Dropout, Activation 256 | from keras.optimizers import RMSprop 257 | 258 | model = Sequential() 259 | model.add(Dense(512, input_shape=(784,))) 260 | model.add(Activation('relu')) 261 | model.add(Dropout({{uniform(0, 1)}})) 262 | model.add(Dense({{choice([256, 512, 1024])}})) 263 | model.add(Activation('relu')) 264 | model.add(Dropout({{uniform(0, 1)}})) 265 | model.add(Dense(10)) 266 | model.add(Activation('softmax')) 267 | 268 | rms = RMSprop() 269 | model.compile(loss='categorical_crossentropy', optimizer=rms) 270 | 271 | model.fit(X_train, Y_train, 272 | batch_size={{choice([64, 128])}}, 273 | nb_epoch=1, 274 | show_accuracy=True, 275 | verbose=2, 276 | validation_data=(X_test, Y_test)) 277 | score, acc = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0) 278 | print('Test accuracy:', acc) 279 | return {'loss': -acc, 'status': STATUS_OK, 'model': model.to_yaml(), 'weights': pickle.dumps(model.get_weights())} 280 | ``` 281 | 282 | Once the basic setup is defined, running the minimization is done in just a few lines of code: 283 | 284 | ```{python} 285 | from hyperas import optim 286 | from elephas.hyperparam import HyperParamModel 287 | from pyspark import SparkContext, SparkConf 288 | 289 | # Create Spark context 290 | conf = SparkConf().setAppName('Elephas_Hyperparameter_Optimization').setMaster('local[8]') 291 | sc = SparkContext(conf=conf) 292 | 293 | # Define hyper-parameter model and run optimization 294 | hyperparam_model = HyperParamModel(sc) 295 | hyperparam_model.minimize(model=model, data=data, max_evals=5) 296 | ``` 297 | 298 | ## Distributed training of ensemble models 299 | 300 | Building on the last section, it is possible to train ensemble models with elephas by means of running hyper-parameter optimization on large search spaces and defining a resulting voting classifier on the top-n performing models. With ```data``` and ```model```` defined as above, this is a simple as running 301 | 302 | ```{python} 303 | result = hyperparam_model.best_ensemble(nb_ensemble_models=10, model=model, data=data, max_evals=5) 304 | ``` 305 | In this example an ensemble of 10 models is built, based on optimization of at most 5 runs on each of the Spark workers. 306 | 307 | ## Discussion 308 | 309 | Premature parallelization may not be the root of all evil, but it may not always be the best idea to do so. Keep in mind that more workers mean less data per worker and parallelizing a model is not an excuse for actual learning. So, if you can perfectly well fit your data into memory *and* you're happy with training speed of the model consider just using keras. 310 | 311 | One exception to this rule may be that you're already working within the Spark ecosystem and want to leverage what's there. The above SparkML example shows how to use evaluation modules from Spark and maybe you wish to further process the outcome of an elephas model down the road. In this case, we recommend to use elephas as a simple wrapper by setting num_workers=1. 312 | 313 | Note that right now elephas restricts itself to data-parallel algorithms for two reasons. First, Spark simply makes it very easy to distribute data. Second, neither Spark nor Theano make it particularly easy to split up the actual model in parts, thus making model-parallelism practically impossible to realize. 
314 | 315 | Having said all that, we hope you learn to appreciate elephas as a pretty easy to setup and use playground for data-parallel deep-learning algorithms. 316 | 317 | 318 | ## Future work & contributions 319 | 320 | Constructive feedback and pull requests for elephas are very welcome. Here's a few things we're having in mind for future development 321 | 322 | - Benchmarks for training speed and accuracy. 323 | - Some real-world tests on EC2 instances with large data sets like imagenet. 324 | 325 | ## Literature 326 | [1] J. Dean, G.S. Corrado, R. Monga, K. Chen, M. Devin, QV. Le, MZ. Mao, M’A. Ranzato, A. Senior, P. Tucker, K. Yang, and AY. Ng. [Large Scale Distributed Deep Networks](http://research.google.com/archive/large_deep_networks_nips2012.html). 327 | 328 | [2] F. Niu, B. Recht, C. Re, S.J. Wright [HOGWILD!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent](http://arxiv.org/abs/1106.5730) 329 | 330 | [3] C. Noel, S. Osindero. [Dogwild! — Distributed Hogwild for CPU & GPU](http://stanford.edu/~rezab/nips2014workshop/submits/dogwild.pdf) 331 | -------------------------------------------------------------------------------- /examples/Spark_ML_Pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Spark ML model pipelines on Distributed Deep Neural Nets\n", 10 | "\n", 11 | "This notebook describes how to build machine learning [pipelines with Spark ML](http://spark.apache.org/docs/latest/ml-guide.html) for distributed versions of Keras deep learning models. As data set we use the Otto Product Classification challenge from Kaggle. The reason we chose this data is that it is small and very structured. This way, we can focus more on technical components rather than prepcrocessing intricacies. Also, users with slow hardware or without a full-blown Spark cluster should be able to run this example locally, and still learn a lot about the distributed mode.\n", 12 | "\n", 13 | "Often, the need to distribute computation is not imposed by model training, but rather by building the data pipeline, i.e. ingestion, transformation etc. In training, deep neural networks tend to do fairly well on one or more GPUs on one machine. Most of the time, using gradient descent methods, you will process one batch after another anyway. Even so, it may still be beneficial to use frameworks like Spark to integrate your models with your surrounding infrastructure. On top of that, the convenience provided by Spark ML pipelines can be very valuable (being syntactically very close to what you might know from [```scikit-learn```](http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html)).\n", 14 | "\n", 15 | "**TL;DR:** We will show how to tackle a classification problem using distributed deep neural nets and Spark ML pipelines in an example that is essentially a distributed version of the one found [here](https://github.com/fchollet/keras/blob/master/examples/kaggle_otto_nn.py)." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Using this notebook\n", 23 | "As we are going to use elephas, you will need access to a running Spark context to run this notebook. If you don't have it already, install Spark locally by following the [instructions provided here](https://github.com/maxpumperla/elephas/blob/master/README.md). 
Make sure to also export ```SPARK_HOME``` to your path and start your ipython/jupyter notebook as follows:" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "```\n", 31 | "IPYTHON_OPTS=\"notebook\" ${SPARK_HOME}/bin/pyspark --driver-memory 4G elephas/examples/Spark_ML_Pipeline.ipynb\n", 32 | "```" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "To test your environment, try to print the Spark context (provided as ```sc```), i.e. execute the following cell." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": { 46 | "collapsed": false 47 | }, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "from __future__ import print_function\n", 59 | "print(sc)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Otto Product Classification Data" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "Training and test data is available [here](https://www.kaggle.com/c/otto-group-product-classification-challenge/data). Go ahead and download the data. Inspecting it, you will see that the provided csv files consist of an id column, 93 integer feature columns. ```train.csv``` has an additional column for labels, which ```test.csv``` is missing. The challenge is to accurately predict test labels. For the rest of this notebook, we will assume data is stored at ```data_path```, which you should modify below as needed." 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "data_path = \"./\" # <-- Make sure to adapt this to where your csv files are." 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "Loading data is relatively simple, but we have to take care of a few things. First, while you can shuffle rows of an RDD, it is generally not very efficient. But since data in ```train.csv``` is sorted by category, we'll have to shuffle in order to make the model perform well. This is what the function ```shuffle_csv``` below is for. Next, we read in plain text in ```load_data_rdd```, split lines by comma and convert features to float vector type. Also, note that the last column in ```train.csv``` represents the category, which has a ```Class_``` prefix. " 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### Defining Data Frames\n", 99 | "\n", 100 | "Spark has a few core data structures, among them is the ```data frame```, which is a distributed version of the named columnar data structure many will now from either [R](https://stat.ethz.ch/R-manual/R-devel/library/base/html/data.frame.html) or [Pandas](http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html). We need a so called ```SQLContext``` and an optional column-to-names mapping to create a data frame from scratch. 
" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": { 107 | "collapsed": false 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "from pyspark.sql import SQLContext\n", 112 | "from pyspark.mllib.linalg import Vectors\n", 113 | "import numpy as np\n", 114 | "import random\n", 115 | "\n", 116 | "sql_context = SQLContext(sc)\n", 117 | "\n", 118 | "def shuffle_csv(csv_file):\n", 119 | " lines = open(csv_file).readlines()\n", 120 | " random.shuffle(lines)\n", 121 | " open(csv_file, 'w').writelines(lines)\n", 122 | "\n", 123 | "def load_data_frame(csv_file, shuffle=True, train=True):\n", 124 | " if shuffle:\n", 125 | " shuffle_csv(csv_file)\n", 126 | " data = sc.textFile(data_path + csv_file) # This is an RDD, which will later be transformed to a data frame\n", 127 | " data = data.filter(lambda x:x.split(',')[0] != 'id').map(lambda line: line.split(','))\n", 128 | " if train:\n", 129 | " data = data.map(\n", 130 | " lambda line: (Vectors.dense(np.asarray(line[1:-1]).astype(np.float32)),\n", 131 | " str(line[-1])) )\n", 132 | " else:\n", 133 | " # Test data gets dummy labels. We need the same structure as in Train data\n", 134 | " data = data.map( lambda line: (Vectors.dense(np.asarray(line[1:]).astype(np.float32)),\"Class_1\") ) \n", 135 | " return sqlContext.createDataFrame(data, ['features', 'category'])\n", 136 | " " 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "Let's load both train and test data and print a few rows of data using the convenient ```show``` method." 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 7, 149 | "metadata": { 150 | "collapsed": false 151 | }, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "Train data frame:\n", 158 | "+--------------------+--------+\n", 159 | "| features|category|\n", 160 | "+--------------------+--------+\n", 161 | "|[0.0,0.0,0.0,0.0,...| Class_8|\n", 162 | "|[0.0,0.0,0.0,0.0,...| Class_8|\n", 163 | "|[0.0,0.0,0.0,0.0,...| Class_2|\n", 164 | "|[0.0,1.0,0.0,1.0,...| Class_6|\n", 165 | "|[0.0,0.0,0.0,0.0,...| Class_9|\n", 166 | "|[0.0,0.0,0.0,0.0,...| Class_2|\n", 167 | "|[0.0,0.0,0.0,0.0,...| Class_2|\n", 168 | "|[0.0,0.0,0.0,0.0,...| Class_3|\n", 169 | "|[0.0,0.0,4.0,0.0,...| Class_8|\n", 170 | "|[0.0,0.0,0.0,0.0,...| Class_7|\n", 171 | "+--------------------+--------+\n", 172 | "only showing top 10 rows\n", 173 | "\n", 174 | "Test data frame (note the dummy category):\n", 175 | "+--------------------+--------+\n", 176 | "| features|category|\n", 177 | "+--------------------+--------+\n", 178 | "|[1.0,0.0,0.0,1.0,...| Class_1|\n", 179 | "|[0.0,1.0,13.0,1.0...| Class_1|\n", 180 | "|[0.0,0.0,1.0,1.0,...| Class_1|\n", 181 | "|[0.0,0.0,0.0,0.0,...| Class_1|\n", 182 | "|[2.0,0.0,5.0,1.0,...| Class_1|\n", 183 | "|[0.0,0.0,0.0,0.0,...| Class_1|\n", 184 | "|[0.0,0.0,0.0,0.0,...| Class_1|\n", 185 | "|[0.0,0.0,0.0,1.0,...| Class_1|\n", 186 | "|[0.0,0.0,0.0,0.0,...| Class_1|\n", 187 | "|[0.0,0.0,0.0,0.0,...| Class_1|\n", 188 | "+--------------------+--------+\n", 189 | "only showing top 10 rows\n", 190 | "\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "train_df = load_data_frame(\"train.csv\")\n", 196 | "test_df = load_data_frame(\"test.csv\", shuffle=False, train=False) # No need to shuffle test data\n", 197 | "\n", 198 | "print(\"Train data frame:\")\n", 199 | "train_df.show(10)\n", 200 | "\n", 201 | "print(\"Test data frame (note the dummy category):\")\n", 
202 | "test_df.show(10)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "## Preprocessing: Defining Transformers\n", 210 | "\n", 211 | "Up until now, we basically just read in raw data. Luckily, ```Spark ML``` has quite a few preprocessing features available, so the only thing we will ever have to do is define transformations of data frames.\n", 212 | "\n", 213 | "To proceed, we will first transform category strings to double values. This is done by a so called ```StringIndexer```. Note that we carry out the actual transformation here already, but that is just for demonstration purposes. All we really need is too define ```string_indexer``` to put it into a pipeline later on." 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 8, 219 | "metadata": { 220 | "collapsed": false 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "from pyspark.ml.feature import StringIndexer\n", 225 | "\n", 226 | "string_indexer = StringIndexer(inputCol=\"category\", outputCol=\"index_category\")\n", 227 | "fitted_indexer = string_indexer.fit(train_df)\n", 228 | "indexed_df = fitted_indexer.transform(train_df)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "Next, it's good practice to normalize the features, which is done with a ```StandardScaler```." 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 9, 241 | "metadata": { 242 | "collapsed": false 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "from pyspark.ml.feature import StandardScaler\n", 247 | "\n", 248 | "scaler = StandardScaler(inputCol=\"features\", outputCol=\"scaled_features\", withStd=True, withMean=True)\n", 249 | "fitted_scaler = scaler.fit(indexed_df)\n", 250 | "scaled_df = fitted_scaler.transform(indexed_df)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 10, 256 | "metadata": { 257 | "collapsed": false 258 | }, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "The result of indexing and scaling. Each transformation adds new columns to the data frame:\n", 265 | "+--------------------+--------+--------------+--------------------+\n", 266 | "| features|category|index_category| scaled_features|\n", 267 | "+--------------------+--------+--------------+--------------------+\n", 268 | "|[0.0,0.0,0.0,0.0,...| Class_8| 2.0|[-0.2535060296260...|\n", 269 | "|[0.0,0.0,0.0,0.0,...| Class_8| 2.0|[-0.2535060296260...|\n", 270 | "|[0.0,0.0,0.0,0.0,...| Class_2| 0.0|[-0.2535060296260...|\n", 271 | "|[0.0,1.0,0.0,1.0,...| Class_6| 1.0|[-0.2535060296260...|\n", 272 | "|[0.0,0.0,0.0,0.0,...| Class_9| 4.0|[-0.2535060296260...|\n", 273 | "|[0.0,0.0,0.0,0.0,...| Class_2| 0.0|[-0.2535060296260...|\n", 274 | "|[0.0,0.0,0.0,0.0,...| Class_2| 0.0|[-0.2535060296260...|\n", 275 | "|[0.0,0.0,0.0,0.0,...| Class_3| 3.0|[-0.2535060296260...|\n", 276 | "|[0.0,0.0,4.0,0.0,...| Class_8| 2.0|[-0.2535060296260...|\n", 277 | "|[0.0,0.0,0.0,0.0,...| Class_7| 5.0|[-0.2535060296260...|\n", 278 | "+--------------------+--------+--------------+--------------------+\n", 279 | "only showing top 10 rows\n", 280 | "\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "print(\"The result of indexing and scaling. 
Each transformation adds new columns to the data frame:\")\n", 286 | "scaled_df.show(10)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "## Keras Deep Learning model\n", 294 | "\n", 295 | "Now that we have a data frame with processed features and labels, let's define a deep neural net that we can use to address the classification problem. Chances are you came here because you know a thing or two about deep learning. If so, the model below will look very straightforward to you. We build a keras model by choosing a set of three consecutive Dense layers with dropout and ReLU activations. There are certainly much better architectures for the problem out there, but we really just want to demonstrate the general flow here." 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 12, 301 | "metadata": { 302 | "collapsed": false 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "from keras.models import Sequential\n", 307 | "from keras.layers.core import Dense, Dropout, Activation\n", 308 | "from keras.utils import np_utils, generic_utils\n", 309 | "\n", 310 | "nb_classes = train_df.select(\"category\").distinct().count()\n", 311 | "input_dim = len(train_df.select(\"features\").first()[0])\n", 312 | "\n", 313 | "model = Sequential()\n", 314 | "model.add(Dense(512, input_shape=(input_dim,)))\n", 315 | "model.add(Activation('relu'))\n", 316 | "model.add(Dropout(0.5))\n", 317 | "model.add(Dense(512))\n", 318 | "model.add(Activation('relu'))\n", 319 | "model.add(Dropout(0.5))\n", 320 | "model.add(Dense(512))\n", 321 | "model.add(Activation('relu'))\n", 322 | "model.add(Dropout(0.5))\n", 323 | "model.add(Dense(nb_classes))\n", 324 | "model.add(Activation('softmax'))\n", 325 | "\n", 326 | "model.compile(loss='categorical_crossentropy', optimizer='adam')" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "## Distributed Elephas model\n", 334 | "\n", 335 | "To lift the above Keras ```model``` to Spark, we define an ```Estimator``` on top of it. An ```Estimator``` is Spark's incarnation of a model that still has to be trained. It essentially only comes with only a single (required) method, namely ```fit```. Once we call ```fit``` on a data frame, we get back a ```Model```, which is a trained model with a ```transform``` method to predict labels.\n", 336 | "\n", 337 | "We do this by initializing an ```ElephasEstimator``` and setting a few properties. As by now our input data frame will have many columns, we have to tell the model where to find features and labels by column name. Then we provide serialized versions of Keras model and Elephas optimizer. We can not plug in keras models into the ```Estimator``` directly, as Spark will have to serialize them anyway for communication with workers, so it's better to provide the serialization ourselves. In fact, while pyspark knows how to serialize ```model```, it is extremely inefficient and can break if models become too large. Spark ML is especially picky (and rightly so) about parameters and more or less prohibits you from providing non-atomic types and arrays of the latter. Most of the remaining parameters are optional and rather self explainatory. Plus, many of them you know if you have ever run a keras model before. We just include them here to show the full set of training configuration." 
338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 14, 343 | "metadata": { 344 | "collapsed": false 345 | }, 346 | "outputs": [ 347 | { 348 | "data": { 349 | "text/plain": [ 350 | "ElephasEstimator_415398ab22cb1699f794" 351 | ] 352 | }, 353 | "execution_count": 14, 354 | "metadata": {}, 355 | "output_type": "execute_result" 356 | } 357 | ], 358 | "source": [ 359 | "from elephas.ml_model import ElephasEstimator\n", 360 | "from elephas import optimizers as elephas_optimizers\n", 361 | "\n", 362 | "# Define elephas optimizer (which tells the model how to aggregate updates on the Spark master)\n", 363 | "adadelta = elephas_optimizers.Adadelta()\n", 364 | "\n", 365 | "# Initialize SparkML Estimator and set all relevant properties\n", 366 | "estimator = ElephasEstimator()\n", 367 | "estimator.setFeaturesCol(\"scaled_features\") # These two come directly from pyspark,\n", 368 | "estimator.setLabelCol(\"index_category\") # hence the camel case. Sorry :)\n", 369 | "estimator.set_keras_model_config(model.to_yaml()) # Provide serialized Keras model\n", 370 | "estimator.set_optimizer_config(adadelta.get_config()) # Provide serialized Elephas optimizer\n", 371 | "estimator.set_categorical_labels(True)\n", 372 | "estimator.set_nb_classes(nb_classes)\n", 373 | "estimator.set_num_workers(1) # We just use one worker here. Feel free to adapt it.\n", 374 | "estimator.set_nb_epoch(20) \n", 375 | "estimator.set_batch_size(128)\n", 376 | "estimator.set_verbosity(1)\n", 377 | "estimator.set_validation_split(0.15)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "## SparkML Pipelines\n", 385 | "\n", 386 | "Now for the easy part: Defining pipelines is really as easy as listing pipeline stages. We can provide any configuration of ```Transformers``` and ```Estimators``` really, but here we simply take the three components defined earlier. Note that ```string_indexer``` and ```scaler``` and interchangable, while ```estimator``` somewhat obviously has to come last in the pipeline." 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 15, 392 | "metadata": { 393 | "collapsed": true 394 | }, 395 | "outputs": [], 396 | "source": [ 397 | "from pyspark.ml import Pipeline\n", 398 | "\n", 399 | "pipeline = Pipeline(stages=[string_indexer, scaler, estimator])" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "### Fitting and evaluating the pipeline\n", 407 | "\n", 408 | "The last step now is to fit the pipeline on training data and evaluate it. We evaluate, i.e. transform, on _training data_, since only in that case do we have labels to check accuracy of the model. If you like, you could transform the ```test_df``` as well." 
409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 17, 414 | "metadata": { 415 | "collapsed": false 416 | }, 417 | "outputs": [ 418 | { 419 | "name": "stdout", 420 | "output_type": "stream", 421 | "text": [ 422 | "61878/61878 [==============================] - 0s \n", 423 | "+--------------+----------+\n", 424 | "|index_category|prediction|\n", 425 | "+--------------+----------+\n", 426 | "| 2.0| 2.0|\n", 427 | "| 2.0| 2.0|\n", 428 | "| 0.0| 0.0|\n", 429 | "| 1.0| 1.0|\n", 430 | "| 4.0| 4.0|\n", 431 | "| 0.0| 0.0|\n", 432 | "| 0.0| 0.0|\n", 433 | "| 3.0| 3.0|\n", 434 | "| 2.0| 2.0|\n", 435 | "| 5.0| 0.0|\n", 436 | "| 0.0| 0.0|\n", 437 | "| 4.0| 4.0|\n", 438 | "| 0.0| 0.0|\n", 439 | "| 4.0| 1.0|\n", 440 | "| 2.0| 2.0|\n", 441 | "| 1.0| 1.0|\n", 442 | "| 0.0| 0.0|\n", 443 | "| 6.0| 0.0|\n", 444 | "| 2.0| 2.0|\n", 445 | "| 1.0| 1.0|\n", 446 | "| 2.0| 2.0|\n", 447 | "| 8.0| 8.0|\n", 448 | "| 1.0| 1.0|\n", 449 | "| 5.0| 0.0|\n", 450 | "| 0.0| 0.0|\n", 451 | "| 0.0| 3.0|\n", 452 | "| 0.0| 0.0|\n", 453 | "| 1.0| 1.0|\n", 454 | "| 4.0| 4.0|\n", 455 | "| 2.0| 2.0|\n", 456 | "| 0.0| 3.0|\n", 457 | "| 3.0| 3.0|\n", 458 | "| 0.0| 0.0|\n", 459 | "| 3.0| 0.0|\n", 460 | "| 1.0| 5.0|\n", 461 | "| 3.0| 3.0|\n", 462 | "| 2.0| 2.0|\n", 463 | "| 1.0| 1.0|\n", 464 | "| 0.0| 0.0|\n", 465 | "| 2.0| 2.0|\n", 466 | "| 2.0| 2.0|\n", 467 | "| 1.0| 1.0|\n", 468 | "| 6.0| 6.0|\n", 469 | "| 1.0| 1.0|\n", 470 | "| 0.0| 3.0|\n", 471 | "| 7.0| 0.0|\n", 472 | "| 0.0| 0.0|\n", 473 | "| 0.0| 0.0|\n", 474 | "| 1.0| 1.0|\n", 475 | "| 1.0| 1.0|\n", 476 | "| 6.0| 6.0|\n", 477 | "| 0.0| 0.0|\n", 478 | "| 0.0| 3.0|\n", 479 | "| 2.0| 2.0|\n", 480 | "| 0.0| 0.0|\n", 481 | "| 2.0| 2.0|\n", 482 | "| 0.0| 0.0|\n", 483 | "| 4.0| 4.0|\n", 484 | "| 0.0| 0.0|\n", 485 | "| 6.0| 6.0|\n", 486 | "| 2.0| 5.0|\n", 487 | "| 0.0| 3.0|\n", 488 | "| 3.0| 0.0|\n", 489 | "| 0.0| 0.0|\n", 490 | "| 3.0| 3.0|\n", 491 | "| 4.0| 4.0|\n", 492 | "| 0.0| 3.0|\n", 493 | "| 0.0| 0.0|\n", 494 | "| 0.0| 0.0|\n", 495 | "| 4.0| 4.0|\n", 496 | "| 3.0| 0.0|\n", 497 | "| 2.0| 2.0|\n", 498 | "| 1.0| 1.0|\n", 499 | "| 7.0| 7.0|\n", 500 | "| 0.0| 0.0|\n", 501 | "| 0.0| 0.0|\n", 502 | "| 0.0| 3.0|\n", 503 | "| 1.0| 1.0|\n", 504 | "| 1.0| 1.0|\n", 505 | "| 5.0| 4.0|\n", 506 | "| 1.0| 1.0|\n", 507 | "| 1.0| 1.0|\n", 508 | "| 4.0| 4.0|\n", 509 | "| 3.0| 3.0|\n", 510 | "| 0.0| 0.0|\n", 511 | "| 2.0| 2.0|\n", 512 | "| 4.0| 4.0|\n", 513 | "| 7.0| 7.0|\n", 514 | "| 2.0| 2.0|\n", 515 | "| 0.0| 0.0|\n", 516 | "| 1.0| 1.0|\n", 517 | "| 0.0| 0.0|\n", 518 | "| 4.0| 4.0|\n", 519 | "| 1.0| 1.0|\n", 520 | "| 0.0| 0.0|\n", 521 | "| 0.0| 0.0|\n", 522 | "| 0.0| 0.0|\n", 523 | "| 0.0| 3.0|\n", 524 | "| 0.0| 3.0|\n", 525 | "| 0.0| 0.0|\n", 526 | "+--------------+----------+\n", 527 | "only showing top 100 rows\n", 528 | "\n", 529 | "0.764132648114\n" 530 | ] 531 | } 532 | ], 533 | "source": [ 534 | "from pyspark.mllib.evaluation import MulticlassMetrics\n", 535 | "\n", 536 | "fitted_pipeline = pipeline.fit(train_df) # Fit model to data\n", 537 | "\n", 538 | "prediction = fitted_pipeline.transform(train_df) # Evaluate on train data.\n", 539 | "# prediction = fitted_pipeline.transform(test_df) # <-- The same code evaluates test data.\n", 540 | "pnl = prediction.select(\"index_category\", \"prediction\")\n", 541 | "pnl.show(100)\n", 542 | "\n", 543 | "prediction_and_label = pnl.map(lambda row: (row.index_category, row.prediction))\n", 544 | "metrics = MulticlassMetrics(prediction_and_label)\n", 545 | "print(metrics.precision())" 546 | ] 547 | }, 548 | { 549 | 
"cell_type": "markdown", 550 | "metadata": { 551 | "collapsed": true 552 | }, 553 | "source": [ 554 | "## Conclusion\n", 555 | "\n", 556 | "It may certainly take some time to master the principles and syntax of both Keras and Spark, depending where you come from, of course. However, we also hope you come to the conclusion that once you get beyond the stage of struggeling with defining your models and preprocessing your data, the business of building and using SparkML pipelines is quite an elegant and useful one. \n", 557 | "\n", 558 | "If you like what you see, consider helping further improve elephas or contributing to Keras or Spark. Do you have any constructive remarks on this notebook? Is there something you want me to clarify? In any case, feel free to contact me." 559 | ] 560 | } 561 | ], 562 | "metadata": { 563 | "kernelspec": { 564 | "display_name": "Python 2", 565 | "language": "python", 566 | "name": "python2" 567 | }, 568 | "language_info": { 569 | "codemirror_mode": { 570 | "name": "ipython", 571 | "version": 2 572 | }, 573 | "file_extension": ".py", 574 | "mimetype": "text/x-python", 575 | "name": "python", 576 | "nbconvert_exporter": "python", 577 | "pygments_lexer": "ipython2", 578 | "version": "2.7.10" 579 | } 580 | }, 581 | "nbformat": 4, 582 | "nbformat_minor": 0 583 | } 584 | --------------------------------------------------------------------------------