├── .gitignore
├── .prettierignore
├── LICENSE
├── README.md
├── backend
│   ├── classifier
│   │   ├── __init__.py
│   │   ├── admin.py
│   │   ├── apps.py
│   │   ├── migrations
│   │   │   ├── 0001_initial.py
│   │   │   └── __init__.py
│   │   ├── models.py
│   │   ├── serializers.py
│   │   ├── tests.py
│   │   ├── urls.py
│   │   └── views.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── asgi.py
│   │   ├── settings.py
│   │   ├── urls.py
│   │   └── wsgi.py
│   ├── manage.py
│   ├── ml_model
│   │   ├── cnn_mnist.py
│   │   └── cnn_model.h5
│   ├── requirements.txt
│   └── setup.cfg
└── frontend
    ├── package-lock.json
    ├── package.json
    ├── public
    │   ├── favicon.ico
    │   └── index.html
    └── src
        ├── App.js
        ├── assets
        │   └── images
        │       ├── drawing_editor.png
        │       └── img1.jpg
        ├── components
        │   ├── CustomAlert.js
        │   ├── CustomDivider.js
        │   ├── Description.js
        │   ├── DescriptionItem.js
        │   ├── EditorButtons.js
        │   ├── EditorHeader.js
        │   ├── Hero.js
        │   ├── HeroButtons.js
        │   ├── SketchCanvas.js
        │   └── Spacer.js
        ├── index.js
        ├── layout
        │   ├── Footer.js
        │   ├── Header.js
        │   ├── Layout.js
        │   └── Sidebar.js
        ├── pages
        │   ├── DrawingEditor.js
        │   └── Home.js
        └── theme
            └── theme.js

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
.DS_Store
db.sqlite3
build/
.ipynb_checkpoints/
node_modules/
venv/
backend/media/images
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
*.html
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Bob's Programming Academy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Image Classification MNIST

This is an image classification app built with **TensorFlow 2**, **Django 3**, **Django REST Framework 3**, **React 17**, and **Material UI 5**. The app uses a convolutional neural network (CNN), built in TensorFlow and trained on the MNIST dataset, to recognize handwritten digits.

![Drawing editor](https://github.com/BobsProgrammingAcademy/Image-Classification-MNIST/blob/master/frontend/src/assets/images/drawing_editor.png?raw=true)


## Table of Contents
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Run the application](#run-the-application)
- [View the application](#view-the-application)
- [Customize the application](#customize-the-application)
- [Copyright and License](#copyright-and-license)


## Prerequisites

Install the following prerequisites:

1. [Python 3.7-3.9](https://www.python.org/downloads/)
   This project uses **TensorFlow v2.7.0**, which supports only Python 3.7-3.9, so make sure one of those versions is installed. See TensorFlow's [tested build configurations](https://www.tensorflow.org/install/source#tested_build_configurations) for details.
2. [Node.js](https://nodejs.org/en/)
3. [Visual Studio Code](https://code.visualstudio.com/download)


## Installation

### Backend

#### 1. Create a virtual environment

From the **root** directory, run:

```bash
cd backend
```
```bash
python -m venv venv
```

#### 2. Activate the virtual environment

From the **backend** directory, run:

On macOS and Linux:

```bash
source venv/bin/activate
```

On Windows:

```bash
venv\Scripts\activate
```

#### 3. Install required backend dependencies

From the **backend** directory, run:

```bash
pip install -r requirements.txt
```

#### 4. Run migrations

From the **backend** directory, run:

```bash
python manage.py makemigrations
```

```bash
python manage.py migrate
```

### Frontend

#### 1. Install required frontend dependencies

From the **root** directory, run:

```bash
cd frontend
```
```bash
npm install
```

## Run the application

Both the backend and the frontend must be running at the same time, each in its own terminal.

### 1. Run backend

From the **backend** directory, with the virtual environment activated, run:

```bash
python manage.py runserver
```

### 2. Run frontend

From the **frontend** directory, run:

```bash
npm start
```

## View the application

Go to http://localhost:3000/ to view the application.
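The frontend is the intended client, but the REST API can also be exercised directly. Judging from `config/urls.py` (`api/`) and the router in `classifier/urls.py` (`classifier`), a new drawing is created by a POST to `/api/classifier/`, and `Base64ImageField` expects the `image` field as a data URL that contains `;base64,`. The following is a minimal sketch under those assumptions; `build_classifier_payload` and `decode_like_serializer` are hypothetical helpers, not part of the project.

```python
import base64
import json


def build_classifier_payload(png_bytes: bytes) -> str:
    """Return a JSON body whose "image" value is a base64 PNG data URL.

    Hypothetical helper: the "data:image/png;base64," prefix is assumed from
    the way Base64ImageField splits its input on ";base64,".
    """
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return json.dumps({"image": f"data:image/png;base64,{encoded}"})


def decode_like_serializer(data_url: str) -> bytes:
    """Repeat the split-and-decode step of Base64ImageField.to_internal_value."""
    _format, str_img = data_url.split(";base64,")
    return base64.b64decode(str_img)


# Round trip: the bytes the client encodes are the bytes the serializer decodes.
sample = b"\x89PNG\r\n\x1a\n"  # PNG signature only, not a complete image
body = build_classifier_payload(sample)
assert decode_like_serializer(json.loads(body)["image"]) == sample
```

With the backend running on Django's default development address, the resulting body could be POSTed to `http://127.0.0.1:8000/api/classifier/` with a `Content-Type: application/json` header. Use a real PNG rather than the placeholder bytes, since the `ImageField` validation is expected to reject data that Pillow cannot open. The response should then include the `result` value set by `Classifier.save()`.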

## Customize the application

This section describes how to customize the application.

### 1. Changing Colors

To modify the colors in the application, make changes in the `frontend/src/theme/theme.js` file.

### 2. Changing Fonts

To modify the fonts in the application, first add a new font to the `frontend/public/index.html` file, and then make changes in the `frontend/src/theme/theme.js` file.

### 3. Changing the Logo

To modify the logo in the application, make changes in the `frontend/src/layout/Header.js` and `frontend/src/layout/Sidebar.js` files.

### 4. Changing the Image in the Hero Section

To modify the image in the Hero section, make changes in the `frontend/src/components/Hero.js` and `frontend/src/layout/Footer.js` files.

### 5. Changing the Text in the Hero Section

To modify the text in the Hero section, make changes in the `frontend/src/components/Hero.js` file.

### 6. Changing the Buttons in the Hero Section

To modify the two buttons in the Hero section, make changes in the `frontend/src/components/HeroButtons.js` file.

### 7. Changing the App Description

To modify the app's description on the home page, make changes in the `frontend/src/components/Description.js` file.

## Copyright and License

Copyright © 2022 Bob's Programming Academy. Code released under the MIT license.

--------------------------------------------------------------------------------
/backend/classifier/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/classifier/__init__.py
--------------------------------------------------------------------------------
/backend/classifier/admin.py:
--------------------------------------------------------------------------------
from django.contrib import admin
from .models import Classifier

admin.site.register(Classifier)

--------------------------------------------------------------------------------
/backend/classifier/apps.py:
--------------------------------------------------------------------------------
from django.apps import AppConfig


class ClassifierConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'classifier'

--------------------------------------------------------------------------------
/backend/classifier/migrations/0001_initial.py:
--------------------------------------------------------------------------------
# Generated by Django 3.2.11 on 2022-01-28 05:01

from django.db import migrations, models


class Migration(migrations.Migration):

    initial = True

    dependencies = [
    ]

    operations = [
        migrations.CreateModel(
            name='Classifier',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('image', models.ImageField(upload_to='images')),
                ('result', models.CharField(blank=True, max_length=3)),
                ('date_created', models.DateTimeField(auto_now_add=True)),
                ('date_updated', models.DateTimeField(auto_now=True)),
            ],
        ),
    ]

--------------------------------------------------------------------------------
/backend/classifier/migrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/classifier/migrations/__init__.py
--------------------------------------------------------------------------------
/backend/classifier/models.py:
--------------------------------------------------------------------------------
import cv2
import os
import numpy as np
import tensorflow as tf
from django.conf import settings
from django.db import models
from PIL import Image


class Classifier(models.Model):
    image = models.ImageField(upload_to='images')
    result = models.CharField(max_length=3, blank=True)
    date_created = models.DateTimeField(auto_now_add=True)
    date_updated = models.DateTimeField(auto_now=True)

    def save(self, *args, **kwargs):
        img = Image.open(self.image)
        img_array = tf.keras.preprocessing.image.img_to_array(img)

        new_img = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)

        dimensions = (28, 28)

        # Interpolation - a method of constructing new data points within the range
        # of a discrete set of known data points.
        resized_image = cv2.resize(new_img, dimensions, interpolation=cv2.INTER_AREA)

        ready_image = np.expand_dims(resized_image, axis=2)
        ready_image = np.expand_dims(ready_image, axis=0)

        try:
            file_model = os.path.join(settings.BASE_DIR, 'ml_model/cnn_model.h5')
            graph = tf.compat.v1.get_default_graph()

            with graph.as_default():
                model = tf.keras.models.load_model(file_model)
                prediction = np.argmax(model.predict(ready_image))
                self.result = str(prediction)
                print(f'Classified as {prediction}')
        except Exception:
            print('Failed to classify')
            self.result = 'F'

        return super().save(*args, **kwargs)

--------------------------------------------------------------------------------
/backend/classifier/serializers.py:
--------------------------------------------------------------------------------
import base64
import uuid
from django.core.files.base import ContentFile
from rest_framework import serializers

from .models import Classifier


class Base64ImageField(serializers.ImageField):
    def to_internal_value(self, data):
        _format, str_img = data.split(";base64,")
        decoded_file = base64.b64decode(str_img)
        file_name = f'{str(uuid.uuid4())[:10]}.png'
        data = ContentFile(decoded_file, name=file_name)
        return super().to_internal_value(data)


class ClassifierSerializer(serializers.ModelSerializer):
    image = Base64ImageField()

    class Meta:
        model = Classifier
        fields = '__all__'

--------------------------------------------------------------------------------
/backend/classifier/tests.py:
--------------------------------------------------------------------------------
# from django.test import TestCase

# Create your tests here.
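In `Classifier.save()` above, the uploaded drawing is converted to grayscale, resized to 28x28, and passed to `model.predict` with pixel values still in the 0-255 range, whereas `prepare_dataset()` in `ml_model/cnn_mnist.py` divides the training images by 255. The following NumPy-only sketch shows a preprocessing step that mirrors that training normalization. It is illustrative, not the project's code: `preprocess_for_mnist_model` is a hypothetical helper; it assumes the image height and width are integer multiples of 28, so area interpolation reduces to averaging equal-sized blocks (essentially what `cv2.INTER_AREA` does for integer factors); and it applies RGB luma weights because PIL returns channels in RGB order, while `save()` uses `cv2.COLOR_BGR2GRAY`. How much the missing scaling affects the deployed model's predictions has not been measured here.

```python
import numpy as np


def preprocess_for_mnist_model(image: np.ndarray, size: int = 28) -> np.ndarray:
    """Turn an H x W x C image (C = 3 or 4, values 0-255) into a
    (1, size, size, 1) float32 batch scaled to [0, 1].

    Hypothetical helper. Assumes H and W are integer multiples of `size`.
    """
    rgb = image[..., :3].astype(np.float32)  # drop an alpha channel if present
    # ITU-R BT.601 luma weights, applied in RGB channel order
    gray = rgb @ np.array([0.299, 0.587, 0.114], dtype=np.float32)

    h, w = gray.shape
    if h % size or w % size:
        raise ValueError("height and width must be multiples of size")
    fy, fx = h // size, w // size
    # Average each fy x fx block (area interpolation for integer factors)
    small = gray.reshape(size, fy, size, fx).mean(axis=(1, 3))

    # Match prepare_dataset(): scale 0-255 intensities to 0-1
    small = small / 255.0
    # Add batch and channel axes: (1, size, size, 1), the shape the model takes
    return small[np.newaxis, :, :, np.newaxis].astype(np.float32)
```

For canvas sizes that are not multiples of 28, a similar effect could be obtained by keeping the existing `cv2.resize(..., interpolation=cv2.INTER_AREA)` call and dividing `ready_image` by 255.0 before `model.predict`.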

--------------------------------------------------------------------------------
/backend/classifier/urls.py:
--------------------------------------------------------------------------------
from django.urls import path, include
from rest_framework import routers

from .views import ClassifierViewSet

router = routers.DefaultRouter()
router.register(r'classifier', ClassifierViewSet)

urlpatterns = [
    path('', include(router.urls)),
]

--------------------------------------------------------------------------------
/backend/classifier/views.py:
--------------------------------------------------------------------------------
from rest_framework import viewsets

from .serializers import ClassifierSerializer
from .models import Classifier


class ClassifierViewSet(viewsets.ModelViewSet):
    queryset = Classifier.objects.all()
    serializer_class = ClassifierSerializer

--------------------------------------------------------------------------------
/backend/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/config/__init__.py
--------------------------------------------------------------------------------
/backend/config/asgi.py:
--------------------------------------------------------------------------------
"""
ASGI config for config project.

It exposes the ASGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
"""

import os

from django.core.asgi import get_asgi_application

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')

application = get_asgi_application()

--------------------------------------------------------------------------------
/backend/config/settings.py:
--------------------------------------------------------------------------------
import os
from pathlib import Path

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'django-insecure-qq!ws+!t@^1n9yq#07o$goah_-_#4#@m#k&-s)ta8%hd*$3sr1'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = []


# Application definition

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',

    # 3rd party
    'rest_framework',
    'corsheaders',

    # Local
    'classifier',
]

MIDDLEWARE = [
    'corsheaders.middleware.CorsMiddleware',
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'config.urls'

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'config.wsgi.application'


# Database
# https://docs.djangoproject.com/en/3.2/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': BASE_DIR / 'db.sqlite3',
    }
}


# Password validation
# https://docs.djangoproject.com/en/3.2/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/3.2/topics/i18n/

LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/3.2/howto/static-files/

STATIC_URL = '/static/'
STATICFILES_DIRS = [os.path.join(BASE_DIR, 'build/static')]
STATIC_ROOT = os.path.join(BASE_DIR, 'static')

MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')


# Default primary key field type
# https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field

DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'

CORS_ORIGIN_ALLOW_ALL = True

FILE_UPLOAD_PERMISSIONS = 0o640

--------------------------------------------------------------------------------
/backend/config/urls.py:
--------------------------------------------------------------------------------
from django.conf import settings
from django.conf.urls.static import static
from django.contrib import admin
from django.urls import path, include

urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/', include('classifier.urls'))
]

urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)

--------------------------------------------------------------------------------
/backend/config/wsgi.py:
--------------------------------------------------------------------------------
"""
WSGI config for config project.

It exposes the WSGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/3.2/howto/deployment/wsgi/
"""

import os

from django.core.wsgi import get_wsgi_application

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')

application = get_wsgi_application()

--------------------------------------------------------------------------------
/backend/manage.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys


def main():
    """Run administrative tasks."""
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
    try:
        from django.core.management import execute_from_command_line
    except ImportError as exc:
        raise ImportError(
            "Couldn't import Django. Are you sure it's installed and "
            "available on your PYTHONPATH environment variable? Did you "
            "forget to activate a virtual environment?"
        ) from exc
    execute_from_command_line(sys.argv)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/backend/ml_model/cnn_mnist.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf


def load_dataset():
    # load the dataset
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

    # reshape the dataset to have a single channel
    X_train = np.expand_dims(X_train, axis=3)
    X_test = np.expand_dims(X_test, axis=3)

    # one-hot encode target values
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    y_test = tf.keras.utils.to_categorical(y_test, 10)

    return X_train, y_train, X_test, y_test


def prepare_dataset(train, test):
    # convert from integers to floats
    train_norm = train.astype('float32')
    test_norm = test.astype('float32')

    # normalize to range 0-1
    train_norm = train_norm / 255.0
    test_norm = test_norm / 255.0

    # return normalized images
    return train_norm, test_norm


# define the CNN model
def define_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=(5, 5),
        strides=1,
        activation='relu',
        kernel_regularizer=tf.keras.regularizers.l2(0.0005),
        input_shape=(28, 28, 1),
        name='convolution_1'
    ))
    model.add(tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=(5, 5),
        strides=1,
        activation='relu',
        use_bias=False,
        name='convolution_2'
    ))
    model.add(tf.keras.layers.BatchNormalization(name='batchnorm_1'))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPooling2D(
        pool_size=(2, 2),
        strides=2,
        name='max_pool_1'
    ))
    model.add(tf.keras.layers.Dropout(0.25, name='dropout_1'))
    model.add(tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=(3, 3),
        strides=1,
        activation='relu',
        kernel_regularizer=tf.keras.regularizers.l2(0.0005),
        name='convolution_3'
    ))
    model.add(tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=(3, 3),
        strides=1,
        activation='relu',
        use_bias=False,
        name='convolution_4'
    ))
    model.add(tf.keras.layers.BatchNormalization(name='batchnorm_2'))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPooling2D(
        pool_size=(2, 2),
        strides=2,
        name='max_pool_2'
    ))
    model.add(tf.keras.layers.Dropout(0.25, name='dropout_2'))
    model.add(tf.keras.layers.Flatten(name='flatten'))
    model.add(tf.keras.layers.Dense(
        units=256,
        activation='relu',
        use_bias=False,
        name='fully_connected_1'
    ))
    model.add(tf.keras.layers.BatchNormalization(name='batchnorm_3'))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Dense(
        units=128,
        activation='relu',
        use_bias=False,
        name='fully_connected_2'
    ))
    model.add(tf.keras.layers.BatchNormalization(name='batchnorm_4'))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Dense(
        units=84,
        activation='relu',
        use_bias=False,
        name='fully_connected_3'
    ))
    model.add(tf.keras.layers.BatchNormalization(name='batchnorm_5'))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Dropout(0.25, name='dropout_3'))
    model.add(tf.keras.layers.Dense(
        units=10,
        activation='softmax',
        name='output'
    ))

    # compile the model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    model.summary()

    return model


# train, evaluate and save the model
def train_and_save_model(X_train, y_train, X_test, y_test):
    # define the model
    model = define_model()

    # train the model
    model.fit(
        X_train,
        y_train,
        epochs=30,
        batch_size=64,
        validation_split=0.2,
        verbose=1,
        shuffle=True,
        callbacks=[tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=2
        )]
    )

    # evaluate the model on test data
    score = model.evaluate(X_test, y_test, verbose=0)
    print('Error on test data: ', score[0])  # printed: Error on test data: 0.022791262716054916
    print('Accuracy on test data: {0:.2f}%'.format(score[1] * 100))  # printed: Accuracy on test data: 99.45%

    # save the model; the path is relative to the working directory, so run this
    # script from the backend directory to match the path used in classifier/models.py
    model.save('ml_model/cnn_model.h5')
    print('Saved model to disk')


# run the test harness for evaluating the model
def run_test_harness():
    # load the dataset
    X_train, y_train, X_test, y_test = load_dataset()

    # prepare the dataset
    X_train, X_test = prepare_dataset(X_train, X_test)

    # train, evaluate and save the model
    train_and_save_model(X_train, y_train, X_test, y_test)


# entry point: run the test harness only when executed as a script
if __name__ == '__main__':
    run_test_harness()

--------------------------------------------------------------------------------
/backend/ml_model/cnn_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/ml_model/cnn_model.h5
--------------------------------------------------------------------------------
/backend/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/requirements.txt
--------------------------------------------------------------------------------
/backend/setup.cfg:
--------------------------------------------------------------------------------
[flake8]
exclude = .git,*migrations*,*venv*
max-line-length = 119
indent-size = 2
--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
{
  "name": "image-classification-mnist",
  "version": "1.0.0",
  "description": "Image Classification MNIST",
  "scripts": {
    "start": "react-scripts start",
    "build": "react-scripts build",
    "test": "react-scripts test",
    "eject": "react-scripts eject"
  },
  "author": "MG",
  "license": "MIT",
  "dependencies": {
    "@emotion/react": "^11.7.1",
    "@emotion/styled": "^11.6.0",
    "@fortawesome/fontawesome-svg-core": "^1.2.36",
    "@fortawesome/free-regular-svg-icons": "^5.15.4",
    "@fortawesome/free-solid-svg-icons": "^5.15.4",
    "@fortawesome/react-fontawesome": "^0.1.16",
    "@mui/icons-material": "^5.3.1",
    "@mui/material": "^5.3.1",
    "aos": "^2.3.4",
    "axios": "^0.25.0",
    "file-saver": "^2.0.5",
    "react": "^17.0.2",
    "react-dom": "^17.0.2",
    "react-helmet-async": "^1.2.2",
    "react-lazy-load-image-component": "^1.5.1",
    "react-router-dom": "^6.2.1",
    "react-scripts": "^5.0.1",
    "react-sketch-canvas": "^6.1.0"
  },
  "browserslist": {
    "production": [
      ">0.2%",
      "not dead",
      "not op_mini all"
    ],
    "development": [
      "last 1 chrome version",
      "last 1 firefox version",
      "last 1 safari version"
    ]
  }
}

--------------------------------------------------------------------------------
/frontend/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/frontend/public/favicon.ico -------------------------------------------------------------------------------- /frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Image Classification MNIST 7 | 8 | 9 | 10 | 14 | 15 | 16 | 20 | 21 | 22 |
23 | 24 | -------------------------------------------------------------------------------- /frontend/src/App.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { BrowserRouter, Routes, Route } from 'react-router-dom'; 3 | import { HelmetProvider, Helmet } from 'react-helmet-async'; 4 | import 'aos/dist/aos.css'; 5 | 6 | import Layout from './layout/Layout'; 7 | import Home from './pages/Home'; 8 | import DrawingEditor from './pages/DrawingEditor'; 9 | 10 | const App = () => { 11 | return ( 12 | 13 | 17 | 18 | 19 | 20 | } /> 21 | } /> 22 | 23 | 24 | 25 | 26 | ); 27 | }; 28 | 29 | export default App; 30 | -------------------------------------------------------------------------------- /frontend/src/assets/images/drawing_editor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/frontend/src/assets/images/drawing_editor.png -------------------------------------------------------------------------------- /frontend/src/assets/images/img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/frontend/src/assets/images/img1.jpg -------------------------------------------------------------------------------- /frontend/src/components/CustomAlert.js: -------------------------------------------------------------------------------- 1 | // Material UI 2 | import Alert from '@mui/material/Alert'; 3 | import AlertTitle from '@mui/material/AlertTitle'; 4 | import Box from '@mui/material/Box'; 5 | 6 | const CustomAlert = ({ variant, severity, title, children }) => { 7 | return ( 8 | 9 | 10 | {title} 11 | {children} 12 | 13 | 14 | ); 15 | }; 16 | 17 | export default CustomAlert; 18 | 
-------------------------------------------------------------------------------- /frontend/src/components/CustomDivider.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import Divider from '@mui/material/Divider'; 3 | 4 | const CustomDivider = () => { 5 | return ( 6 | 12 | 13 | 14 | ); 15 | }; 16 | 17 | export default CustomDivider; 18 | -------------------------------------------------------------------------------- /frontend/src/components/Description.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import Container from '@mui/material/Container'; 3 | import Grid from '@mui/material/Grid'; 4 | import Typography from '@mui/material/Typography'; 5 | import { useTheme } from '@mui/material'; 6 | 7 | // Font Awesome Icons 8 | import { library } from '@fortawesome/fontawesome-svg-core'; 9 | import { faEdit as EditIcon } from '@fortawesome/free-regular-svg-icons'; 10 | import { faDownload as DownloadIcon } from '@fortawesome/free-solid-svg-icons'; 11 | import { faShareSquare as ShareSquareIcon } from '@fortawesome/free-regular-svg-icons'; 12 | import { faLaptopCode as LaptopCodeIcon } from '@fortawesome/free-solid-svg-icons'; 13 | library.add(EditIcon, DownloadIcon, ShareSquareIcon, LaptopCodeIcon); 14 | 15 | import DescriptionItem from './DescriptionItem'; 16 | 17 | const Description = () => { 18 | const theme = useTheme(); 19 | 20 | return ( 21 | 28 | 33 | 42 | 49 | How Does It Work? 
50 | 51 | 60 | A step-by-step guide on how to use the app 61 | 62 | 63 | 69 | 75 | 81 | 87 | 88 | 89 | 90 | 91 | ); 92 | }; 93 | 94 | export default Description; 95 | -------------------------------------------------------------------------------- /frontend/src/components/DescriptionItem.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import Grid from '@mui/material/Grid'; 3 | import ListItem from '@mui/material/ListItem'; 4 | import ListItemAvatar from '@mui/material/ListItemAvatar'; 5 | import ListItemText from '@mui/material/ListItemText'; 6 | 7 | // Font Awesome Icons 8 | import { library } from '@fortawesome/fontawesome-svg-core'; 9 | import { faEdit as EditIcon } from '@fortawesome/free-regular-svg-icons'; 10 | import { faDownload as DownloadIcon } from '@fortawesome/free-solid-svg-icons'; 11 | import { faShareSquare as ShareSquareIcon } from '@fortawesome/free-regular-svg-icons'; 12 | import { FontAwesomeIcon } from '@fortawesome/react-fontawesome'; 13 | library.add(EditIcon, DownloadIcon, ShareSquareIcon); 14 | 15 | const DescriptionItem = ({ color, icon, title, subtitle }) => { 16 | return ( 17 | <> 18 | 19 | 25 | 26 | 27 | 28 | 29 | 30 | 43 | 44 | 45 | 46 | ); 47 | }; 48 | 49 | export default DescriptionItem; 50 | -------------------------------------------------------------------------------- /frontend/src/components/EditorButtons.js: -------------------------------------------------------------------------------- 1 | import Button from '@mui/material/Button'; 2 | import Box from '@mui/material/Box'; 3 | import { useTheme, useMediaQuery } from '@mui/material'; 4 | 5 | // Material Icons 6 | import ResetIcon from '@mui/icons-material/RotateLeft'; 7 | import SendIcon from '@mui/icons-material/SendToMobile'; 8 | import DownloadIcon from '@mui/icons-material/Download'; 9 | 10 | const EditorButtons = ({ submitOnClick, resetOnClick, downloadOnClick }) => { 11 | const theme = 
useTheme(); 12 | const isMd = useMediaQuery(theme.breakpoints.up('md'), { 13 | defaultMatches: true, 14 | }); 15 | 16 | return ( 17 | <> 18 | 25 | 47 | 52 | 74 | 75 | 80 | 102 | 103 | 104 | 105 | ); 106 | }; 107 | 108 | export default EditorButtons; 109 | -------------------------------------------------------------------------------- /frontend/src/components/EditorHeader.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import Typography from '@mui/material/Typography'; 3 | import { useTheme, useMediaQuery } from '@mui/material'; 4 | import { green } from '@mui/material/colors'; 5 | 6 | const EditorHeader = () => { 7 | const theme = useTheme(); 8 | const isMd = useMediaQuery(theme.breakpoints.up('md'), { 9 | defaultMatches: true, 10 | }); 11 | 12 | return ( 13 | <> 14 | 15 | Drawing Editor 16 | 17 | 18 | 19 | Draw a digit{' '} 20 | between 0 and 9 in the 21 | window below 22 | 23 | 24 | 25 | ); 26 | }; 27 | 28 | export default EditorHeader; 29 | -------------------------------------------------------------------------------- /frontend/src/components/Hero.js: -------------------------------------------------------------------------------- 1 | import { useEffect } from 'react'; 2 | import { LazyLoadImage } from 'react-lazy-load-image-component'; 3 | import AOS from 'aos'; 4 | import Box from '@mui/material/Box'; 5 | import Grid from '@mui/material/Grid'; 6 | import Typography from '@mui/material/Typography'; 7 | import { useTheme, useMediaQuery } from '@mui/material'; 8 | 9 | import bgImage from '../assets/images/img1.jpg'; 10 | import HeroButtons from '../components/HeroButtons'; 11 | 12 | const Hero = () => { 13 | const theme = useTheme(); 14 | const isMd = useMediaQuery(theme.breakpoints.up('md'), { 15 | defaultMatches: true, 16 | }); 17 | 18 | useEffect(() => { 19 | AOS.init({ 20 | once: true, 21 | delay: 50, 22 | duration: 600, 23 | easing: 'ease-in-out', 24 | }); 25 | }, []); 26 | 
27 | return ( 28 | 35 | 36 | 37 | 38 | 39 | 45 | Draw a Digit 46 | 47 | 48 | 49 | 57 | The App Will Tell You What Digit You Have Drawn 58 | 59 | 60 | 61 | 62 | 63 | 71 | 83 | 94 | 95 | 96 | 97 | 98 | ); 99 | }; 100 | 101 | export default Hero; 102 | -------------------------------------------------------------------------------- /frontend/src/components/HeroButtons.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import Button from '@mui/material/Button'; 3 | import { useTheme, useMediaQuery } from '@mui/material'; 4 | 5 | // Material Icons 6 | import InfoIcon from '@mui/icons-material/HelpOutline'; 7 | import PlayIcon from '@mui/icons-material/PlayCircleOutlineOutlined'; 8 | 9 | const HeroButtons = () => { 10 | const theme = useTheme(); 11 | const isMd = useMediaQuery(theme.breakpoints.up('md'), { 12 | defaultMatches: true, 13 | }); 14 | 15 | return ( 16 | <> 17 | 24 | 47 | 52 | 75 | 76 | 77 | 78 | ); 79 | }; 80 | 81 | export default HeroButtons; 82 | -------------------------------------------------------------------------------- /frontend/src/components/SketchCanvas.js: -------------------------------------------------------------------------------- 1 | import { ReactSketchCanvas } from 'react-sketch-canvas'; 2 | import Box from '@mui/material/Box'; 3 | import { useTheme } from '@mui/material'; 4 | 5 | const SketchCanvas = ({ inputRef }) => { 6 | const theme = useTheme(); 7 | 8 | return ( 9 | 16 | 30 | 31 | ); 32 | }; 33 | 34 | export default SketchCanvas; 35 | -------------------------------------------------------------------------------- /frontend/src/components/Spacer.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import { useTheme } from '@mui/material'; 3 | 4 | const Spacer = ({ sx = [] }) => { 5 | const theme = useTheme(); 6 | 7 | return ( 8 | 16 | ); 17 | }; 18 | 19 | export default Spacer; 20 
| -------------------------------------------------------------------------------- /frontend/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import { ThemeProvider } from '@mui/material/styles'; 4 | import { CssBaseline } from '@mui/material'; 5 | import 'aos/dist/aos'; 6 | 7 | import theme from './theme/theme'; 8 | import App from './App'; 9 | 10 | ReactDOM.render( 11 | 12 | 13 | 14 | , 15 | document.getElementById('root') 16 | ); 17 | -------------------------------------------------------------------------------- /frontend/src/layout/Footer.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import Divider from '@mui/material/Divider'; 3 | import Grid from '@mui/material/Grid'; 4 | import Hidden from '@mui/material/Hidden'; 5 | import Link from '@mui/material/Link'; 6 | import List from '@mui/material/List'; 7 | import ListItemButton from '@mui/material/ListItemButton'; 8 | import ListItemText from '@mui/material/ListItemText'; 9 | import Typography from '@mui/material/Typography'; 10 | import { useTheme } from '@mui/material'; 11 | 12 | const Footer = () => { 13 | const theme = useTheme(); 14 | 15 | return ( 16 | <> 17 | 27 | 28 | 33 | 34 | 35 | 36 | 43 | 44 | 50 | Privacy Policy 51 | 52 | } 53 | /> 54 | 55 | 56 | 62 | Terms of Use 63 | 64 | } 65 | /> 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 79 | Copyright © {new Date().getFullYear()} Bob's 80 | Programming Academy. 81 | 82 | } 83 | /> 84 | 85 | 86 | 87 | 88 | 89 | 90 | 96 | Photo by{' '} 97 | 104 | Karim Manjra 105 | {' '} 106 | on{' '} 107 | 114 | Unsplash 115 | 116 | . 
117 | 118 | } 119 | /> 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | ); 128 | }; 129 | 130 | export default Footer; 131 | -------------------------------------------------------------------------------- /frontend/src/layout/Header.js: -------------------------------------------------------------------------------- 1 | import { Link } from 'react-router-dom'; 2 | import AppBar from '@mui/material/AppBar'; 3 | import Avatar from '@mui/material/Avatar'; 4 | import Button from '@mui/material/Button'; 5 | import Box from '@mui/material/Box'; 6 | import Divider from '@mui/material/Divider'; 7 | import IconButton from '@mui/material/IconButton'; 8 | import Toolbar from '@mui/material/Toolbar'; 9 | import Typography from '@mui/material/Typography'; 10 | import { useTheme } from '@mui/material'; 11 | import { green } from '@mui/material/colors'; 12 | 13 | // Material Icons 14 | import MenuIcon from '@mui/icons-material/Menu'; 15 | import HomeIcon from '@mui/icons-material/Home'; 16 | 17 | // Font Awesome Icons 18 | import { library } from '@fortawesome/fontawesome-svg-core'; 19 | import { faPencilAlt } from '@fortawesome/free-solid-svg-icons'; 20 | import { faEdit } from '@fortawesome/free-solid-svg-icons'; 21 | import { FontAwesomeIcon } from '@fortawesome/react-fontawesome'; 22 | library.add(faPencilAlt, faEdit); 23 | 24 | const Header = ({ onSidebarMobileOpen }) => { 25 | const theme = useTheme(); 26 | 27 | return ( 28 | <> 29 | 36 | 37 | 42 | 43 | 44 | 45 | 46 | 47 | 56 | 64 | 65 | 75 | Image Classification MNIST 76 | 77 | 78 | 79 | 80 | 81 | 99 | 118 | 119 | 120 | 121 | 122 | ); 123 | }; 124 | 125 | export default Header; 126 | -------------------------------------------------------------------------------- /frontend/src/layout/Layout.js: -------------------------------------------------------------------------------- 1 | import Box from '@mui/material/Box'; 2 | import { useTheme } from '@mui/material'; 3 | 4 | import Header from './Header'; 5 | import Footer from 
'./Footer'; 6 | import Sidebar from './Sidebar'; 7 | import { useState } from 'react'; 8 | const Layout = ({ children }) => { 9 | const theme = useTheme(); 10 | const [isSidebarMobileOpen, setIsSidebarMobileOpen] = useState(false); 11 | 12 | return ( 13 | 19 |
setIsSidebarMobileOpen(true)} /> 20 | setIsSidebarMobileOpen(false)} 22 | openMobile={isSidebarMobileOpen} 23 | /> 24 |
{children}
25 |