├── .gitignore
├── .prettierignore
├── LICENSE
├── README.md
├── backend
│   ├── classifier
│   │   ├── __init__.py
│   │   ├── admin.py
│   │   ├── apps.py
│   │   ├── migrations
│   │   │   ├── 0001_initial.py
│   │   │   └── __init__.py
│   │   ├── models.py
│   │   ├── serializers.py
│   │   ├── tests.py
│   │   ├── urls.py
│   │   └── views.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── asgi.py
│   │   ├── settings.py
│   │   ├── urls.py
│   │   └── wsgi.py
│   ├── manage.py
│   ├── ml_model
│   │   ├── cnn_mnist.py
│   │   └── cnn_model.h5
│   ├── requirements.txt
│   └── setup.cfg
└── frontend
    ├── package-lock.json
    ├── package.json
    ├── public
    │   ├── favicon.ico
    │   └── index.html
    └── src
        ├── App.js
        ├── assets
        │   └── images
        │       ├── drawing_editor.png
        │       └── img1.jpg
        ├── components
        │   ├── CustomAlert.js
        │   ├── CustomDivider.js
        │   ├── Description.js
        │   ├── DescriptionItem.js
        │   ├── EditorButtons.js
        │   ├── EditorHeader.js
        │   ├── Hero.js
        │   ├── HeroButtons.js
        │   ├── SketchCanvas.js
        │   └── Spacer.js
        ├── index.js
        ├── layout
        │   ├── Footer.js
        │   ├── Header.js
        │   ├── Layout.js
        │   └── Sidebar.js
        ├── pages
        │   ├── DrawingEditor.js
        │   └── Home.js
        └── theme
            └── theme.js
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .DS_Store
3 | db.sqlite3
4 | build/
5 | .ipynb_checkpoints/
6 | node_modules/
7 | venv/
8 | backend/media/images
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | *.html
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Bob's Programming Academy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Classification MNIST
2 |
3 | This is an image classification app built using **TensorFlow 2**, **Django 3**, **Django REST Framework 3**, **React 17**, and **Material UI 5**. The app uses a machine learning model built in TensorFlow and trained on the MNIST dataset to recognize handwritten digits.
4 |
5 | 
6 |
7 |
8 | ## Table of Contents
9 | - [Prerequisites](#prerequisites)
10 | - [Installation](#installation)
11 | - [Running the application](#run-the-application)
12 | - [Customizing the application](#customize-the-application)
13 | - [Copyright and License](#copyright-and-license)
14 |
15 |
16 | ## Prerequisites
17 |
18 | Install the following prerequisites:
19 |
20 | 1. [Python 3.7-3.9](https://www.python.org/downloads/)
21 |    This project uses **TensorFlow v2.7.0**, which only supports specific Python versions (3.7-3.9 for this release). See TensorFlow's [tested build configurations](https://www.tensorflow.org/install/source#tested_build_configurations) for details.
22 | 2. [Node.js](https://nodejs.org/en/)
23 | 3. [Visual Studio Code](https://code.visualstudio.com/download)
24 |
25 |
26 | ## Installation
27 |
28 | ### Backend
29 |
30 | #### 1. Create a virtual environment
31 |
32 | From the **root** directory, run:
33 |
34 | ```bash
35 | cd backend
36 | ```
37 | ```bash
38 | python -m venv venv
39 | ```
40 |
41 | #### 2. Activate the virtual environment
42 |
43 | From the **backend** directory, run:
44 |
45 | On macOS or Linux:
46 |
47 | ```bash
48 | source venv/bin/activate
49 | ```
50 |
51 | On Windows:
52 |
53 | ```bash
54 | venv\Scripts\activate
55 | ```
56 |
57 | #### 3. Install required backend dependencies
58 |
59 | From the **backend** directory, run:
60 |
61 | ```bash
62 | pip install -r requirements.txt
63 | ```
64 |
65 | #### 4. Run migrations
66 |
67 | From the **backend** directory, run:
68 |
69 | ```bash
70 | python manage.py makemigrations
71 | ```
72 |
73 | ```bash
74 | python manage.py migrate
75 | ```
76 |
77 | ### Frontend
78 |
79 | #### 1. Install required frontend dependencies
80 |
81 | From the **root** directory, run:
82 |
83 | ```bash
84 | cd frontend
85 | ```
86 | ```bash
87 | npm install
88 | ```
89 |
90 | ## Run the application
91 |
92 | To run the application, you need to have both the backend and the frontend up and running.
93 |
94 | ### 1. Run backend
95 |
96 | From the **backend** directory, run:
97 |
98 | ```bash
99 | python manage.py runserver
100 | ```
101 |
102 | ### 2. Run frontend
103 |
104 | From the **frontend** directory, run:
105 |
106 | ```bash
107 | npm start
108 | ```
109 |
110 | ## View the application
111 |
112 | Go to http://localhost:3000/ to view the application.
113 |
114 | ## Customize the application
115 |
116 | This section describes how to customize the application.
117 |
118 | ### 1. Changing Colors
119 |
120 | To modify the colors in the application, make changes in the ```frontend/src/theme/theme.js``` file.
121 |
122 | ### 2. Changing Fonts
123 |
124 | To modify the fonts in the application, first, add a new font to the ```frontend/public/index.html``` file, and then make changes in the ```frontend/src/theme/theme.js``` file.
125 |
126 | ### 3. Changing Logo
127 |
128 | To modify the logo in the application, make changes in the ```frontend/src/layout/Header.js``` and ```frontend/src/layout/Sidebar.js``` files.
129 |
130 | ### 4. Changing the Image in the Hero Section
131 |
132 | To modify the image in the Hero section, make changes in the ```frontend/src/components/Hero.js``` and ```frontend/src/layout/Footer.js``` files.
133 |
134 | ### 5. Changing the Text in the Hero Section
135 |
136 | To modify the text in the Hero section, make changes in the ```frontend/src/components/Hero.js``` file.
137 |
138 | ### 6. Changing Buttons in the Hero Section
139 |
140 | To modify the two buttons in the Hero section, make changes in the ```frontend/src/components/HeroButtons.js``` file.
141 |
142 | ### 7. Changing the App Description
143 |
144 | To modify the app's description on the home page, make changes in the ```frontend/src/components/Description.js``` file.
145 |
146 | ## Copyright and License
147 |
148 | Copyright © 2022 Bob's Programming Academy. Code released under the MIT license.
149 |
--------------------------------------------------------------------------------
/backend/classifier/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/classifier/__init__.py
--------------------------------------------------------------------------------
/backend/classifier/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 | from .models import Classifier
3 |
4 | admin.site.register(Classifier)
5 |
--------------------------------------------------------------------------------
/backend/classifier/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
4 | class ClassifierConfig(AppConfig):
5 |     default_auto_field = 'django.db.models.BigAutoField'
6 |     name = 'classifier'
7 |
--------------------------------------------------------------------------------
/backend/classifier/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 3.2.11 on 2022-01-28 05:01
2 |
3 | from django.db import migrations, models
4 |
5 |
6 | class Migration(migrations.Migration):
7 |
8 |     initial = True
9 |
10 |     dependencies = [
11 |     ]
12 |
13 |     operations = [
14 |         migrations.CreateModel(
15 |             name='Classifier',
16 |             fields=[
17 |                 ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
18 |                 ('image', models.ImageField(upload_to='images')),
19 |                 ('result', models.CharField(blank=True, max_length=3)),
20 |                 ('date_created', models.DateTimeField(auto_now_add=True)),
21 |                 ('date_updated', models.DateTimeField(auto_now=True)),
22 |             ],
23 |         ),
24 |     ]
25 |
--------------------------------------------------------------------------------
/backend/classifier/migrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/classifier/migrations/__init__.py
--------------------------------------------------------------------------------
/backend/classifier/models.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | import tensorflow as tf
5 | from django.conf import settings
6 | from django.db import models
7 | from PIL import Image
8 |
9 |
10 | class Classifier(models.Model):
11 |     image = models.ImageField(upload_to='images')
12 |     result = models.CharField(max_length=3, blank=True)
13 |     date_created = models.DateTimeField(auto_now_add=True)
14 |     date_updated = models.DateTimeField(auto_now=True)
15 |
16 |     def save(self, *args, **kwargs):
17 |         img = Image.open(self.image)
18 |         img_array = tf.keras.preprocessing.image.img_to_array(img)
19 |
20 |         new_img = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
21 |
22 |         dimensions = (28, 28)
23 |
24 |         # Interpolation - a method of constructing new data points within the range
25 |         # of a discrete set of known data points.
26 |         resized_image = cv2.resize(new_img, dimensions, interpolation=cv2.INTER_AREA)
27 |
28 |         ready_image = np.expand_dims(resized_image, axis=2)
29 |         ready_image = np.expand_dims(ready_image, axis=0)
30 |
31 |         try:
32 |             file_model = os.path.join(settings.BASE_DIR, 'ml_model/cnn_model.h5')
33 |             graph = tf.compat.v1.get_default_graph()
34 |
35 |             with graph.as_default():
36 |                 model = tf.keras.models.load_model(file_model)
37 |                 prediction = np.argmax(model.predict(ready_image))
38 |                 self.result = str(prediction)
39 |                 print(f'Classified as {prediction}')
40 |         except Exception:
41 |             print('Failed to classify')
42 |             self.result = 'F'
43 |
44 |         return super().save(*args, **kwargs)
45 |
--------------------------------------------------------------------------------
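`Classifier.save()` converts the upload to grayscale, shrinks it to 28x28 with `cv2.INTER_AREA`, and adds channel and batch axes to reach the `(1, 28, 28, 1)` shape the network expects. When the shrink factor is a whole number, area interpolation essentially amounts to averaging each factor-by-factor block of source pixels. The pure-Python sketch below illustrates that idea, assuming the drawing has already been converted to a square grayscale grid whose side is a multiple of 28 (for example 280x280); `area_downsample` and `to_model_input` are hypothetical names. It also divides by 255, mirroring `prepare_dataset()` in `ml_model/cnn_mnist.py`; `save()` as written does not apply that scaling before `model.predict()`.

```python
from typing import List

Image2D = List[List[float]]


def area_downsample(pixels: Image2D, size: int = 28) -> Image2D:
    """Shrink a square image to size x size by averaging equal-sized blocks.

    Assumes the side length is an exact multiple of `size`, the case in which
    area interpolation reduces to a plain block mean.
    """
    side = len(pixels)
    if side % size != 0 or any(len(row) != side for row in pixels):
        raise ValueError("expected a square image whose side is a multiple of size")
    factor = side // size
    out: Image2D = []
    for by in range(size):
        row_out = []
        for bx in range(size):
            block_sum = 0.0
            for y in range(by * factor, (by + 1) * factor):
                for x in range(bx * factor, (bx + 1) * factor):
                    block_sum += pixels[y][x]
            row_out.append(block_sum / (factor * factor))
        out.append(row_out)
    return out


def to_model_input(pixels: Image2D) -> List[List[List[List[float]]]]:
    """Downsample, scale to [0, 1] and wrap into shape (1, 28, 28, 1)."""
    small = area_downsample(pixels)
    return [[[[value / 255.0] for value in row] for row in small]]


if __name__ == "__main__":
    # Hypothetical 280x280 canvas: black background with one white vertical stroke.
    canvas = [[0.0] * 280 for _ in range(280)]
    for y in range(100, 180):
        for x in range(130, 150):
            canvas[y][x] = 255.0
    batch = to_model_input(canvas)
    print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
```

In the real app, NumPy and OpenCV do this work far faster; the loops here only make the block-averaging step visible.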
/backend/classifier/serializers.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import uuid
3 | from django.core.files.base import ContentFile
4 | from rest_framework import serializers
5 |
6 | from .models import Classifier
7 |
8 |
9 | class Base64ImageField(serializers.ImageField):
10 |     def to_internal_value(self, data):
11 |         _format, str_img = data.split(";base64,")
12 |         decoded_file = base64.b64decode(str_img)
13 |         file_name = f'{str(uuid.uuid4())[:10]}.png'
14 |         data = ContentFile(decoded_file, name=file_name)
15 |         return super().to_internal_value(data)
16 |
17 |
18 | class ClassifierSerializer(serializers.ModelSerializer):
19 |     image = Base64ImageField()
20 |
21 |     class Meta:
22 |         model = Classifier
23 |         fields = '__all__'
24 |
--------------------------------------------------------------------------------
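`Base64ImageField.to_internal_value()` expects a data URL such as `data:image/png;base64,...` (the form a browser canvas export typically produces), splits it on `;base64,`, decodes the payload, and wraps the bytes in a `ContentFile` named after the first 10 characters of a UUID4. The stdlib-only sketch below reproduces that parsing and naming outside Django so it can be tried in isolation. `parse_data_url` and `random_png_name` are hypothetical helpers; unlike the serializer, the sketch rejects input without the `;base64,` marker or with invalid base64 by raising `ValueError` explicitly, rather than failing on tuple unpacking.

```python
import base64
import binascii
import uuid
from typing import Tuple

MARKER = ";base64,"


def parse_data_url(data: str) -> Tuple[str, bytes]:
    """Split 'data:<mime>;base64,<payload>' into (mime, decoded bytes)."""
    if MARKER not in data:
        raise ValueError("not a base64 data URL")
    header, payload = data.split(MARKER, 1)
    mime = header[len("data:"):] if header.startswith("data:") else header
    try:
        decoded = base64.b64decode(payload, validate=True)
    except binascii.Error as exc:
        raise ValueError("invalid base64 payload") from exc
    return mime, decoded


def random_png_name() -> str:
    """Mirror the serializer's naming: first 10 characters of a UUID4 plus '.png'."""
    return f"{str(uuid.uuid4())[:10]}.png"


if __name__ == "__main__":
    # Every PNG file starts with this 8-byte signature.
    png_signature = b"\x89PNG\r\n\x1a\n"
    url = "data:image/png;base64," + base64.b64encode(png_signature).decode("ascii")
    mime, raw = parse_data_url(url)
    print(mime, raw == png_signature, random_png_name())
```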
/backend/classifier/tests.py:
--------------------------------------------------------------------------------
1 | # from django.test import TestCase
2 |
3 | # Create your tests here.
4 |
--------------------------------------------------------------------------------
/backend/classifier/urls.py:
--------------------------------------------------------------------------------
1 | from django.urls import path, include
2 | from rest_framework import routers
3 |
4 | from .views import ClassifierViewSet
5 |
6 | router = routers.DefaultRouter()
7 | router.register(r'classifier', ClassifierViewSet)
8 |
9 | urlpatterns = [
10 |     path('', include(router.urls)),
11 | ]
12 |
--------------------------------------------------------------------------------
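`classifier/urls.py` registers `ClassifierViewSet` with a `DefaultRouter` under the prefix `classifier`, and `config/urls.py` includes the result under `api/`. For a `ModelViewSet` this yields a list route and a detail route keyed by primary key, with trailing slashes by default. The sketch below is a simplified, hand-written illustration of the resulting method-to-action mapping, not DRF's own resolver: it ignores the router's API root view and format-suffix variants, and returns `None` where DRF would answer 404 or 405 Method Not Allowed. `ROUTES` and `resolve` are hypothetical names; the `pk` pattern `[^/.]+` is DRF's default lookup regex.

```python
import re
from typing import Dict, Optional, Tuple

# Simplified route table for a ModelViewSet registered as r'classifier'
# and included under 'api/'.
ROUTES: Dict[str, Dict[str, str]] = {
    r"^api/classifier/$": {
        "GET": "list",
        "POST": "create",
    },
    r"^api/classifier/(?P<pk>[^/.]+)/$": {
        "GET": "retrieve",
        "PUT": "update",
        "PATCH": "partial_update",
        "DELETE": "destroy",
    },
}


def resolve(method: str, path: str) -> Optional[Tuple[str, Dict[str, str]]]:
    """Return (viewset action, URL kwargs) for a request, or None if nothing matches."""
    for pattern, actions in ROUTES.items():
        match = re.match(pattern, path.lstrip("/"))
        if match and method.upper() in actions:
            return actions[method.upper()], match.groupdict()
    return None


if __name__ == "__main__":
    print(resolve("POST", "/api/classifier/"))
    print(resolve("GET", "/api/classifier/7/"))
    print(resolve("DELETE", "/api/classifier/"))
```

In this app the frontend only needs the `create` action: POSTing a drawing both stores it and runs the classifier inside `Classifier.save()`.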
/backend/classifier/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets
2 |
3 | from .serializers import ClassifierSerializer
4 | from .models import Classifier
5 |
6 |
7 | class ClassifierViewSet(viewsets.ModelViewSet):
8 |     queryset = Classifier.objects.all()
9 |     serializer_class = ClassifierSerializer
10 |
--------------------------------------------------------------------------------
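Because `ClassifierViewSet` is a plain `ModelViewSet` over `ClassifierSerializer` with `fields = '__all__'`, a client classifies a digit by POSTing JSON whose `image` field holds a base64 PNG data URL. The model's `save()` fills in `result` before the row is stored, so the response should carry the prediction alongside `id`, `date_created`, and `date_updated`. The stdlib sketch below only builds such a request. It assumes the development server listens on `http://127.0.0.1:8000` (the `runserver` default) and that the API is mounted under `api/` as in `config/urls.py`; `build_classify_request` is a hypothetical helper.

```python
import base64
import json
import urllib.request

API_URL = "http://127.0.0.1:8000/api/classifier/"  # assumed dev-server address


def build_classify_request(png_bytes: bytes, url: str = API_URL) -> urllib.request.Request:
    """Wrap PNG bytes in a data URL and build a JSON POST for the classifier API."""
    data_url = "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")
    body = json.dumps({"image": data_url}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


if __name__ == "__main__":
    # Placeholder bytes; pass the contents of a real PNG drawing in practice.
    req = build_classify_request(b"\x89PNG\r\n\x1a\n")
    print(req.get_method(), req.full_url)
```

With the backend running and real PNG bytes, `urllib.request.urlopen(req)` sends the request; the JSON response should contain the predicted digit under `result`, or `'F'` if classification failed.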
/backend/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/config/__init__.py
--------------------------------------------------------------------------------
/backend/config/asgi.py:
--------------------------------------------------------------------------------
1 | """
2 | ASGI config for config project.
3 |
4 | It exposes the ASGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.asgi import get_asgi_application
13 |
14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
15 |
16 | application = get_asgi_application()
17 |
--------------------------------------------------------------------------------
/backend/config/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | # Build paths inside the project like this: BASE_DIR / 'subdir'.
5 | BASE_DIR = Path(__file__).resolve().parent.parent
6 |
7 |
8 | # Quick-start development settings - unsuitable for production
9 | # See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/
10 |
11 | # SECURITY WARNING: keep the secret key used in production secret!
12 | SECRET_KEY = 'django-insecure-qq!ws+!t@^1n9yq#07o$goah_-_#4#@m#k&-s)ta8%hd*$3sr1'
13 |
14 | # SECURITY WARNING: don't run with debug turned on in production!
15 | DEBUG = True
16 |
17 | ALLOWED_HOSTS = []
18 |
19 |
20 | # Application definition
21 |
22 | INSTALLED_APPS = [
23 |     'django.contrib.admin',
24 |     'django.contrib.auth',
25 |     'django.contrib.contenttypes',
26 |     'django.contrib.sessions',
27 |     'django.contrib.messages',
28 |     'django.contrib.staticfiles',
29 |
30 |     # 3rd party
31 |     'rest_framework',
32 |     'corsheaders',
33 |
34 |     # Local
35 |     'classifier',
36 | ]
37 |
38 | MIDDLEWARE = [
39 |     'corsheaders.middleware.CorsMiddleware',
40 |     'django.middleware.security.SecurityMiddleware',
41 |     'django.contrib.sessions.middleware.SessionMiddleware',
42 |     'django.middleware.common.CommonMiddleware',
43 |     'django.middleware.csrf.CsrfViewMiddleware',
44 |     'django.contrib.auth.middleware.AuthenticationMiddleware',
45 |     'django.contrib.messages.middleware.MessageMiddleware',
46 |     'django.middleware.clickjacking.XFrameOptionsMiddleware',
47 | ]
48 |
49 | ROOT_URLCONF = 'config.urls'
50 |
51 | TEMPLATES = [
52 |     {
53 |         'BACKEND': 'django.template.backends.django.DjangoTemplates',
54 |         'DIRS': [],
55 |         'APP_DIRS': True,
56 |         'OPTIONS': {
57 |             'context_processors': [
58 |                 'django.template.context_processors.debug',
59 |                 'django.template.context_processors.request',
60 |                 'django.contrib.auth.context_processors.auth',
61 |                 'django.contrib.messages.context_processors.messages',
62 |             ],
63 |         },
64 |     },
65 | ]
66 |
67 | WSGI_APPLICATION = 'config.wsgi.application'
68 |
69 |
70 | # Database
71 | # https://docs.djangoproject.com/en/3.2/ref/settings/#databases
72 |
73 | DATABASES = {
74 |     'default': {
75 |         'ENGINE': 'django.db.backends.sqlite3',
76 |         'NAME': BASE_DIR / 'db.sqlite3',
77 |     }
78 | }
79 |
80 |
81 | # Password validation
82 | # https://docs.djangoproject.com/en/3.2/ref/settings/#auth-password-validators
83 |
84 | AUTH_PASSWORD_VALIDATORS = [
85 |     {
86 |         'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
87 |     },
88 |     {
89 |         'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
90 |     },
91 |     {
92 |         'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
93 |     },
94 |     {
95 |         'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
96 |     },
97 | ]
98 |
99 |
100 | # Internationalization
101 | # https://docs.djangoproject.com/en/3.2/topics/i18n/
102 |
103 | LANGUAGE_CODE = 'en-us'
104 |
105 | TIME_ZONE = 'UTC'
106 |
107 | USE_I18N = True
108 |
109 | USE_L10N = True
110 |
111 | USE_TZ = True
112 |
113 |
114 | # Static files (CSS, JavaScript, Images)
115 | # https://docs.djangoproject.com/en/3.2/howto/static-files/
116 |
117 | STATIC_URL = '/static/'
118 | STATICFILES_DIRS = [os.path.join(BASE_DIR, 'build/static')]
119 | STATIC_ROOT = os.path.join(BASE_DIR, 'static')
120 |
121 | MEDIA_URL = '/media/'
122 | MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
123 |
124 |
125 | # Default primary key field type
126 | # https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field
127 |
128 | DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
129 |
130 | CORS_ORIGIN_ALLOW_ALL = True
131 |
132 | FILE_UPLOAD_PERMISSIONS = 0o640
133 |
--------------------------------------------------------------------------------
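Together, `MEDIA_ROOT = os.path.join(BASE_DIR, 'media')` and the model's `upload_to='images'` place uploaded drawings in `backend/media/images/`, the directory excluded in `.gitignore`. `FILE_UPLOAD_PERMISSIONS = 0o640` gives each saved file owner read/write, group read, and no access for anyone else. The sketch below makes both settings concrete with `pathlib` and `stat`; the `BASE_DIR` value is a stand-in path, and `upload_path` and `describe_mode` are hypothetical helpers, not Django's storage API (the real storage may also rename a file to avoid collisions).

```python
import stat
from pathlib import PurePosixPath

# Stand-in for settings.BASE_DIR (the backend/ directory).
BASE_DIR = PurePosixPath("/srv/app/backend")
MEDIA_ROOT = BASE_DIR / "media"
UPLOAD_TO = "images"  # Classifier.image upload_to
FILE_UPLOAD_PERMISSIONS = 0o640


def upload_path(file_name: str) -> PurePosixPath:
    """Where a file saved through the default file-system storage should end up."""
    return MEDIA_ROOT / UPLOAD_TO / file_name


def describe_mode(mode: int) -> str:
    """Render permission bits for a regular file the way `ls -l` does."""
    return stat.filemode(stat.S_IFREG | mode)


if __name__ == "__main__":
    print(upload_path("1a2b3c4d-5.png"))
    print(describe_mode(FILE_UPLOAD_PERMISSIONS))
```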
/backend/config/urls.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | from django.conf.urls.static import static
3 | from django.contrib import admin
4 | from django.urls import path, include
5 |
6 | urlpatterns = [
7 |     path('admin/', admin.site.urls),
8 |     path('api/', include('classifier.urls'))
9 | ]
10 |
11 | urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
12 | urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
13 |
--------------------------------------------------------------------------------
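`config/urls.py` appends `static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)`, so during development a file stored as `images/<name>.png` is served at `MEDIA_URL` followed by that relative name. Django's `static()` helper only adds these patterns when `DEBUG` is `True` (it returns an empty list otherwise), so a production deployment needs a web server or storage backend to serve media and static files. The sketch below shows the URL mapping only; `HOST` is an assumed development address, and `media_url_for` is a hypothetical helper that approximates, rather than reproduces, Django's own URL quoting.

```python
from urllib.parse import quote, urljoin

HOST = "http://127.0.0.1:8000"  # assumed runserver address
MEDIA_URL = "/media/"


def media_url_for(stored_name: str, host: str = HOST) -> str:
    """Build the URL the dev server uses for a file saved under MEDIA_ROOT."""
    # stored_name is relative to MEDIA_ROOT, e.g. 'images/abc.png'.
    return urljoin(host, MEDIA_URL + quote(stored_name))


if __name__ == "__main__":
    print(media_url_for("images/1a2b3c4d-5.png"))
```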
/backend/config/wsgi.py:
--------------------------------------------------------------------------------
1 | """
2 | WSGI config for config project.
3 |
4 | It exposes the WSGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/3.2/howto/deployment/wsgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.wsgi import get_wsgi_application
13 |
14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
15 |
16 | application = get_wsgi_application()
17 |
--------------------------------------------------------------------------------
/backend/manage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Django's command-line utility for administrative tasks."""
3 | import os
4 | import sys
5 |
6 |
7 | def main():
8 |     """Run administrative tasks."""
9 |     os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
10 |     try:
11 |         from django.core.management import execute_from_command_line
12 |     except ImportError as exc:
13 |         raise ImportError(
14 |             "Couldn't import Django. Are you sure it's installed and "
15 |             "available on your PYTHONPATH environment variable? Did you "
16 |             "forget to activate a virtual environment?"
17 |         ) from exc
18 |     execute_from_command_line(sys.argv)
19 |
20 |
21 | if __name__ == '__main__':
22 |     main()
23 |
--------------------------------------------------------------------------------
/backend/ml_model/cnn_mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | def load_dataset():
6 |     # load the dataset
7 |     (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
8 |
9 |     # reshape the dataset to have a single channel
10 |     X_train = np.expand_dims(X_train, axis=3)
11 |     X_test = np.expand_dims(X_test, axis=3)
12 |
13 |     # one-hot encode target values
14 |     y_train = tf.keras.utils.to_categorical(y_train, 10)
15 |     y_test = tf.keras.utils.to_categorical(y_test, 10)
16 |
17 |     return X_train, y_train, X_test, y_test
18 |
19 |
20 | def prepare_dataset(train, test):
21 |     # convert from integers to floats
22 |     train_norm = train.astype('float32')
23 |     test_norm = test.astype('float32')
24 |
25 |     # normalize to range 0-1
26 |     train_norm = train_norm / 255.0
27 |     test_norm = test_norm / 255.0
28 |
29 |     # return normalized images
30 |     return train_norm, test_norm
31 |
32 |
33 | # define the CNN model
34 | def define_model():
35 |     model = tf.keras.models.Sequential()
36 |     model.add(tf.keras.layers.Conv2D(
37 |         filters=32,
38 |         kernel_size=(5, 5),
39 |         strides=1,
40 |         activation='relu',
41 |         kernel_regularizer=tf.keras.regularizers.l2(0.0005),
42 |         input_shape=(28, 28, 1),
43 |         name='convolution_1'
44 |     ))
45 |     model.add(tf.keras.layers.Conv2D(
46 |         filters=32,
47 |         kernel_size=(5, 5),
48 |         strides=1,
49 |         activation='relu',
50 |         use_bias=False,
51 |         name='convolution_2'
52 |     ))
53 |     model.add(tf.keras.layers.BatchNormalization(name='batchnorm_1'))
54 |     model.add(tf.keras.layers.Activation('relu'))
55 |     model.add(tf.keras.layers.MaxPooling2D(
56 |         pool_size=(2, 2),
57 |         strides=2,
58 |         name='max_pool_1'
59 |     ))
60 |     model.add(tf.keras.layers.Dropout(0.25, name='dropout_1'))
61 |     model.add(tf.keras.layers.Conv2D(
62 |         filters=64,
63 |         kernel_size=(3, 3),
64 |         strides=1,
65 |         activation='relu',
66 |         kernel_regularizer=tf.keras.regularizers.l2(0.0005),
67 |         name='convolution_3'
68 |     ))
69 |     model.add(tf.keras.layers.Conv2D(
70 |         filters=64,
71 |         kernel_size=(3, 3),
72 |         strides=1,
73 |         activation='relu',
74 |         use_bias=False,
75 |         name='convolution_4'
76 |     ))
77 |     model.add(tf.keras.layers.BatchNormalization(name='batchnorm_2'))
78 |     model.add(tf.keras.layers.Activation('relu'))
79 |     model.add(tf.keras.layers.MaxPooling2D(
80 |         pool_size=(2, 2),
81 |         strides=2,
82 |         name='max_pool_2'
83 |     ))
84 |     model.add(tf.keras.layers.Dropout(0.25, name='dropout_2'))
85 |     model.add(tf.keras.layers.Flatten(name='flatten'))
86 |     model.add(tf.keras.layers.Dense(
87 |         units=256,
88 |         activation='relu',
89 |         use_bias=False,
90 |         name='fully_connected_1'
91 |     ))
92 |     model.add(tf.keras.layers.BatchNormalization(name='batchnorm_3'))
93 |     model.add(tf.keras.layers.Activation('relu'))
94 |     model.add(tf.keras.layers.Dense(
95 |         units=128,
96 |         activation='relu',
97 |         use_bias=False,
98 |         name='fully_connected_2'
99 |     ))
100 |     model.add(tf.keras.layers.BatchNormalization(name='batchnorm_4'))
101 |     model.add(tf.keras.layers.Activation('relu'))
102 |     model.add(tf.keras.layers.Dense(
103 |         units=84,
104 |         activation='relu',
105 |         use_bias=False,
106 |         name='fully_connected_3'
107 |     ))
108 |     model.add(tf.keras.layers.BatchNormalization(name='batchnorm_5'))
109 |     model.add(tf.keras.layers.Activation('relu'))
110 |     model.add(tf.keras.layers.Dropout(0.25, name='dropout_3'))
111 |     model.add(tf.keras.layers.Dense(
112 |         units=10,
113 |         activation='softmax',
114 |         name='output'
115 |     ))
116 |
117 |     # compile the model
118 |     model.compile(
119 |         optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
120 |         loss='categorical_crossentropy',
121 |         metrics=['accuracy']
122 |     )
123 |     model.summary()
124 |
125 |     return model
126 |
127 |
128 | # train, evaluate and save the model
129 | def train_and_save_model(X_train, y_train, X_test, y_test):
130 |     # define the model
131 |     model = define_model()
132 |
133 |     # train the model
134 |     model.fit(
135 |         X_train,
136 |         y_train,
137 |         epochs=30,
138 |         batch_size=64,
139 |         validation_split=0.2,
140 |         verbose=1,
141 |         shuffle=True,
142 |         callbacks=[tf.keras.callbacks.ReduceLROnPlateau(
143 |             monitor='val_loss',
144 |             factor=0.2,
145 |             patience=2
146 |         )]
147 |     )
148 |
149 |     # evaluate the model on test data
150 |     score = model.evaluate(X_test, y_test, verbose=0)
151 |     print('Error on test data: ', score[0])  # printed: Error on test data: 0.022791262716054916
152 |     print('Accuracy on test data: {0:.2f}%'.format(score[1] * 100))  # printed: Accuracy on test data: 99.45%
153 |
154 |     # save the model
155 |     model.save('ml_model/cnn_model.h5')
156 |     print('Saved model to disk')
157 |
158 |
159 | # run the test harness for evaluating the model
160 | def run_test_harness():
161 |     # load the dataset
162 |     X_train, y_train, X_test, y_test = load_dataset()
163 |
164 |     # prepare the dataset
165 |     X_train, X_test = prepare_dataset(X_train, X_test)
166 |
167 |     # train, evaluate and save the model
168 |     train_and_save_model(X_train, y_train, X_test, y_test)
169 |
170 |
171 | # entry point, run the test harness (only when executed as a script, not on import)
172 | if __name__ == '__main__':
173 |     run_test_harness()
174 |
--------------------------------------------------------------------------------
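`define_model()` relies on Keras's default `'valid'` padding, so each stride-1 convolution shrinks the feature map by `kernel_size - 1`, and each 2x2 max pool with stride 2 halves it (rounding down). Tracing a 28x28 input through the layers explains why `flatten` produces 3 x 3 x 64 = 576 features. The pure-Python sketch below does that bookkeeping and also tallies parameters, assuming Keras's conventions: a convolution or dense layer has one weight per input-output connection, plus one bias per unit only when `use_bias` is on, and `BatchNormalization` keeps four values per channel (gamma, beta, moving mean, moving variance). The `Layer` list is transcribed by hand from `define_model()` and does not import TensorFlow, so `model.summary()` remains the authoritative check.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Layer:
    kind: str            # 'conv', 'bn', 'pool', 'flatten' or 'dense'
    units: int = 0       # filters for conv, units for dense
    kernel: int = 0      # square kernel size for conv
    use_bias: bool = True


# Dropout and Activation layers are omitted: they change neither shape nor parameters.
LAYERS: List[Layer] = [
    Layer("conv", 32, 5),                  # convolution_1
    Layer("conv", 32, 5, use_bias=False),  # convolution_2
    Layer("bn"),                           # batchnorm_1
    Layer("pool"),                         # max_pool_1
    Layer("conv", 64, 3),                  # convolution_3
    Layer("conv", 64, 3, use_bias=False),  # convolution_4
    Layer("bn"),                           # batchnorm_2
    Layer("pool"),                         # max_pool_2
    Layer("flatten"),                      # flatten
    Layer("dense", 256, use_bias=False),   # fully_connected_1
    Layer("bn"),                           # batchnorm_3
    Layer("dense", 128, use_bias=False),   # fully_connected_2
    Layer("bn"),                           # batchnorm_4
    Layer("dense", 84, use_bias=False),    # fully_connected_3
    Layer("bn"),                           # batchnorm_5
    Layer("dense", 10),                    # output
]


def trace(side: int = 28, depth: int = 1) -> Tuple[int, int]:
    """Return (flattened feature count, total parameter count) for LAYERS.

    `side` is the spatial size and `depth` the channel count before flatten;
    after flatten, `depth` holds the width of the current dense layer.
    """
    flattened = 0
    total = 0
    for layer in LAYERS:
        bias = layer.units if layer.use_bias else 0
        if layer.kind == "conv":
            total += layer.kernel * layer.kernel * depth * layer.units + bias
            side, depth = side - layer.kernel + 1, layer.units  # 'valid', stride 1
        elif layer.kind == "pool":
            side = (side - 2) // 2 + 1  # 2x2 window, stride 2, 'valid'
        elif layer.kind == "bn":
            total += 4 * depth  # gamma, beta, moving mean, moving variance
        elif layer.kind == "flatten":
            flattened = side * side * depth
            side, depth = 1, flattened
        elif layer.kind == "dense":
            total += depth * layer.units + bias
            depth = layer.units
    return flattened, total


if __name__ == "__main__":
    print(trace())
```

Under these assumptions `trace()` returns `(576, 275874)`; the spatial size goes 28, 24, 20, 10 (pool), 8, 6, 3 (pool) before flattening.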
/backend/ml_model/cnn_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/ml_model/cnn_model.h5
--------------------------------------------------------------------------------
/backend/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/backend/requirements.txt
--------------------------------------------------------------------------------
/backend/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | exclude = .git,*migrations*,*venv*
3 | max-line-length = 119
4 | indent-size = 2
--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "image-classification-mnist",
3 |   "version": "1.0.0",
4 |   "description": "Image Classification MNIST",
5 |   "scripts": {
6 |     "start": "react-scripts start",
7 |     "build": "react-scripts build",
8 |     "test": "react-scripts test",
9 |     "eject": "react-scripts eject"
10 |   },
11 |   "author": "MG",
12 |   "license": "MIT",
13 |   "dependencies": {
14 |     "@emotion/react": "^11.7.1",
15 |     "@emotion/styled": "^11.6.0",
16 |     "@fortawesome/fontawesome-svg-core": "^1.2.36",
17 |     "@fortawesome/free-regular-svg-icons": "^5.15.4",
18 |     "@fortawesome/free-solid-svg-icons": "^5.15.4",
19 |     "@fortawesome/react-fontawesome": "^0.1.16",
20 |     "@mui/icons-material": "^5.3.1",
21 |     "@mui/material": "^5.3.1",
22 |     "aos": "^2.3.4",
23 |     "axios": "^0.25.0",
24 |     "file-saver": "^2.0.5",
25 |     "react": "^17.0.2",
26 |     "react-dom": "^17.0.2",
27 |     "react-helmet-async": "^1.2.2",
28 |     "react-lazy-load-image-component": "^1.5.1",
29 |     "react-router-dom": "^6.2.1",
30 |     "react-scripts": "^5.0.1",
31 |     "react-sketch-canvas": "^6.1.0"
32 |   },
33 |   "browserslist": {
34 |     "production": [
35 |       ">0.2%",
36 |       "not dead",
37 |       "not op_mini all"
38 |     ],
39 |     "development": [
40 |       "last 1 chrome version",
41 |       "last 1 firefox version",
42 |       "last 1 safari version"
43 |     ]
44 |   }
45 | }
46 |
--------------------------------------------------------------------------------
/frontend/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobsProgrammingAcademy/image-classification-mnist/c19dd50615f3fe2348640ae7a99d380f1f3c66d6/frontend/public/favicon.ico
--------------------------------------------------------------------------------
/frontend/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |