├── APIProjectFolder
├── APIProject
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── urls.cpython-37.pyc
│ │ ├── wsgi.cpython-37.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ └── settings.cpython-37.pyc
│ ├── wsgi.py
│ ├── urls.py
│ └── settings.py
├── Prediction
│ ├── __init__.py
│ ├── migrations
│ │ ├── __init__.py
│ │ └── __pycache__
│ │ │ └── __init__.cpython-37.pyc
│ ├── models.py
│ ├── tests.py
│ ├── admin.py
│ ├── __pycache__
│ │ ├── admin.cpython-37.pyc
│ │ ├── apps.cpython-37.pyc
│ │ ├── urls.cpython-37.pyc
│ │ ├── views.cpython-37.pyc
│ │ ├── models.cpython-37.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ └── throttles.cpython-37.pyc
│ ├── classifier
│ │ └── IRISRandomForestClassifier.joblib
│ ├── throttles.py
│ ├── urls.py
│ ├── apps.py
│ └── views.py
├── .vscode
│ ├── settings.json
│ └── launch.json
├── db.sqlite3
└── manage.py
├── .gitattributes
├── .vs
├── PythonSettings.json
├── slnx.sqlite
├── DjangoRestAPI
│ └── v16
│ │ └── .suo
├── DjangoRestAPIDemo
│ └── v16
│ │ └── .suo
└── VSWorkspaceState.json
├── Iris.pptx
├── .ipynb_checkpoints
└── CreateModelWithJupyter-checkpoint.ipynb
├── Steps.txt
├── README.md
└── CreateModelWithJupyter.ipynb
/APIProjectFolder/APIProject/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * linguist-vendored
2 | *.py linguist-vendored=false
--------------------------------------------------------------------------------
/.vs/PythonSettings.json:
--------------------------------------------------------------------------------
1 | {
2 | "Interpreter": "Global|VisualStudio|Env64 (Python 3.7 (64-bit))"
3 | }
--------------------------------------------------------------------------------
/Iris.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/Iris.pptx
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 | # Create your models here.
4 |
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | # Create your tests here.
4 |
--------------------------------------------------------------------------------
/.vs/slnx.sqlite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/.vs/slnx.sqlite
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | # Register your models here.
4 |
--------------------------------------------------------------------------------
/APIProjectFolder/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "Z:\\MachineLearning\\Env64\\Scripts\\python.exe"
3 | }
--------------------------------------------------------------------------------
/.vs/DjangoRestAPI/v16/.suo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/.vs/DjangoRestAPI/v16/.suo
--------------------------------------------------------------------------------
/APIProjectFolder/db.sqlite3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/db.sqlite3
--------------------------------------------------------------------------------
/.vs/DjangoRestAPIDemo/v16/.suo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/.vs/DjangoRestAPIDemo/v16/.suo
--------------------------------------------------------------------------------
/.ipynb_checkpoints/CreateModelWithJupyter-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/.vs/VSWorkspaceState.json:
--------------------------------------------------------------------------------
1 | {
2 | "ExpandedNodes": [
3 | "",
4 | "\\APIProject"
5 | ],
6 | "SelectedNode": "\\Steps.txt",
7 | "PreviewInSolutionExplorer": false
8 | }
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/__pycache__/urls.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/urls.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/__pycache__/wsgi.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/wsgi.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/admin.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/admin.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/apps.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/apps.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/urls.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/urls.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/views.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/views.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/__pycache__/settings.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/settings.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/__pycache__/throttles.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/throttles.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/classifier/IRISRandomForestClassifier.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/classifier/IRISRandomForestClassifier.joblib
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/migrations/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/migrations/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/throttles.py:
--------------------------------------------------------------------------------
1 | from rest_framework.throttling import UserRateThrottle
2 |
3 | # Custom Throttle classes
class LimitedRateThrottle(UserRateThrottle):
    """User-rate throttle registered under the ``'limited'`` scope."""
    # Scope name; the matching rate is expected under this key in
    # REST_FRAMEWORK['DEFAULT_THROTTLE_RATES'] in the project settings.
    scope = 'limited'
6 |
class BurstRateThrottle(UserRateThrottle):
    """User-rate throttle registered under the ``'burst'`` scope."""
    # Scope name; the matching rate is expected under this key in
    # REST_FRAMEWORK['DEFAULT_THROTTLE_RATES'] in the project settings.
    scope = 'burst'
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/urls.py:
--------------------------------------------------------------------------------
1 | from django.urls import path
2 | import Prediction.views as views
3 |
# Routes of the Prediction app; every pattern is relative to the prefix under
# which this URLconf is included.
urlpatterns = [
    # Function-based view that sums the posted values.
    path('add/', views.api_add, name = 'api_add'),
    # Class-based variant of the sum endpoint.
    path('add_values/', views.Add_Values.as_view(), name = 'api_add_values'),
    # Iris species prediction endpoint.
    path('predict/', views.IRIS_Model_Predict.as_view(), name = 'predict'),
]
--------------------------------------------------------------------------------
/Steps.txt:
--------------------------------------------------------------------------------
1 | 1) Install Django version 2.2 with the command:
2 | python -m pip install Django==2.2.0
3 | 2) Check that Django is installed by running:
4 | python -m django --version
5 | This should produce the following output:
6 | 2.2
7 | 3) From the command line, cd into a directory where you’d like to store your code, then run the following command:
8 | django-admin startproject ModelAPI
9 |
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/wsgi.py:
--------------------------------------------------------------------------------
1 | """
2 | WSGI config for APIProject project.
3 |
4 | It exposes the WSGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.wsgi import get_wsgi_application
13 |
14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'APIProject.settings')
15 |
16 | application = get_wsgi_application()
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PredictionAPIWithDjangoRESTDemoApp
2 |
3 | Full instructions are available at https://www.datagraphi.com/blog/post/2019/12/19/rest-api-guide-productionizing-a-machine-learning-model-by-creating-a-rest-api-with-python-django-and-django-rest-framework
4 |
5 | The article takes a hands-on approach and explains, clearly and concisely, everything from setting up Django to creating and testing REST APIs, serializing a machine learning model, integrating the model into an API, and adding authentication and throttling.
6 |
--------------------------------------------------------------------------------
/APIProjectFolder/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Django",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${workspaceFolder}\\manage.py",
12 | "args": [
13 | "runserver",
14 | "--noreload"
15 | ],
16 | "django": true
17 | }
18 | ]
19 | }
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 | import pandas as pd
3 | from joblib import load
4 | import os
5 |
6 |
class PredictionConfig(AppConfig):
    """App configuration for ``Prediction`` that also holds the loaded classifier.

    The class body runs when this module is first imported, so the joblib
    file is read at import time and the resulting object is shared through
    the ``classifier`` class attribute.
    """
    name = 'Prediction'
    # Directory two levels above this file, i.e. the folder that contains the
    # app package.
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # Folder holding the serialized model (the path keeps its trailing slash).
    CLASSIFIER_FOLDER = os.path.join(BASE_DIR, 'Prediction/classifier/')
    # Full path of the serialized classifier file.
    CLASSIFIER_FILE = os.path.join(CLASSIFIER_FOLDER, "IRISRandomForestClassifier.joblib")
    # Deserialize once; any error raised by load() propagates at import time.
    # NOTE(review): joblib files are pickle-based, so only load trusted files.
    classifier = load(CLASSIFIER_FILE)
--------------------------------------------------------------------------------
/APIProjectFolder/manage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Django's command-line utility for administrative tasks."""
3 | import os
4 | import sys
5 |
6 |
def main():
    """Select the project settings module and run the requested Django command.

    Raises:
        ImportError: if Django cannot be imported; the original error is
            chained as the cause.
    """
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'APIProject.settings')
    try:
        from django.core import management
    except ImportError as import_error:
        hint = (
            "Couldn't import Django. Are you sure it's installed and "
            "available on your PYTHONPATH environment variable? Did you "
            "forget to activate a virtual environment?"
        )
        raise ImportError(hint) from import_error
    # Hand the raw command line over to Django's management dispatcher.
    management.execute_from_command_line(sys.argv)
18 |
19 |
20 | if __name__ == '__main__':
21 | main()
22 |
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/urls.py:
--------------------------------------------------------------------------------
1 | """APIProject URL Configuration
2 |
3 | The `urlpatterns` list routes URLs to views. For more information please see:
4 | https://docs.djangoproject.com/en/2.2/topics/http/urls/
5 | Examples:
6 | Function views
7 | 1. Add an import: from my_app import views
8 | 2. Add a URL to urlpatterns: path('', views.home, name='home')
9 | Class-based views
10 | 1. Add an import: from other_app.views import Home
11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
12 | Including another URLconf
13 | 1. Import the include() function: from django.urls import include, path
14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
15 | """
16 | from django.contrib import admin
17 | from django.urls import path, include
18 |
19 | urlpatterns = [
20 | path('admin/', admin.site.urls),
21 | path('api-auth/', include('rest_framework.urls', namespace='rest_framework')),
22 | path('api/', include('Prediction.urls')),
23 | ]
24 |
25 |
26 |
--------------------------------------------------------------------------------
/APIProjectFolder/Prediction/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import status
2 | from rest_framework.decorators import api_view
3 | from rest_framework.response import Response
4 | from rest_framework.views import APIView
5 | from .apps import PredictionConfig
6 | import pandas as pd
7 |
8 | from rest_framework.permissions import IsAuthenticated
9 | from .throttles import LimitedRateThrottle, BurstRateThrottle
10 |
11 | # Create your views here.
@api_view(['GET', 'POST'])
def api_add(request):
    """Return the sum of the values of a posted object.

    GET responds with an empty object. POST expects an object whose values
    are all numbers and responds with ``{"sum": <total>}``. Both cases keep
    the 201 status code that existing clients already receive.

    A POST body that is not an object, or that contains a non-numeric value,
    is answered with 400 instead of failing with an unhandled exception.
    """
    response_dict = {}
    if request.method == 'POST':
        data = request.data
        try:
            # The builtin sum() starts at 0, so an empty object yields 0.
            total = sum(data.values())
        except (AttributeError, TypeError):
            # AttributeError: the body is not a mapping (e.g. a JSON array).
            # TypeError: at least one value cannot be added to a number.
            return Response(
                {"error": "Request body must be an object whose values are all numbers."},
                status=status.HTTP_400_BAD_REQUEST,
            )
        response_dict = {"sum": total}
    return Response(response_dict, status=status.HTTP_201_CREATED)
26 |
27 | # Class based view to add numbers
class Add_Values(APIView):
    """Authenticated endpoint that adds up the values of a posted object."""
    permission_classes = [IsAuthenticated]

    def post(self, request, format=None):
        """Respond with ``{"sum": <total>}`` for an object of numeric values.

        Success keeps the 201 status code that existing clients already
        receive. A body that is not an object, or that contains a
        non-numeric value, is answered with 400 instead of failing with an
        unhandled exception.
        """
        data = request.data
        try:
            # The builtin sum() starts at 0, so an empty object yields 0.
            total = sum(data.values())
        except (AttributeError, TypeError):
            # AttributeError: the body is not a mapping (e.g. a JSON array).
            # TypeError: at least one value cannot be added to a number.
            return Response(
                {"error": "Request body must be an object whose values are all numbers."},
                status=status.HTTP_400_BAD_REQUEST,
            )
        response_dict = {"sum": total}
        return Response(response_dict, status=status.HTTP_201_CREATED)
38 |
39 | # Class based view to predict based on IRIS model
class IRIS_Model_Predict(APIView):
    """Authenticated, rate-limited endpoint that classifies one iris sample."""
    permission_classes = [IsAuthenticated]
    throttle_classes = [LimitedRateThrottle]

    def post(self, request, format=None):
        """Predict the iris species for the posted feature values.

        The values of the posted object are used as a single feature row, in
        the order in which they appear in the payload; the keys are ignored.
        NOTE(review): presumably the order must match the feature order used
        when the classifier was trained - confirm against the training code.

        Responds with 200 and the predicted species name, or with 400 when
        the body is not an object or the classifier rejects the values.
        """
        bad_request = {"error": "Request body must be an object holding one numeric value per model feature."}
        data = request.data
        try:
            values = list(data.values())
        except AttributeError:
            # The body is not a mapping (e.g. a JSON array).
            return Response(bad_request, status=status.HTTP_400_BAD_REQUEST)
        # One row with as many columns as values were posted.
        X = pd.Series(values).to_numpy().reshape(1, -1)
        loaded_classifier = PredictionConfig.classifier
        try:
            y_pred = loaded_classifier.predict(X)
        except (TypeError, ValueError):
            # Raised for a wrong number of features or non-numeric values.
            return Response(bad_request, status=status.HTTP_400_BAD_REQUEST)
        target_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
        species = pd.Series(y_pred).map(target_map).to_numpy()[0]
        # The misspelled key is kept unchanged because clients already read it.
        response_dict = {"Prediced Iris Species": species}
        return Response(response_dict, status=200)
--------------------------------------------------------------------------------
/APIProjectFolder/APIProject/settings.py:
--------------------------------------------------------------------------------
1 | """
2 | Django settings for APIProject project.
3 |
4 | Generated by 'django-admin startproject' using Django 2.2.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/2.2/topics/settings/
8 |
9 | For the full list of settings and their values, see
10 | https://docs.djangoproject.com/en/2.2/ref/settings/
11 | """
12 |
13 | import os
14 |
15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...)
16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17 |
18 |
19 | # Quick-start development settings - unsuitable for production
20 | # See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/
21 |
# SECURITY WARNING: keep the secret key used in production secret!
# The key is taken from the DJANGO_SECRET_KEY environment variable when it is
# set. The committed value is only a development fallback and must not be
# used for a real deployment.
SECRET_KEY = os.environ.get(
    'DJANGO_SECRET_KEY',
    '67sc$))b_ykz&ze3py@=)9j-repx%1b%lkl672k2oiw&qjj&hy',
)

# SECURITY WARNING: don't run with debug turned on in production!
# Defaults to True for development; set DJANGO_DEBUG to 0, false, no or off
# (case-insensitive) to turn it off.
DEBUG = os.environ.get('DJANGO_DEBUG', 'true').strip().lower() not in (
    '0', 'false', 'no', 'off',
)

# Comma-separated host names from DJANGO_ALLOWED_HOSTS; empty by default.
ALLOWED_HOSTS = [
    host.strip()
    for host in os.environ.get('DJANGO_ALLOWED_HOSTS', '').split(',')
    if host.strip()
]
29 |
30 |
31 | # Application definition
32 |
33 | INSTALLED_APPS = [
34 | 'django.contrib.admin',
35 | 'django.contrib.auth',
36 | 'django.contrib.contenttypes',
37 | 'django.contrib.sessions',
38 | 'django.contrib.messages',
39 | 'django.contrib.staticfiles',
40 | 'Prediction',
41 | 'rest_framework',
42 | ]
43 |
44 | MIDDLEWARE = [
45 | 'django.middleware.security.SecurityMiddleware',
46 | 'django.contrib.sessions.middleware.SessionMiddleware',
47 | 'django.middleware.common.CommonMiddleware',
48 | 'django.middleware.csrf.CsrfViewMiddleware',
49 | 'django.contrib.auth.middleware.AuthenticationMiddleware',
50 | 'django.contrib.messages.middleware.MessageMiddleware',
51 | 'django.middleware.clickjacking.XFrameOptionsMiddleware',
52 | ]
53 |
54 | ROOT_URLCONF = 'APIProject.urls'
55 |
56 | TEMPLATES = [
57 | {
58 | 'BACKEND': 'django.template.backends.django.DjangoTemplates',
59 | 'DIRS': [],
60 | 'APP_DIRS': True,
61 | 'OPTIONS': {
62 | 'context_processors': [
63 | 'django.template.context_processors.debug',
64 | 'django.template.context_processors.request',
65 | 'django.contrib.auth.context_processors.auth',
66 | 'django.contrib.messages.context_processors.messages',
67 | ],
68 | },
69 | },
70 | ]
71 |
72 | WSGI_APPLICATION = 'APIProject.wsgi.application'
73 |
74 |
75 | # Database
76 | # https://docs.djangoproject.com/en/2.2/ref/settings/#databases
77 |
78 | DATABASES = {
79 | 'default': {
80 | 'ENGINE': 'django.db.backends.sqlite3',
81 | 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
82 | }
83 | }
84 |
85 |
86 | # Password validation
87 | # https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators
88 |
89 | AUTH_PASSWORD_VALIDATORS = [
90 | {
91 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
92 | },
93 | {
94 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
95 | },
96 | {
97 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
98 | },
99 | {
100 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
101 | },
102 | ]
103 |
104 |
105 | # Internationalization
106 | # https://docs.djangoproject.com/en/2.2/topics/i18n/
107 |
108 | LANGUAGE_CODE = 'en-us'
109 |
110 | TIME_ZONE = 'UTC'
111 |
112 | USE_I18N = True
113 |
114 | USE_L10N = True
115 |
116 | USE_TZ = True
117 |
118 |
119 | # Static files (CSS, JavaScript, Images)
120 | # https://docs.djangoproject.com/en/2.2/howto/static-files/
121 |
122 | STATIC_URL = '/static/'
123 |
124 | REST_FRAMEWORK = {
125 | 'DEFAULT_THROTTLE_CLASSES': [
126 | 'Prediction.throttles.LimitedRateThrottle',
127 | 'Prediction.throttles.BurstRateThrottle'
128 | ],
129 | 'DEFAULT_THROTTLE_RATES': {
130 | 'limited': '2/min',
131 | 'burst': '10/min'
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/CreateModelWithJupyter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 13,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "### Load the iris dataset from sklearn "
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from sklearn.datasets import load_iris"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 65,
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "data": {
35 | "text/plain": [
36 | "sklearn.utils.Bunch"
37 | ]
38 | },
39 | "execution_count": 65,
40 | "metadata": {},
41 | "output_type": "execute_result"
42 | }
43 | ],
44 | "source": [
45 | "iris = load_iris()\n",
46 | "type(iris)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 6,
52 | "metadata": {},
53 | "outputs": [
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "print(iris.feature_names)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 11,
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "name": "stdout",
73 | "output_type": "stream",
74 | "text": [
75 | "['setosa' 'versicolor' 'virginica']\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "print(iris.target_names)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 8,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "# Define feature matrix in \"X\"\n",
90 | "X = iris.data\n",
91 | "\n",
92 | "# Define target response vector in \"y\"\n",
93 | "y = iris.target"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 | "### Get the basic statistics for the features"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 15,
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "data": {
110 | "text/html": [
111 | "
\n",
112 | "\n",
125 | "
\n",
126 | " \n",
127 | " \n",
128 | " | \n",
129 | " sepal length (cm) | \n",
130 | " sepal width (cm) | \n",
131 | " petal length (cm) | \n",
132 | " petal width (cm) | \n",
133 | "
\n",
134 | " \n",
135 | " \n",
136 | " \n",
137 | " | 0 | \n",
138 | " 5.1 | \n",
139 | " 3.5 | \n",
140 | " 1.4 | \n",
141 | " 0.2 | \n",
142 | "
\n",
143 | " \n",
144 | " | 1 | \n",
145 | " 4.9 | \n",
146 | " 3.0 | \n",
147 | " 1.4 | \n",
148 | " 0.2 | \n",
149 | "
\n",
150 | " \n",
151 | " | 2 | \n",
152 | " 4.7 | \n",
153 | " 3.2 | \n",
154 | " 1.3 | \n",
155 | " 0.2 | \n",
156 | "
\n",
157 | " \n",
158 | "
\n",
159 | "
"
160 | ],
161 | "text/plain": [
162 | " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
163 | "0 5.1 3.5 1.4 0.2\n",
164 | "1 4.9 3.0 1.4 0.2\n",
165 | "2 4.7 3.2 1.3 0.2"
166 | ]
167 | },
168 | "execution_count": 15,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | }
172 | ],
173 | "source": [
174 | "X_df = pd.DataFrame(data=X, columns = iris.feature_names)\n",
175 | "X_df.head(3)"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 16,
181 | "metadata": {},
182 | "outputs": [
183 | {
184 | "data": {
185 | "text/html": [
186 | "\n",
187 | "\n",
200 | "
\n",
201 | " \n",
202 | " \n",
203 | " | \n",
204 | " sepal length (cm) | \n",
205 | " sepal width (cm) | \n",
206 | " petal length (cm) | \n",
207 | " petal width (cm) | \n",
208 | "
\n",
209 | " \n",
210 | " \n",
211 | " \n",
212 | " | count | \n",
213 | " 150.000000 | \n",
214 | " 150.000000 | \n",
215 | " 150.000000 | \n",
216 | " 150.000000 | \n",
217 | "
\n",
218 | " \n",
219 | " | mean | \n",
220 | " 5.843333 | \n",
221 | " 3.057333 | \n",
222 | " 3.758000 | \n",
223 | " 1.199333 | \n",
224 | "
\n",
225 | " \n",
226 | " | std | \n",
227 | " 0.828066 | \n",
228 | " 0.435866 | \n",
229 | " 1.765298 | \n",
230 | " 0.762238 | \n",
231 | "
\n",
232 | " \n",
233 | " | min | \n",
234 | " 4.300000 | \n",
235 | " 2.000000 | \n",
236 | " 1.000000 | \n",
237 | " 0.100000 | \n",
238 | "
\n",
239 | " \n",
240 | " | 25% | \n",
241 | " 5.100000 | \n",
242 | " 2.800000 | \n",
243 | " 1.600000 | \n",
244 | " 0.300000 | \n",
245 | "
\n",
246 | " \n",
247 | " | 50% | \n",
248 | " 5.800000 | \n",
249 | " 3.000000 | \n",
250 | " 4.350000 | \n",
251 | " 1.300000 | \n",
252 | "
\n",
253 | " \n",
254 | " | 75% | \n",
255 | " 6.400000 | \n",
256 | " 3.300000 | \n",
257 | " 5.100000 | \n",
258 | " 1.800000 | \n",
259 | "
\n",
260 | " \n",
261 | " | max | \n",
262 | " 7.900000 | \n",
263 | " 4.400000 | \n",
264 | " 6.900000 | \n",
265 | " 2.500000 | \n",
266 | "
\n",
267 | " \n",
268 | "
\n",
269 | "
"
270 | ],
271 | "text/plain": [
272 | " sepal length (cm) sepal width (cm) petal length (cm) \\\n",
273 | "count 150.000000 150.000000 150.000000 \n",
274 | "mean 5.843333 3.057333 3.758000 \n",
275 | "std 0.828066 0.435866 1.765298 \n",
276 | "min 4.300000 2.000000 1.000000 \n",
277 | "25% 5.100000 2.800000 1.600000 \n",
278 | "50% 5.800000 3.000000 4.350000 \n",
279 | "75% 6.400000 3.300000 5.100000 \n",
280 | "max 7.900000 4.400000 6.900000 \n",
281 | "\n",
282 | " petal width (cm) \n",
283 | "count 150.000000 \n",
284 | "mean 1.199333 \n",
285 | "std 0.762238 \n",
286 | "min 0.100000 \n",
287 | "25% 0.300000 \n",
288 | "50% 1.300000 \n",
289 | "75% 1.800000 \n",
290 | "max 2.500000 "
291 | ]
292 | },
293 | "execution_count": 16,
294 | "metadata": {},
295 | "output_type": "execute_result"
296 | }
297 | ],
298 | "source": [
299 | "X_df.describe()"
300 | ]
301 | },
302 | {
303 | "cell_type": "markdown",
304 | "metadata": {},
305 | "source": [
306 | "### Split the dataset into training and testing dataset"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": 17,
312 | "metadata": {},
313 | "outputs": [],
314 | "source": [
315 | "from sklearn.model_selection import train_test_split"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 19,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=1, stratify=y)"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "### Use the Random Forest Classifier"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 20,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "from sklearn.ensemble import RandomForestClassifier"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 21,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "clf_rf = RandomForestClassifier(random_state = 1, n_estimators = 10, n_jobs = -1)\n",
350 | "estimator_rf = clf_rf"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 22,
356 | "metadata": {},
357 | "outputs": [
358 | {
359 | "data": {
360 | "text/plain": [
361 | "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
362 | " max_depth=None, max_features='auto', max_leaf_nodes=None,\n",
363 | " min_impurity_decrease=0.0, min_impurity_split=None,\n",
364 | " min_samples_leaf=1, min_samples_split=2,\n",
365 | " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=-1,\n",
366 | " oob_score=False, random_state=1, verbose=0,\n",
367 | " warm_start=False)"
368 | ]
369 | },
370 | "execution_count": 22,
371 | "metadata": {},
372 | "output_type": "execute_result"
373 | }
374 | ],
375 | "source": [
376 | "estimator_rf.fit(X=X_train, y=y_train)"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {},
382 | "source": [
383 | "### Testing Accuracy"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": 24,
389 | "metadata": {},
390 | "outputs": [
391 | {
392 | "data": {
393 | "text/plain": [
394 | "0.9666666666666667"
395 | ]
396 | },
397 | "execution_count": 24,
398 | "metadata": {},
399 | "output_type": "execute_result"
400 | }
401 | ],
402 | "source": [
403 | "estimator_rf.score(X_test,y_test)"
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "metadata": {},
409 | "source": [
410 | "### Cross Validated Accuracy"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": 28,
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "from sklearn.model_selection import cross_val_score"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": 29,
425 | "metadata": {},
426 | "outputs": [
427 | {
428 | "data": {
429 | "text/plain": [
430 | "0.9666666666666668"
431 | ]
432 | },
433 | "execution_count": 29,
434 | "metadata": {},
435 | "output_type": "execute_result"
436 | }
437 | ],
438 | "source": [
439 | "estimator_cv = clf_rf\n",
440 | "scores = cross_val_score(estimator_cv, X, y, cv = 5, scoring = 'accuracy')\n",
441 | "scores.mean()"
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {},
447 | "source": [
448 | "## Serialize the model"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": 30,
454 | "metadata": {},
455 | "outputs": [],
456 | "source": [
457 | "from joblib import dump, load"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": 31,
463 | "metadata": {},
464 | "outputs": [
465 | {
466 | "data": {
467 | "text/plain": [
468 | "['IRISRandomForestClassifier.joblib']"
469 | ]
470 | },
471 | "execution_count": 31,
472 | "metadata": {},
473 | "output_type": "execute_result"
474 | }
475 | ],
476 | "source": [
477 | "dump(estimator_rf, 'IRISRandomForestClassifier.joblib') "
478 | ]
479 | },
480 | {
481 | "cell_type": "markdown",
482 | "metadata": {},
483 | "source": [
484 | "### Load back the saved model and check accuracy"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 32,
490 | "metadata": {},
491 | "outputs": [],
492 | "source": [
493 | "loaded_classifier = load('IRISRandomForestClassifier.joblib') "
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 33,
499 | "metadata": {},
500 | "outputs": [
501 | {
502 | "data": {
503 | "text/plain": [
504 | "0.9666666666666667"
505 | ]
506 | },
507 | "execution_count": 33,
508 | "metadata": {},
509 | "output_type": "execute_result"
510 | }
511 | ],
512 | "source": [
513 | "loaded_classifier.score(X_test,y_test)"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": 59,
519 | "metadata": {},
520 | "outputs": [],
521 | "source": [
522 | "target_map = {}\n",
523 | "i = 0\n",
524 | "for target in iris.target_names:\n",
525 | " target_map[i]= target\n",
526 | " i+=1\n",
527 | "\n",
528 | "def iris_predictor(X):\n",
529 | " y_pred = loaded_classifier.predict(X)\n",
530 | " y_pred = pd.Series(y_pred)\n",
531 | " y_pred = y_pred.map(target_map).to_numpy()\n",
532 | " return(y_pred)"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 60,
538 | "metadata": {},
539 | "outputs": [],
540 | "source": [
541 | "y_pred = iris_predictor(pd.DataFrame(X_test[0].reshape(1, -1)))"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": 62,
547 | "metadata": {},
548 | "outputs": [
549 | {
550 | "name": "stdout",
551 | "output_type": "stream",
552 | "text": [
553 | "virginica\n"
554 | ]
555 | }
556 | ],
557 | "source": [
558 | "print(y_pred[0])"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": 63,
564 | "metadata": {},
565 | "outputs": [
566 | {
567 | "data": {
568 | "text/plain": [
569 | "array([[1. , 0. , 0. ],\n",
570 | " [1. , 0. , 0. ],\n",
571 | " [1. , 0. , 0. ],\n",
572 | " [1. , 0. , 0. ],\n",
573 | " [1. , 0. , 0. ],\n",
574 | " [1. , 0. , 0. ],\n",
575 | " [1. , 0. , 0. ],\n",
576 | " [1. , 0. , 0. ],\n",
577 | " [1. , 0. , 0. ],\n",
578 | " [1. , 0. , 0. ],\n",
579 | " [1. , 0. , 0. ],\n",
580 | " [1. , 0. , 0. ],\n",
581 | " [1. , 0. , 0. ],\n",
582 | " [1. , 0. , 0. ],\n",
583 | " [0.9, 0.1, 0. ],\n",
584 | " [1. , 0. , 0. ],\n",
585 | " [1. , 0. , 0. ],\n",
586 | " [1. , 0. , 0. ],\n",
587 | " [1. , 0. , 0. ],\n",
588 | " [1. , 0. , 0. ],\n",
589 | " [1. , 0. , 0. ],\n",
590 | " [1. , 0. , 0. ],\n",
591 | " [1. , 0. , 0. ],\n",
592 | " [1. , 0. , 0. ],\n",
593 | " [1. , 0. , 0. ],\n",
594 | " [1. , 0. , 0. ],\n",
595 | " [1. , 0. , 0. ],\n",
596 | " [1. , 0. , 0. ],\n",
597 | " [1. , 0. , 0. ],\n",
598 | " [1. , 0. , 0. ],\n",
599 | " [1. , 0. , 0. ],\n",
600 | " [1. , 0. , 0. ],\n",
601 | " [1. , 0. , 0. ],\n",
602 | " [1. , 0. , 0. ],\n",
603 | " [1. , 0. , 0. ],\n",
604 | " [1. , 0. , 0. ],\n",
605 | " [1. , 0. , 0. ],\n",
606 | " [1. , 0. , 0. ],\n",
607 | " [1. , 0. , 0. ],\n",
608 | " [1. , 0. , 0. ],\n",
609 | " [1. , 0. , 0. ],\n",
610 | " [1. , 0. , 0. ],\n",
611 | " [1. , 0. , 0. ],\n",
612 | " [1. , 0. , 0. ],\n",
613 | " [1. , 0. , 0. ],\n",
614 | " [1. , 0. , 0. ],\n",
615 | " [1. , 0. , 0. ],\n",
616 | " [1. , 0. , 0. ],\n",
617 | " [1. , 0. , 0. ],\n",
618 | " [1. , 0. , 0. ],\n",
619 | " [0. , 1. , 0. ],\n",
620 | " [0. , 1. , 0. ],\n",
621 | " [0. , 1. , 0. ],\n",
622 | " [0. , 1. , 0. ],\n",
623 | " [0. , 1. , 0. ],\n",
624 | " [0. , 1. , 0. ],\n",
625 | " [0. , 1. , 0. ],\n",
626 | " [0. , 1. , 0. ],\n",
627 | " [0. , 1. , 0. ],\n",
628 | " [0.1, 0.9, 0. ],\n",
629 | " [0. , 1. , 0. ],\n",
630 | " [0. , 1. , 0. ],\n",
631 | " [0. , 1. , 0. ],\n",
632 | " [0. , 1. , 0. ],\n",
633 | " [0. , 1. , 0. ],\n",
634 | " [0. , 1. , 0. ],\n",
635 | " [0. , 1. , 0. ],\n",
636 | " [0. , 1. , 0. ],\n",
637 | " [0. , 1. , 0. ],\n",
638 | " [0. , 1. , 0. ],\n",
639 | " [0. , 0.5, 0.5],\n",
640 | " [0. , 1. , 0. ],\n",
641 | " [0. , 0.9, 0.1],\n",
642 | " [0. , 1. , 0. ],\n",
643 | " [0. , 1. , 0. ],\n",
644 | " [0. , 1. , 0. ],\n",
645 | " [0. , 1. , 0. ],\n",
646 | " [0. , 0.8, 0.2],\n",
647 | " [0. , 1. , 0. ],\n",
648 | " [0. , 1. , 0. ],\n",
649 | " [0. , 1. , 0. ],\n",
650 | " [0. , 1. , 0. ],\n",
651 | " [0. , 1. , 0. ],\n",
652 | " [0. , 0.7, 0.3],\n",
653 | " [0. , 1. , 0. ],\n",
654 | " [0. , 1. , 0. ],\n",
655 | " [0. , 1. , 0. ],\n",
656 | " [0. , 1. , 0. ],\n",
657 | " [0. , 1. , 0. ],\n",
658 | " [0. , 1. , 0. ],\n",
659 | " [0. , 1. , 0. ],\n",
660 | " [0. , 1. , 0. ],\n",
661 | " [0. , 1. , 0. ],\n",
662 | " [0. , 1. , 0. ],\n",
663 | " [0. , 1. , 0. ],\n",
664 | " [0. , 1. , 0. ],\n",
665 | " [0. , 1. , 0. ],\n",
666 | " [0. , 1. , 0. ],\n",
667 | " [0. , 1. , 0. ],\n",
668 | " [0. , 1. , 0. ],\n",
669 | " [0. , 0. , 1. ],\n",
670 | " [0. , 0. , 1. ],\n",
671 | " [0. , 0. , 1. ],\n",
672 | " [0. , 0. , 1. ],\n",
673 | " [0. , 0. , 1. ],\n",
674 | " [0. , 0. , 1. ],\n",
675 | " [0. , 1. , 0. ],\n",
676 | " [0. , 0. , 1. ],\n",
677 | " [0. , 0. , 1. ],\n",
678 | " [0. , 0. , 1. ],\n",
679 | " [0. , 0. , 1. ],\n",
680 | " [0. , 0. , 1. ],\n",
681 | " [0. , 0. , 1. ],\n",
682 | " [0. , 0.1, 0.9],\n",
683 | " [0. , 0. , 1. ],\n",
684 | " [0. , 0. , 1. ],\n",
685 | " [0. , 0. , 1. ],\n",
686 | " [0. , 0. , 1. ],\n",
687 | " [0. , 0. , 1. ],\n",
688 | " [0. , 0.4, 0.6],\n",
689 | " [0. , 0. , 1. ],\n",
690 | " [0. , 0.2, 0.8],\n",
691 | " [0. , 0. , 1. ],\n",
692 | " [0. , 0.1, 0.9],\n",
693 | " [0. , 0. , 1. ],\n",
694 | " [0. , 0. , 1. ],\n",
695 | " [0. , 0.3, 0.7],\n",
696 | " [0. , 0.2, 0.8],\n",
697 | " [0. , 0. , 1. ],\n",
698 | " [0. , 0.2, 0.8],\n",
699 | " [0. , 0. , 1. ],\n",
700 | " [0. , 0. , 1. ],\n",
701 | " [0. , 0. , 1. ],\n",
702 | " [0. , 0.2, 0.8],\n",
703 | " [0. , 0.3, 0.7],\n",
704 | " [0. , 0. , 1. ],\n",
705 | " [0. , 0. , 1. ],\n",
706 | " [0. , 0. , 1. ],\n",
707 | " [0. , 0.3, 0.7],\n",
708 | " [0. , 0. , 1. ],\n",
709 | " [0. , 0. , 1. ],\n",
710 | " [0. , 0. , 1. ],\n",
711 | " [0. , 0. , 1. ],\n",
712 | " [0. , 0. , 1. ],\n",
713 | " [0. , 0. , 1. ],\n",
714 | " [0. , 0. , 1. ],\n",
715 | " [0. , 0. , 1. ],\n",
716 | " [0. , 0. , 1. ],\n",
717 | " [0. , 0.1, 0.9],\n",
718 | " [0. , 0. , 1. ]])"
719 | ]
720 | },
721 | "execution_count": 63,
722 | "metadata": {},
723 | "output_type": "execute_result"
724 | }
725 | ],
726 | "source": [
727 | "loaded_classifier.predict_proba(X)"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": 64,
733 | "metadata": {},
734 | "outputs": [
735 | {
736 | "data": {
737 | "text/html": [
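738 | "<div>\n",
739 | "<style scoped>\n",
740 | "    .dataframe tbody tr th:only-of-type {\n",
741 | "        vertical-align: middle;\n",
742 | "    }\n",
743 | "\n",
744 | "    .dataframe tbody tr th {\n",
745 | "        vertical-align: top;\n",
746 | "    }\n",
747 | "\n",
748 | "    .dataframe thead th {\n",
749 | "        text-align: right;\n",
750 | "    }\n",
751 | "</style>\n",
752 | "<table border=\"1\" class=\"dataframe\">\n",
753 | "  <thead>\n",
754 | "    <tr style=\"text-align: right;\">\n",
755 | "      <th></th>\n",
756 | "      <th>0</th>\n",
757 | "      <th>1</th>\n",
758 | "      <th>2</th>\n",
759 | "      <th>3</th>\n",
760 | "    </tr>\n",
761 | "  </thead>\n",
762 | "  <tbody>\n",
763 | "    <tr>\n",
764 | "      <th>0</th>\n",
765 | "      <td>7.3</td>\n",
766 | "      <td>2.9</td>\n",
767 | "      <td>6.3</td>\n",
768 | "      <td>1.8</td>\n",
769 | "    </tr>\n",
770 | "  </tbody>\n",
771 | "</table>\n",
772 | "</div>"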
773 | ],
774 | "text/plain": [
775 | " 0 1 2 3\n",
776 | "0 7.3 2.9 6.3 1.8"
777 | ]
778 | },
779 | "execution_count": 64,
780 | "metadata": {},
781 | "output_type": "execute_result"
782 | }
783 | ],
784 | "source": [
785 | "pd.DataFrame(X_test[0].reshape(1, -1))"
786 | ]
787 | },
788 | {
789 | "cell_type": "code",
790 | "execution_count": null,
791 | "metadata": {},
792 | "outputs": [],
793 | "source": []
794 | }
795 | ],
796 | "metadata": {
797 | "kernelspec": {
798 | "display_name": "Python 3",
799 | "language": "python",
800 | "name": "python3"
801 | },
802 | "language_info": {
803 | "codemirror_mode": {
804 | "name": "ipython",
805 | "version": 3
806 | },
807 | "file_extension": ".py",
808 | "mimetype": "text/x-python",
809 | "name": "python",
810 | "nbconvert_exporter": "python",
811 | "pygments_lexer": "ipython3",
812 | "version": "3.7.4"
813 | }
814 | },
815 | "nbformat": 4,
816 | "nbformat_minor": 2
817 | }
818 |
--------------------------------------------------------------------------------