├── preprocessing
│   ├── utils.py
│   ├── 003_hog.py
│   ├── 002_data_augmentation.py
│   └── 001_load_data.py
├── data
│   └── wgetgdrive.sh
├── requirements.txt
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── index.rst
│       ├── result.rst
│       ├── install.rst
│       ├── conf.py
│       └── module.rst
├── setup.py
├── models
│   ├── svm.py
│   ├── random_forest.py
│   ├── vgg16.py
│   ├── ensemble_majority_voting.py
│   ├── cnn_custom.py
│   └── ensemble_stacked_prediction.py
├── LICENSE
├── .gitignore
├── app.py
├── README.md
└── visualization.py
--------------------------------------------------------------------------------
/preprocessing/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | def make_folder(path):
5 |     """Check if the folder exists; if it doesn't, create one at the given path.
6 | 
7 |     Args:
8 |         path (str): path where the folder needs to be created.
9 | 
10 |     """
11 |     if not os.path.exists(path):
12 |         print('[INFO] Creating new folder...')
13 | 
14 |         os.makedirs(path)
15 | 
--------------------------------------------------------------------------------
/data/wgetgdrive.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Get files from Google Drive
4 | # Usage: ./wgetgdrive.sh <file_id> <file_name>.zip
5 | 
6 | # $1 = file ID
7 | # $2 = file name
8 | 
9 | URL="https://docs.google.com/uc?export=download&id=$1"
10 | 
11 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "$URL" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$1" -O "$2" && rm -rf /tmp/cookies.txt
12 | unzip "$2" -d raw
13 | rm "$2"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | flake8==3.8.2
2 | h5py==2.10.0
3 | ipython==6.5.0
4 | joblib==0.13.2
5 | jupyter==1.0.0
6 | Keras==2.3.1
7 | Keras-Applications==1.0.8
8 | Keras-Preprocessing==1.1.0
9 | Markdown==3.1.1
10 | matplotlib==3.1.1
11 | notebook==6.0.1
12 | numpy==1.18.1
13 | opencv-python==4.1.0.25
14 | pandas==0.25.1
15 | pickleshare==0.7.5
16 | scikit-image==0.15.0
17 | scikit-learn==0.22.2.post1
18 | scipy==1.4.1
19 | seaborn==0.10.0
20 | Sphinx==3.0.4
21 | tensorflow-gpu==1.14.0
22 | tqdm==4.41.1
23 | yellowbrick==1.1
24 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
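# For example, "make html" builds the HTML documentation into build/html and
# "make clean" removes the build directory; both are dispatched to
# sphinx-build's -M make-mode by the catch-all rule below.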
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 | 
3 | pushd %~dp0
4 | 
5 | REM Command file for Sphinx documentation
6 | 
7 | if "%SPHINXBUILD%" == "" (
8 | 	set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | 
13 | if "%1" == "" goto help
14 | 
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | 	echo.
18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | 	echo.installed, then set the SPHINXBUILD environment variable to point
20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | 	echo.may add the Sphinx directory to PATH.
22 | 	echo.
23 | 	echo.If you don't have Sphinx installed, grab it from
24 | 	echo.http://sphinx-doc.org/
25 | 	exit /b 1
26 | )
27 | 
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 | 
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 | 
34 | :end
35 | popd
36 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | 
4 | 
5 | # Utility function to read the README file.
6 | # Used for the long_description.
7 | def read(file_name):
8 |     with open(os.path.join(os.path.dirname(__file__), file_name)) as f:
9 |         return f.read()
10 | 
11 | 
12 | def get_packages():
13 |     requirement_path = 'requirements.txt'
14 |     packages = []
15 |     if os.path.isfile(requirement_path):
16 |         with open(requirement_path) as f:
17 |             packages = f.read().splitlines()
18 |     return packages
19 | 
20 | 
21 | setup(
22 |     name="grape_disease_classification",
23 |     version="1.0.0",
24 |     author="Sanjana Srinivas",
25 |     author_email="sanjanasrinivas73@gmail.com",
26 |     description=("A demonstration of how to classify diseases in"
27 |                  " plants (grape) using various Machine Learning models."),
28 |     keywords="disease classification",
29 |     url="https://github.com/Sanjana7395/Grape-disease-classification.git",
30 |     install_requires=get_packages(),
31 |     long_description=read('README.md'),
32 | )
33 | 
--------------------------------------------------------------------------------
/models/svm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.svm import LinearSVC
3 | import joblib
4 | from preprocessing.utils import make_folder
5 | 
6 | 
7 | def main():
8 |     """ Load the data.
9 |         Train a linear SVM model.
10 |         Print accuracy on test data.
11 | 
12 |     """
13 |     # Load stored data
14 |     X_train = np.load('../data/processed/ImageTrainHOG_input.npy')
15 |     y_train = np.load('../data/augment/DiseaseAugment_input.npy')
16 |     print("=== TRAIN DATA ===")
17 |     print(X_train.shape)
18 |     print(y_train.shape)
19 | 
20 |     X_test = np.load('../data/processed/ImageTestHOG_input.npy')
21 |     y_test = np.load('../data/test/DiseaseTest_input.npy')
22 |     print("=== TEST DATA ===")
23 |     print(X_test.shape)
24 |     print(y_test.shape)
25 | 
26 |     # Classifier
27 |     svm_model = LinearSVC(C=0.01)
28 |     svm_model.fit(X_train, y_train)
29 |     print(svm_model.score(X_test, y_test))
30 | 
31 |     make_folder('../results/models')
32 |     filename = '../results/models/SVM_model.sav'
33 |     joblib.dump(svm_model, filename)
34 | 
35 | 
36 | if __name__ == "__main__":
37 |     main()
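# A minimal sketch of reusing the classifier saved above for inference
# (illustrative only; assumes it runs from the models folder, with X_test
# loaded as in main()):
#
#     import joblib
#     svm_model = joblib.load('../results/models/SVM_model.sav')
#     print(svm_model.predict(X_test[:5]))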
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Sanjana7395
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. grape_disease_classification documentation master file, created by
2 |    sphinx-quickstart on Wed May 27 16:16:34 2020.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 | 
6 | Grape Disease Classification
7 | ****************************
8 | 
9 | .. toctree::
10 |    :maxdepth: 2
11 |    :caption: Table of Contents
12 |    :numbered:
13 | 
14 |    install
15 |    module
16 |    result
17 | 
18 | 
19 | Technology used - Tensorflow (1.14.0)
20 | 
21 | This project classifies diseases in grape plants using various Machine Learning classification algorithms.
22 | Grape plants are susceptible to various diseases. The diseases that are classified in this project are:
23 | 
24 | 1. Black rot
25 | 2. Black Measles (esca)
26 | 3. Powdery mildew
27 | 4. Leaf blight
28 | 5. Healthy
29 | 
30 | The Machine learning classification models in this project include:
31 | 
32 | 1. Random forest classification
33 | 2. Support vector machine classification
34 | 3. CNN - VGG16
35 | 4. CNN - Custom
36 | 5. Ensemble model - Majority voting
37 | 6. Ensemble model - Stacked prediction
38 | 
39 | 
40 | Indices and Tables
41 | -------------------
42 | 
43 | * :ref:`genindex`
44 | * :ref:`modindex`
45 | * :ref:`search`
46 | 
--------------------------------------------------------------------------------
/models/random_forest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.ensemble import RandomForestClassifier
3 | import joblib
4 | from preprocessing.utils import make_folder
5 | 
6 | 
7 | def main():
8 |     """ Load data.
9 |         Train random forest model.
10 |         Print accuracy on test data.
11 | 
12 |     """
13 |     # Load stored data
14 |     X_train = np.load('../data/processed/ImageTrainHOG_input.npy')
15 |     y_train = np.load('../data/augment/DiseaseAugment_input.npy')
16 |     print("=== TRAIN DATA ===")
17 |     print(X_train.shape)
18 |     print(y_train.shape)
19 | 
20 |     X_test = np.load('../data/processed/ImageTestHOG_input.npy')
21 |     y_test = np.load('../data/test/DiseaseTest_input.npy')
22 |     print("=== TEST DATA ===")
23 |     print(X_test.shape)
24 |     print(y_test.shape)
25 | 
26 |     # Classifier
27 |     Random_classifier = RandomForestClassifier(n_estimators=500, max_depth=35,
28 |                                                n_jobs=-1, warm_start=True,
29 |                                                oob_score=True,
30 |                                                max_features='sqrt')
31 | 
32 |     Random_classifier.fit(X_train, y_train)
33 |     print(Random_classifier.score(X_test, y_test))
34 | 
35 |     make_folder('../results/models')
36 |     filename = '../results/models/Random_model.sav'
37 |     joblib.dump(Random_classifier, filename)
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     main()
42 | 
--------------------------------------------------------------------------------
/docs/source/result.rst:
--------------------------------------------------------------------------------
1 | Results
2 | ========
3 | 
4 | Below are the results obtained on the test set for various models trained in the project.
5 | 
6 | .. note:: The results are system specific. Different combinations of the cuDNN
7 |           library version and the NVIDIA driver version can produce slightly
8 |           different numbers; reproducing the environment should, however, give
9 |           results in the same ballpark as those reported below.
10 | 
11 | +----------------------------------+---------------+
12 | | Models                           | Accuracy (%)  |
13 | +==================================+===============+
14 | | Random forest                    | 75.35         |
15 | +----------------------------------+---------------+
16 | | SVM                              | 82.89         |
17 | +----------------------------------+---------------+
18 | | CNN - VGG16                      | 93.62         |
19 | +----------------------------------+---------------+
20 | | Ensemble - Majority voting       | 98.05         |
21 | +----------------------------------+---------------+
22 | | Ensemble - Stacked prediction    | 98.23         |
23 | +----------------------------------+---------------+
24 | | CNN - Custom                     | 98.76         |
25 | +----------------------------------+---------------+
26 | 
27 | Indices and Tables
28 | -------------------
29 | 
30 | * :ref:`genindex`
31 | * :ref:`modindex`
32 | * :ref:`search`
33 | 
--------------------------------------------------------------------------------
/preprocessing/003_hog.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.feature import hog
3 | from preprocessing.utils import make_folder
4 | 
5 | 
6 | def hog_feature(image, multichannel=True):
7 |     """ Extract HOG feature descriptors from the image.
8 | 
9 |     Args:
10 |         image (numpy array): Array of image pixels.
11 | 
12 |         multichannel (bool): True for RGB image, else False.
13 | 
14 |     Returns:
15 |         (numpy array): Feature descriptors.
16 | 
17 |     """
18 |     hog_feature_var = hog(image, multichannel=multichannel)
19 |     return hog_feature_var
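# NOTE: the 32400-dimensional descriptors used in main() below follow from
# skimage's default HOG parameters (orientations=9, pixels_per_cell=(8, 8),
# cells_per_block=(3, 3)) applied to the project's 180x180 RGB images:
# 180 // 8 = 22 cells per axis, 22 - 3 + 1 = 20 blocks per axis, and
# 20 * 20 * 3 * 3 * 9 = 32400 features. A quick sanity check (a sketch,
# not part of the pipeline):
#
#     import numpy as np
#     from skimage.feature import hog
#     assert hog(np.zeros((180, 180, 3)), multichannel=True).shape == (32400,)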
20 | 
21 | 
22 | def main():
23 |     """ Load images.
24 |         Extract HOG feature descriptors.
25 | 
26 |     """
27 |     # Load stored data
28 |     X_train = np.load('../data/augment/ImageAugment_input.npy')
29 |     print('=== TRAIN DATA ===')
30 |     print(X_train.shape)
31 | 
32 |     X_test = np.load('../data/test/ImageTest_input.npy')
33 |     print('=== TEST DATA ===')
34 |     print(X_test.shape)
35 | 
36 |     print("Extracting HOG features...")
37 |     RF_train = np.zeros([len(X_train), 32400])
38 |     for i in range(len(X_train)):
39 |         RF_train[i] = hog_feature(X_train[i])
40 |     print("FEATURE DESCRIPTORS")
41 |     print(RF_train.shape)
42 | 
43 |     RF_test = np.zeros([len(X_test), 32400])
44 |     for i in range(len(X_test)):
45 |         RF_test[i] = hog_feature(X_test[i])
46 |     print(RF_test.shape)
47 | 
48 |     # Save data
49 |     make_folder('../data/processed')
50 |     np.save('../data/processed/ImageTestHOG_input.npy', RF_test)
51 |     np.save('../data/processed/ImageTrainHOG_input.npy', RF_train)
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     main()
56 | 
--------------------------------------------------------------------------------
/docs/source/install.rst:
--------------------------------------------------------------------------------
1 | Configuration of Project Environment
2 | =====================================
3 | 
4 | 1. Clone the project.
5 | 2. Install packages required.
6 | 3. Download the data set.
7 | 4. Run the project.
8 | 
9 | Setup procedure
10 | ----------------
11 | 1. Clone the project from `GitHub <https://github.com/Sanjana7395/Grape-disease-classification.git>`_.
12 |    Change to the directory Grape-Disease-Classification.
13 | 2. Install packages
14 |    In order to reproduce the code, install the packages
15 |    A. Manually install the packages mentioned in the requirements.txt file, or use the command. ::
16 | 
17 |        pip install -r requirements.txt
18 | 
19 |    B. Install packages using the setup.py file. ::
20 | 
21 |        python setup.py install
22 | 
23 |       The **--user** option directs setup.py to install the package
24 |       in the user site-packages directory for the running Python.
25 |       Alternatively, you can use the **--home** or **--prefix** option to install
26 |       your package in a different location (where you have the necessary permissions).
27 | 
28 | .. note:: The requirements.txt file replicates the virtual environment that I use. It has many packages
29 |           that are not relevant to this project. Feel free to edit the packages list.
30 | 
31 | 3. Download the required data set.
32 |    The data set that is used in this project is available
33 |    `here <https://drive.google.com/drive/folders/1SFBc-dNzr325jHw434j8LYyCii6djzkC?usp=sharing>`_.
34 |    The data set includes images from the `kaggle <https://www.kaggle.com/xabdallahali/plantvillage-dataset>`_
35 |    grape disease data set and the images collected online and labelled using the LabelMe tool.
36 |    Download the zip file and extract the files in the **data/raw** folder.
37 | 
38 |    [OR]
39 | 
40 |    Run the below command ::
41 | 
42 |        ./wgetgdrive.sh <drive_id> <zip_name>.zip
43 | 
44 |    drive_id is **1gsUyWEkxz9H1-yn2ONx4scHg88kWU-38**.
45 |    Provide any zip_name.
46 | 
47 | 4. Run the project.
48 |    See the **Documentation for the code** section for further details.
49 | 
50 | Indices and Tables
51 | -------------------
52 | 
53 | * :ref:`genindex`
54 | * :ref:`modindex`
55 | * :ref:`search`
56 | 
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 | 
7 | # -- Path setup --------------------------------------------------------------
8 | 
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | 
16 | sys.path.insert(0, os.path.abspath('../..'))
17 | sys.setrecursionlimit(1500)
18 | 
19 | # -- Project information -----------------------------------------------------
20 | 
21 | project = 'Overview'
22 | copyright = '2020, Sanjana Srinivas'
23 | author = 'Sanjana Srinivas'
24 | 
25 | # The full version, including alpha/beta/rc tags
26 | release = '1.0.0'
27 | 
28 | # -- General configuration ---------------------------------------------------
29 | 
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 |     'sphinx.ext.viewcode',
35 |     'sphinx.ext.githubpages',
36 |     'sphinx.ext.autodoc',
37 |     'sphinx.ext.autosummary',
38 |     'sphinx.ext.coverage',
39 |     'sphinx.ext.graphviz',
40 |     'sphinx.ext.doctest',
41 |     'sphinx.ext.intersphinx',
42 |     'sphinx.ext.todo',
43 |     'sphinx.ext.ifconfig',
44 |     'matplotlib.sphinxext.plot_directive',
45 | ]
46 | 
47 | # Add any paths that contain templates here, relative to this directory.
48 | templates_path = ['_templates']
49 | 
50 | 
51 | # List of patterns, relative to source directory, that match files and
52 | # directories to ignore when looking for source files.
53 | # This pattern also affects html_static_path and html_extra_path.
54 | exclude_patterns = []
55 | 
56 | # -- Options for HTML output -------------------------------------------------
57 | 
58 | # The theme to use for HTML and HTML Help pages. See the documentation for
59 | # a list of builtin themes.
60 | #
61 | html_theme = 'alabaster'
62 | 
63 | # Add any paths that contain custom static files (such as style sheets) here,
64 | # relative to this directory. They are copied after the builtin static files,
65 | # so a file named "default.css" will overwrite the builtin "default.css".
66 | html_static_path = ['_static']
67 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea
6 | 
7 | # C extensions
8 | *.so
9 | 
10 | # Distribution / packaging
11 | .Python
12 | augment/
13 | intermediate/
14 | processed/
15 | test/
16 | raw/
17 | results/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 | 
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | 148 | -------------------------------------------------------------------------------- /models/vgg16.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import LabelEncoder 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.keras.utils import to_categorical 5 | from tensorflow.keras.models import Sequential 6 | from tensorflow.keras.layers import Dense, Conv2D, Activation 7 | from tensorflow.keras.optimizers import Adam 8 | import pickle 9 | from preprocessing.utils import make_folder 10 | 11 | NUM_CLASSES = 5 12 | 13 | 14 | class Hist: 15 | """ Dummy class 16 | 17 | """ 18 | def __init__(self): 19 | pass 20 | 21 | 22 | def encoder(data, class_count): 23 | """ Transform string data to unique int. 24 | Convert unique int data to one-hot encoded data. 25 | 26 | Args: 27 | data (numpy array): array to be encoded. 28 | 29 | class_count (int): number of classes. 30 | 31 | Returns: 32 | (numpy array): one-hot encoded array. 33 | 34 | """ 35 | labeler = LabelEncoder() 36 | y = labeler.fit_transform(data) 37 | y = to_categorical(y, num_classes=class_count) 38 | return y 39 | 40 | 41 | def main(): 42 | """ Load data. 43 | Normalize and encode. 44 | Train CNN-VGG16 model. 45 | Print accuracy on test data. 
46 | 47 | """ 48 | # Load stored data 49 | X_train = np.load('../data/augment/ImageAugment_input.npy') 50 | y_train = np.load('../data/augment/DiseaseAugment_input.npy') 51 | print("=== TRAIN DATA ===") 52 | print(X_train.shape) 53 | print(y_train.shape) 54 | 55 | X_test = np.load('../data/test/ImageTest_input.npy') 56 | y_test = np.load('../data/test/DiseaseTest_input.npy') 57 | print("=== TEST DATA ===") 58 | print(X_test.shape) 59 | print(y_test.shape) 60 | 61 | # hot encoding of labels 62 | y_train = encoder(y_train, NUM_CLASSES) 63 | y_test = encoder(y_test, NUM_CLASSES) 64 | 65 | # Input normalization 66 | X_train = (X_train / 255.0).astype(np.float32) 67 | X_test = (X_test / 255.0).astype(np.float32) 68 | 69 | # VGG16 CNN model 70 | IMG_SHAPE = (180, 180, 3) 71 | VGG16_MODEL = tf.keras.applications.VGG16(input_shape=IMG_SHAPE, 72 | include_top=False, 73 | weights='imagenet') 74 | 75 | VGG16_MODEL.trainable = False 76 | global_average_layer = tf.keras.layers.GlobalAveragePooling2D() 77 | prediction_layer = Dense(NUM_CLASSES, activation='softmax') 78 | 79 | model_vgg16 = Sequential([ 80 | VGG16_MODEL, 81 | Conv2D(512, kernel_size=(3, 3), padding='same'), 82 | Activation('relu'), 83 | Conv2D(1024, kernel_size=(3, 3), padding='same'), 84 | global_average_layer, 85 | prediction_layer 86 | ]) 87 | 88 | model_vgg16.compile(optimizer=Adam(), 89 | loss="categorical_crossentropy", 90 | metrics=['accuracy']) 91 | 92 | # Learning rate decay 93 | reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 94 | factor=0.2, 95 | patience=5, 96 | min_lr=0.00001) 97 | history_custom = model_vgg16.fit(X_train, y_train, batch_size=8, 98 | epochs=20, verbose=1, 99 | validation_split=.1, 100 | callbacks=[reduce_lr]) 101 | scores = model_vgg16.evaluate(X_test, y_test, verbose=0) 102 | print("========================") 103 | print("TEST SET: %s: %.2f%%" % (model_vgg16.metrics_names[1], 104 | scores[1] * 100)) 105 | print("========================") 106 | 107 | print(model_vgg16.summary()) 108 | print("=== BASE MODEL SUMMARY ===") 109 | print(VGG16_MODEL.summary()) 110 | 111 | # save model 112 | make_folder('../results/models') 113 | model_vgg16.save('../results/models/vgg16.h5') 114 | history = dict() 115 | history['acc'] = history_custom.history['acc'] 116 | history['val_acc'] = history_custom.history['val_acc'] 117 | history['loss'] = history_custom.history['loss'] 118 | history['val_loss'] = history_custom.history['val_loss'] 119 | 120 | hist = Hist() 121 | setattr(hist, 'history', history) 122 | pickle.dump(hist, open('../results/models/vgg16_training_history.pkl', 'wb')) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /docs/source/module.rst: -------------------------------------------------------------------------------- 1 | Documentation for the code 2 | =========================== 3 | 4 | 1. Pre processing 5 | This folder contains 6 | A. Code to load the images and json(contains labelling information) files. This is present in 7 | preprocessing/001_load_data.py. To execute this code, within the 'preprocessing' folder enter the below 8 | command. :: 9 | 10 | python 001_load_data.py 11 | 12 | B. Augment data. The code is present in preproprocessing/002_data_augmentation.py. To execute, run 13 | the below command. :: 14 | 15 | python 002_data_augmentation.py 16 | 17 | C. Extract histograms of feature descriptors. Feature descriptors are used to train only 18 | random forest and SVM. 
The code is present in preprocessing/003_hog.py. ::
19 | 
20 |        python 003_hog.py
21 | 
22 | 2. Models
23 |    This folder contains the various models used in this project, namely:
24 |    A. Random forest
25 |    B. Support vector machine
26 |    C. CNN - VGG16
27 |    D. CNN - Custom
28 |    E. Ensemble model - Majority voting
29 |    F. Ensemble model - Stacked prediction
30 | 
31 |    The ensemble models are the aggregation of random forest, SVM, CNN-custom and CNN-VGG16.
32 |    The models can be trained by executing the below command within the models folder. ::
33 | 
34 |        python <model_name>.py
35 | 
36 | 3. visualization.py
37 |    This file contains all the visualization techniques used in this project.
38 |    A. Confusion matrix, using an sns heat map with modifications to display details within each box.
39 |    B. Loss and accuracy curves for the neural networks.
40 |    C. Tree representation for the random forest.
41 |    D. ROC-AUC curves using Yellowbrick.
42 | 
43 |    Usage is as follows ::
44 | 
45 |        python visualization.py -m <model> -t <visualization_technique>
46 | 
47 |    For help on the available models and visualization techniques ::
48 | 
49 |        python visualization.py --help
50 | 
51 | 4. app.py
52 |    This file predicts the disease of the input image. Usage is as follows ::
53 | 
54 |        python app.py -m <model> -i <image_index>
55 | 
56 |    For help on usage ::
57 | 
58 |        python app.py --help
59 | 
60 | Classification main (app.py)
61 | ----------------------------
62 | .. automodule:: app
63 |    :members:
64 | 
65 | Classification visualization (visualization.py)
66 | ------------------------------------------------
67 | .. automodule:: visualization
68 |    :members:
69 | 
70 | Pre processing - load data (001_load_data.py)
71 | ----------------------------------------------
72 | .. automodule:: preprocessing.001_load_data
73 |    :members:
74 | 
75 | Pre processing - data augmentation (002_data_augmentation.py)
76 | --------------------------------------------------------------
77 | 
78 | Data augmentation techniques used are -
79 | 
80 | 1. Horizontal flip
81 | 2. Vertical flip
82 | 3. Gamma correction
83 | 4. Intensity scaling
84 | 5. Random rotation
85 | 
86 | .. automodule:: preprocessing.002_data_augmentation
87 |    :members:
88 | 
89 | Pre processing - HOG (003_hog.py)
90 | ----------------------------------
91 | 
92 | Feature descriptors are generated using Histograms of Oriented Gradients (HOG).
93 | 
94 | .. automodule:: preprocessing.003_hog
95 |    :members:
96 | 
97 | Models - CNN-Custom (cnn_custom)
98 | ---------------------------------
99 | .. automodule:: models.cnn_custom
100 |    :members:
101 | 
102 | Models - Majority Voting (ensemble_majority_voting.py)
103 | -------------------------------------------------------
104 | 
105 | In the majority voting technique, the output prediction is the one that
106 | receives more than half of the votes or the maximum
107 | number of votes. If none of the predictions gets more
108 | than half of the votes, or if there is a tie, the
109 | ensemble method could not make a stable
110 | prediction for this instance. In such a situation the
111 | prediction of the model with the highest accuracy
112 | (here, the custom CNN) is taken as the final output.
113 | 
114 | .. automodule:: models.ensemble_majority_voting
115 |    :members:
116 | 
117 | Models - Stacked Prediction (ensemble_stacked_prediction.py)
118 | -------------------------------------------------------------
119 | 
120 | The network is trained with the array of class probabilities from all 4 models.
121 | 
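The meta-learner input therefore has 4 models x 5 classes = 20 columns. A
minimal sketch of how the stacked features are assembled (mirroring
``get_predictions`` in ensemble_stacked_prediction.py; the variable names are
illustrative only)::

    stacked = np.concatenate([rf_proba, svm_proba, cnn_proba, vgg_proba], axis=-1)  # shape (n_samples, 20)
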
122 | .. automodule:: models.ensemble_stacked_prediction
123 |    :members:
124 | 
125 | Models - Random forest (random_forest.py)
126 | ------------------------------------------
127 | .. automodule:: models.random_forest
128 |    :members:
129 | 
130 | Models - SVM (svm.py)
131 | ----------------------
132 | .. automodule:: models.svm
133 |    :members:
134 | 
135 | Models - VGG16 (vgg16.py)
136 | -------------------------
137 | .. automodule:: models.vgg16
138 |    :members:
139 | 
140 | Indices and Tables
141 | -------------------
142 | 
143 | * :ref:`genindex`
144 | * :ref:`modindex`
145 | * :ref:`search`
146 | 
--------------------------------------------------------------------------------
/preprocessing/002_data_augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from skimage import transform, exposure
4 | from preprocessing.utils import make_folder
5 | 
6 | 
7 | def random_rotation(img):
8 |     """ Randomly rotate the image.
9 | 
10 |     Pick a random degree of rotation between
11 |     25 degrees to the left and 25 degrees to the right.
12 | 
13 |     Args:
14 |         img (numpy array): Array of image pixels to rotate.
15 | 
16 |     Returns:
17 |         (numpy array): Rotated image.
18 | 
19 |     """
20 |     random_degree = random.uniform(-25, 25)
21 |     return (transform.rotate(img, random_degree,
22 |                              preserve_range=True)).astype(np.uint8)
23 | 
24 | 
25 | def horizontal_flip(img):
26 |     """ Flip the image horizontally.
27 | 
28 |     A horizontal flip doesn't need skimage;
29 |     it's as easy as reversing the image array of pixels!
30 | 
31 |     Args:
32 |         img (numpy array): Array of image pixels.
33 | 
34 |     Returns:
35 |         (numpy array): Flipped image.
36 | 
37 |     """
38 |     return img[:, ::-1]
39 | 
40 | 
41 | def intensity(img):
42 |     """ Change the intensity of the image.
43 | 
44 |     Args:
45 |         img (numpy array): Array of image pixels.
46 | 
47 |     Returns:
48 |         (numpy array): Intensity-rescaled image.
49 | 
50 |     """
51 |     v_min, v_max = np.percentile(img, (0.2, 99.8))
52 |     if np.abs(v_max - v_min) < 1e-3:
53 |         v_max += 1e-3
54 |     return exposure.rescale_intensity(img, in_range=(v_min, v_max))
55 | 
56 | 
57 | def gamma(img):
58 |     """ Perform gamma correction of the image.
59 | 
60 |     Args:
61 |         img (numpy array): Array of image pixels.
62 | 
63 |     Returns:
64 |         (numpy array): Gamma-corrected image.
65 | 
66 |     """
67 |     return exposure.adjust_gamma(img, gamma=0.4, gain=0.9)
68 | 
69 | 
70 | def vertical_flip(img):
71 |     """ Flip the image vertically.
72 | 
73 |     A vertical flip doesn't need skimage;
74 |     it's as easy as reversing the image array of pixels!
75 | 
76 |     Args:
77 |         img (numpy array): Array of image pixels.
78 | 
79 |     Returns:
80 |         (numpy array): Flipped image.
81 | 
82 |     """
83 |     return img[::-1, :]
84 | 
85 | 
86 | def data_augment(img, y_label):
87 |     """ Perform image augmentation using rotation,
88 |     intensity scaling, flip and gamma correction.
89 | 
90 |     Args:
91 |         img (numpy array): Array of image pixels.
92 | 
93 |         y_label (str): Label of the image.
94 | 
95 |     Returns:
96 |         (numpy array): Augmented images.
97 | 
98 |         (numpy array): Array of labels corresponding to the images.
99 | 
100 |     """
101 |     temp = [horizontal_flip(img), vertical_flip(img),
102 |             random_rotation(img), gamma(img), intensity(img)]
103 |     label = [y_label, y_label, y_label, y_label, y_label]
104 |     return temp, label
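# NOTE: data_augment() returns five transformed copies per input image
# (horizontal flip, vertical flip, random rotation, gamma correction and
# intensity rescaling). main() below applies it selectively, capping the
# number of augmented 'black rot', 'ecsa' and 'leaf_blight' images,
# presumably to keep the class distribution balanced after augmentation.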
105 | 
106 | 
107 | def main():
108 |     """ Load train data.
109 |         Augment the data.
110 | 
111 |     """
112 |     # Load data
113 |     X_train = np.load('../data/intermediate/ImageTrain_input.npy')
114 |     y_train = np.load('../data/intermediate/DiseaseTrain_input.npy')
115 |     print('TO BE AUGMENTED DATA')
116 |     print(X_train.shape)
117 |     print(y_train.shape)
118 | 
119 |     br_count = e_count = lb_count = 0
120 |     transformed_img = []
121 |     y_array = []
122 | 
123 |     # Augment class by class; the counters cap how many images of each disease are augmented
124 |     for i, name in enumerate(y_train):
125 |         if name == 'healthy':
126 |             x, y = data_augment(X_train[i], name)
127 |             transformed_img.extend(x)
128 |             y_array.extend(y)
129 | 
130 |         elif (name == 'black rot') and (br_count < 450):
131 |             x, y = data_augment(X_train[i], name)
132 |             transformed_img.extend(x)
133 |             y_array.extend(y)
134 |             br_count += 1
135 | 
136 |         elif (name == 'ecsa') and (e_count < 321):
137 |             x, y = data_augment(X_train[i], name)
138 |             transformed_img.extend(x)
139 |             y_array.extend(y)
140 |             e_count += 1
141 | 
142 |         elif (name == 'leaf_blight') and (lb_count < 308):
143 |             x, y = data_augment(X_train[i], name)
144 |             transformed_img.extend(x)
145 |             y_array.extend(y)
146 |             lb_count += 1
147 | 
148 |         elif name == 'powdery mildew':
149 |             x, y = data_augment(X_train[i], name)
150 |             transformed_img.extend(x)
151 |             y_array.extend(y)
152 | 
153 |     transformed_img = np.array(transformed_img)
154 |     y_array = np.array(y_array)
155 |     print('AUGMENTED DATA')
156 |     print(transformed_img.shape)
157 |     print(y_array.shape)
158 | 
159 |     # Concatenate with the initial image array
160 |     X_train = np.concatenate((X_train, transformed_img), axis=0)
161 |     y_train = np.concatenate((y_train, y_array), axis=0)
162 |     print('TOTAL MODEL INPUT DATA')
163 |     print(X_train.shape)
164 |     print(y_train.shape)
165 | 
166 |     # Save data
167 |     make_folder('../data/augment')
168 |     np.save('../data/augment/ImageAugment_input.npy', X_train)
169 |     np.save('../data/augment/DiseaseAugment_input.npy', y_train)
170 | 
171 | 
172 | if __name__ == "__main__":
173 |     main()
174 | 
--------------------------------------------------------------------------------
/models/ensemble_majority_voting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tensorflow.keras.utils import to_categorical
4 | from tensorflow.keras.models import load_model
5 | from sklearn.preprocessing import LabelEncoder
6 | from sklearn.metrics import accuracy_score
7 | import joblib
8 | 
9 | ROOT_DIR = '../results/models/'
10 | 
11 | 
12 | def majority_voting(rf_prediction, sv_prediction,
13 |                     custom_prediction, vgg_prediction):
14 |     """ Find the majority vote among the predictions of the given models.
15 | 
16 |     Args:
17 |         rf_prediction (numpy array): predictions of random forest model.
18 | 
19 |         sv_prediction (numpy array): predictions of SVM model.
20 | 
21 |         custom_prediction (numpy array): predictions of custom CNN model.
22 | 
23 |         vgg_prediction (numpy array): predictions of CNN-VGG16 model.
24 | 25 | Returns: 26 | (numpy array): predictions of ensemble-majority voting model 27 | 28 | """ 29 | # loop over all predictions 30 | final_prediction = list() 31 | for rf, sv, custom, vgg in zip(rf_prediction, 32 | sv_prediction, 33 | custom_prediction, 34 | vgg_prediction): 35 | # Keep track of votes per class 36 | br = e = h = lb = pm = 0 37 | 38 | # Loop over all models 39 | image_predictions = [rf, sv, custom, vgg] 40 | for img_prediction in image_predictions: 41 | # Voting 42 | if img_prediction == 'black rot': 43 | br += 1 44 | elif img_prediction == 'ecsa': 45 | e += 1 46 | elif img_prediction == 'healthy': 47 | h += 1 48 | elif img_prediction == 'leaf_blight': 49 | lb += 1 50 | elif img_prediction == 'powdery mildew': 51 | pm += 1 52 | 53 | # Find max vote 54 | count_dict = {'br': br, 'e': e, 'h': h, 'lb': lb, 'pm': pm} 55 | highest = max(count_dict.values()) 56 | max_values = [k for k, v in count_dict.items() if v == highest] 57 | ensemble_prediction = [] 58 | for max_value in max_values: 59 | if max_value == 'br': 60 | ensemble_prediction.append('black rot') 61 | elif max_value == 'e': 62 | ensemble_prediction.append('ecsa') 63 | elif max_value == 'h': 64 | ensemble_prediction.append('healthy') 65 | elif max_value == 'lb': 66 | ensemble_prediction.append('leaf_blight') 67 | elif max_value == 'pm': 68 | ensemble_prediction.append('powdery mildew') 69 | 70 | predict = '' 71 | if len(ensemble_prediction) > 1: 72 | predict = custom 73 | else: 74 | predict = ensemble_prediction[0] 75 | 76 | # Store max vote 77 | final_prediction.append(predict) 78 | 79 | return np.array(final_prediction) 80 | 81 | 82 | def main(): 83 | """ Load data. 84 | Normalize and encode. 85 | Train ensemble-majority voting model. 86 | Print accuracy of the model. 87 | 88 | """ 89 | X_test1 = np.load('../data/processed/ImageTestHOG_input.npy') 90 | X_test2 = np.load('../data/test/ImageTest_input.npy') 91 | y_test1 = np.load('../data/test/DiseaseTest_input.npy') 92 | print("=== TEST DATA ===") 93 | print(X_test1.shape) 94 | print(X_test2.shape) 95 | print(y_test1.shape) 96 | 97 | # hot encoding of labels 98 | labeler = LabelEncoder() 99 | y_test2 = labeler.fit_transform(y_test1) 100 | y_test2 = to_categorical(y_test2, num_classes=5) 101 | 102 | try: 103 | rf_model = joblib.load(os.path.join(ROOT_DIR, 'Random_model.sav')) 104 | sv_model = joblib.load(os.path.join(ROOT_DIR, 'SVM_model.sav')) 105 | custom_model = load_model(os.path.join(ROOT_DIR, 'custom.h5')) 106 | vgg_model = load_model(os.path.join(ROOT_DIR, 'vgg16.h5')) 107 | 108 | # Normalize image for CNN 109 | X_test2 = (X_test2 / 255.0).astype(np.float32) 110 | 111 | rf_prediction = rf_model.predict(X_test1) 112 | sv_prediction = sv_model.predict(X_test1) 113 | custom_prediction = np.argmax(custom_model.predict(X_test2), axis=-1) 114 | custom_prediction = labeler.inverse_transform(custom_prediction) 115 | vgg_prediction = np.argmax(vgg_model.predict(X_test2), axis=-1) 116 | vgg_prediction = labeler.inverse_transform(vgg_prediction) 117 | 118 | final_prediction = majority_voting(rf_prediction, 119 | sv_prediction, 120 | custom_prediction, 121 | vgg_prediction) 122 | # Compute accuracy 123 | print("ACCURACY:", accuracy_score(y_test1, final_prediction)) 124 | 125 | # Save model 126 | np.save(os.path.join(ROOT_DIR, 'Ensemble.npy'), final_prediction) 127 | 128 | except FileNotFoundError as err: 129 | print('[ERROR] Train random forest, SVM, CNN-custom ' 130 | 'and VGG16 models before executing ensemble model!') 131 | print('[ERROR MESSAGE]', err) 132 | 133 | 
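# NOTE: a more compact way to express the same vote, sketched here for
# reference only (not used by the pipeline). collections.Counter tallies the
# four per-image predictions; ties fall back to the custom-CNN prediction,
# matching the tie-break rule in majority_voting() above.
#
#     from collections import Counter
#
#     def majority_vote_sketch(rf, sv, custom, vgg):
#         votes = Counter([rf, sv, custom, vgg])
#         label, count = votes.most_common(1)[0]
#         tied = [lab for lab, c in votes.items() if c == count]
#         return custom if len(tied) > 1 else label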
134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /models/cnn_custom.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import LabelEncoder 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.keras.utils import to_categorical 5 | from tensorflow.keras.models import Sequential 6 | import tensorflow.keras.layers as layers 7 | from tensorflow.keras.optimizers import Adam 8 | import pickle 9 | from preprocessing.utils import make_folder 10 | 11 | NUM_CLASSES = 5 12 | 13 | 14 | class Hist: 15 | """ Dummy class 16 | 17 | """ 18 | def __init__(self): 19 | pass 20 | 21 | 22 | def encoder(data, class_count): 23 | """ Transform string data to unique int. 24 | Convert unique int data to one-hot encoded data. 25 | 26 | Args: 27 | data (numpy array): array to be encoded. 28 | 29 | class_count (int): number of classes. 30 | 31 | Returns: 32 | (numpy array): one-hot encoded array. 33 | 34 | """ 35 | labeler = LabelEncoder() 36 | y = labeler.fit_transform(data) 37 | y = to_categorical(y, num_classes=class_count) 38 | return y 39 | 40 | 41 | def main(): 42 | """ Load data. 43 | Normalize and encode. 44 | Train custom CNN model. 45 | Print accuracy on test data. 46 | 47 | """ 48 | # Load stored data 49 | X_train = np.load('../data/augment/ImageAugment_input.npy') 50 | y_train = np.load('../data/augment/DiseaseAugment_input.npy') 51 | print("=== TRAIN DATA ===") 52 | print(X_train.shape) 53 | print(y_train.shape) 54 | 55 | X_test = np.load('../data/test/ImageTest_input.npy') 56 | y_test = np.load('../data/test/DiseaseTest_input.npy') 57 | print("=== TEST DATA ===") 58 | print(X_test.shape) 59 | print(y_test.shape) 60 | 61 | # hot encoding of labels 62 | y_train = encoder(y_train, NUM_CLASSES) 63 | y_test = encoder(y_test, NUM_CLASSES) 64 | 65 | # Input normalization 66 | X_train = (X_train / 255.0).astype(np.float32) 67 | X_test = (X_test / 255.0).astype(np.float32) 68 | 69 | # Custom CNN model 70 | model_custom = Sequential(( 71 | layers.Conv2D(32, kernel_size=(3, 3), padding='same', 72 | input_shape=(180, 180, 3)), 73 | layers.Activation('relu'), 74 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 75 | layers.Activation('relu'), 76 | layers.MaxPooling2D(pool_size=(2, 2)), 77 | 78 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 79 | layers.Activation('relu'), 80 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 81 | layers.Activation('relu'), 82 | layers.MaxPooling2D(pool_size=(2, 2)), 83 | 84 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 85 | layers.Activation('relu'), 86 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 87 | layers.Activation('relu'), 88 | layers.MaxPooling2D(pool_size=(2, 2)), 89 | 90 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 91 | layers.Activation('relu'), 92 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 93 | layers.Activation('relu'), 94 | layers.MaxPooling2D(pool_size=(2, 2)), 95 | 96 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 97 | layers.Activation('relu'), 98 | layers.Conv2D(32, kernel_size=(3, 3), padding='same'), 99 | layers.Activation('relu'), 100 | layers.MaxPooling2D(pool_size=(2, 2)), 101 | 102 | layers.Flatten(), 103 | layers.Dropout(0.5), 104 | layers.Dense(128), 105 | layers.Activation('relu'), 106 | layers.Dense(NUM_CLASSES, activation='softmax'))) 107 | 108 | model_custom.compile(optimizer=Adam(), 109 | 
loss="categorical_crossentropy", 110 | metrics=['accuracy']) 111 | 112 | # Learning rate decay 113 | reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 114 | factor=0.2, 115 | patience=5, 116 | min_lr=0.00001) 117 | history_custom = model_custom.fit(X_train, y_train, batch_size=8, 118 | epochs=1, verbose=1, 119 | validation_split=.1, 120 | callbacks=[reduce_lr]) 121 | scores = model_custom.evaluate(X_test, y_test, verbose=0) 122 | print("========================") 123 | print("TEST SET: %s: %.2f%%" % (model_custom.metrics_names[1], 124 | scores[1] * 100)) 125 | print("========================") 126 | 127 | print(model_custom.summary()) 128 | 129 | # save model 130 | make_folder('../results/models/') 131 | model_custom.save('../results/models/custom.h5') 132 | 133 | history = dict() 134 | history['acc'] = history_custom.history['acc'] 135 | history['val_acc'] = history_custom.history['val_acc'] 136 | history['loss'] = history_custom.history['loss'] 137 | history['val_loss'] = history_custom.history['val_loss'] 138 | 139 | hist = Hist() 140 | setattr(hist, 'history', history) 141 | pickle.dump(hist, open('../results/models/custom_training_history.pkl', 'wb')) 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | from scipy.special import softmax 6 | from tensorflow.keras.models import load_model 7 | from sklearn.preprocessing import LabelEncoder 8 | import joblib 9 | from preprocessing.utils import make_folder 10 | 11 | ROOT_DIR = 'results/models/' 12 | 13 | 14 | def plot(image, label, index, model, x_hog): 15 | """ Display image, true label and predicted label. 16 | Save the result in the output folder 17 | 18 | Args: 19 | image (numpy array): image to predict. 20 | 21 | label (numpy array): true labels of corresponding images. 22 | 23 | index (int): index of the test image entered by the user. 24 | 25 | model (str): model to use. (Entered by user) 26 | 27 | x_hog (numpy array): feature descriptors of the image. 28 | 29 | """ 30 | plt.figure(figsize=(8, 6)) 31 | plt.imshow(image[index]) 32 | plt.axis('off') 33 | plt.title('True label: {}'.format(label[index]), 34 | fontdict={'fontweight': 'bold', 'fontsize': 'x-large'}) 35 | predictions, percent = model_predict(model, image, x_hog, label) 36 | if model == "majority_voting": 37 | if predictions[index] == label[index]: 38 | plt.suptitle('Predicted label: {}'.format(predictions[index]), 39 | color="green") 40 | else: 41 | plt.suptitle('Predicted label: {}'.format(predictions[index]), 42 | color="red") 43 | 44 | else: 45 | if predictions[index] == label[index]: 46 | plt.suptitle('Predicted label: {} ({:.2f} %)'.format(predictions[index], 47 | np.max(percent[index]) * 100), 48 | color="green") 49 | else: 50 | plt.suptitle('Predicted label: {} ({:.2f} %)'.format(predictions[index], 51 | np.max(percent[index]) * 100), 52 | color="red") 53 | 54 | make_folder('results/visualization') 55 | plt.savefig('results/visualization/app.png', bbox_inches='tight') 56 | 57 | 58 | def model_predict(model, x, hog, y): 59 | """ Load the given model and predict the test images 60 | 61 | Args: 62 | model (str): model as entered by the user. 63 | 64 | x (numpy array): images to predict. 65 | 66 | hog (numpy array): feature descriptors of the images. 
67 | 68 | y (numpy array): labels of the corresponding image. 69 | 70 | Returns: 71 | predictions (numpy array): predicted labels of the corresponding image. 72 | 73 | percent (numpy array): Accuracy of corresponding predictions. 74 | 75 | """ 76 | predictions = [] 77 | percent = [] 78 | labeler = LabelEncoder() 79 | labeler.fit(y) 80 | if model == "random_forest": 81 | rf_model = joblib.load(os.path.join(ROOT_DIR, 'Random_model.sav')) 82 | 83 | predictions = rf_model.predict(hog) 84 | percent = rf_model.predict_proba(hog) 85 | 86 | elif model == "svm": 87 | sv_model = joblib.load(os.path.join(ROOT_DIR, 'SVM_model.sav')) 88 | 89 | predictions = sv_model.predict(hog) 90 | percent = sv_model.decision_function(hog) 91 | percent = softmax(percent, axis=1) 92 | 93 | elif model == "custom_cnn": 94 | custom_model = load_model(os.path.join(ROOT_DIR, 'custom.h5')) 95 | 96 | percent = custom_model.predict(x) 97 | predictions = np.argmax(percent, axis=-1) 98 | predictions = labeler.inverse_transform(predictions) 99 | 100 | elif model == "vgg": 101 | vgg_model = load_model(os.path.join(ROOT_DIR, 'vgg16.h5')) 102 | 103 | percent = vgg_model.predict(x) 104 | predictions = np.argmax(percent, axis=-1) 105 | predictions = labeler.inverse_transform(predictions) 106 | 107 | elif model == "majority_voting": 108 | predictions = np.load(os.path.join(ROOT_DIR, 'Ensemble.npy')) 109 | 110 | elif model == "stacked_prediction": 111 | en_model = load_model(os.path.join(ROOT_DIR, 'custom_ensemble.h5')) 112 | 113 | percent = en_model.predict(np.load('data/test/X_test_ensemble.npy')) 114 | predictions = np.argmax(percent, 115 | axis=-1) 116 | predictions = labeler.inverse_transform(predictions) 117 | 118 | return predictions, percent 119 | 120 | 121 | def main(): 122 | """ Predict the disease of the given image. 
123 | 
124 |     Usage example:
125 |         python app.py -m vgg -i 49
126 | 
127 |     """
128 |     # construct the argument parser and parse the arguments
129 |     ap = argparse.ArgumentParser()
130 |     ap.add_argument("-i", "--image", type=int, required=True,
131 |                     help="index of the test image")
132 |     ap.add_argument("-m", "--model", type=str, required=True,
133 |                     choices=("vgg", "custom_cnn",
134 |                              "svm", "random_forest",
135 |                              "majority_voting",
136 |                              "stacked_prediction"),
137 |                     help="model to be used")
138 |     args = vars(ap.parse_args())
139 | 
140 |     X_image = np.load('data/test/ImageTest_input.npy')
141 |     X_processed = np.load('data/processed/ImageTestHOG_input.npy')
142 |     y = np.load('data/test/DiseaseTest_input.npy')
143 | 
144 |     plot(X_image, y, args["image"], args["model"], X_processed)
145 | 
146 | 
147 | if __name__ == "__main__":
148 |     main()
149 | 
--------------------------------------------------------------------------------
/models/ensemble_stacked_prediction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from sklearn.preprocessing import LabelEncoder
4 | from scipy.special import softmax
5 | import tensorflow as tf
6 | from tensorflow.keras.utils import to_categorical
7 | from tensorflow.keras.models import Sequential
8 | from tensorflow.keras.layers import Dropout, Dense, Activation
9 | from tensorflow.keras.optimizers import Adam
10 | from tensorflow.keras.models import load_model
11 | import pickle
12 | import joblib
13 | 
14 | NUM_CLASS = 5
15 | ROOT_DIR = '../results/models/'
16 | 
17 | 
18 | class Hist:
19 |     """ Dummy container used to pickle the training history.
20 | 
21 |     """
22 |     def __init__(self):
23 |         pass
24 | 
25 | 
26 | def encoder(data, class_count):
27 |     """ Transform string data to unique int.
28 |         Convert unique int data to one-hot encoded data.
29 | 
30 |     Args:
31 |         data (numpy array): array to be encoded.
32 | 
33 |         class_count (int): number of classes.
34 | 
35 |     Returns:
36 |         (numpy array): one-hot encoded array.
37 | 
38 |     """
39 |     labeler = LabelEncoder()
40 |     y = labeler.fit_transform(data)
41 |     y = to_categorical(y, num_classes=class_count)
42 |     return y
43 | 
44 | 
45 | def get_predictions(x_hog, x, model):
46 |     """ Get the class-probability predictions of the base models.
47 | 
48 |     Args:
49 |         x_hog (numpy array): array of HOG feature descriptors.
50 | 
51 |         x (numpy array): array of normalized images.
52 | 
53 |         model (list): trained base models [random forest, SVM, custom CNN, VGG16].
54 | 
55 |     Returns:
56 |         (numpy array): per-model class probabilities, concatenated along the last axis.
57 | 
58 |     """
59 |     rf_prediction = model[0].predict_proba(x_hog)
60 |     sv_prediction = model[1].decision_function(x_hog)
61 |     sv_prediction = softmax(sv_prediction, axis=1)
62 |     custom_prediction = model[2].predict(x)
63 |     vgg_prediction = model[3].predict(x)
64 |     return np.concatenate([rf_prediction,
65 |                            sv_prediction,
66 |                            custom_prediction,
67 |                            vgg_prediction], axis=-1)
68 | 
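# NOTE: each base model contributes a 5-way probability vector
# (NUM_CLASS = 5), so the concatenation above yields 4 * 5 = 20 features per
# image. This is why the meta-learner built in main() below takes
# input_shape=(20,).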
69 | 
70 | def main():
71 |     """ Load data.
72 |         Normalize and encode.
73 |         Train the stacked-prediction ensemble model.
74 |         Print accuracy on test data.
75 | 
76 |     """
77 |     X_test1 = np.load('../data/processed/ImageTestHOG_input.npy')
78 |     X_test2 = np.load('../data/test/ImageTest_input.npy')
79 |     y_test1 = np.load('../data/test/DiseaseTest_input.npy')
80 |     print("=== TEST DATA ===")
81 |     print(X_test1.shape)
82 |     print(X_test2.shape)
83 |     print(y_test1.shape)
84 | 
85 |     X_train1 = np.load('../data/processed/ImageTrainHOG_input.npy')
86 |     X_train2 = np.load('../data/augment/ImageAugment_input.npy')
87 |     y_train1 = np.load('../data/augment/DiseaseAugment_input.npy')
88 |     print("=== TRAIN DATA ===")
89 |     print(X_train1.shape)
90 |     print(X_train2.shape)
91 |     print(y_train1.shape)
92 | 
93 |     # Normalize images
94 |     X_train2 = (X_train2 / 255.0).astype(np.float32)
95 |     X_test2 = (X_test2 / 255.0).astype(np.float32)
96 | 
97 |     # hot encoding of labels
98 |     y_test2 = encoder(y_test1, NUM_CLASS)
99 |     y_train2 = encoder(y_train1, NUM_CLASS)
100 | 
101 |     try:
102 |         rf_model = joblib.load(os.path.join(ROOT_DIR, 'Random_model.sav'))
103 |         sv_model = joblib.load(os.path.join(ROOT_DIR, 'SVM_model.sav'))
104 |         custom_model = load_model(os.path.join(ROOT_DIR, 'custom.h5'))
105 |         vgg_model = load_model(os.path.join(ROOT_DIR, 'vgg16.h5'))
106 | 
107 |         X_train_f = get_predictions(X_train1, X_train2,
108 |                                     [rf_model, sv_model, custom_model, vgg_model])
109 |         X_test_f = get_predictions(X_test1, X_test2,
110 |                                    [rf_model, sv_model, custom_model, vgg_model])
111 |         np.save('../data/test/X_test_ensemble.npy', X_test_f)
112 | 
113 |         # Meta-learner: a small dense network over the stacked probabilities
114 |         model_custom = Sequential([
115 |             Dense(128, input_shape=(20,)),
116 |             Activation('relu'),
117 |             Dropout(0.5),
118 |             Dense(256),
119 |             Activation('relu'),
120 |             Dense(NUM_CLASS, activation='softmax')])
121 | 
122 |         model_custom.compile(optimizer=Adam(),
123 |                              loss="categorical_crossentropy",
124 |                              metrics=['accuracy'])
125 | 
126 |         # Learning rate decay
127 |         reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
128 |                                                          factor=0.2,
129 |                                                          patience=5,
130 |                                                          min_lr=0.00001)
131 |         history_custom = model_custom.fit(X_train_f, y_train2,
132 |                                           batch_size=64,
133 |                                           epochs=1, verbose=1,
134 |                                           validation_split=.1,
135 |                                           callbacks=[reduce_lr])
136 |         scores = model_custom.evaluate(X_test_f, y_test2, verbose=0)
137 |         print("========================")
138 |         print("TEST SET: %s: %.2f%%" % (model_custom.metrics_names[1],
139 |                                         scores[1] * 100))
140 |         print("========================")
141 | 
142 |         print(model_custom.summary())
143 | 
144 |         # save model
145 |         model_custom.save(os.path.join(ROOT_DIR, 'custom_ensemble.h5'))
146 | 
147 |         history = dict()
148 |         history['acc'] = history_custom.history['acc']
149 |         history['val_acc'] = history_custom.history['val_acc']
150 |         history['loss'] = history_custom.history['loss']
151 |         history['val_loss'] = history_custom.history['val_loss']
152 | 
153 |         hist = Hist()
154 |         setattr(hist, 'history', history)
155 |         pickle.dump(hist, open(os.path.join(ROOT_DIR, 'stacked_training_history.pkl'), 'wb'))
156 | 
157 |     except FileNotFoundError as err:
158 |         print('[ERROR] Train random forest, SVM, CNN-custom '
159 |               'and VGG16 models before executing ensemble model!')
160 |         print('[ERROR MESSAGE]', err)
161 | 
162 | 
163 | if __name__ == "__main__":
164 |     main()
165 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Grape Disease Detection
2 | 
3 | This project classifies diseases in grape plants using various Machine Learning classification algorithms.
4 | Grape plants are susceptible to various diseases. The diseases that are classified in this project are:
5 | 
6 | 1. Black rot
7 | 2. Black Measles (esca)
8 | 3. Powdery mildew
9 | 4. Leaf blight
10 | 5. Healthy
11 | 
12 | The Machine learning classification models in this project include:
13 | 
14 | 1. Random forest classification
15 | 2. Support vector machine classification
16 | 3. CNN - VGG16
17 | 4. CNN - Custom
18 | 5. Ensemble model - Majority voting
19 | 6. Ensemble model - Stacked prediction
20 | 
21 | Configuration of Project Environment
22 | =====================================
23 | 
24 | 1. Clone the project.
25 | 2. Install packages required.
26 | 3. Download the data set.
27 | 4. Run the project.
28 | 
29 | Setup procedure
30 | ----------------
31 | 1. Clone project from [GitHub](https://github.com/Sanjana7395/Grape-disease-classification.git).
32 |    Change to the directory Grape-Disease-Classification.
33 | 2. Install packages
34 |    In order to reproduce the code, install the packages
35 | 
36 |    1. Manually install packages mentioned in requirements.txt file or use the command.
37 | 
38 |           pip install -r requirements.txt
39 | 
40 |    2. Install packages using setup.py file.
41 | 
42 |           python setup.py install
43 | 
44 |       The **--user** option directs setup.py to install the package
45 |       in the user site-packages directory for the running Python.
46 |       Alternatively, you can use the **--home** or **--prefix** option to install
47 |       your package in a different location (where you have the necessary permissions).
48 | 
49 | 3. Download the required data set.
50 |    The data set that is used in this project is available
51 |    [here](https://drive.google.com/drive/folders/1SFBc-dNzr325jHw434j8LYyCii6djzkC?usp=sharing).
52 |    The data set includes images from [kaggle](https://www.kaggle.com/xabdallahali/plantvillage-dataset)
53 |    grape disease data set and the images collected online and labelled using the **LabelMe** tool.
54 |    Download the zip file and extract the files in the **data/raw** folder.
55 | 
56 |    [OR]
57 | 
58 |    Run the below command
59 | 
60 |        ./wgetgdrive.sh <drive_id> <zip_name>.zip
61 | 
62 |    drive_id is **1gsUyWEkxz9H1-yn2ONx4scHg88kWU-38**.
63 |    Provide any zip_name.
64 | 
65 | 4. Run the project.
66 |    See the **Documentation for the code** section for further details.
67 | 
68 | Documentation for the code
69 | ===========================
70 | 
71 | 1. __Pre processing__
72 |    This folder contains
73 | 
74 |    1. Code to load the images and json (contains labelling information) files. This is present in
75 |       preprocessing/001_load_data.py. To execute this code, within the 'preprocessing' folder enter the below
76 |       command
77 | 
78 |           python 001_load_data.py
79 | 
80 |    2. Augment data. The code is present in preprocessing/002_data_augmentation.py. To execute, run
81 |       the below command
82 | 
83 |           python 002_data_augmentation.py
84 | 
85 |       The data augmentation techniques used are
86 |       - Horizontal flip
87 |       - Vertical flip
88 |       - Random rotation
89 |       - Intensity scaling
90 |       - Gamma correction
91 | 
92 |    3. Extract HOG (Histogram of Oriented Gradients) feature descriptors. Feature descriptors are used to train only
93 |       random forest and SVM. The code is present in preprocessing/003_hog.py
94 | 
95 |           python 003_hog.py
96 | 
97 | 2. __Models__
98 |    This folder contains the various models used in this project, namely:
99 | 
100 |    1. Random forest
101 |    2. Support vector machine
102 |    3. CNN - VGG16
103 |    4. CNN - Custom
176 |
177 | 3. __visualization.py__
178 |    This file contains all the visualization techniques used in this project.
179 |
180 |    1. Confusion matrix, using a seaborn heat map with modifications to display details within each box.
181 |    2. Loss and accuracy curves for the neural networks.
182 |    3. Tree representation for the random forest.
183 |    4. ROC-AUC curves using Yellowbrick.
184 |
185 |    Usage is as follows
186 |
187 |        python visualization.py -m <model> -t <visualization_type>
188 |
189 |    For help on the available models and visualization techniques
190 |
191 |        python visualization.py --help
192 |
193 | 4. __app.py__
194 |    This file predicts the disease of the input image. Usage is as follows
195 |
196 |        python app.py -m <model> -i <image_path>
197 |
198 |    For help on usage
199 |
200 |        python app.py --help
201 |
202 | Results
203 | ========
204 |
205 | Below are the results obtained on the test set for the various models trained in the project.
206 |
207 | > NOTE
208 | > The results obtained are system specific. Due to different combinations of the
209 | > cuDNN library versions and NVIDIA driver versions used by the neural networks,
210 | > the results can differ slightly. To the best of my knowledge, upon reproducing
211 | > the environment, the numbers will be in the same ballpark as those reported here.
212 |
213 | | Models                           | Accuracy (%) |
214 | |----------------------------------|:------------:|
215 | | Random forest                    |    75.35     |
216 | | SVM                              |    82.89     |
217 | | CNN - VGG16                      |    93.62     |
218 | | Ensemble - Majority voting       |    98.05     |
219 | | Ensemble - Stacked prediction    |    98.23     |
220 | | CNN - Custom                     |    98.76     |
--------------------------------------------------------------------------------
/preprocessing/001_load_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import json
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pandas as pd
8 | from sklearn.model_selection import train_test_split
9 | from sklearn.utils import shuffle
10 | from preprocessing.utils import make_folder
11 |
12 | ROOT_DIR = '../data/raw/'
13 |
14 |
15 | def get_json_data(json_path):
16 |     """ Get data from json files and store it in a table.
17 |
18 |     The images are labelled using the LabelMe tool.
19 |     The labelled bounding box details are stored as json files.
20 |     This function extracts details from the json files in the given path.
21 |
22 |     Args:
23 |         json_path (str): Path of the json files.
24 |
25 |     Returns:
26 |         (Data frame): Contains path of the image, co-ordinates and label.
27 |
28 |     """
29 |     json_files = [pos_json for pos_json in os.listdir(json_path)
30 |                   if pos_json.endswith('.json')]
31 |     # store fields from json file in data frame
32 |     table_data = pd.DataFrame(columns=['path', 'col1', 'col2',
33 |                                        'row1', 'row2', 'label'])
34 |     index = 0
35 |
36 |     for js in json_files:
37 |         with open(os.path.join(json_path, js)) as file:
38 |             # load json file
39 |             json_text = json.load(file)
40 |
41 |         for x in json_text['shapes']:
42 |             path = json_text['imagePath']
43 |             points = x['points']
44 |             # extract image section within bounding box
45 |             if x['shape_type'] == 'rectangle':
46 |                 col1 = int(min(points[0][1], points[1][1]))
47 |                 col2 = int(max(points[0][1], points[1][1]))
48 |                 row1 = int(min(points[0][0], points[1][0]))
49 |                 row2 = int(max(points[0][0], points[1][0]))
50 |             else:
51 |                 col1 = int(min(points[0][1], points[3][1]))
52 |                 col2 = int(max(points[1][1], points[2][1]))
53 |                 row1 = int(min(points[0][0], points[1][0]))
54 |                 row2 = int(max(points[2][0], points[3][0]))
55 |             label = x['label']
56 |             if label == 'black measles':
57 |                 label = 'ecsa'  # 'ecsa' is the label string used throughout this project
58 |
59 |             table_data.loc[index] = [path, col1, col2, row1, row2, label]
60 |             index += 1
61 |     return table_data
62 |
63 |
64 | def resize_with_aspect_ratio(img, size, interpolation):
65 |     """ Resize an image to size x size while maintaining its aspect ratio.
66 |
67 |     Args:
68 |         img (numpy array): Image to resize.
69 |
70 |         size (int): Target width and height of the output image.
71 |
72 |         interpolation (int): OpenCV interpolation flag to use when resizing,
73 |             e.g. cv2.INTER_AREA.
74 |
75 |     Returns:
76 |         (array): Resized image.
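77 |
78 |     Example:
79 |         Illustrative only: a 100x180 image is zero-padded to 180x180,
80 |         then resized to 64x64.
81 |
82 |         >>> img = np.zeros((100, 180, 3), dtype=np.uint8)
83 |         >>> resize_with_aspect_ratio(img, 64, cv2.INTER_AREA).shape
84 |         (64, 64, 3)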
85 |
86 |     """
87 |     h, w = img.shape[:2]
88 |     c = None if len(img.shape) < 3 else img.shape[2]
89 |     # if h == w, no padding is needed
90 |     if h == w:
91 |         # NOTE: interpolation must be passed as a keyword argument; the
92 |         # third positional argument of cv2.resize is dst, not interpolation
93 |         return cv2.resize(img, (size, size), interpolation=interpolation)
94 |     # if h != w, make h == w by padding with 0
95 |     if h > w:
96 |         dif = h
97 |     else:
98 |         dif = w
99 |     x_pos = int((dif - w) / 2.)
100 |     y_pos = int((dif - h) / 2.)
101 |     if c is None:
102 |         mask = np.zeros((dif, dif), dtype=img.dtype)
103 |         mask[y_pos:y_pos + h, x_pos:x_pos + w] = img[:h, :w]
104 |     else:
105 |         mask = np.zeros((dif, dif, c), dtype=img.dtype)
106 |         mask[y_pos:y_pos + h, x_pos:x_pos + w, :] = img[:h, :w, :]
107 |
108 |     return cv2.resize(mask, (size, size), interpolation=interpolation)
109 |
110 |
111 | def get_images(img_path, js=None, valid=[".jpg", ".jpeg", ".png"], name=None):
112 |     """ Get images from the path and store them as a numpy array.
113 |
114 |     There are two sets of data:
115 |
116 |     1. The Kaggle data set, which is used as is.
117 |     2. Google images that are labelled using the LabelMe tool.
118 |
119 |     Args:
120 |         img_path (str): Path of the images folder.
121 |
122 |         js (data frame, optional): Labelling info table. Initialized when the
123 |             images are Google images; default is None.
124 |
125 |         valid (list): List of valid file extensions. Defaults are .jpg, .jpeg, .png.
126 |
127 |         name (str): Image label. Initialized with the image label when the data
128 |             set used is the Kaggle images; default is None, in which case the
129 |             label comes from the 'js' table.
130 |
131 |     Returns:
132 |         (numpy array): Array of desired images.
133 |
134 |         (numpy array): Array of labels of the corresponding images in the above image array.
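135 |
136 |     Example:
137 |         Illustrative only, mirroring the calls in main() below:
138 |
139 |         >>> imgs, labels = get_images('../data/raw/Grape/Black_rot/',
140 |         ...                           name='black rot')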
141 |
142 |     """
143 |     images = []
144 |     labels = []
145 |
146 |     for f in os.listdir(img_path):
147 |         # check for image files only
148 |         ext = os.path.splitext(f)[1]
149 |         if ext.lower() not in valid:
150 |             continue
151 |
152 |         # store original image
153 |         img = plt.imread(os.path.join(img_path, f))
154 |
155 |         # kaggle data set
156 |         if name:
157 |             resized_img = cv2.resize(img, (180, 180),
158 |                                      interpolation=cv2.INTER_AREA)
159 |             images.append(resized_img)
160 |             labels.append(name)
161 |
162 |         # google images
163 |         else:
164 |             # find the corresponding json rows (one per bounding box)
165 |             for index, j in enumerate(js.path):
166 |                 if j == f:
167 |                     right_file = js.iloc[index]
168 |
169 |                     cut = img[right_file.col1:right_file.col2,
170 |                               right_file.row1:right_file.row2]
171 |                     resized_img = resize_with_aspect_ratio(cut,
172 |                                                            180,
173 |                                                            cv2.INTER_AREA)
174 |                     images.append(resized_img)
175 |                     labels.append(right_file.label)
176 |     image_arr = np.array(images)
177 |     label_arr = np.array(labels)
178 |     print(image_arr.shape)
179 |     return image_arr, label_arr
180 |
181 |
182 | def main():
183 |     """ Load images and json files from all folders and concatenate them into a single array.
184 |     Shuffle the array.
185 |     Split it into test and train data sets.
186 |
187 |     """
188 |     print('[INFO] Extracting json data...')
189 |     # Accumulate data from json files
190 |     json_df_images = get_json_data(os.path.join(ROOT_DIR, 'images/'))
191 |     json_df_positive = get_json_data(os.path.join(ROOT_DIR, 'positive/'))
192 |     json_df_healthy = get_json_data(os.path.join(ROOT_DIR, 'healthy/'))
193 |     json_df_team4 = get_json_data(os.path.join(ROOT_DIR, 'team4/'))
194 |     json_df_team4_br = get_json_data(os.path.join(ROOT_DIR, 'team4_br/'))
195 |     json_df_leaf_blight = get_json_data(os.path.join(ROOT_DIR, 'leaf_blight/'))
196 |
197 |     # Accumulate the data set from all folders
198 |     print('[INFO] Extracting images and corresponding labels...')
199 |     array1, disease1 = get_images(os.path.join(ROOT_DIR, 'Grape/Black_rot/'),
200 |                                   name='black rot')
201 |     array2, disease2 = get_images(os.path.join(ROOT_DIR, 'Grape/Esca/'),
202 |                                   name='ecsa')
203 |     array3, disease3 = get_images(os.path.join(ROOT_DIR, 'Grape/Leaf_blight/'),
204 |                                   name='leaf_blight')
205 |     array4, disease4 = get_images(os.path.join(ROOT_DIR, 'Grape/healthy/'),
206 |                                   name='healthy')
207 |     array5, disease5 = get_images(os.path.join(ROOT_DIR, 'images/'),
208 |                                   js=json_df_images)
209 |     array6, disease6 = get_images(os.path.join(ROOT_DIR, 'positive/'),
210 |                                   js=json_df_positive)
211 |     array7, disease7 = get_images(os.path.join(ROOT_DIR, 'healthy/'),
212 |                                   js=json_df_healthy)
213 |     array8, disease8 = get_images(os.path.join(ROOT_DIR, 'team4/'),
214 |                                   js=json_df_team4)
215 |     array9, disease9 = get_images(os.path.join(ROOT_DIR, 'team4_br/'),
216 |                                   js=json_df_team4_br)
217 |     array10, disease10 = get_images(os.path.join(ROOT_DIR, 'leaf_blight/'),
218 |                                     js=json_df_leaf_blight)
219 |
220 |     # Concatenate the data
221 |     disease_arr = np.concatenate((disease1, disease2,
222 |                                   disease3, disease4,
223 |                                   disease5, disease6,
224 |                                   disease7, disease8,
225 |                                   disease9, disease10), axis=0)
226 |     print('=== TOTAL DATA ===')
227 |     print(disease_arr.shape)
228 |     img_arr = np.concatenate((array1, array2,
229 |                               array3, array4,
230 |                               array5, array6,
231 |                               array7, array8,
232 |                               array9, array10), axis=0)
233 |     print(img_arr.shape)
234 |
235 |     # Shuffle the data
236 |     img_arr, disease_arr = shuffle(img_arr, disease_arr, random_state=42)
237 |     print(np.unique(disease_arr))
238 |
239 |     # split into train and test sets
240 |     X_train, X_test, y_train, y_test = train_test_split(img_arr, disease_arr,
241 |                                                         test_size=0.2,
242 |                                                         random_state=42)
243 |     print('=== TRAIN TEST SPLIT ===')
244 |     print(X_test.shape)
245 |     print(X_train.shape)
246 |
247 |     # Save the data
248 |     make_folder('../data/test')
249 |     make_folder('../data/intermediate')
250 |     np.save('../data/test/ImageTest_input.npy', X_test)
251 |     np.save('../data/test/DiseaseTest_input.npy', y_test)
252 |     np.save('../data/intermediate/ImageTrain_input.npy', X_train)
253 |     np.save('../data/intermediate/DiseaseTrain_input.npy', y_train)
254 |
255 |
256 | if __name__ == "__main__":
257 |     main()
258 |
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import seaborn as sns
6 | import yellowbrick
7 | from sklearn.metrics import confusion_matrix
8 | from sklearn.preprocessing import LabelEncoder
9 | from sklearn.tree import export_graphviz
10 | from tensorflow.keras.models import load_model
11 | import joblib
12 | import pickle
13 | from preprocessing.utils import make_folder
14 |
15 | ROOT_DIR = 'results/models/'
16 |
17 |
18 | class Hist:
19 |     """ Dummy container used to hold a Keras-style 'history' dict.
20 |
21 |     """
22 |
23 |     def __init__(self):
24 |         pass
25 |
26 |
27 | def visualize(visual_type, model, x, y):
28 |     """ Execute a plotting function depending on the user input 'visual_type'.
29 |
30 |     Args:
31 |         visual_type (str): type of visualization technique to plot.
32 |
33 |         model (str): model for which the visualization is plotted.
34 |
35 |         x (numpy array): test images.
36 |
37 |         y (numpy array): test labels.
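38 |
39 |     Example:
40 |         Illustrative only, mirroring the call in main() below:
41 |
42 |         >>> visualize('confusion_matrix', 'svm', X_test, y_test)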
43 |
44 |     """
45 |     if visual_type == "confusion_matrix":
46 |         con_matrix(model, x, y)
47 |
48 |     elif visual_type == "acc_loss":
49 |         plot(model)
50 |
51 |     elif visual_type == "tree":
52 |         tree()
53 |
54 |     elif visual_type == "ROC":
55 |         roc(x, y)
56 |
57 |
58 | def roc(x, y):
59 |     """ Plot the ROC-AUC curves for the random forest model.
60 |     Save the image in the output folder.
61 |
62 |     Args:
63 |         x (numpy array): test images.
64 |
65 |         y (numpy array): test labels.
66 |
67 |     """
68 |     model = joblib.load(os.path.join(ROOT_DIR, 'Random_model.sav'))
69 |
70 |     visualizer = yellowbrick.classifier.ROCAUC(model,
71 |                                                classes=['healthy',
72 |                                                         'leaf_blight',
73 |                                                         'ecsa',
74 |                                                         'black rot',
75 |                                                         'powdery mildew'])
76 |     visualizer.score(x, y)
77 |     ax = visualizer.show()
78 |     make_folder('results/visualization')
79 |     ax.figure.savefig('results/visualization/auc_roc.png')
80 |
81 |
82 | def tree():
83 |     """ Plot the trees of the random forest model.
84 |     Save the dot file in the output folder.
85 |     Convert the dot file to png by using the command:
86 |     'dot -Tpng tree.dot -o tree.png'
87 |
88 |     """
89 |     model = joblib.load(os.path.join(ROOT_DIR, 'Random_model.sav'))
90 |     tree_num = model.estimators_
91 |     make_folder('results/visualization')
92 |     # NOTE: each estimator is written to the same file, so 'tree.dot'
93 |     # ends up holding the last tree in the forest
94 |     for tree_in_forest in tree_num:
95 |         export_graphviz(tree_in_forest, out_file='results/visualization/tree.dot',
96 |                         filled=True, rounded=True,
97 |                         precision=2)
98 |
99 |
100 | def plot(model):
101 |     """ Plot the accuracy and loss curves for the neural networks.
102 |     Save the file in the output folder.
103 |
104 |     Args:
105 |         model (str): model for which the visualization is plotted.
106 |
107 |     """
108 |     history_custom = Hist()
109 |     if model == "cnn_custom":
110 |         history_custom = pickle.load(open(os.path.join(ROOT_DIR, 'custom_training_history.pkl'),
111 |                                           'rb'))
112 |
113 |     elif model == "vgg":
114 |         history_custom = pickle.load(open(os.path.join(ROOT_DIR, 'vgg16_training_history.pkl'),
115 |                                           'rb'))
116 |
117 |     # Plot training & validation accuracy values
118 |     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[15, 8])
119 |     ax1.plot(history_custom.history['acc'])
120 |     ax1.plot(history_custom.history['val_acc'])
121 |     ax1.set_title('Model accuracy')
122 |     ax1.set_ylabel('Accuracy')
123 |     ax1.set_xlabel('Epoch')
124 |     ax1.legend(['Train', 'Validation'], loc='lower right')
125 |
126 |     # Plot training & validation loss values
127 |     ax2.plot(history_custom.history['loss'])
128 |     ax2.plot(history_custom.history['val_loss'])
129 |     ax2.set_title('Model loss')
130 |     ax2.set_ylabel('Loss')
131 |     ax2.set_xlabel('Epoch')
132 |     ax2.legend(['Train', 'Validation'], loc='upper right')
133 |
134 |     make_folder('results/visualization')
135 |     plt.savefig('results/visualization/acc_loss_{}.png'.format(model))
136 |
137 |
138 | def con_matrix(model, x, y):
139 |     """ Plot the confusion matrix for the given model.
140 |     Save the png in the output folder.
141 |
142 |     Args:
143 |         model (str): model for which the visualization is plotted.
144 |
145 |         x (numpy array): test images.
146 |
147 |         y (numpy array): test labels.
148 |
149 |     """
150 |     corr = []
151 |     if model == "random_forest":
152 |         loaded_model = joblib.load(os.path.join(ROOT_DIR, 'Random_model.sav'))
153 |         classifier_prediction = loaded_model.predict(x)
154 |         corr = confusion_matrix(y, classifier_prediction)
155 |
156 |     elif model == "svm":
157 |         loaded_model = joblib.load(os.path.join(ROOT_DIR, 'SVM_model.sav'))
158 |         classifier_prediction = loaded_model.predict(x)
159 |         corr = confusion_matrix(y, classifier_prediction)
160 |
161 |     elif model == "majority_voting":
162 |         classifier_prediction = np.load(os.path.join(ROOT_DIR, 'Ensemble.npy'))
163 |         corr = confusion_matrix(y, classifier_prediction)
164 |
165 |     elif model == "stacked_prediction":
166 |         labeler = LabelEncoder()
167 |         labeler.fit(y)
168 |         loaded_model = load_model(os.path.join(ROOT_DIR, 'custom_ensemble.h5'))
169 |         y_prediction = loaded_model.predict(np.load('data/test/X_test_ensemble.npy'))
170 |         prediction = np.argmax(y_prediction, axis=-1)
171 |         prediction = labeler.inverse_transform(prediction)
172 |         corr = confusion_matrix(y, prediction)
173 |
174 |     make_confusion_matrix(corr,
175 |                           categories=['blackrot', 'ecsa',
176 |                                       'healthy', 'leafblight',
177 |                                       'pmildew'],
178 |                           count=True,
179 |                           percent=True,
180 |                           color_bar=False,
181 |                           xy_ticks=True,
182 |                           xy_plot_labels=True,
183 |                           sum_stats=True,
184 |                           fig_size=(8, 6),
185 |                           c_map='OrRd',
186 |                           title='Confusion matrix')
187 |     # error correction - cropped heat map
188 |     b, t = plt.ylim()  # discover the values for bottom and top
189 |     b += 0.5           # add 0.5 to the bottom
190 |     t -= 0.5           # subtract 0.5 from the top
191 |     plt.ylim(b, t)     # update the ylim(bottom, top) values
192 |
193 |     make_folder('results/visualization')
194 |     plt.savefig('results/visualization/confusion_matrix_{}.png'.format(model),
195 |                 bbox_inches='tight')
196 |
197 |
198 | def make_confusion_matrix(cf, categories,
199 |                           group_names=None,
200 |                           count=True,
201 |                           percent=True,
202 |                           color_bar=True,
203 |                           xy_ticks=True,
204 |                           xy_plot_labels=True,
205 |                           sum_stats=True,
206 |                           fig_size=None,
207 |                           c_map='Blues',
208 |                           title=None):
209 |     """ Generate the text within each box and beautify the confusion matrix.
210 |
211 |     Args:
212 |         cf (numpy array): Confusion matrix.
213 |
214 |         categories (list): class names shown on the axes.
215 |
216 |         group_names (list): names to display inside the boxes.
217 |
218 |         count (bool): whether to display the count for each class.
219 |
220 |         percent (bool): whether to display the percentage for each class.
221 |
222 |         color_bar (bool): whether to display the color bar for the heat map.
223 |
224 |         xy_ticks (bool): whether to display the xy tick labels.
225 |
226 |         xy_plot_labels (bool): whether to display the xy axis titles.
227 |
228 |         sum_stats (bool): whether to display the overall accuracy.
229 |
230 |         fig_size (tuple): size of the plot.
231 |
232 |         c_map (str): color scheme to use.
233 |
234 |         title (str): title of the plot.
235 |
236 |     """
237 |     blanks = ['' for i in range(cf.size)]
238 |
239 |     if group_names and len(group_names) == cf.size:
240 |         group_labels = ["{}\n".format(value) for value in group_names]
241 |     else:
242 |         group_labels = blanks
243 |
244 |     if count:
245 |         group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
246 |     else:
247 |         group_counts = blanks
248 |
249 |     if percent:
250 |         row_size = np.size(cf, 0)
251 |         col_size = np.size(cf, 1)
252 |         group_percentages = []
253 |         for i in range(row_size):
254 |             for j in range(col_size):
255 |                 group_percentages.append(cf[i][j] / cf[i].sum())
256 |         group_percentages = ["{0:.2%}".format(value)
257 |                              for value in group_percentages]
258 |     else:
259 |         group_percentages = blanks
260 |
261 |     box_labels = [f"{v1}{v2}{v3}".strip()
262 |                   for v1, v2, v3 in zip(group_labels,
263 |                                         group_counts,
264 |                                         group_percentages)]
265 |     box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])
266 |
267 |     # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
268 |     if sum_stats:
269 |         # Accuracy is the sum of the diagonal divided by total observations
270 |         accuracy = np.trace(cf) / float(np.sum(cf))
271 |         stats_text = "\n\nAccuracy={0:0.2%}".format(accuracy)
272 |     else:
273 |         stats_text = ""
274 |
275 |     # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
276 |     if fig_size is None:
277 |         # Get the default figure size if not set
278 |         fig_size = plt.rcParams.get('figure.figsize')
279 |
280 |     if not xy_ticks:
281 |         # Do not show categories if xy_ticks is False
282 |         categories = False
283 |
284 |     # MAKE THE HEAT MAP VISUALIZATION
285 |     plt.figure(figsize=fig_size)
286 |     sns.heatmap(cf, annot=box_labels, fmt="",
287 |                 cmap=c_map, cbar=color_bar,
288 |                 xticklabels=categories,
289 |                 yticklabels=categories)
290 |
291 |     if xy_plot_labels:
292 |         plt.ylabel('True label')
293 |         plt.xlabel('Predicted label' + stats_text)
294 |     else:
295 |         plt.xlabel(stats_text)
296 |
297 |     if title:
298 |         plt.title(title)
299 |
300 |
301 | def main():
302 |     """ Accept user input.
303 |     Depending on the input, plot the required graph.
304 |
305 |     Usage example:
306 |         python visualization.py -t confusion_matrix -m svm
307 |
308 |     """
309 |     # construct the argument parser and parse the arguments
310 |     ap = argparse.ArgumentParser()
311 |     ap.add_argument("-t", "--type", type=str, required=True,
312 |                     choices=("confusion_matrix", "acc_loss",
313 |                              "tree", "ROC"),
314 |                     help="type of visualization")
315 |     ap.add_argument("-m", "--model", type=str, required=False,
316 |                     choices=("random_forest", "svm",
317 |                              "majority_voting", "stacked_prediction",
318 |                              "cnn_custom", "vgg"),
319 |                     help="model for which the visualization is plotted")
320 |     args = vars(ap.parse_args())
321 |
322 |     X_test = np.load('data/processed/ImageTestHOG_input.npy')
323 |     y_test = np.load('data/test/DiseaseTest_input.npy')
324 |     print(X_test.shape)
325 |     print(y_test.shape)
326 |
327 |     visualize(args["type"], args["model"], X_test, y_test)
328 |
329 |
330 | if __name__ == "__main__":
331 |     main()
332 |
--------------------------------------------------------------------------------