├── iris ├── __init__.py ├── migrations │ ├── __init__.py │ └── 0001_initial.py ├── tests.py ├── apps.py ├── admin.py ├── templates │ ├── home.html │ └── base.html ├── forms.py ├── models.py └── views.py ├── mysite ├── __init__.py ├── asgi.py ├── wsgi.py ├── urls.py └── settings.py ├── ml_model ├── iris_model.pkl └── iris_model.py ├── tutorial_imgs ├── admin-site.png ├── hello-world.png ├── landing-page.png ├── model-prediction.png ├── model-prediction-1.png └── model-prediction-images.png ├── requirements.txt ├── manage.py ├── README.md ├── LICENSE └── .gitignore /iris/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mysite/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iris/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /iris/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /iris/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class IrisConfig(AppConfig): 5 | name = 'iris' 6 | -------------------------------------------------------------------------------- /ml_model/iris_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/ml_model/iris_model.pkl -------------------------------------------------------------------------------- /tutorial_imgs/admin-site.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/tutorial_imgs/admin-site.png -------------------------------------------------------------------------------- /tutorial_imgs/hello-world.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/tutorial_imgs/hello-world.png -------------------------------------------------------------------------------- /tutorial_imgs/landing-page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/tutorial_imgs/landing-page.png -------------------------------------------------------------------------------- /tutorial_imgs/model-prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/tutorial_imgs/model-prediction.png -------------------------------------------------------------------------------- /tutorial_imgs/model-prediction-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/tutorial_imgs/model-prediction-1.png -------------------------------------------------------------------------------- /tutorial_imgs/model-prediction-images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katiehouse/django-scikit-learn-tutorial/HEAD/tutorial_imgs/model-prediction-images.png -------------------------------------------------------------------------------- /iris/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from .models import Predictions 3 | 4 | 5 | @admin.register(Predictions) 6 | class PredictionsAdmin(admin.ModelAdmin): 7 | pass 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asgiref==3.6.0 2 | dj-database-url==1.2.0 3 | Django==4.1.7 4 | django-heroku==0.3.1 5 | django-on-heroku==1.1.2 6 | gunicorn==20.1.0 7 | joblib==1.2.0 8 | numpy==1.24.2 9 | psycopg2==2.9.5 10 | psycopg2-binary==2.9.5 11 | scikit-learn==1.2.1 12 | scipy==1.10.1 13 | sqlparse==0.4.3 14 | threadpoolctl==3.1.0 15 | whitenoise==6.4.0 16 | -------------------------------------------------------------------------------- /iris/templates/home.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} 2 | 3 | {% block content %} 4 |
5 | {% csrf_token %} 6 | {{ form }} 7 | 8 |
9 | 10 | {% if form.is_valid %} 11 |

The model predicted: {{ prediction_name }}

12 | 13 | 14 | {% endif %} 15 | 16 | {% endblock %} -------------------------------------------------------------------------------- /ml_model/iris_model.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_iris 2 | from sklearn.tree import DecisionTreeClassifier 3 | import pickle 4 | 5 | # Load the Iris dataset 6 | iris = load_iris() 7 | X = iris.data 8 | y = iris.target 9 | 10 | # Train a Decision Tree Classifier 11 | clf = DecisionTreeClassifier(random_state=0) 12 | clf.fit(X, y) 13 | 14 | # Save the model as a pkl file 15 | filename = 'ml_model/iris_model.pkl' 16 | pickle.dump(clf, open(filename, 'wb')) 17 | -------------------------------------------------------------------------------- /mysite/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for mysite project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.1/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mysite.settings') 15 | 16 | application = get_asgi_application() 17 | -------------------------------------------------------------------------------- /mysite/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for mysite project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.1/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mysite.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /iris/forms.py: -------------------------------------------------------------------------------- 1 | from django import forms 2 | 3 | 4 | class ModelForm(forms.Form): 5 | sepal_length = forms.DecimalField( 6 | label='Sepal Length (cm)', decimal_places=2, max_digits=3) 7 | sepal_width = forms.DecimalField( 8 | label='Sepal Width (cm)', decimal_places=2, max_digits=3) 9 | petal_length = forms.DecimalField( 10 | label='Pedal Length (cm)', decimal_places=2, max_digits=3) 11 | petal_width = forms.DecimalField( 12 | label='Pedal Width (cm)', decimal_places=2, max_digits=3) 13 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | """Run administrative tasks.""" 9 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'mysite.settings') 10 | try: 11 | from django.core.management import execute_from_command_line 12 | except ImportError as exc: 13 | raise ImportError( 14 | "Couldn't import Django. Are you sure it's installed and " 15 | "available on your PYTHONPATH environment variable? Did you " 16 | "forget to activate a virtual environment?" 17 | ) from exc 18 | execute_from_command_line(sys.argv) 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Django with Scikit-Learn Tutorial 3 | This tutorial creates a Django web app that tests a simple classification model with the `iris` dataset. This tutorial is performed on Mac OS, so some commands may be different for a PC. 4 | 5 | Deployed App: https://iris-django.herokuapp.com/ 6 | 7 | [Create your own app with this tutorial](https://github.com/katiehouse/django-scikit-learn-tutorial/wiki) 8 | 9 | #### To run locally 10 | ``` 11 | pip3 install -r requirements.txt 12 | python3 manage.py migrate 13 | python3 manage.py runserver 14 | ``` 15 | 16 | In this tutorial, you will integrate a Decision Tree Classifier on the Iris dataset with a Django web app! 17 | 18 | -------------------------------------------------------------------------------- /iris/templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 9 | 10 | 11 | 12 |
13 |
14 |
15 |

