├── .gitignore ├── README.md ├── app.py ├── model_assets ├── model_V0.pkl └── vectorizer_V0.pkl ├── model_dev ├── BaseModel.py ├── data │ └── cyber_data.json └── model_dev.ipynb ├── requirements.txt ├── static ├── screen-shot-ui.png └── troll-guy.png ├── templates └── index.html └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python,pycharm 3 | # Edit at https://www.gitignore.io/?templates=python,pycharm 4 | 5 | ### PyCharm ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # Generated files 17 | .idea/**/contentModel.xml 18 | 19 | # Sensitive or high-churn files 20 | .idea/**/dataSources/ 21 | .idea/**/dataSources.ids 22 | .idea/**/dataSources.local.xml 23 | .idea/**/sqlDataSources.xml 24 | .idea/**/dynamic.xml 25 | .idea/**/uiDesigner.xml 26 | .idea/**/dbnavigator.xml 27 | 28 | # Gradle 29 | .idea/**/gradle.xml 30 | .idea/**/libraries 31 | 32 | # Gradle and Maven with auto-import 33 | # When using Gradle or Maven with auto-import, you should exclude module files, 34 | # since they will be recreated, and may cause churn. Uncomment if using 35 | # auto-import. 
36 | # .idea/modules.xml 37 | # .idea/*.iml 38 | # .idea/modules 39 | # *.iml 40 | # *.ipr 41 | 42 | # CMake 43 | cmake-build-*/ 44 | 45 | # Mongo Explorer plugin 46 | .idea/**/mongoSettings.xml 47 | 48 | # File-based project format 49 | *.iws 50 | 51 | # IntelliJ 52 | out/ 53 | 54 | # mpeltonen/sbt-idea plugin 55 | .idea_modules/ 56 | 57 | # JIRA plugin 58 | atlassian-ide-plugin.xml 59 | 60 | # Cursive Clojure plugin 61 | .idea/replstate.xml 62 | 63 | # Crashlytics plugin (for Android Studio and IntelliJ) 64 | com_crashlytics_export_strings.xml 65 | crashlytics.properties 66 | crashlytics-build.properties 67 | fabric.properties 68 | 69 | # Editor-based Rest Client 70 | .idea/httpRequests 71 | 72 | # Android studio 3.1+ serialized cache file 73 | .idea/caches/build_file_checksums.ser 74 | 75 | ### PyCharm Patch ### 76 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 77 | 78 | # *.iml 79 | # modules.xml 80 | # .idea/misc.xml 81 | # *.ipr 82 | 83 | # Sonarlint plugin 84 | .idea/sonarlint 85 | 86 | ### Python ### 87 | # Byte-compiled / optimized / DLL files 88 | __pycache__/ 89 | *.py[cod] 90 | *$py.class 91 | 92 | # C extensions 93 | *.so 94 | 95 | # Distribution / packaging 96 | .Python 97 | build/ 98 | develop-eggs/ 99 | dist/ 100 | downloads/ 101 | eggs/ 102 | .eggs/ 103 | lib/ 104 | lib64/ 105 | parts/ 106 | sdist/ 107 | var/ 108 | wheels/ 109 | pip-wheel-metadata/ 110 | share/python-wheels/ 111 | *.egg-info/ 112 | .installed.cfg 113 | *.egg 114 | MANIFEST 115 | 116 | # PyInstaller 117 | # Usually these files are written by a python script from a template 118 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
119 | *.manifest 120 | *.spec 121 | 122 | # Installer logs 123 | pip-log.txt 124 | pip-delete-this-directory.txt 125 | 126 | # Unit test / coverage reports 127 | htmlcov/ 128 | .tox/ 129 | .nox/ 130 | .coverage 131 | .coverage.* 132 | .cache 133 | nosetests.xml 134 | coverage.xml 135 | *.cover 136 | .hypothesis/ 137 | .pytest_cache/ 138 | 139 | # Translations 140 | *.mo 141 | *.pot 142 | 143 | # Django stuff: 144 | *.log 145 | local_settings.py 146 | db.sqlite3 147 | db.sqlite3-journal 148 | 149 | # Flask stuff: 150 | instance/ 151 | .webassets-cache 152 | 153 | # Scrapy stuff: 154 | .scrapy 155 | 156 | # Sphinx documentation 157 | docs/_build/ 158 | 159 | # PyBuilder 160 | target/ 161 | 162 | # Jupyter Notebook 163 | .ipynb_checkpoints 164 | 165 | # IPython 166 | profile_default/ 167 | ipython_config.py 168 | 169 | # pyenv 170 | .python-version 171 | 172 | # pipenv 173 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 174 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 175 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 176 | # install all needed dependencies. 
177 | #Pipfile.lock 178 | 179 | # celery beat schedule file 180 | celerybeat-schedule 181 | 182 | # SageMath parsed files 183 | *.sage.py 184 | 185 | # Environments 186 | .env 187 | .venv 188 | env/ 189 | venv/ 190 | ENV/ 191 | env.bak/ 192 | venv.bak/ 193 | 194 | # Spyder project settings 195 | .spyderproject 196 | .spyproject 197 | 198 | # Rope project settings 199 | .ropeproject 200 | 201 | # mkdocs documentation 202 | /site 203 | 204 | # mypy 205 | .mypy_cache/ 206 | .dmypy.json 207 | dmypy.json 208 | 209 | # Pyre type checker 210 | .pyre/ 211 | 212 | # End of https://www.gitignore.io/api/python,pycharm -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ml-flask-web-app 2 | 3 | This is a web application designed to show the project structure for a machine learning model deployed using flask. This project features a machine learning model that has been trained to detect whether or not an online comment is a `Cyber-Troll` or `Non Cyber-Troll`. This application acts as an interface for a user to submit new queries. The machine learning model was built using various features of scikit-learn: 4 | 5 | * Support Vector Machine (SVM) 6 | * Bag-of-Words text representation (BoW) 7 | * Grid Search + Cross Validation 8 | 9 | Each of these components is developed within the project in an offline setting inside `/model_dev`. The SVM and BoW models will still be needed in a production or testing setting in order to be able to predict user-submitted queries, so they can be serialized via python's pickle functionality and stored within the `/model_assets` folder. 10 | 11 | In order to detect whether or not an online comment is from a cyber troll, you can deploy this application locally and submit queries to the machine learning model to receive predictions through a simple user interface. 
The model was trained using the 12 | Dataset for Detection of Cyber-Trolls ([see here](https://dataturks.com/projects/abhishek.narayanan/Dataset%20for%20Detection%20of%20Cyber-Trolls/)). This project emphasizes the process of building deploy-friendly machine learning projects rather than the creation of the predictive model itself. 13 | 14 | The model development notebook is located [here](https://github.com/wgopar/ml-flask-web-app/blob/master/model_dev/model_dev.ipynb). 15 | 16 | You can also find a blog post that accompanies this repo [here](http://www.wmendozagopar.com/creating-and-deploying-a-machine-learning-project-with-flask.html#creating-and-deploying-a-machine-learning-project-with-flask). 17 | 18 | Note that this project is still *in progress*. 19 | 20 | ## Installation 21 | 22 | First clone the repo locally. 23 | ~~~bash 24 | git clone https://github.com/wgopar/ml-flask-web-app.git 25 | ~~~ 26 | 27 | Create a new virtual environment in the project directory. 28 | ~~~bash 29 | python3 -m venv ./venv 30 | ~~~ 31 | 32 | Activate the virtual environment. 33 | ~~~bash 34 | source venv/bin/activate 35 | ~~~ 36 | 37 | While in the virtual environment, install required dependencies from `requirements.txt`. 38 | 39 | ~~~bash 40 | pip install -r ./requirements.txt 41 | ~~~ 42 | 43 | Now we can deploy the web application via 44 | ~~~bash 45 | python app.py 46 | ~~~ 47 | 48 | and navigate to `http://127.0.0.1:5000/` to see it live. On this page, a user can then submit text into the text 49 | field and receive predictions from the trained model and determine if the text most likely came from a `Cyber-Troll` or 50 | `Non Cyber-Troll`. 51 | 52 | ![Screen shot](/static/screen-shot-ui.png "User Interface") 53 | 54 | 55 | The application may then be terminated with the following commands. 
56 | ~~~bash 57 | $ ^C # exit flask application (ctrl-c) 58 | $ deactivate # exit virtual environment 59 | ~~~ 60 | 61 | ## Project Structure 62 | 63 | ~~~ 64 | ml-flask-web-app 65 | ├── model_assets 66 | │ ├── model.pkl 67 | │ └── vectorizer.pkl 68 | ├── model_dev 69 | │ ├── data 70 | │ │ └── data.json 71 | │ └── model_dev.ipynb 72 | ├── templates 73 | │ └── index.html 74 | ├── app.py 75 | ├── utils.py 76 | ├── requirements.txt 77 | └── README.md 78 | ~~~ 79 | 80 | ### detailed 81 | 82 | `/model_assets` is used to store persisted states of the predictive model and learned feature extractors from scikit-learn. 83 | 84 | `/model_dev` is used as the model development playground where an `.ipynb` is used to develop the model and save new versions of persisted states. 85 | 86 | Storing new persisted states of the model can be done within the jupyter notebook. As an example, within `model_dev.ipynb` 87 | I can create or retrain a model and save it into the `./model_assets` folder when I am satisfied. A simple example: 88 | 89 | ~~~~python 90 | import utils 91 | 92 | clf = LogisticRegression() 93 | clf.fit(X_train, y_train) 94 | utils.persist_model(clf, description='clf_v.0.0') # creates model_clf_v.0.0.pkl in /model_assets folder 95 | ~~~~ 96 | 97 | The model and vectorizer versions used at run time are selected by the `MODEL_VERSION` and `VECTORIZER_VERSION` constants 98 | in `app.py`. 99 | 100 | `/templates` holds the html templates for the application. 
101 | 102 | 103 | 104 | []: './static/screen-shot-ui.png' 105 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, render_template, jsonify, url_for 2 | from utils import clean_text 3 | import pickle 4 | import time 5 | import os 6 | 7 | app = Flask(__name__) 8 | 9 | MODEL_VERSION = 'model_V0.pkl' 10 | VECTORIZER_VERSION = 'vectorizer_V0.pkl' 11 | 12 | # load model assets 13 | vectorizer_path = os.path.join(os.getcwd(), 'model_assets', VECTORIZER_VERSION) 14 | model_path = os.path.join(os.getcwd(), 'model_assets', MODEL_VERSION) 15 | vectorizer = pickle.load(open(vectorizer_path, 'rb')) 16 | model = pickle.load(open(model_path, 'rb')) 17 | 18 | # TODO: add versioning to url 19 | @app.route('/', methods=['GET', 'POST']) 20 | def predict(): 21 | """ Main webpage with user input through form and prediction displayed 22 | 23 | :return: main webpage host, displays prediction if user submitted in text field 24 | """ 25 | 26 | if request.method == 'POST': 27 | 28 | response = request.form['text'] 29 | input_text = clean_text(response) 30 | input_text = vectorizer.transform([input_text]) 31 | prediction = model.predict(input_text) 32 | prediction = 'Cyber-Troll' if prediction[0] == 1 else 'Non Cyber-Troll' 33 | return render_template('index.html', text=prediction, submission=response) 34 | 35 | if request.method == 'GET': 36 | return render_template('index.html') 37 | 38 | # TODO: add versioning to api 39 | @app.route('/predict', methods=['POST']) 40 | def predict_api(): 41 | """ endpoint for model queries (non gui) 42 | 43 | :return: json, model prediction and response time 44 | """ 45 | start_time = time.time() 46 | 47 | request_data = request.json 48 | input_text = request_data['data'] 49 | input_text = clean_text(input_text) 50 | input_text = vectorizer.transform([input_text]) 51 | prediction = 
model.predict(input_text) 52 | prediction = 'Cyber-Troll' if prediction[0] == 1 else "Non Cyber-Troll" # post processing 53 | 54 | response = {'prediction': prediction, 'response_time': time.time() - start_time} 55 | return jsonify(response) 56 | 57 | 58 | if __name__ == '__main__': 59 | app.run(debug=True) 60 | -------------------------------------------------------------------------------- /model_assets/model_V0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgopar/ml-flask-web-app/df90234a2aa2e3b0009292fdc356dfb6a5c05bcb/model_assets/model_V0.pkl -------------------------------------------------------------------------------- /model_assets/vectorizer_V0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgopar/ml-flask-web-app/df90234a2aa2e3b0009292fdc356dfb6a5c05bcb/model_assets/vectorizer_V0.pkl -------------------------------------------------------------------------------- /model_dev/BaseModel.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import confusion_matrix 2 | from sklearn.model_selection import GridSearchCV 3 | from sklearn.svm import LinearSVC 4 | 5 | 6 | class SVM: 7 | """ SVM model to support Online Troll comment detection 8 | 9 | Attributes 10 | ---------- 11 | description : string, model description for referencing parameters of object 12 | 13 | clf : sklearn svm model object 14 | 15 | """ 16 | 17 | def __init__(self, description): 18 | """ Initialize model configuration 19 | 20 | Parameters 21 | -------------- 22 | data: (dict) dictionary of training and testings sets of data 23 | description: (str) description of model being trained 24 | """ 25 | self.description = description 26 | self.clf = LinearSVC() 27 | 28 | def train(self, data, **params): 29 | """ Trains model with grid search over user defined parameters 30 | 31 | Parameters 32 | 
------------- 33 | parameters: (dict) key value pairs specifying sklearn parameter search 34 | 35 | """ 36 | 37 | self.clf = GridSearchCV(self.clf, params, cv=5) 38 | self.clf.fit(data['X_train'], data['y_train']) 39 | 40 | def display_results(self, data): 41 | """ Prints testing and training accuracies along with other model 42 | validation metrics. 43 | 44 | Parameters 45 | --------------- 46 | clf: (scikit-learn model) predictive model to test 47 | data: (dict) training and testing data for model 48 | """ 49 | train_accuracy = self.clf.score(data['X_train'], data['y_train']) 50 | test_accuracy = self.clf.score(data['X_test'], data['y_test']) 51 | y_pred = self.clf.predict(data['X_test']) 52 | print('{:>20s} {:.2f}'.format('Train Accuracy:', train_accuracy)) 53 | print('{:>20s} {:.2f}'.format('Test Accuracy:', test_accuracy)) 54 | print(confusion_matrix(data['y_test'], y_pred)) -------------------------------------------------------------------------------- /model_dev/model_dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 175, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "import warnings \n", 13 | "warnings.filterwarnings('ignore')\n", 14 | "\n", 15 | "# include app-wide functions \n", 16 | "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('utils.py'))))\n", 17 | "import utils\n", 18 | "\n", 19 | "from sklearn.feature_extraction.text import CountVectorizer\n", 20 | "from sklearn.model_selection import train_test_split\n", 21 | "from collections import Counter\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from bs4 import BeautifulSoup\n", 24 | "import numpy as np\n", 25 | "import pandas as pd\n", 26 | "import pickle \n", 27 | "import random\n", 28 | "import string\n", 29 | "import json\n", 30 | "\n", 31 | "%config InlineBackend.figure_format = 'retina'\n", 32 | 
"%matplotlib inline" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 87, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "df = utils.load_data() # cleaned dataset" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "#TODO: only allow words occuring at least 3 times + remove stop words (and or i)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## some exploratory data analysis" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 88, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "image/png": "\n", 68 | "text/plain": [ 69 | "" 70 | ] 71 | }, 72 | "metadata": { 73 | "image/png": { 74 | "height": 261, 75 | "width": 390 76 | }, 77 | "needs_background": "light" 78 | }, 79 | "output_type": "display_data" 80 | } 81 | ], 82 | "source": [ 83 | "# distribution of cyber trolls vs non-cyber trolls\n", 84 | "counter = Counter(df.label)\n", 85 | "plt.title('Distribution of CyberTrolls - Train set')\n", 86 | "plt.bar(list(counter.keys())[0], list(counter.values())[0], align='center', color='g', label='Non Cyber-Agressive')\n", 87 | "plt.bar(list(counter.keys())[1], list(counter.values())[1], align='center', color='r', label='Cyber-Agressive')\n", 88 | "plt.xticks(list(set(df.label)))\n", 89 | "plt.legend()\n", 90 | "plt.show()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 90, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Label: 0\n", 103 | "Index: 788\tuh well suck complainjk\n", 104 | "\n", 105 | "Label: 0\n", 106 | "Index: 9969\tmmm\n", 107 | "\n", 108 | "Label: 0\n", 109 | "Index: 5062\tdamn trying number list dont need more p\n", 110 | "\n", 111 | "Label: 0\n", 112 | "Index: 1167\tson cant decide cuz blackberry sooo delish oh yea ill cd find damn blank cd 
lol\n", 113 | "\n", 114 | "Label: 0\n", 115 | "Index: 9286\tdo you go wit anybody at this moment\n", 116 | "\n", 117 | "Label: 0\n", 118 | "Index: 8664\ti would love follow r absolutely\n", 119 | "\n", 120 | "Label: 0\n", 121 | "Index: 11964\tany new exciting news youud like share\n", 122 | "\n", 123 | "Label: 0\n", 124 | "Index: 10774\thahhahah yeah you\n", 125 | "\n", 126 | "Label: 0\n", 127 | "Index: 190\toh man funny tweet im tear here hahaha i hate mac genuises soooo much\n", 128 | "\n", 129 | "Label: 0\n", 130 | "Index: 3028\thi best fucking friend reply me jaja i love you motherfucker mua or cku\n", 131 | "\n", 132 | "Label: 1\n", 133 | "Index: 3108\tknow bambis as nice place lol\n", 134 | "\n", 135 | "Label: 1\n", 136 | "Index: 951\ti nominate jcroft shorty award downassbitches totally one\n", 137 | "\n", 138 | "Label: 1\n", 139 | "Index: 7185\ttotally the as end jeep life form impacted\n", 140 | "\n", 141 | "Label: 1\n", 142 | "Index: 7085\tnerd\n", 143 | "\n", 144 | "Label: 1\n", 145 | "Index: 7620\tlmao i seen dude mad angry got as beat\n", 146 | "\n", 147 | "Label: 1\n", 148 | "Index: 3562\ti hate wedding u shud grateful male worrying u got wear\n", 149 | "\n", 150 | "Label: 1\n", 151 | "Index: 6945\tgah fuck tainting private account new follower xd\n", 152 | "\n", 153 | "Label: 1\n", 154 | "Index: 976\tbooooooo hot as mushroom stock wanna come get work\n", 155 | "\n", 156 | "Label: 1\n", 157 | "Index: 5431\tlet hate gambit grow embrace it spread it\n", 158 | "\n", 159 | "Label: 1\n", 160 | "Index: 6509\tand boycott fucking bar carry it\n", 161 | "\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "# sample some of the online comments\n", 167 | "utils.sample_data(df, n=10) # feature engineering (create number hashtags col?)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 91, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "Cyber Trolls\n", 180 | 
"('i', 2981)\n", 181 | "('hate', 1313)\n", 182 | "('damn', 1059)\n", 183 | "('fuck', 1036)\n", 184 | "('as', 1022)\n", 185 | "\n", 186 | "Non Cyber Trolls\n", 187 | "('i', 3839)\n", 188 | "('hate', 1488)\n", 189 | "('damn', 1307)\n", 190 | "('im', 988)\n", 191 | "('like', 942)\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "# most common words by label\n", 197 | "trolls = Counter(' '.join(list(df[df.label == 1].text)).split())\n", 198 | "non_trolls = Counter(' '.join(list(df[df.label == 0].text)).split())\n", 199 | "\n", 200 | "print('Cyber Trolls')\n", 201 | "print(*trolls.most_common()[:5], sep='\\n')\n", 202 | "print('\\nNon Cyber Trolls')\n", 203 | "print(*non_trolls.most_common()[:5], sep='\\n')" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 3, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "## Save vectorizer in ./model_assets\n", 213 | "utils.persist_vectorizer(vectorizer, 'test_v.0.0') " 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## text feature extractors" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 92, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "True\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "# bag-of-words encoding\n", 238 | "enc = utils.build_encoder(df.text, count_vectorizer=True)\n", 239 | "count_vectorized = enc.fit_transform(df.text).toarray()\n", 240 | "\n", 241 | "# tf-idf encoding\n", 242 | "enc = utils.build_encoder(df.text, tf_idf=True)\n", 243 | "tf_idf = enc.fit_transform(df.text).toarray()\n", 244 | "\n", 245 | "print(tf_idf.shape == count_vectorized.shape)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "# Model dev" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 94, 258 | "metadata": {}, 259 | "outputs": [], 260 | 
"source": [ 261 | "X_train, X_test, y_train, y_test = train_test_split(count_vectorized, df.label)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 95, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "data": { 271 | "image/png": "\n", 272 | "text/plain": [ 273 | "" 274 | ] 275 | }, 276 | "metadata": { 277 | "image/png": { 278 | "height": 316, 279 | "width": 885 280 | }, 281 | "needs_background": "light" 282 | }, 283 | "output_type": "display_data" 284 | } 285 | ], 286 | "source": [ 287 | "# Show of labels train/test sets\n", 288 | "fig, axs = plt.subplots(1, 2, figsize=(15,5))\n", 289 | "\n", 290 | "train_count = Counter(y_train)\n", 291 | "axs[0].set_title('Distribution of CyberTrolls - Train set')\n", 292 | "axs[0].bar(list(train_count.keys())[0], list(train_count.values())[0], align='center', color='g', label='Non Cyber-Agressive')\n", 293 | "axs[0].bar(list(train_count.keys())[1], list(train_count.values())[1], align='center', color='r', label='Cyber-Agressive')\n", 294 | "axs[0].set_xticks(list(set(y_train)))\n", 295 | "axs[0].legend()\n", 296 | "\n", 297 | "test_count = Counter(y_test)\n", 298 | "axs[1].set_title('Distribution of CyberTrolls - Test set')\n", 299 | "axs[1].bar(list(test_count.keys())[0], list(test_count.values())[0], align='center', color='g', label='Non Cyber-Agressive')\n", 300 | "axs[1].bar(list(test_count.keys())[1], list(test_count.values())[1], align='center', color='r', label='Cyber-Agressive')\n", 301 | "axs[1].set_xticks(list(set(y_test)))\n", 302 | "axs[1].legend()\n", 303 | "\n", 304 | "plt.show()" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 186, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | " Train Accuracy: 0.97\n", 317 | " Test Accuracy: 0.85\n", 318 | "[[2456 593]\n", 319 | " [ 166 1786]]\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "from BaseModel import SVM\n", 325 | "\n", 
326 | "# svc params\n", 327 | "params = {'C': np.logspace(-5, 5, 5)}\n", 328 | "data = {'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test}\n", 329 | "clf = SVM(description='dev')\n", 330 | "clf.train(data=data, **params)\n", 331 | "clf.display_results(data)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 193, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "ename": "AttributeError", 341 | "evalue": "'SVM' object has no attribute 'display_results_'", 342 | "output_type": "error", 343 | "traceback": [ 344 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 345 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 346 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisplay_results_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 347 | "\u001b[0;31mAttributeError\u001b[0m: 'SVM' object has no attribute 'display_results_'" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "clf.display_results_" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": {}, 358 | "source": [ 359 | "## save model and vectorizer" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 10, 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "name": "stdout", 369 | "output_type": "stream", 370 | "text": [ 371 | "Model Saved.\n" 372 | ] 373 | } 374 | ], 375 | "source": [ 376 | "persist_model(clf, 'test_v.0.0')\n", 377 | "persist_model(clf, 'test_v.0.0')" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 165, 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/plain": [ 388 | "(15000, 16673)" 389 | ] 390 | }, 391 | "execution_count": 165, 392 | "metadata": {}, 393 | "output_type": "execute_result" 394 | } 395 | ], 396 | "source": [ 397 | "data['X_train'].shape" 398 | ] 
399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [] 406 | } 407 | ], 408 | "metadata": { 409 | "kernelspec": { 410 | "display_name": "Python 3", 411 | "language": "python", 412 | "name": "python3" 413 | }, 414 | "language_info": { 415 | "codemirror_mode": { 416 | "name": "ipython", 417 | "version": 3 418 | }, 419 | "file_extension": ".py", 420 | "mimetype": "text/x-python", 421 | "name": "python", 422 | "nbconvert_exporter": "python", 423 | "pygments_lexer": "ipython3", 424 | "version": "3.5.6" 425 | } 426 | }, 427 | "nbformat": 4, 428 | "nbformat_minor": 2 429 | } 430 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.8.0 2 | bs4==0.0.1 3 | certifi==2019.6.16 4 | chardet==3.0.4 5 | Click==7.0 6 | Flask==1.1.1 7 | idna==2.8 8 | itsdangerous==1.1.0 9 | Jinja2==2.10.1 10 | joblib==0.13.2 11 | MarkupSafe==1.1.1 12 | nltk==3.4.5 13 | numpy==1.17.0 14 | pandas==0.25.1 15 | python-dateutil==2.8.0 16 | pytz==2019.2 17 | scikit-learn==0.21.3 18 | scipy==1.3.1 19 | six==1.12.0 20 | sklearn==0.0 21 | soupsieve==1.9.3 22 | urllib3==1.25.3 23 | Werkzeug==0.15.5 24 | -------------------------------------------------------------------------------- /static/screen-shot-ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgopar/ml-flask-web-app/df90234a2aa2e3b0009292fdc356dfb6a5c05bcb/static/screen-shot-ui.png -------------------------------------------------------------------------------- /static/troll-guy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgopar/ml-flask-web-app/df90234a2aa2e3b0009292fdc356dfb6a5c05bcb/static/troll-guy.png 
-------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 |
7 |

Welcome to the Cyber-Troll Prediction Service

8 |
9 | 10 |
11 | 12 |
13 | 14 |
15 | 16 | This is an example of a machine learning model being deployed using Flask and 17 | Scikit-Learn! You can interact with the machine learning model that was built 18 | to predict whether or not an online comment is from a cyber troll. Enter a comment 19 | below for a prediction! 20 | 21 |
22 | 23 |
24 |
25 | 26 | 27 |
28 |
29 | 30 |

{{ submission }}

31 |

{{ text }}

""" utils.py -- helpers shared by the Flask app and the model development notebook

Covers loading the labelled comment data set, cleaning raw comment text,
building text feature extractors, pickling trained assets and printing
random samples of the data.
"""
import functools
import json
import os
import pickle
import random
import string

import pandas as pd
from bs4 import BeautifulSoup
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

# Characters stripped by clean_text: ASCII punctuation plus the typographic
# apostrophe. ('.' and ',' are already part of string.punctuation.)
_PUNCTUATION = frozenset(string.punctuation) | {"’"}


@functools.lru_cache(maxsize=None)
def _english_stopwords():
    """ return the NLTK English stop word list as a frozenset, loaded once

    Loading is deferred to the first call so that importing this module does
    not touch the NLTK corpus.
    """
    return frozenset(stopwords.words('english'))


def _dump_asset(obj, filename):
    """ pickle obj into <parent of cwd>/model_assets/<filename> and close the file """
    path = os.path.join(os.pardir, 'model_assets', filename)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def load_data(raw=None):
    """ load data to development workspaces

    Reads './data/cyber_data.json' (one JSON object per line), resolved
    against the current working directory.

    Parameters
    --------------
    raw: (bool) if truthy, the text column is returned exactly as stored;
         otherwise (default) every text is passed through `clean_text`.
    return (df) data frame with columns `text` and `label` (int)
    """
    with open('./data/cyber_data.json') as f:
        raw_data = [json.loads(line) for line in f]

    # each record carries its label as the first entry of annotation.label
    labels = [int(d['annotation']['label'][0]) for d in raw_data]
    text = [d['content'] for d in raw_data]
    df = pd.DataFrame({'text': text, 'label': labels}, columns=['text', 'label'])

    if not raw:
        df.text = df.text.apply(clean_text)
    return df


def clean_text(text):
    """ clean input text for the prediction model

    Steps, in order: strip HTML markup, drop English stop words, lower-case,
    remove punctuation and digits, lemmatize each remaining token.

    Parameters
    -------------
    text: (str) text to clean
    return (str) post-processed clean text
    """
    stop_words = _english_stopwords()  # set lookup instead of re-reading the corpus per token
    text = BeautifulSoup(text, 'html.parser').text

    # NOTE: the stop word test runs on the token *before* lower-casing, so
    # capitalised stop words (e.g. "I", "The") survive. This is kept as-is
    # because changing it would change the features seen by already
    # persisted vectorizers/models.
    filtered_text = ' '.join(word.lower() for word in text.split() if word not in stop_words)

    # '-' is part of the punctuation set, so no separate hyphen handling is needed.
    filtered_text = ''.join(c for c in filtered_text
                            if c not in _PUNCTUATION and not c.isdigit())

    lemmatizer = WordNetLemmatizer()
    return ' '.join(lemmatizer.lemmatize(w) for w in filtered_text.split())


def persist_model(clf, description):
    """ saves pickled classifier in /model_assets folder with naming convention: model_[description].pkl

    The folder is resolved as `../model_assets` relative to the current
    working directory.

    Parameters
    -------------
    clf: (obj) scikit-learn trained model
    description: (str) model version/descriptor
    """
    _dump_asset(clf, 'model_{}.pkl'.format(description))
    print('Model Saved.')


def build_encoder(text, count_vectorizer=None, tf_idf=None):
    """ builds a text feature extractor given an iterable of text data

    Parameters
    ---------------
    text: (list or series) of text data to fit the extractor on
    count_vectorizer: (bool) If `True` returns a fitted BoW `CountVectorizer`
    tf_idf: (bool) If `True` returns a fitted `TfidfVectorizer`
    return fitted extractor; `count_vectorizer` wins when both flags are set,
           and `None` is returned when neither flag is set
    """
    if count_vectorizer:
        encoder = CountVectorizer()
    elif tf_idf:
        encoder = TfidfVectorizer()
    else:
        return None

    encoder.fit(text)
    return encoder


def persist_vectorizer(vectorizer, description):
    """ saves bag-of-words vectorizer in /model_assets folder with naming convention: vectorizer_[description].pkl

    The folder is resolved as `../model_assets` relative to the current
    working directory.

    Parameters
    -------------
    vectorizer: (obj) sklearn vectorizer object
    description: (str) vectorizer version/descriptor
    """
    _dump_asset(vectorizer, 'vectorizer_{}.pkl'.format(description))
    print('Vectorizer Saved.')


def sample_data(df, n):
    """ prints to console n random samples of data in the data frame (e.g online comment and label)

    Samples are drawn with replacement, n per distinct label. The printed
    index is the row position inside that label's subset, not the index of
    the original frame.

    Parameters
    -------------
    df: pandas DataFrame to be sampled (columns `label` and `text`)
    n: number of samples to generate per label

    """
    for label in set(df.label):
        subset = df[df.label == label]
        # randrange excludes the upper bound; randint(0, len) could return
        # len itself and make iloc raise IndexError.
        rand_idxs = [random.randrange(subset.shape[0]) for _ in range(n)]
        for idx in rand_idxs:
            row = subset.iloc[idx]
            print('Label: {}\nIndex: {}\t{}\n'.format(row['label'], idx, row['text']))