├── images ├── 0 │ └── in.txt ├── 1 │ └── in.txt ├── 2 │ └── in.txt ├── 3 │ └── in.txt ├── 4 │ └── in.txt ├── 5 │ └── in.txt ├── 6 │ └── in.txt ├── 7 │ └── in.txt ├── 8 │ └── in.txt ├── 9 │ └── in.txt └── in.txt ├── LICENSE ├── utils.py ├── .gitignore ├── main.py └── README.md /images/in.txt: -------------------------------------------------------------------------------- 1 | fdsfs 2 | -------------------------------------------------------------------------------- /images/5/in.txt: -------------------------------------------------------------------------------- 1 | All image of 5. 2 | -------------------------------------------------------------------------------- /images/8/in.txt: -------------------------------------------------------------------------------- 1 | All image of 8. 2 | -------------------------------------------------------------------------------- /images/9/in.txt: -------------------------------------------------------------------------------- 1 | All image of 9. 2 | -------------------------------------------------------------------------------- /images/0/in.txt: -------------------------------------------------------------------------------- 1 | All image of Zeros. 2 | -------------------------------------------------------------------------------- /images/1/in.txt: -------------------------------------------------------------------------------- 1 | All image of One. 2 | 3 | -------------------------------------------------------------------------------- /images/2/in.txt: -------------------------------------------------------------------------------- 1 | All image of 2. 
2 | 3 | -------------------------------------------------------------------------------- /images/3/in.txt: -------------------------------------------------------------------------------- 1 | All image of 3 2 | 3 | -------------------------------------------------------------------------------- /images/4/in.txt: -------------------------------------------------------------------------------- 1 | All image of 4. 2 | 3 | -------------------------------------------------------------------------------- /images/6/in.txt: -------------------------------------------------------------------------------- 1 | All image of 6. 2 | 3 | -------------------------------------------------------------------------------- /images/7/in.txt: -------------------------------------------------------------------------------- 1 | All image of 7. 2 | 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Atomic Cluster 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | 6 | 7 | def multiclass_roc_auc_score(y_test, y_pred, average="macro"): 8 | lb = LabelBinarizer() 9 | lb.fit(y_test) 10 | y_test = lb.transform(y_test) 11 | y_pred = lb.transform(y_pred) 12 | return roc_auc_score(y_test, y_pred, average=average) 13 | 14 | 15 | def plot_confusion_matrix(cm, classes, normalize=False, 16 | title='Confusion matrix', 17 | cmap=plt.cm.GnBu): 18 | 19 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 20 | plt.title(title) 21 | tick_marks = np.arange(len(classes)) 22 | plt.xticks(tick_marks, classes, rotation=45) 23 | plt.yticks(tick_marks, classes) 24 | 25 | fmt = '.2f' if normalize else 'd' 26 | thresh = cm.max() / 2. 
27 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 28 | plt.text(j, i, format(cm[i, j], fmt), 29 | horizontalalignment="center", 30 | color="white" if cm[i, j] > thresh else "black") 31 | 32 | plt.tight_layout() 33 | plt.ylabel('True label') 34 | plt.xlabel('Predicted label') 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | from utils import plot_confusion_matrix, multiclass_roc_auc_score 3 | from pyspark.ml.evaluation import MulticlassClassificationEvaluator 4 | from pyspark.ml.evaluation import MulticlassClassificationEvaluator 5 | from pyspark.ml.classification import LogisticRegression 6 | from pyspark.ml.image import ImageSchema 7 | from pyspark.sql.functions import lit 8 | from pyspark.sql import SparkSession 9 | from pyspark.ml import Pipeline 10 | from sparkdl import DeepImageFeaturizer 11 | 12 | from sklearn.metrics import roc_curve, auc, roc_auc_score 13 | from sklearn.metrics import classification_report 14 | from sklearn.preprocessing import LabelBinarizer 15 | from 
sklearn.metrics import confusion_matrix 16 | 17 | import matplotlib.pyplot as plt 18 | from functools import reduce 19 | import seaborn as sns 20 | import numpy as np 21 | import itertools 22 | 23 | # create spark session 24 | spark = SparkSession.builder.appName('BD Recognizer').getOrCreate() 25 | 26 | # loaded image 27 | zero_df = ImageSchema.readImages("images/0").withColumn("label", lit(0)) 28 | one_df = ImageSchema.readImages("images/1").withColumn("label", lit(1)) 29 | two_df = ImageSchema.readImages("images/2").withColumn("label", lit(2)) 30 | three_df = ImageSchema.readImages("images/3").withColumn("label", lit(3)) 31 | four_df = ImageSchema.readImages("images/4").withColumn("label", lit(4)) 32 | five_df = ImageSchema.readImages("images/5").withColumn("label", lit(5)) 33 | six_df = ImageSchema.readImages("images/6").withColumn("label", lit(6)) 34 | seven_df = ImageSchema.readImages("images/7").withColumn("label", lit(7)) 35 | eight_df = ImageSchema.readImages("images/8").withColumn("label", lit(8)) 36 | nine_df = ImageSchema.readImages("images/9").withColumn("label", lit(9)) 37 | 38 | 39 | # merge data frame 40 | dataframes = [zero_df, one_df, two_df, three_df, 41 | four_df,five_df,six_df,seven_df,eight_df,nine_df] 42 | 43 | df = reduce(lambda first, second: first.union(second), dataframes) 44 | 45 | # repartition dataframe 46 | df = df.repartition(200) 47 | 48 | # split the data-frame 49 | train, test = df.randomSplit([0.8, 0.2], 42) 50 | 51 | print(df.toPandas().size) 52 | print(df.printSchema()) 53 | 54 | 55 | ''' 56 | --------------------------------- Model Building & Training ----------------------------------- 57 | ''' 58 | 59 | 60 | # model: InceptionV3 61 | # extracting feature from images 62 | featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", 63 | modelName="InceptionV3") 64 | 65 | # used as a multi class classifier 66 | lr = LogisticRegression(maxIter=5, regParam=0.03, 67 | elasticNetParam=0.5, labelCol="label") 68 | 69 | # 
define a pipeline model 70 | sparkdn = Pipeline(stages=[featurizer, lr]) 71 | spark_model = sparkdn.fit(train) 72 | 73 | 74 | ''' 75 | --------------------------------- Model Evaluation ----------------------------------- 76 | ''' 77 | 78 | # Evaluation Matrix 79 | evaluator = MulticlassClassificationEvaluator() 80 | transform_test = spark_model.transform(test) 81 | 82 | print('F1-Score ', evaluator.evaluate(transform_test, 83 | {evaluator.metricName: 'f1'})) 84 | print('Precision ', evaluator.evaluate(transform_test, 85 | {evaluator.metricName: 'weightedPrecision'})) 86 | print('Recall ', evaluator.evaluate(transform_test, 87 | {evaluator.metricName: 'weightedRecall'})) 88 | print('Accuracy ', evaluator.evaluate(transform_test, 89 | {evaluator.metricName: 'accuracy'})) 90 | 91 | # 92 | # ----------------------------------------------------- 93 | # Confusion Matrix 94 | 95 | ''' 96 | - Convert Spark-DataFrame to Pnadas-DataFrame 97 | - Call Confusion Matrix With 'True' and 'Predicted' Label 98 | ''' 99 | 100 | y_true = transform_test.select("label") 101 | y_true = y_true.toPandas() # convert to pandas dataframe from spark dataframe 102 | 103 | y_pred = transform_test.select("prediction") 104 | y_pred = y_pred.toPandas() # convert to pandas dataframe from spark dataframe 105 | 106 | cnf_matrix = confusion_matrix(y_true, y_pred,labels=range(10)) 107 | 108 | sns.set_style("darkgrid") 109 | plt.figure(figsize=(7,7)) 110 | plt.grid(False) 111 | 112 | # call pre defined function 113 | plot_confusion_matrix(cnf_matrix, classes=range(10)) 114 | 115 | # 116 | # ----------------------------------------------------- 117 | # Classification Report 118 | 119 | ''' 120 | - Classification Report of each class group 121 | ''' 122 | target_names = ["Class {}".format(i) for i in range(10)] 123 | print(classification_report(y_true, y_pred, target_names = target_names)) 124 | 125 | 126 | # 127 | # ----------------------------------------------------- 128 | # ROC AUC Score 129 | 130 | 
''' 131 | - A custom ROC AUC score function for multi-class classification problem 132 | ''' 133 | 134 | 135 | print('ROC AUC score:', multiclass_roc_auc_score(y_true,y_pred)) 136 | 137 | 138 | # 139 | # ----------------------------------------------------- 140 | # Sample Prediction 141 | 142 | ''' 143 | - Comparing true vs predicted samples 144 | ''' 145 | # all columns after transformations 146 | print(transform_test.columns) 147 | 148 | # see some predicted output 149 | transform_test.select('image', "prediction", "label").show() 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Multi-Class Image Classification Using Transfer Learning With PySpark 2 | 3 | A promising solution for a **Computer Vision** problem that combines two state-of-the-art technologies: **Deep Learning** and **Apache Spark**. We will leverage the power of **Deep Learning Pipelines** for a Multi-Class image classification problem. 4 | 5 | **Deep Learning Pipelines** is a high-level Deep Learning framework that facilitates common Deep Learning workflows via the **Spark MLlib** Pipelines API. It currently supports TensorFlow and Keras with the TensorFlow backend. The library comes from Databricks. 6 | 7 | **Blog Articles:** Read the technical description in the sources below. 8 | 9 | - [Towards Data Science | Transfer Learning In PySpark](https://towardsdatascience.com/transfer-learning-with-pyspark-729d49604d45) 10 | 11 | - [LinkedIn Article | Transfer Learning In PySpark](https://www.linkedin.com/pulse/transfer-learning-pyspark-mohammed-innat/) 12 | 13 | --- 14 | 15 | ### Installation 16 | 17 | Installation is a bit tricky right now, so I have described it separately. Check this [GitHub Gist](https://gist.github.com/iphton/b0ab252c954eb2a28a984774e3ee1f2d) for the necessary packages and setup details.
18 | 19 | ### Data Set 20 | 21 | We chose [**NumtaDB**](https://arxiv.org/abs/1806.02452) as the source of our dataset. It's a collection of Bengali handwritten digits. The dataset contains more than **85,000** digits from over **2,700** contributors. Rather than working on the whole dataset, we randomly choose 50 images of each class. You can get this sample of images [here](https://drive.google.com/open?id=1AbTGJIfD2lhGe-stNIymGaowy7vyVovn); it contains 500 images in total. 22 | 23 | ### Model Training 24 | 25 | Here we combine the **InceptionV3** model and **logistic regression** in Spark. The **DeepImageFeaturizer** automatically peels off the last layer of a pre-trained neural network and uses the output from all the previous layers as features for the logistic regression algorithm. 26 | 27 | 28 | ```python 29 | from pyspark.ml.evaluation import MulticlassClassificationEvaluator 30 | from pyspark.ml.classification import LogisticRegression 31 | from pyspark.ml import Pipeline 32 | from sparkdl import DeepImageFeaturizer 33 | 34 | # model: InceptionV3 35 | # extracting feature from images 36 | featurizer = DeepImageFeaturizer(inputCol="image", 37 | outputCol="features", 38 | modelName="InceptionV3") 39 | 40 | # used as a multi class classifier 41 | lr = LogisticRegression(maxIter=5, regParam=0.03, 42 | elasticNetParam=0.5, labelCol="label") 43 | 44 | # define a pipeline model 45 | sparkdn = Pipeline(stages=[featurizer, lr]) 46 | spark_model = sparkdn.fit(train) # start fitting or training 47 | ``` 48 | 49 | 50 | ## Evaluation 51 | 52 | **Evaluation metrics** 53 | ``` 54 | F1-Score 0.81117 55 | Precision 0.84220 56 | Recall 0.80909 57 | Accuracy 0.80909 58 | ``` 59 | 60 | **Confusion matrix** 61 | ![Screenshot from 2019-07-23 00-40-15](https://user-images.githubusercontent.com/17668390/61664640-00afd880-acf5-11e9-8544-91b3e05fbbf4.png) 62 | 63 | 64 | **Classification report** 65 | ``` 66 | precision recall f1-score support 67 | 68 |
Class 0 1.00 0.92 0.96 13 69 | Class 1 0.57 1.00 0.73 8 70 | Class 2 0.64 1.00 0.78 7 71 | Class 3 0.88 0.70 0.78 10 72 | Class 4 0.90 1.00 0.95 9 73 | Class 5 0.67 0.83 0.74 12 74 | Class 6 0.83 0.62 0.71 8 75 | Class 7 1.00 0.80 0.89 10 76 | Class 8 1.00 0.80 0.89 20 77 | Class 9 0.70 0.54 0.61 13 78 | 79 | micro avg 0.81 0.81 0.81 110 80 | macro avg 0.82 0.82 0.80 110 81 | weighted avg 0.84 0.81 0.81 110 82 | ``` 83 | 84 | 85 | **Predicted Samples** 86 | 87 | ``` 88 | ['image', 'label', 'features', 'rawPrediction', 'probability', 'prediction'] 89 | +--------------------+----------+-----+ 90 | | image|prediction|label| 91 | +--------------------+----------+-----+ 92 | |[file:/home/i...| 1.0| 1| 93 | |[file:/home/i...| 8.0| 8| 94 | |[file:/home/i...| 9.0| 9| 95 | |[file:/home/i...| 1.0| 8| 96 | |[file:/home/i...| 1.0| 1| 97 | |[file:/home/i...| 1.0| 9| 98 | |[file:/home/i...| 0.0| 0| 99 | |[file:/home/i...| 2.0| 9| 100 | |[file:/home/i...| 8.0| 8| 101 | |[file:/home/i...| 9.0| 9| 102 | |[file:/home/i...| 0.0| 0| 103 | |[file:/home/i...| 4.0| 0| 104 | |[file:/home/i...| 5.0| 9| 105 | |[file:/home/i...| 1.0| 1| 106 | |[file:/home/i...| 9.0| 9| 107 | |[file:/home/i...| 9.0| 9| 108 | |[file:/home/i...| 1.0| 1| 109 | |[file:/home/i...| 1.0| 1| 110 | |[file:/home/i...| 9.0| 9| 111 | |[file:/home/i...| 3.0| 6| 112 | +--------------------+----------+-----+ 113 | only showing top 20 rows 114 | ``` 115 | 116 | --- 117 | 118 | **Related Technologies** 119 | - Distributed DL with Keras & PySpark — [Elephas](https://github.com/maxpumperla/elephas?source=post_page---------------------------) 120 | - Distributed Deep Learning Library for Apache Spark — [BigDL](https://github.com/intel-analytics/BigDL?source=post_page---------------------------) 121 | - TensorFlow to Apache Spark clusters — [TensorFlowOnSpark](https://github.com/yahoo/TensorFlowOnSpark?source=post_page---------------------------) 122 | 123 | **References** 124 | - Databricks: [Deep Learning 
Guide](https://docs.databricks.com/applications/deep-learning/index.html?source=post_page---------------------------) 125 | - Apache Spark: [PySpark Machine Learning](https://spark.apache.org/docs/latest/api/python/index.html?source=post_page---------------------------) 126 | --------------------------------------------------------------------------------