Iris Model Predictions

16 |
17 | {% block content %} 18 | {% endblock %} 19 |
20 |
21 |
22 | 23 | 24 | -------------------------------------------------------------------------------- /mysite/urls.py: -------------------------------------------------------------------------------- 1 | """mysite URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/3.1/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path 18 | from iris import views 19 | 20 | urlpatterns = [ 21 | path('admin/', admin.site.urls), 22 | path('', views.predict_model, name='predict_model'), 23 | ] 24 | -------------------------------------------------------------------------------- /iris/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django.utils import timezone 3 | 4 | # Create your models here. 5 | 6 | 7 | class Predictions(models.Model): 8 | # The possible predictions the model can make in the 'predictions' field 9 | # defined by: (, ) 10 | PREDICT_OPTIONS = [ 11 | ('setosa', 'Setosa'), 12 | ('versicolor', 'Versicolor'), 13 | ('virginica', 'Virginica') 14 | ] 15 | 16 | # Prediction table fields (or columns) are defined by creating attributes 17 | # and assigning them to field instances such as models.CharField() 18 | predict_datetime = models.DateTimeField(default=timezone.now) 19 | sepal_length = models.DecimalField(decimal_places=2, max_digits=3) 20 | sepal_width = models.DecimalField(decimal_places=2, max_digits=3) 21 | petal_length = models.DecimalField(decimal_places=2, max_digits=3) 22 | petal_width = models.DecimalField(decimal_places=2, max_digits=3) 23 | prediction = models.CharField(choices=PREDICT_OPTIONS, max_length=10) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Katie House 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /iris/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 3.1.1 on 2020-09-14 17:59 2 | 3 | from django.db import migrations, models 4 | import django.utils.timezone 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [ 12 | ] 13 | 14 | operations = [ 15 | migrations.CreateModel( 16 | name='Predictions', 17 | fields=[ 18 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 19 | ('predict_datetime', models.DateTimeField(default=django.utils.timezone.now)), 20 | ('sepal_length', models.DecimalField(decimal_places=2, max_digits=3)), 21 | ('sepal_width', models.DecimalField(decimal_places=2, max_digits=3)), 22 | ('petal_length', models.DecimalField(decimal_places=2, max_digits=3)), 23 | ('petal_width', models.DecimalField(decimal_places=2, max_digits=3)), 24 | ('prediction', models.CharField(choices=[('setosa', 'Setosa'), ('versicolor', 'Versicolor'), ('virginica', 'Virginica')], max_length=10)), 25 | ], 26 | ), 27 | ] 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # vscode 133 | .vscode -------------------------------------------------------------------------------- /iris/views.py: -------------------------------------------------------------------------------- 1 | from django.shortcuts import render 2 | from .forms import ModelForm 3 | from .models import Predictions 4 | import pickle 5 | 6 | 7 | def predict_model(request): 8 | # if this is a POST request we need to process the form data 9 | if request.method == 'POST': 10 | # create a form instance and populate it with data from the request: 11 | form = ModelForm(request.POST) 12 | # check whether it's valid: 13 | if form.is_valid(): 14 | # process the data in form.cleaned_data as required 15 | sepal_length = form.cleaned_data['sepal_length'] 16 | sepal_width = form.cleaned_data['sepal_width'] 17 | petal_length = form.cleaned_data['petal_length'] 18 | petal_width = form.cleaned_data['petal_width'] 19 | 20 | # Run new features through ML model 21 | model_features = [ 22 | [sepal_length, sepal_width, petal_length, petal_width]] 23 | loaded_model = pickle.load( 24 | open("ml_model/iris_model.pkl", 'rb')) 25 | prediction = loaded_model.predict(model_features)[0] 26 | 27 | prediction_dict = [{'name': 'setosa', 28 | 'img': 'https://alchetron.com/cdn/iris-setosa-0ab3145a-68f2-41ca-a529-c02fa2f5b02-resize-750.jpeg'}, 29 | {'name': 'versicolor', 30 | 'img': 'https://www.plantmorenatives.com/uploads/4/2/5/3/42539229/s308030450575347893_p39_i3_w1984.jpeg'}, 31 | {'name': 'virginica', 32 | 'img': 'https://s3.amazonaws.com/eit-planttoolbox-prod/media/images/Iris-virginica--Justin-Meissen--CC-BY-SA.jpg'}] 33 | 34 | prediction_name = prediction_dict[prediction]['name'] 35 | prediction_img = prediction_dict[prediction]['img'] 36 | 37 | # Save prediction to database Predictions table 38 | Predictions.objects.create(sepal_length=sepal_length, 39 | sepal_width=sepal_width, 40 | petal_length=petal_length, 41 | petal_width=petal_width, 42 | prediction=prediction_name) 43 | 44 | return render(request, 'home.html', {'form': form, 'prediction': prediction, 45 | 'prediction_name': prediction_name, 46 | 'prediction_img': prediction_img}) 47 | 48 | # if a GET (or any other method) we'll create a blank form 49 | else: 50 | form = ModelForm() 51 | 52 | return render(request, 'home.html', {'form': form}) 53 | -------------------------------------------------------------------------------- /mysite/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for mysite project. 3 | 4 | Generated by 'django-admin startproject' using Django 3.1.1. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.1/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/3.1/ref/settings/ 11 | """ 12 | 13 | from pathlib import Path 14 | 15 | # Build paths inside the project like this: BASE_DIR / 'subdir'. 16 | BASE_DIR = Path(__file__).resolve().parent.parent 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = '5br5f=sjdc7xtl+)5sj3%jfbvc%bi!-2)mj=rq&2w@=il#!5hh' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'django.contrib.admin', 35 | 'django.contrib.auth', 36 | 'django.contrib.contenttypes', 37 | 'django.contrib.sessions', 38 | 'django.contrib.messages', 39 | 'django.contrib.staticfiles', 40 | 'iris', 41 | ] 42 | 43 | MIDDLEWARE = [ 44 | 'django.middleware.security.SecurityMiddleware', 45 | 'django.contrib.sessions.middleware.SessionMiddleware', 46 | 'django.middleware.common.CommonMiddleware', 47 | 'django.middleware.csrf.CsrfViewMiddleware', 48 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 49 | 'django.contrib.messages.middleware.MessageMiddleware', 50 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 51 | ] 52 | 53 | ROOT_URLCONF = 'mysite.urls' 54 | 55 | TEMPLATES = [ 56 | { 57 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 58 | 'DIRS': [], 59 | 'APP_DIRS': True, 60 | 'OPTIONS': { 61 | 'context_processors': [ 62 | 'django.template.context_processors.debug', 63 | 'django.template.context_processors.request', 64 | 'django.contrib.auth.context_processors.auth', 65 | 'django.contrib.messages.context_processors.messages', 66 | ], 67 | }, 68 | }, 69 | ] 70 | 71 | WSGI_APPLICATION = 'mysite.wsgi.application' 72 | 73 | 74 | # Database 75 | # https://docs.djangoproject.com/en/3.1/ref/settings/#databases 76 | 77 | DATABASES = { 78 | 'default': { 79 | 'ENGINE': 'django.db.backends.sqlite3', 80 | 'NAME': BASE_DIR / 'db.sqlite3', 81 | } 82 | } 83 | 84 | 85 | # Password validation 86 | # https://docs.djangoproject.com/en/3.1/ref/settings/#auth-password-validators 87 | 88 | AUTH_PASSWORD_VALIDATORS = [ 89 | { 90 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 91 | }, 92 | { 93 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 94 | }, 95 | { 96 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 97 | }, 98 | { 99 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 100 | }, 101 | ] 102 | 103 | 104 | # Internationalization 105 | # https://docs.djangoproject.com/en/3.1/topics/i18n/ 106 | 107 | LANGUAGE_CODE = 'en-us' 108 | 109 | TIME_ZONE = 'UTC' 110 | 111 | USE_I18N = True 112 | 113 | USE_L10N = True 114 | 115 | USE_TZ = True 116 | 117 | 118 | # Static files (CSS, JavaScript, Images) 119 | # https://docs.djangoproject.com/en/3.1/howto/static-files/ 120 | 121 | STATIC_URL = '/static/' 122 | --------------------------------------------------------------------------------