├── APIProjectFolder ├── APIProject │ ├── __init__.py │ ├── __pycache__ │ │ ├── urls.cpython-37.pyc │ │ ├── wsgi.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── settings.cpython-37.pyc │ ├── wsgi.py │ ├── urls.py │ └── settings.py ├── Prediction │ ├── __init__.py │ ├── migrations │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ └── __init__.cpython-37.pyc │ ├── models.py │ ├── tests.py │ ├── admin.py │ ├── __pycache__ │ │ ├── admin.cpython-37.pyc │ │ ├── apps.cpython-37.pyc │ │ ├── urls.cpython-37.pyc │ │ ├── views.cpython-37.pyc │ │ ├── models.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── throttles.cpython-37.pyc │ ├── classifier │ │ └── IRISRandomForestClassifier.joblib │ ├── throttles.py │ ├── urls.py │ ├── apps.py │ └── views.py ├── .vscode │ ├── settings.json │ └── launch.json ├── db.sqlite3 └── manage.py ├── .gitattributes ├── .vs ├── PythonSettings.json ├── slnx.sqlite ├── DjangoRestAPI │ └── v16 │ │ └── .suo ├── DjangoRestAPIDemo │ └── v16 │ │ └── .suo └── VSWorkspaceState.json ├── Iris.pptx ├── .ipynb_checkpoints └── CreateModelWithJupyter-checkpoint.ipynb ├── Steps.txt ├── README.md └── CreateModelWithJupyter.ipynb /APIProjectFolder/APIProject/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * linguist-vendored 2 | *.py linguist-vendored=false -------------------------------------------------------------------------------- /.vs/PythonSettings.json: -------------------------------------------------------------------------------- 1 | { 2 | "Interpreter": "Global|VisualStudio|Env64 (Python 3.7 (64-bit))" 3 | } -------------------------------------------------------------------------------- /Iris.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/Iris.pptx -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | # Create your models here. 4 | -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /.vs/slnx.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/.vs/slnx.sqlite -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | # Register your models here. 4 | -------------------------------------------------------------------------------- /APIProjectFolder/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "Z:\\MachineLearning\\Env64\\Scripts\\python.exe" 3 | } -------------------------------------------------------------------------------- /.vs/DjangoRestAPI/v16/.suo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/.vs/DjangoRestAPI/v16/.suo -------------------------------------------------------------------------------- /APIProjectFolder/db.sqlite3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/db.sqlite3 -------------------------------------------------------------------------------- /.vs/DjangoRestAPIDemo/v16/.suo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/.vs/DjangoRestAPIDemo/v16/.suo -------------------------------------------------------------------------------- /.ipynb_checkpoints/CreateModelWithJupyter-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /.vs/VSWorkspaceState.json: -------------------------------------------------------------------------------- 1 | { 2 | "ExpandedNodes": [ 3 | "", 4 | "\\APIProject" 5 | ], 6 | "SelectedNode": "\\Steps.txt", 7 | "PreviewInSolutionExplorer": false 8 | } -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/__pycache__/urls.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/urls.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/__pycache__/wsgi.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/wsgi.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/admin.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/admin.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/apps.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/apps.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/urls.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/urls.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/views.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/views.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/__pycache__/settings.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/APIProject/__pycache__/settings.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/__pycache__/throttles.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/__pycache__/throttles.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/classifier/IRISRandomForestClassifier.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/classifier/IRISRandomForestClassifier.joblib -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/migrations/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MausamGaurav/PredictionAPIWithDjangoRESTDemoApp/HEAD/APIProjectFolder/Prediction/migrations/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/throttles.py: -------------------------------------------------------------------------------- 1 | from rest_framework.throttling import UserRateThrottle 2 | 3 | # Custom Throttle classes 4 | class LimitedRateThrottle(UserRateThrottle): 5 | scope = 'limited' 6 | 7 | class BurstRateThrottle(UserRateThrottle): 8 | scope = 'burst' -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | import Prediction.views as views 3 | 4 | urlpatterns = [ 5 | path('add/', views.api_add, name = 'api_add'), 6 | path('add_values/', views.Add_Values.as_view(), name = 'api_add_values'), 7 | path('predict/', views.IRIS_Model_Predict.as_view(), name = 'predict'), 8 | ] -------------------------------------------------------------------------------- /Steps.txt: -------------------------------------------------------------------------------- 1 | 1) Install Django version 2.2 with command: 2 | python -m pip install Django==2.2.0 3 | 2) Check Django is installed by command: 4 | python -m django --version 5 | This should produce the output below: 6 | 2.2 7 | 3) From the command line, cd into a directory where you’d like to store your code, then run the following command: 8 | django-admin startproject ModelAPI 9 | -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for APIProject project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'APIProject.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PredictionAPIWithDjangoRESTDemoApp 2 | 3 | Full instructions available on https://www.datagraphi.com/blog/post/2019/12/19/rest-api-guide-productionizing-a-machine-learning-model-by-creating-a-rest-api-with-python-django-and-django-rest-framework 4 | 5 | The article explains everything from setting up Django, creating REST APIs, Testing REST APIs, Serializing a Machine Learning Model, Integrating a Machine Learning Model into an API, Authentication and Throttling with a hands-on approach in a clear and concise manner. 6 | -------------------------------------------------------------------------------- /APIProjectFolder/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Django", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${workspaceFolder}\\manage.py", 12 | "args": [ 13 | "runserver", 14 | "--noreload" 15 | ], 16 | "django": true 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | import pandas as pd 3 | from joblib import load 4 | import os 5 | 6 | 7 | class PredictionConfig(AppConfig): 8 | name = 'Prediction' 9 | #CLASSIFIER_FOLDER = Path("classifier") 10 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | CLASSIFIER_FOLDER = os.path.join(BASE_DIR, 'Prediction/classifier/') 12 | #CLASSIFIER_FILE = CLASSIFIER_FOLDER / "IRISRandomForestClassifier.joblib" 13 | CLASSIFIER_FILE = os.path.join(CLASSIFIER_FOLDER, "IRISRandomForestClassifier.joblib") 14 | classifier = load(CLASSIFIER_FILE) -------------------------------------------------------------------------------- /APIProjectFolder/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'APIProject.settings') 9 | try: 10 | from django.core.management import execute_from_command_line 11 | except ImportError as exc: 12 | raise ImportError( 13 | "Couldn't import Django. Are you sure it's installed and " 14 | "available on your PYTHONPATH environment variable? Did you " 15 | "forget to activate a virtual environment?" 16 | ) from exc 17 | execute_from_command_line(sys.argv) 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/urls.py: -------------------------------------------------------------------------------- 1 | """APIProject URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/2.2/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path, include 18 | 19 | urlpatterns = [ 20 | path('admin/', admin.site.urls), 21 | path('api-auth/', include('rest_framework.urls', namespace='rest_framework')), 22 | path('api/', include('Prediction.urls')), 23 | ] 24 | 25 | 26 | -------------------------------------------------------------------------------- /APIProjectFolder/Prediction/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import status 2 | from rest_framework.decorators import api_view 3 | from rest_framework.response import Response 4 | from rest_framework.views import APIView 5 | from .apps import PredictionConfig 6 | import pandas as pd 7 | 8 | from rest_framework.permissions import IsAuthenticated 9 | from .throttles import LimitedRateThrottle, BurstRateThrottle 10 | 11 | # Create your views here. 12 | @api_view(['GET', 'POST']) 13 | def api_add(request): 14 | sum = 0 15 | response_dict = {} 16 | if request.method == 'GET': 17 | # Do nothing 18 | pass 19 | elif request.method == 'POST': 20 | # Add the numbers 21 | data = request.data 22 | for key in data: 23 | sum += data[key] 24 | response_dict = {"sum": sum} 25 | return Response(response_dict, status=status.HTTP_201_CREATED) 26 | 27 | # Class based view to add numbers 28 | class Add_Values(APIView): 29 | permission_classes = [IsAuthenticated] 30 | def post(self, request, format=None): 31 | sum = 0 32 | # Add the numbers 33 | data = request.data 34 | for key in data: 35 | sum += data[key] 36 | response_dict = {"sum": sum} 37 | return Response(response_dict, status=status.HTTP_201_CREATED) 38 | 39 | # Class based view to predict based on IRIS model 40 | class IRIS_Model_Predict(APIView): 41 | permission_classes = [IsAuthenticated] 42 | throttle_classes = [LimitedRateThrottle] 43 | def post(self, request, format=None): 44 | data = request.data 45 | keys = [] 46 | values = [] 47 | for key in data: 48 | keys.append(key) 49 | values.append(data[key]) 50 | X = pd.Series(values).to_numpy().reshape(1, -1) 51 | loaded_classifier = PredictionConfig.classifier 52 | y_pred = loaded_classifier.predict(X) 53 | y_pred = pd.Series(y_pred) 54 | target_map = {0: 'setosa', 1: 'versicolor', 2: 'virginica'} 55 | y_pred = y_pred.map(target_map).to_numpy() 56 | response_dict = {"Prediced Iris Species": y_pred[0]} 57 | return Response(response_dict, status=200) -------------------------------------------------------------------------------- /APIProjectFolder/APIProject/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for APIProject project. 3 | 4 | Generated by 'django-admin startproject' using Django 2.2. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.2/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/2.2/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = '67sc$))b_ykz&ze3py@=)9j-repx%1b%lkl672k2oiw&qjj&hy' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'django.contrib.admin', 35 | 'django.contrib.auth', 36 | 'django.contrib.contenttypes', 37 | 'django.contrib.sessions', 38 | 'django.contrib.messages', 39 | 'django.contrib.staticfiles', 40 | 'Prediction', 41 | 'rest_framework', 42 | ] 43 | 44 | MIDDLEWARE = [ 45 | 'django.middleware.security.SecurityMiddleware', 46 | 'django.contrib.sessions.middleware.SessionMiddleware', 47 | 'django.middleware.common.CommonMiddleware', 48 | 'django.middleware.csrf.CsrfViewMiddleware', 49 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 50 | 'django.contrib.messages.middleware.MessageMiddleware', 51 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 52 | ] 53 | 54 | ROOT_URLCONF = 'APIProject.urls' 55 | 56 | TEMPLATES = [ 57 | { 58 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 59 | 'DIRS': [], 60 | 'APP_DIRS': True, 61 | 'OPTIONS': { 62 | 'context_processors': [ 63 | 'django.template.context_processors.debug', 64 | 'django.template.context_processors.request', 65 | 'django.contrib.auth.context_processors.auth', 66 | 'django.contrib.messages.context_processors.messages', 67 | ], 68 | }, 69 | }, 70 | ] 71 | 72 | WSGI_APPLICATION = 'APIProject.wsgi.application' 73 | 74 | 75 | # Database 76 | # https://docs.djangoproject.com/en/2.2/ref/settings/#databases 77 | 78 | DATABASES = { 79 | 'default': { 80 | 'ENGINE': 'django.db.backends.sqlite3', 81 | 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 82 | } 83 | } 84 | 85 | 86 | # Password validation 87 | # https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators 88 | 89 | AUTH_PASSWORD_VALIDATORS = [ 90 | { 91 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 92 | }, 93 | { 94 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 95 | }, 96 | { 97 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 98 | }, 99 | { 100 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 101 | }, 102 | ] 103 | 104 | 105 | # Internationalization 106 | # https://docs.djangoproject.com/en/2.2/topics/i18n/ 107 | 108 | LANGUAGE_CODE = 'en-us' 109 | 110 | TIME_ZONE = 'UTC' 111 | 112 | USE_I18N = True 113 | 114 | USE_L10N = True 115 | 116 | USE_TZ = True 117 | 118 | 119 | # Static files (CSS, JavaScript, Images) 120 | # https://docs.djangoproject.com/en/2.2/howto/static-files/ 121 | 122 | STATIC_URL = '/static/' 123 | 124 | REST_FRAMEWORK = { 125 | 'DEFAULT_THROTTLE_CLASSES': [ 126 | 'Prediction.throttles.LimitedRateThrottle', 127 | 'Prediction.throttles.BurstRateThrottle' 128 | ], 129 | 'DEFAULT_THROTTLE_RATES': { 130 | 'limited': '2/min', 131 | 'burst': '10/min' 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /CreateModelWithJupyter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "### Load the iris dataset from sklearn " 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from sklearn.datasets import load_iris" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 65, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "data": { 35 | "text/plain": [ 36 | "sklearn.utils.Bunch" 37 | ] 38 | }, 39 | "execution_count": 65, 40 | "metadata": {}, 41 | "output_type": "execute_result" 42 | } 43 | ], 44 | "source": [ 45 | "iris = load_iris()\n", 46 | "type(iris)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 6, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "print(iris.feature_names)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 11, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "['setosa' 'versicolor' 'virginica']\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "print(iris.target_names)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 8, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# Define feature matrix in \"X\"\n", 90 | "X = iris.data\n", 91 | "\n", 92 | "# Define target response vector in \"y\"\n", 93 | "y = iris.target" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "### Get the basic statistics for the features" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 15, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/html": [ 111 | "
\n", 112 | "\n", 125 | "\n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)
05.13.51.40.2
14.93.01.40.2
24.73.21.30.2
\n", 159 | "
" 160 | ], 161 | "text/plain": [ 162 | " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", 163 | "0 5.1 3.5 1.4 0.2\n", 164 | "1 4.9 3.0 1.4 0.2\n", 165 | "2 4.7 3.2 1.3 0.2" 166 | ] 167 | }, 168 | "execution_count": 15, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "X_df = pd.DataFrame(data=X, columns = iris.feature_names)\n", 175 | "X_df.head(3)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 16, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/html": [ 186 | "
\n", 187 | "\n", 200 | "\n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | "
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)
count150.000000150.000000150.000000150.000000
mean5.8433333.0573333.7580001.199333
std0.8280660.4358661.7652980.762238
min4.3000002.0000001.0000000.100000
25%5.1000002.8000001.6000000.300000
50%5.8000003.0000004.3500001.300000
75%6.4000003.3000005.1000001.800000
max7.9000004.4000006.9000002.500000
\n", 269 | "
" 270 | ], 271 | "text/plain": [ 272 | " sepal length (cm) sepal width (cm) petal length (cm) \\\n", 273 | "count 150.000000 150.000000 150.000000 \n", 274 | "mean 5.843333 3.057333 3.758000 \n", 275 | "std 0.828066 0.435866 1.765298 \n", 276 | "min 4.300000 2.000000 1.000000 \n", 277 | "25% 5.100000 2.800000 1.600000 \n", 278 | "50% 5.800000 3.000000 4.350000 \n", 279 | "75% 6.400000 3.300000 5.100000 \n", 280 | "max 7.900000 4.400000 6.900000 \n", 281 | "\n", 282 | " petal width (cm) \n", 283 | "count 150.000000 \n", 284 | "mean 1.199333 \n", 285 | "std 0.762238 \n", 286 | "min 0.100000 \n", 287 | "25% 0.300000 \n", 288 | "50% 1.300000 \n", 289 | "75% 1.800000 \n", 290 | "max 2.500000 " 291 | ] 292 | }, 293 | "execution_count": 16, 294 | "metadata": {}, 295 | "output_type": "execute_result" 296 | } 297 | ], 298 | "source": [ 299 | "X_df.describe()" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "### Split the dataset into training and testing dataset" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 17, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "from sklearn.model_selection import train_test_split" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 19, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=1, stratify=y)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": {}, 330 | "source": [ 331 | "### Use the Random Forest Classifier" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 20, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "from sklearn.ensemble import RandomForestClassifier" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 21, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "clf_rf = RandomForestClassifier(random_state = 1, n_estimators = 10, n_jobs = -1)\n", 350 | "estimator_rf = clf_rf" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 22, 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "data": { 360 | "text/plain": [ 361 | "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", 362 | " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", 363 | " min_impurity_decrease=0.0, min_impurity_split=None,\n", 364 | " min_samples_leaf=1, min_samples_split=2,\n", 365 | " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=-1,\n", 366 | " oob_score=False, random_state=1, verbose=0,\n", 367 | " warm_start=False)" 368 | ] 369 | }, 370 | "execution_count": 22, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "estimator_rf.fit(X=X_train, y=y_train)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "### Testing Accuracy" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 24, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "0.9666666666666667" 395 | ] 396 | }, 397 | "execution_count": 24, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "estimator_rf.score(X_test,y_test)" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": {}, 409 | "source": [ 410 | "### Cross Validated Accuracy" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 28, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "from sklearn.model_selection import cross_val_score" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 29, 425 | "metadata": {}, 426 | "outputs": [ 427 | { 428 | "data": { 429 | "text/plain": [ 430 | "0.9666666666666668" 431 | ] 432 | }, 433 | "execution_count": 29, 434 | "metadata": {}, 435 | "output_type": "execute_result" 436 | } 437 | ], 438 | "source": [ 439 | "estimator_cv = clf_rf\n", 440 | "scores = cross_val_score(estimator_cv, X, y, cv = 5, scoring = 'accuracy')\n", 441 | "scores.mean()" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "## Serialize the model" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 30, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "from joblib import dump, load" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 31, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "data": { 467 | "text/plain": [ 468 | "['IRISRandomForestClassifier.joblib']" 469 | ] 470 | }, 471 | "execution_count": 31, 472 | "metadata": {}, 473 | "output_type": "execute_result" 474 | } 475 | ], 476 | "source": [ 477 | "dump(estimator_rf, 'IRISRandomForestClassifier.joblib') " 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "metadata": {}, 483 | "source": [ 484 | "### Load back the saved model and check accuracy" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 32, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "loaded_classifier = load('IRISRandomForestClassifier.joblib') " 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 33, 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "data": { 503 | "text/plain": [ 504 | "0.9666666666666667" 505 | ] 506 | }, 507 | "execution_count": 33, 508 | "metadata": {}, 509 | "output_type": "execute_result" 510 | } 511 | ], 512 | "source": [ 513 | "loaded_classifier.score(X_test,y_test)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 59, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [ 522 | "target_map = {}\n", 523 | "i = 0\n", 524 | "for target in iris.target_names:\n", 525 | " target_map[i]= target\n", 526 | " i+=1\n", 527 | "\n", 528 | "def iris_predictor(X):\n", 529 | " y_pred = loaded_classifier.predict(X)\n", 530 | " y_pred = pd.Series(y_pred)\n", 531 | " y_pred = y_pred.map(target_map).to_numpy()\n", 532 | " return(y_pred)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 60, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "y_pred = iris_predictor(pd.DataFrame(X_test[0].reshape(1, -1)))" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 62, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "virginica\n" 554 | ] 555 | } 556 | ], 557 | "source": [ 558 | "print(y_pred[0])" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 63, 564 | "metadata": {}, 565 | "outputs": [ 566 | { 567 | "data": { 568 | "text/plain": [ 569 | "array([[1. , 0. , 0. ],\n", 570 | " [1. , 0. , 0. ],\n", 571 | " [1. , 0. , 0. ],\n", 572 | " [1. , 0. , 0. ],\n", 573 | " [1. , 0. , 0. ],\n", 574 | " [1. , 0. , 0. ],\n", 575 | " [1. , 0. , 0. ],\n", 576 | " [1. , 0. , 0. ],\n", 577 | " [1. , 0. , 0. ],\n", 578 | " [1. , 0. , 0. ],\n", 579 | " [1. , 0. , 0. ],\n", 580 | " [1. , 0. , 0. ],\n", 581 | " [1. , 0. , 0. ],\n", 582 | " [1. , 0. , 0. ],\n", 583 | " [0.9, 0.1, 0. ],\n", 584 | " [1. , 0. , 0. ],\n", 585 | " [1. , 0. , 0. ],\n", 586 | " [1. , 0. , 0. ],\n", 587 | " [1. , 0. , 0. ],\n", 588 | " [1. , 0. , 0. ],\n", 589 | " [1. , 0. , 0. ],\n", 590 | " [1. , 0. , 0. ],\n", 591 | " [1. , 0. , 0. ],\n", 592 | " [1. , 0. , 0. ],\n", 593 | " [1. , 0. , 0. ],\n", 594 | " [1. , 0. , 0. ],\n", 595 | " [1. , 0. , 0. ],\n", 596 | " [1. , 0. , 0. ],\n", 597 | " [1. , 0. , 0. ],\n", 598 | " [1. , 0. , 0. ],\n", 599 | " [1. , 0. , 0. ],\n", 600 | " [1. , 0. , 0. ],\n", 601 | " [1. , 0. , 0. ],\n", 602 | " [1. , 0. , 0. ],\n", 603 | " [1. , 0. , 0. ],\n", 604 | " [1. , 0. , 0. ],\n", 605 | " [1. , 0. , 0. ],\n", 606 | " [1. , 0. , 0. ],\n", 607 | " [1. , 0. , 0. ],\n", 608 | " [1. , 0. , 0. ],\n", 609 | " [1. , 0. , 0. ],\n", 610 | " [1. , 0. , 0. ],\n", 611 | " [1. , 0. , 0. ],\n", 612 | " [1. , 0. , 0. ],\n", 613 | " [1. , 0. , 0. ],\n", 614 | " [1. , 0. , 0. ],\n", 615 | " [1. , 0. , 0. ],\n", 616 | " [1. , 0. , 0. ],\n", 617 | " [1. , 0. , 0. ],\n", 618 | " [1. , 0. , 0. ],\n", 619 | " [0. , 1. , 0. ],\n", 620 | " [0. , 1. , 0. ],\n", 621 | " [0. , 1. , 0. ],\n", 622 | " [0. , 1. , 0. ],\n", 623 | " [0. , 1. , 0. ],\n", 624 | " [0. , 1. , 0. ],\n", 625 | " [0. , 1. , 0. ],\n", 626 | " [0. , 1. , 0. ],\n", 627 | " [0. , 1. , 0. ],\n", 628 | " [0.1, 0.9, 0. ],\n", 629 | " [0. , 1. , 0. ],\n", 630 | " [0. , 1. , 0. ],\n", 631 | " [0. , 1. , 0. ],\n", 632 | " [0. , 1. , 0. ],\n", 633 | " [0. , 1. , 0. ],\n", 634 | " [0. , 1. , 0. ],\n", 635 | " [0. , 1. , 0. ],\n", 636 | " [0. , 1. , 0. ],\n", 637 | " [0. , 1. , 0. ],\n", 638 | " [0. , 1. , 0. ],\n", 639 | " [0. , 0.5, 0.5],\n", 640 | " [0. , 1. , 0. ],\n", 641 | " [0. , 0.9, 0.1],\n", 642 | " [0. , 1. , 0. ],\n", 643 | " [0. , 1. , 0. ],\n", 644 | " [0. , 1. , 0. ],\n", 645 | " [0. , 1. , 0. ],\n", 646 | " [0. , 0.8, 0.2],\n", 647 | " [0. , 1. , 0. ],\n", 648 | " [0. , 1. , 0. ],\n", 649 | " [0. , 1. , 0. ],\n", 650 | " [0. , 1. , 0. ],\n", 651 | " [0. , 1. , 0. ],\n", 652 | " [0. , 0.7, 0.3],\n", 653 | " [0. , 1. , 0. ],\n", 654 | " [0. , 1. , 0. ],\n", 655 | " [0. , 1. , 0. ],\n", 656 | " [0. , 1. , 0. ],\n", 657 | " [0. , 1. , 0. ],\n", 658 | " [0. , 1. , 0. ],\n", 659 | " [0. , 1. , 0. ],\n", 660 | " [0. , 1. , 0. ],\n", 661 | " [0. , 1. , 0. ],\n", 662 | " [0. , 1. , 0. ],\n", 663 | " [0. , 1. , 0. ],\n", 664 | " [0. , 1. , 0. ],\n", 665 | " [0. , 1. , 0. ],\n", 666 | " [0. , 1. , 0. ],\n", 667 | " [0. , 1. , 0. ],\n", 668 | " [0. , 1. , 0. ],\n", 669 | " [0. , 0. , 1. ],\n", 670 | " [0. , 0. , 1. ],\n", 671 | " [0. , 0. , 1. ],\n", 672 | " [0. , 0. , 1. ],\n", 673 | " [0. , 0. , 1. ],\n", 674 | " [0. , 0. , 1. ],\n", 675 | " [0. , 1. , 0. ],\n", 676 | " [0. , 0. , 1. ],\n", 677 | " [0. , 0. , 1. ],\n", 678 | " [0. , 0. , 1. ],\n", 679 | " [0. , 0. , 1. ],\n", 680 | " [0. , 0. , 1. ],\n", 681 | " [0. , 0. , 1. ],\n", 682 | " [0. , 0.1, 0.9],\n", 683 | " [0. , 0. , 1. ],\n", 684 | " [0. , 0. , 1. ],\n", 685 | " [0. , 0. , 1. ],\n", 686 | " [0. , 0. , 1. ],\n", 687 | " [0. , 0. , 1. ],\n", 688 | " [0. , 0.4, 0.6],\n", 689 | " [0. , 0. , 1. ],\n", 690 | " [0. , 0.2, 0.8],\n", 691 | " [0. , 0. , 1. ],\n", 692 | " [0. , 0.1, 0.9],\n", 693 | " [0. , 0. , 1. ],\n", 694 | " [0. , 0. , 1. ],\n", 695 | " [0. , 0.3, 0.7],\n", 696 | " [0. , 0.2, 0.8],\n", 697 | " [0. , 0. , 1. ],\n", 698 | " [0. , 0.2, 0.8],\n", 699 | " [0. , 0. , 1. ],\n", 700 | " [0. , 0. , 1. ],\n", 701 | " [0. , 0. , 1. ],\n", 702 | " [0. , 0.2, 0.8],\n", 703 | " [0. , 0.3, 0.7],\n", 704 | " [0. , 0. , 1. ],\n", 705 | " [0. , 0. , 1. ],\n", 706 | " [0. , 0. , 1. ],\n", 707 | " [0. , 0.3, 0.7],\n", 708 | " [0. , 0. , 1. ],\n", 709 | " [0. , 0. , 1. ],\n", 710 | " [0. , 0. , 1. ],\n", 711 | " [0. , 0. , 1. ],\n", 712 | " [0. , 0. , 1. ],\n", 713 | " [0. , 0. , 1. ],\n", 714 | " [0. , 0. , 1. ],\n", 715 | " [0. , 0. , 1. ],\n", 716 | " [0. , 0. , 1. ],\n", 717 | " [0. , 0.1, 0.9],\n", 718 | " [0. , 0. , 1. ]])" 719 | ] 720 | }, 721 | "execution_count": 63, 722 | "metadata": {}, 723 | "output_type": "execute_result" 724 | } 725 | ], 726 | "source": [ 727 | "loaded_classifier.predict_proba(X)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 64, 733 | "metadata": {}, 734 | "outputs": [ 735 | { 736 | "data": { 737 | "text/html": [ 738 | "
\n", 739 | "\n", 752 | "\n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | "
0123
07.32.96.31.8
\n", 772 | "
" 773 | ], 774 | "text/plain": [ 775 | " 0 1 2 3\n", 776 | "0 7.3 2.9 6.3 1.8" 777 | ] 778 | }, 779 | "execution_count": 64, 780 | "metadata": {}, 781 | "output_type": "execute_result" 782 | } 783 | ], 784 | "source": [ 785 | "pd.DataFrame(X_test[0].reshape(1, -1))" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": null, 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [] 794 | } 795 | ], 796 | "metadata": { 797 | "kernelspec": { 798 | "display_name": "Python 3", 799 | "language": "python", 800 | "name": "python3" 801 | }, 802 | "language_info": { 803 | "codemirror_mode": { 804 | "name": "ipython", 805 | "version": 3 806 | }, 807 | "file_extension": ".py", 808 | "mimetype": "text/x-python", 809 | "name": "python", 810 | "nbconvert_exporter": "python", 811 | "pygments_lexer": "ipython3", 812 | "version": "3.7.4" 813 | } 814 | }, 815 | "nbformat": 4, 816 | "nbformat_minor": 2 817 | } 818 | --------------------------------------------------------------------------------