├── README.md ├── datasets └── housing │ └── housing.tgz ├── LICENSE ├── NOTES.md ├── .gitignore ├── ML_Project_Checklist.md └── Ch3_classification.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Hands-On-ML -------------------------------------------------------------------------------- /datasets/housing/housing.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bexxmodd/Hands-On-ML/main/datasets/housing/housing.tgz -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Beka Modebadze 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /NOTES.md: -------------------------------------------------------------------------------- 1 | # Chapter 1 2 | In a famous 1996 paper, David Wolpert demonstrated that if you 3 | make absolutely no assumption about the data, then there is no reason 4 | to prefer one model over any other. This is called the **No Free Lunch 5 | (NFL)** theorem. 6 | 7 | ----- 8 | # Chapter 2 9 | A sequence of data processing components is called a _data pipeline_. 10 | Pipelines are very common in Machine Learning systems, since there is a 11 | lot of data to manipulate and many data transformations to apply. 12 | Components typically run asynchronously. Each component pulls in a 13 | large amount of data, processes it, and spits out the result in another data 14 | store. Then, some time later, the next component in the pipeline pulls this 15 | data and spits out its own output. Each component is fairly self-contained: 16 | the interface between components is simply the data store. This makes the 17 | system simple to grasp (with the help of a data flow graph), and different 18 | teams can focus on different components. Moreover, if a component 19 | breaks down, the downstream components can often continue to run 20 | normally (at least for a while) by just using the last output from the 21 | broken component. This makes the architecture quite robust. 22 | On the other hand, a broken component can go unnoticed for some time if 23 | proper monitoring is not implemented. The data gets stale and the overall 24 | system’s performance drops.
25 | 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /ML_Project_Checklist.md: -------------------------------------------------------------------------------- 1 | This checklist can guide you through your Machine Learning projects. 2 | There are eight main steps: 3 | 4 | 1. Frame the problem and look at the big picture. 5 | 2. Get the data. 6 | 3. Explore the data to gain insights. 7 | 4. Prepare the data to better expose the underlying data patterns to Machine Learning algorithms. 8 | 5. Explore many different models and shortlist the best ones. 9 | 6. Fine-tune your models and combine them into a great solution. 10 | 7. Present your solution. 11 | 8. Launch, monitor, and maintain your system. 12 | 13 | Obviously, you should feel free to adapt this checklist to your needs. 14 | 15 | 16 | ## Frame the Problem and Look at the Big Picture 17 | 18 | 1. Define the objective in business terms. 19 | 2. How will your solution be used? 20 | 3. What are the current solutions/workarounds (if any)? 21 | 4. How should you frame this problem (supervised/unsupervised, online/offline, etc.)? 22 | 5. How should performance be measured? 23 | 6. Is the performance measure aligned with the business objective? 24 | 7. What would be the minimum performance needed to reach the business objective? 25 | 8. 
What are comparable problems? Can you reuse experience or tools? 26 | 9. Is human expertise available? 27 | 10. How would you solve the problem manually? 28 | 11. List the assumptions you (or others) have made so far. 29 | 12. Verify assumptions if possible. 30 | 31 | 32 | ## Get the Data 33 | Note: automate as much as possible so you can easily get fresh data. 34 | 35 | 1. List the data you need and how much you need. 36 | 2. Find and document where you can get that data. 37 | 3. Check how much space it will take. 38 | 4. Check legal obligations, and get authorization if necessary. 39 | 5. Get access authorizations. 40 | 6. Create a workspace (with enough storage space). 41 | 7. Get the data. 42 | 8. Convert the data to a format you can easily manipulate (without changing the data itself). 43 | 9. Ensure sensitive information is deleted or protected (e.g., anonymized). 44 | 10. Check the size and type of data (time series, sample, geographical, etc.). 45 | 11. Sample a test set, put it aside, and never look at it (no data snooping!). 46 | 47 | 48 | ## Explore the Data 49 | Note: try to get insights from a field expert for these steps. 50 | 51 | 1. Create a copy of the data for exploration (sampling it down to a manageable size if necessary). 52 | 2. Create a Jupyter notebook to keep a record of your data exploration. 53 | 3. Study each attribute and its characteristics: 54 | * Name 55 | * Type (categorical, int/float, bounded/unbounded, text, structured, etc.) 56 | * % of missing values 57 | * Noisiness and type of noise (stochastic, outliers, rounding errors, etc.) 58 | * Usefulness for the task 59 | * Type of distribution (Gaussian, uniform, logarithmic, etc.) 60 | 4. For supervised learning tasks, identify the target attribute(s). 61 | 5. Visualize the data. 62 | 6. Study the correlations between attributes. 63 | 7. Study how you would solve the problem manually. 8. Identify the promising transformations you may want to apply. 64 | 9.
Identify extra data that would be useful (go back to “Get the Data”). 65 | 10. Document what you have learned. 66 | 67 | 68 | ## Prepare the Data 69 | Notes: 70 | * Work on copies of the data (keep the original dataset intact). 71 | * Write functions for all data transformations you apply, for five reasons: 72 | * So you can easily prepare the data the next time you get a 73 | fresh dataset 74 | * So you can apply these transformations in future projects 75 | * To clean and prepare the test set 76 | * To clean and prepare new data instances once your solution is live 77 | * To make it easy to treat your preparation choices as hyperparameters 78 | 79 | 1. Data cleaning: 80 | * Fix or remove outliers (optional). 81 | * Fill in missing values (e.g., with zero, mean, median...) or drop their rows (or columns). 82 | 2. Feature selection (optional): 83 | * Drop the attributes that provide no useful information for the task. 84 | 3. Feature engineering, where appropriate: 85 | * Discretize continuous features. 86 | * Decompose features (e.g., categorical, date/time, etc.). 87 | * Add promising transformations of features (e.g., log(x), sqrt(x), x², etc.). 88 | * Aggregate features into promising new features. 89 | 4. Feature scaling: 90 | * Standardize or normalize features. 91 | 92 | 93 | ## Shortlist Promising Models 94 | Notes: 95 | * If the data is huge, you may want to sample smaller training sets so you can train many different models in a reasonable time (be aware that this penalizes complex models such as large neural nets or Random Forests). 96 | * Once again, try to automate these steps as much as possible. 97 | 98 | 1. Train many quick-and-dirty models from different categories (e.g., linear, naive Bayes, SVM, Random Forest, neural net, etc.) using standard parameters. 99 | 2. Measure and compare their performance. For each model, use N-fold cross-validation and compute the mean and standard deviation of the performance measure on the N folds. 100 | 3.
Analyze the most significant variables for each algorithm. 101 | 4. Analyze the types of errors the models make. What data would a human have used to avoid these errors? 102 | 5. Perform a quick round of feature selection and engineering. 103 | 6. Perform one or two more quick iterations of the five previous steps. 104 | 7. Shortlist the top three to five most promising models, preferring models that make different types of errors. 105 | 106 | 107 | ## Fine-Tune the System 108 | Notes: 109 | * You will want to use as much data as possible for this step, especially as you move toward the end of fine-tuning. 110 | * As always, automate what you can. 111 | 1. Fine-tune the hyperparameters using cross-validation: 112 | * Treat your data transformation choices as hyperparameters, especially when you are not sure about them (e.g., if you’re not sure whether to replace missing values with zeros or with the median value, or to just drop the rows). 113 | * Unless there are very few hyperparameter values to explore, prefer random search over grid search. If training is very long, you may prefer a Bayesian optimization approach (e.g., using Gaussian process priors, as described by Jasper Snoek et al. in “Practical Bayesian Optimization of Machine Learning Algorithms,” 2012). 114 | 2. Try Ensemble methods. Combining your best models will often produce better performance than running them individually. 115 | 3. Once you are confident about your final model, measure its performance on the test set to estimate the generalization error. 116 | 117 | _WARNING_ 118 | Don’t tweak your model after measuring the generalization error: you would just 119 | start overfitting the test set. 120 | 121 | 122 | ## Present Your Solution 123 | 124 | 1. Document what you have done. 125 | 2. Create a nice presentation. Make sure you highlight the big picture first. 126 | 3. Explain why your solution achieves the business objective. 127 | 4. Don’t forget to present interesting points you noticed along the way. 128 | * Describe what worked and what did not.
129 | * List your assumptions and your system’s limitations. 130 | 5. Ensure your key findings are communicated through beautiful visualizations or easy-to-remember statements (e.g., “the median income is the number-one predictor of housing prices”). 131 | 132 | 133 | ## Launch! 134 | 1. Get your solution ready for production (plug into production data inputs, write unit tests, etc.). 135 | 2. Write monitoring code to check your system’s live performance at regular intervals and trigger alerts when it drops. 136 | * Beware of slow degradation: models tend to “rot” as data evolves. 137 | * Measuring performance may require a human pipeline (e.g., via a crowdsourcing service). 138 | * Also monitor your inputs’ quality (e.g., a malfunctioning sensor sending random values, or another team’s output becoming stale). This is particularly important for online learning systems. 139 | 3. Retrain your models on a regular basis on fresh data (automate as much as possible). -------------------------------------------------------------------------------- /Ch3_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Chapter 3. Classification\n", 8 | "\n", 9 | "## MNIST\n", 10 | "Scikit-Learn provides many helper functions to download popular datasets. MNIST is one of them, which is a set of 70,000 small images of digits handwritten by high school students and employees of the US Census Bureau. Each image is labeled with the digit it represents." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "data": { 20 | "text/plain": [ 21 | "dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])" 22 | ] 23 | }, 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | } 28 | ], 29 | "source": [ 30 | "from sklearn.datasets import fetch_openml\n", 31 | "\n", 32 | "mnist = fetch_openml('mnist_784', version=1)\n", 33 | "mnist.keys()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "X size (70000, 784)\n", 46 | "label size: (70000,)\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "X, y = mnist['data'], mnist['target']\n", 52 | "print('X size', X.shape)\n", 53 | "print('label size:', y.shape)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "There are 70,000 images, and each image has 784 features. This is because each image is 28 × 28 pixels, and each feature simply represents one pixel’s intensity, from 0 (white) to 255 (black). Let’s take a peek at one digit from the dataset. 
All you need to do is grab an instance’s feature vector, reshape it to a 28 × 28 array, and display it using Matplotlib’s `imshow()` function:" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGaElEQVR4nO3dPUiWfR/G8dveSyprs2gOXHqhcAh6hZqsNRqiJoPKRYnAoTGorWyLpqhFcmgpEmqIIByKXiAHIaKhFrGghiJ81ucBr991Z/Z4XPr5jB6cXSfVtxP6c2rb9PT0P0CeJfN9A8DMxAmhxAmhxAmhxAmhljXZ/Vcu/H1tM33RkxNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCiRNCLZvvG+B//fr1q9y/fPnyVz9/aGio4fb9+/fy2vHx8XK/ceNGuQ8MDDTc7t69W167atWqcr948WK5X7p0qdzngycnhBInhBInhBInhBInhBInhBInhHLOOYMPHz6U+48fP8r92bNn5f706dOG29TUVHnt8PBwuc+nLVu2lPv58+fLfWRkpOG2du3a8tpt27aV+759+8o9kScnhBInhBInhBInhBInhBInhGqbnp6u9nJsVS9evCj3gwcPlvvffm0r1dKlS8v91q1b5d7e3j7rz960aVO5b9iwody3bt0668/+P2ib6YuenBBKnBBKnBBKnBBKnBBKnBBKnBBqUZ5zTk5Olnt3d3e5T0xMzOXtzKlm997sPPDx48cNtxUrVpTXLtbz3zngnBNaiTghlDghlDghlDghlDghlDgh1KL81pgbN24s96tXr5b7/fv3y33Hjh3l3tfXV+6V7du3l/vo6Gi5N3un8s2bNw23a9euldcytzw5IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IdSifJ/zT339+rXcm/24ut7e3obbzZs3y2tv375d7idOnCh3InmfE1qJOCGUOCGUOCGUOCGUOCGUOCHUonyf80+tW7fuj65fv379rK9tdg56/Pjxcl+yxL/HrcKfFIQSJ4QSJ4QSJ4QSJ4QSJ4Tyytg8+PbtW8Otp6envPbJkyfl/uDBg3I/fPhwuTMvvDIGrUScEEqcEEqcEEqcEEqcEEqcEMo5Z5iJiYly37lzZ7l3dHSU+4EDB8p9165dDbezZ8+W17a1zXhcR3POOaGViBNCiRNCiRNCiRNCiRNCiRNCOedsMSMjI+V++vTpcm/24wsrly9fLveTJ0+We2dn56w/e4FzzgmtRJwQSpwQSpwQSpwQSpwQSpwQyjnnAvP69ety7+/vL/fR0dFZf/aZM2fKfXBwsNw3b948689ucc45oZWIE0KJE0KJE0KJE0KJE0KJE0I551xkpqamyv3+/fsNt1OnTpXXNvm79M+hQ4fK/dGjR+W+gDnnhFYiTgglTgglTgglTgglTgjlKIV/beXKleX+8+fPcl++fHm5P3z4sOG2f//+8toW5ygFWok4IZQ4IZQ4IZQ4IZQ4IZQ4IdSy+b4B5tarV6/KfXh4uNzHxsYabs3OMZvp6uoq97179/7Rr7/QeHJCKHFCKHFCKHFCKHFCKHFCKHFCKOecYcbHx
8v9+vXr5X7v3r1y//Tp02/f07+1bFn916mzs7PclyzxrPhvfjcglDghlDghlDghlDghlDghlDghlHPOv6DZWeKdO3cabkNDQ+W179+/n80tzYndu3eX++DgYLkfPXp0Lm9nwfPkhFDihFDihFDihFDihFDihFCOUmbw+fPncn/79m25nzt3rtzfvXv32/c0V7q7u8v9woULDbdjx46V13rla2753YRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQC/acc3JysuHW29tbXvvy5ctyn5iYmM0tzYk9e/aUe39/f7kfOXKk3FevXv3b98Tf4ckJocQJocQJocQJocQJocQJocQJoWLPOZ8/f17uV65cKfexsbGG28ePH2d1T3NlzZo1Dbe+vr7y2mbffrK9vX1W90QeT04IJU4IJU4IJU4IJU4IJU4IJU4IFXvOOTIy8kf7n+jq6ir3np6ecl+6dGm5DwwMNNw6OjrKa1k8PDkhlDghlDghlDghlDghlDghlDghVNv09HS1lyMwJ9pm+qInJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4Rq9iMAZ/yWfcDf58kJocQJocQJocQJocQJocQJof4DO14Dh4wBfawAAAAASUVORK5CYII=\n", 71 | "text/plain": [ 72 | "
" 73 | ] 74 | }, 75 | "metadata": { 76 | "needs_background": "light" 77 | }, 78 | "output_type": "display_data" 79 | } 80 | ], 81 | "source": [ 82 | "import matplotlib as mpl\n", 83 | "import matplotlib.pyplot as plt\n", 84 | "\n", 85 | "some_digit = X.iloc[0].values\n", 86 | "some_digits_image = some_digit.reshape(28, 28)\n", 87 | "\n", 88 | "plt.imshow(some_digits_image, cmap='binary')\n", 89 | "plt.axis(\"off\")\n", 90 | "plt.show()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "5\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "print(y[0])" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Let's cast labels as integers\n", 117 | "import numpy as np\n", 118 | "\n", 119 | "y = y.astype(np.uint8)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "The MNIST dataset is actually already split into a training set (the first 60,000 images) and a test set (the last 10,000 images):" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 6, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## Training a Binary Classifier\n", 143 | "\n", 144 | "Let's try to identify only one digit \"5\". This will be an example of _binary classifier_, capable of distibuishing between just two classes, 5 and not-5." 
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# Set for all 5s\n", 154 | "y_train_5 = (y_train == 5)\n", 155 | "y_test_5 = (y_test == 5)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "Now let’s pick a classifier and train it. A good place to start is with a _Stochastic Gradient Descent (SGD)_ classifier. This classifier handles large datasets efficiently, because it deals with training instances independently, one at a time." 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | "array([ True])" 174 | ] 175 | }, 176 | "execution_count": 8, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "from sklearn.linear_model import SGDClassifier\n", 183 | "\n", 184 | "sgd_clf = SGDClassifier(random_state=42)\n", 185 | "sgd_clf.fit(X_train, y_train_5)\n", 186 | "\n", 187 | "# Predicting if the digit is 5\n", 188 | "sgd_clf.predict([some_digit])" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## Performance Measures\n", 196 | "\n", 197 | "### Measuring Accuracy Using Cross-Validation\n", 198 | "Remember that K-fold cross-validation means splitting the training set into K folds (in this case, three), then making predictions and evaluating them on each fold using a model trained on the remaining folds."
199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 9, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "array([0.95035, 0.96035, 0.9604 ])" 210 | ] 211 | }, 212 | "execution_count": 9, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "from sklearn.model_selection import cross_val_score\n", 219 | "\n", 220 | "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy')" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "before we get too excited, let’s look at a very dumb classifier that just classifies every single image in the “not-5” class:" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 10, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "array([0.91125, 0.90855, 0.90915])" 239 | ] 240 | }, 241 | "execution_count": 10, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "from sklearn.base import BaseEstimator\n", 248 | "\n", 249 | "class Never5Classifier(BaseEstimator):\n", 250 | " \n", 251 | " def fit(self, X, y=None):\n", 252 | " pass\n", 253 | " \n", 254 | " def predict(self, X):\n", 255 | " return np.zeros((len(X), 1), dtype=bool)\n", 256 | "\n", 257 | "never_5_clf = Never5Classifier()\n", 258 | "cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring='accuracy')" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "That’s right, it has over 90% accuracy! 
This is simply because only about 10% of the images are 5s, so if you always guess that an image is not a 5, you will be right about 90% of the time.\n", 266 | "\n", 267 | "This demonstrates why accuracy is generally not the preferred performance measure for classifiers, especially when you are dealing with _skewed datasets_ (i.e., when some classes are much more frequent than others).\n", 268 | "\n", 269 | "### Confusion Matrix\n", 270 | "\n", 271 | "A much better way to evaluate the performance of a classifier is to look at the _confusion matrix_. The general idea is to count the number of times instances of class A are classified as class B. For example, to know the number of times the classifier confused images of 5s with 3s, you would look at row 5, column 3 of the confusion matrix (rows index the actual digit and columns the predicted digit, both counted from 0).\n", 272 | "\n", 273 | "To compute the confusion matrix, you first need to have a set of predictions so that they can be compared to the actual targets. The `cross_val_predict()` function\n", 274 | "performs K-fold cross-validation, but instead of returning the evaluation scores, it returns the predictions made on each test fold. This means that you get a clean prediction for each instance in the training set (“clean” meaning out-of-sample: each prediction is made by a model that never saw that instance during training)."
275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 11, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "from sklearn.model_selection import cross_val_predict\n", 284 | "\n", 285 | "y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "Now get the confusion matrix by passing it the target classes (`y_train_5`) and the predicted classes (`y_train_pred`):" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 12, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "data": { 302 | "text/plain": [ 303 | "array([[53892, 687],\n", 304 | " [ 1891, 3530]])" 305 | ] 306 | }, 307 | "execution_count": 12, 308 | "metadata": {}, 309 | "output_type": "execute_result" 310 | } 311 | ], 312 | "source": [ 313 | "from sklearn.metrics import confusion_matrix\n", 314 | "\n", 315 | "confusion_matrix(y_train_5, y_train_pred)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "Each row in a confusion matrix represents an _actual class_, while each column represents a _predicted class_. The first row of this matrix considers non-5 images (the negative class): 53,892 of them were correctly classified as non-5s (they are called true negatives), while the remaining 687 were wrongly classified as 5s (false positives). The second row considers the images of 5s (the positive class): 1,891 were wrongly classified as non-5s (false negatives), while the remaining 3,530 were correctly classified as 5s (true positives).\n", 323 | "\n", 324 | "An interesting metric to look at is the accuracy of the positive predictions; this is called the _precision_ of the classifier: _precision_ = $\\frac{TP}{TP+FP}$.
Here $TP$ is the number of true positives, and $FP$ is the number of false positives.\n", 325 | "\n", 326 | "Precision is typically used along with another metric named _recall_, also called _sensitivity_ or the _true positive rate_ (TPR): this is the ratio of positive instances that are correctly detected by the classifier: _recall_ = $\\frac{TP}{TP+FN}$, where $FN$ is the number of false negatives.\n", 327 | "\n", 328 | "### Precision and Recall" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 13, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "data": { 338 | "text/plain": [ 339 | "0.8370879772350012" 340 | ] 341 | }, 342 | "execution_count": 13, 343 | "metadata": {}, 344 | "output_type": "execute_result" 345 | } 346 | ], 347 | "source": [ 348 | "from sklearn.metrics import precision_score, recall_score\n", 349 | "\n", 350 | "precision_score(y_train_5, y_train_pred)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 14, 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "data": { 360 | "text/plain": [ 361 | "0.6511713705958311" 362 | ] 363 | }, 364 | "execution_count": 14, 365 | "metadata": {}, 366 | "output_type": "execute_result" 367 | } 368 | ], 369 | "source": [ 370 | "recall_score(y_train_5, y_train_pred)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "metadata": {}, 376 | "source": [ 377 | "It is often convenient to combine precision and recall into a single metric called the $F_1$ score, in particular if you need a simple way to compare two \n", 378 | "classifiers. The $F_1$ score is the harmonic mean of precision and recall.
Whereas the regular mean treats all values equally, the harmonic mean gives much more weight to low values.\n", 379 | "\n", 380 | "$F_1 = \\frac{2}{\\frac{1}{precision} + \\frac{1}{recall}}=\\frac{TP}{TP+\\frac{FN+FP}{2}}$" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 15, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "0.7325171197343846" 392 | ] 393 | }, 394 | "execution_count": 15, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "from sklearn.metrics import f1_score\n", 401 | "\n", 402 | "f1_score(y_train_5, y_train_pred)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "The $F_1$ score favors classifiers that have similar precision and recall. This\n", 410 | "is not always what you want: in some contexts you mostly care about\n", 411 | "precision, and in other contexts you really care about recall. For example,\n", 412 | "if you trained a classifier to detect videos that are safe for kids, you would\n", 413 | "probably prefer a classifier that rejects many good videos (low recall) but\n", 414 | "keeps only safe ones (high precision), rather than a classifier that has a\n", 415 | "much higher recall but lets a few really bad videos show up in your\n", 416 | "product (in such cases, you may even want to add a human pipeline to\n", 417 | "check the classifier’s video selection). On the other hand, suppose you\n", 418 | "train a classifier to detect shoplifters in surveillance images: it is probably\n", 419 | "fine if your classifier has only 30% precision as long as it has 99% recall\n", 420 | "(sure, the security guards will get a few false alerts, but almost all\n", 421 | "shoplifters will get caught).\n", 422 | "\n", 423 | "Unfortunately, you can’t have it both ways: increasing precision reduces\n", 424 | "recall, and vice versa. 
This is called the precision/recall trade-off.\n", 425 | "\n", 426 | "For each instance, **SGDClassifier** computes a score using its decision function; if that score is greater than a threshold (0 by default), it assigns the instance to the positive class, otherwise to the negative class. Raising the threshold typically increases precision and reduces recall; lowering it does the opposite. How do you decide which threshold to use? First, use the `cross_val_predict()` function to get the scores of all instances in the training set, but this time specify that you want to return decision scores instead of predictions:" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 16, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, \n", 436 | " cv=3, method='decision_function')" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "With these scores, use `precision_recall_curve()` to compute precision and recall for all possible thresholds:" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 18, 449 | "metadata": {}, 450 | "outputs": [ 451 | { 452 | "data": { 453 | "image/png": "\n", 454 | "text/plain": [ 455 | "
" 456 | ] 457 | }, 458 | "metadata": { 459 | "needs_background": "light" 460 | }, 461 | "output_type": "display_data" 462 | } 463 | ], 464 | "source": [ 465 | "from sklearn.metrics import precision_recall_curve\n", 466 | "\n", 467 | "precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)\n", 468 | "\n", 469 | "# Plot precision and recall as functions of the threshold value\n", 470 | "def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n", 471 | " plt.figure(figsize=(8,4))\n", 472 | " plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')\n", 473 | " plt.plot(thresholds, recalls[:-1], 'g-', label='Recall')\n", 474 | " plt.grid(axis='both')\n", 475 | " plt.legend()\n", 476 | " \n", 477 | "plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n", 478 | "plt.show()" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "Suppose you decide to aim for 90% precision. You look up the first plot\n", 486 | "and find that you need to use a threshold of about 8,000. To be more\n", 487 | "precise you can search for the lowest threshold that gives you at least 90%\n", 488 | "precision ( `np.argmax()` will give you the first index of the maximum\n", 489 | "value, which in this case means the first `True` value)." 
490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 19, 495 | "metadata": {}, 496 | "outputs": [ 497 | { 498 | "name": "stdout", 499 | "output_type": "stream", 500 | "text": [ 501 | "Precision: 0.9000345901072293\n", 502 | "Recall: 0.4799852425751706\n" 503 | ] 504 | } 505 | ], 506 | "source": [ 507 | "threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]\n", 508 | "\n", 509 | "# Now let's make prediction with this threshold\n", 510 | "y_train_pred_90 = (y_scores >= threshold_90_precision)\n", 511 | "\n", 512 | "print(\"Precision:\", precision_score(y_train_5, y_train_pred_90))\n", 513 | "print(\"Recall:\", recall_score(y_train_5, y_train_pred_90))" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "A high-precision classifier is not very useful if its recall is too low: here, 90% precision comes with only about 48% recall!\n", 521 | "\n", 522 | "### The ROC Curve\n", 523 | "The _receiver operating characteristic (ROC)_ curve is another common\n", 524 | "tool used with binary classifiers. It is very similar to the precision/recall\n", 525 | "curve, but instead of plotting precision versus recall, the ROC curve plots\n", 526 | "the true positive rate (another name for recall) against the false positive\n", 527 | "rate (FPR). The FPR is the ratio of negative instances that are incorrectly\n", 528 | "classified as positive. It is equal to 1 minus the true negative rate (TNR), which\n", 529 | "is the ratio of negative instances that are correctly classified as negative.\n", 530 | "The TNR is also called specificity. Hence, the ROC curve plots sensitivity\n", 531 | "(recall) versus 1 minus specificity." 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 20, 537 | "metadata": {}, 538 | "outputs": [ 539 | { 540 | "data": { 541 | "image/png": "\n", 542 | "text/plain": [ 543 | "
" 544 | ] 545 | }, 546 | "metadata": { 547 | "needs_background": "light" 548 | }, 549 | "output_type": "display_data" 550 | } 551 | ], 552 | "source": [ 553 | "from sklearn.metrics import roc_curve\n", 554 | "\n", 555 | "fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)\n", 556 | "\n", 557 | "# Plot FPR against the TPR\n", 558 | "def plot_roc_curve(fpr, tpr, label=None):\n", 559 | " \"\"\"Plots ROC Curve (FPR agains TPR)\"\"\"\n", 560 | " plt.figure(figsize=(5, 5))\n", 561 | " plt.plot(fpr, tpr, linewidth=3, label=label)\n", 562 | " plt.plot([0, 1], [0, 1], 'k--')\n", 563 | " plt.grid(axis='both')\n", 564 | " plt.ylabel('True Positive Rate (Recall)')\n", 565 | " plt.xlabel('False Positive Rate')\n", 566 | " plt.ylim(0, 1)\n", 567 | " plt.xlim(0, 1)\n", 568 | " \n", 569 | "plot_roc_curve(fpr, tpr)\n", 570 | "plt.show()" 571 | ] 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "metadata": {}, 576 | "source": [ 577 | "One way to compare classifiers is to measure the area under the ROC curve\n", 578 | "(ROC AUC). A perfect classifier will have a ROC AUC equal to 1, whereas a\n", 579 | "purely random classifier will have a ROC AUC equal to 0.5. Scikit-Learn\n", 580 | "provides a function to compute the ROC AUC:" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 21, 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "name": "stdout", 590 | "output_type": "stream", 591 | "text": [ 592 | "ROC AUC Score: 0.9604938554008616\n" 593 | ] 594 | } 595 | ], 596 | "source": [ 597 | "from sklearn.metrics import roc_auc_score\n", 598 | "\n", 599 | "print('ROC AUC Score:', roc_auc_score(y_train_5, y_scores))" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "As a rule of thumb, you should prefer the PR curve\n", 607 | "whenever the positive class is rare or when you care more about the false positives\n", 608 | "than the false negatives. Otherwise, use the ROC curve. For example, the ROC AUC of about 0.96 above looks very good, but that is mostly because there are few positives (5s) compared to negatives (non-5s); the precision/recall trade-off above (only about 48% recall at 90% precision) shows that there is still room for improvement."
609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 22, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "# Let's train random forest model\n", 618 | "from sklearn.ensemble import RandomForestClassifier\n", 619 | "\n", 620 | "forest_clf = RandomForestClassifier(random_state=42)\n", 621 | "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, \n", 622 | " method='predict_proba')" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 23, 628 | "metadata": {}, 629 | "outputs": [ 630 | { 631 | "data": { 632 | "image/png": "\n", 633 | "text/plain": [ 634 | "
" 635 | ] 636 | }, 637 | "metadata": { 638 | "needs_background": "light" 639 | }, 640 | "output_type": "display_data" 641 | } 642 | ], 643 | "source": [ 644 | "# The roc_curve() function expects labels and scores, \n", 645 | "# but instead of scores you can give it class probabilities\n", 646 | "y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n", 647 | "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)\n", 648 | "\n", 649 | "plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": 24, 655 | "metadata": {}, 656 | "outputs": [ 657 | { 658 | "name": "stdout", 659 | "output_type": "stream", 660 | "text": [ 661 | "ROC AUC score: 0.9983436731328145\n" 662 | ] 663 | } 664 | ], 665 | "source": [ 666 | "print('ROC AUC score:', roc_auc_score(y_train_5, y_scores_forest))" 667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": {}, 672 | "source": [ 673 | "You now know how to train binary classifiers, choose the appropriate\n", 674 | "metric for your task, evaluate your classifiers using cross-validation,\n", 675 | "select the precision/recall trade-off that fits your needs, and use ROC\n", 676 | "curves and ROC AUC scores to compare various models.\n", 677 | "\n", 678 | "## Multiclass Classification\n", 679 | "_Multiclass classifiers_ (also called _multinomial classifiers_) can distinguish between more than two classes. Even though some algorithms can only perform binary classification, there are ways to use them for multiclass classification. One way to classify the 10 digits is to train 10 binary classifiers, one per digit, each distinguishing its digit from all the others; to classify an image, you take the class whose classifier outputs the highest score. This is called the _one-versus-the-rest_ (OvR) strategy. Another strategy is to train a binary classifier for every pair of classes: one to distinguish 0s from 1s, another to distinguish 0s from 2s, another for 0s and 3s, and so on. This is called the _one-versus-one_ (OvO) strategy.
For $N$ classes you need to train $\\frac{N\\times (N-1)}{2}$ classifiers.\n", 680 | "\n", 681 | "Scikit-Learn detects when you try to use a binary classification algorithm for a multiclass classification task, and it automatically runs OvR or OvO, depending on the algorithm:" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": 25, 687 | "metadata": {}, 688 | "outputs": [ 689 | { 690 | "data": { 691 | "text/plain": [ 692 | "array([5], dtype=uint8)" 693 | ] 694 | }, 695 | "execution_count": 25, 696 | "metadata": {}, 697 | "output_type": "execute_result" 698 | } 699 | ], 700 | "source": [ 701 | "from sklearn.svm import SVC\n", 702 | "\n", 703 | "svm_clf = SVC()\n", 704 | "svm_clf.fit(X_train, y_train)\n", 705 | "svm_clf.predict([some_digit])" 706 | ] 707 | }, 708 | { 709 | "cell_type": "markdown", 710 | "metadata": {}, 711 | "source": [ 712 | "This code trains the SVC on the training set using the original target classes from 0 to 9 ( y_train ), instead of the 5-versus-the-rest target classes ( y_train_5 )" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": 26, 718 | "metadata": {}, 719 | "outputs": [ 720 | { 721 | "data": { 722 | "text/plain": [ 723 | "array([[ 1.72501977, 2.72809088, 7.2510018 , 8.3076379 , -0.31087254,\n", 724 | " 9.3132482 , 1.70975103, 2.76765202, 6.23049537, 4.84771048]])" 725 | ] 726 | }, 727 | "execution_count": 26, 728 | "metadata": {}, 729 | "output_type": "execute_result" 730 | } 731 | ], 732 | "source": [ 733 | "# This will return 10 scores per instance\n", 734 | "some_digit_scores = svm_clf.decision_function([some_digit])\n", 735 | "some_digit_scores" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 27, 741 | "metadata": {}, 742 | "outputs": [ 743 | { 744 | "data": { 745 | "text/plain": [ 746 | "5" 747 | ] 748 | }, 749 | "execution_count": 27, 750 | "metadata": {}, 751 | "output_type": "execute_result" 752 | } 753 | ], 754 | "source": [ 755 | "# Highest score 
is the one corresponding to class 5\n", 756 | "np.argmax(some_digit_scores)" 757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": 28, 762 | "metadata": {}, 763 | "outputs": [ 764 | { 765 | "data": { 766 | "text/plain": [ 767 | "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8)" 768 | ] 769 | }, 770 | "execution_count": 28, 771 | "metadata": {}, 772 | "output_type": "execute_result" 773 | } 774 | ], 775 | "source": [ 776 | "svm_clf.classes_" 777 | ] 778 | }, 779 | { 780 | "cell_type": "code", 781 | "execution_count": 29, 782 | "metadata": {}, 783 | "outputs": [ 784 | { 785 | "data": { 786 | "text/plain": [ 787 | "5" 788 | ] 789 | }, 790 | "execution_count": 29, 791 | "metadata": {}, 792 | "output_type": "execute_result" 793 | } 794 | ], 795 | "source": [ 796 | "svm_clf.classes_[5]" 797 | ] 798 | }, 799 | { 800 | "cell_type": "markdown", 801 | "metadata": {}, 802 | "source": [ 803 | "**WARNING**: When a classifier is trained, it stores the list of target classes in its `classes_` attribute, ordered by value.\n", 804 | "\n", 805 | "If you want to force Scikit-Learn to use _one-versus-one_ or _one-versus-the-rest_, you can use the `OneVsOneClassifier` or `OneVsRestClassifier` classes" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "execution_count": 30, 811 | "metadata": {}, 812 | "outputs": [ 813 | { 814 | "data": { 815 | "text/plain": [ 816 | "array([5], dtype=uint8)" 817 | ] 818 | }, 819 | "execution_count": 30, 820 | "metadata": {}, 821 | "output_type": "execute_result" 822 | } 823 | ], 824 | "source": [ 825 | "# Classify using OvR\n", 826 | "from sklearn.multiclass import OneVsRestClassifier\n", 827 | "\n", 828 | "ovr_clf = OneVsRestClassifier(SVC())\n", 829 | "ovr_clf.fit(X_train, y_train)\n", 830 | "ovr_clf.predict([some_digit])" 831 | ] 832 | }, 833 | { 834 | "cell_type": "code", 835 | "execution_count": 31, 836 | "metadata": {}, 837 | "outputs": [ 838 | { 839 | "data": { 840 | "text/plain": [ 841 | "10" 842 | ] 843 | }, 844 
| "execution_count": 31, 845 | "metadata": {}, 846 | "output_type": "execute_result" 847 | } 848 | ], 849 | "source": [ 850 | "len(ovr_clf.estimators_)" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": 32, 856 | "metadata": {}, 857 | "outputs": [ 858 | { 859 | "data": { 860 | "text/plain": [ 861 | "array([3], dtype=uint8)" 862 | ] 863 | }, 864 | "execution_count": 32, 865 | "metadata": {}, 866 | "output_type": "execute_result" 867 | } 868 | ], 869 | "source": [ 870 | "# Trying SGD Classifier. No need for OvR or OvO as\n", 871 | "# SGD directly classifies into multiple classes.\n", 872 | "sgd_clf.fit(X_train, y_train)\n", 873 | "sgd_clf.predict([some_digit])" 874 | ] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "execution_count": 33, 879 | "metadata": {}, 880 | "outputs": [ 881 | { 882 | "data": { 883 | "text/plain": [ 884 | "array([[-31893.03095419, -34419.69069632, -9530.63950739,\n", 885 | " 1823.73154031, -22320.14822878, -1385.80478895,\n", 886 | " -26188.91070951, -16147.51323997, -4604.35491274,\n", 887 | " -12050.767298 ]])" 888 | ] 889 | }, 890 | "execution_count": 33, 891 | "metadata": {}, 892 | "output_type": "execute_result" 893 | } 894 | ], 895 | "source": [ 896 | "sgd_clf.decision_function([some_digit])" 897 | ] 898 | }, 899 | { 900 | "cell_type": "markdown", 901 | "metadata": {}, 902 | "source": [ 903 | "This time the classifier gets it wrong: `some_digit` is a 5, but the classifier predicts 3. Class 3 has the only positive score (about 1,824), the correct class 5 comes second (about -1,386), and all the other scores are largely negative.\n", 904 | "\n", 905 | "You can evaluate this classifier using `cross_val_score()`:"
906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": 34, 911 | "metadata": {}, 912 | "outputs": [ 913 | { 914 | "data": { 915 | "text/plain": [ 916 | "array([0.87365, 0.85835, 0.8689 ])" 917 | ] 918 | }, 919 | "execution_count": 34, 920 | "metadata": {}, 921 | "output_type": "execute_result" 922 | } 923 | ], 924 | "source": [ 925 | "cross_val_score(sgd_clf, X_train, y_train,\n", 926 | " cv=3, scoring=\"accuracy\")" 927 | ] 928 | }, 929 | { 930 | "cell_type": "markdown", 931 | "metadata": {}, 932 | "source": [ 933 | "It gets over 85% accuracy on all test folds; a random classifier would only get about 10%. We can still get a better score by scaling the inputs:" 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | "execution_count": 35, 939 | "metadata": {}, 940 | "outputs": [ 941 | { 942 | "data": { 943 | "text/plain": [ 944 | "array([0.8983, 0.891 , 0.9018])" 945 | ] 946 | }, 947 | "execution_count": 35, 948 | "metadata": {}, 949 | "output_type": "execute_result" 950 | } 951 | ], 952 | "source": [ 953 | "from sklearn.preprocessing import StandardScaler\n", 954 | "\n", 955 | "scaler = StandardScaler()\n", 956 | "X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n", 957 | "cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring='accuracy')" 958 | ] 959 | }, 960 | { 961 | "cell_type": "markdown", 962 | "metadata": {}, 963 | "source": [ 964 | "## Error Analysis\n", 965 | "\n", 966 | "One way to improve a promising model is to analyze the types of errors it makes."
967 | ] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": 37, 972 | "metadata": {}, 973 | "outputs": [ 974 | { 975 | "name": "stdout", 976 | "output_type": "stream", 977 | "text": [ 978 | "[[5577 0 22 5 8 43 36 6 225 1]\n", 979 | " [ 0 6400 37 24 4 44 4 7 212 10]\n", 980 | " [ 27 27 5220 92 73 27 67 36 378 11]\n", 981 | " [ 22 17 117 5227 2 203 27 40 403 73]\n", 982 | " [ 12 14 41 9 5182 12 34 27 347 164]\n", 983 | " [ 27 15 30 168 53 4444 75 14 535 60]\n", 984 | " [ 30 15 42 3 44 97 5552 3 131 1]\n", 985 | " [ 21 10 51 30 49 12 3 5684 195 210]\n", 986 | " [ 17 63 48 86 3 126 25 10 5429 44]\n", 987 | " [ 25 18 30 64 118 36 1 179 371 5107]]\n" 988 | ] 989 | } 990 | ], 991 | "source": [ 992 | "# Let's look at the confusion matrix\n", 993 | "y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n", 994 | "conf_mx = confusion_matrix(y_train, y_train_pred)\n", 995 | "print(conf_mx)" 996 | ] 997 | }, 998 | { 999 | "cell_type": "code", 1000 | "execution_count": 61, 1001 | "metadata": {}, 1002 | "outputs": [ 1003 | { 1004 | "data": { 1005 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPoAAAECCAYAAADXWsr9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAALDUlEQVR4nO3dz4vc9R3H8dcr2XXXJCX+aC5mpVmxGEQIq2tRAx6Mh7aKIvRgwUC97KXVKIJoL/4DInoowhLrxaCHGEGkWAvqoZeQTVaIyRoUfyTRiOlCjQgmu867hxlhk9063zHfz35nfD8fIGTHbz55M9lnvjOz3/mMI0IAft7WND0AgPIIHUiA0IEECB1IgNCBBAgdSKCx0G3/1vYx2x/ZfqKpOaqyfbXtd2wftX3E9q6mZ6rC9lrbs7bfaHqWKmxfZnuv7Q9sz9m+temZurH9aOd74n3bL9sebXqmCzUSuu21kv4m6XeSrpf0R9vXNzFLDxYlPRYR10u6RdKfB2BmSdolaa7pIXrwnKQ3I2KrpG3q89ltb5b0sKTJiLhB0lpJ9zc71XJNndF/I+mjiPg4Is5JekXSvQ3NUklEnIqIQ51ff6P2N+DmZqf6cbbHJN0laXfTs1Rhe6Ok2yW9IEkRcS4i/tvoUNUMSbrU9pCkdZK+aHieZZoKfbOkE0u+Pqk+j2Yp21skTUja3/Ao3Twr6XFJrYbnqGpc0mlJL3aebuy2vb7poX5MRHwu6WlJxyWdkvR1RLzV7FTL8WJcj2xvkPSqpEci4kzT8/w/tu+W9FVEHGx6lh4MSbpR0vMRMSHpW0l9/fqN7cvVfjQ6LukqSettP9DsVMs1Ffrnkq5e8vVY57a+ZntY7cj3RMS+pufpYruke2x/qvZToztsv9TsSF2dlHQyIn54pLRX7fD72Z2SPomI0xGxIGmfpNsanmmZpkI/IOnXtsdtX6L2ixevNzRLJbat9nPHuYh4pul5uomIJyNiLCK2qH3/vh0RfXemWSoivpR0wvZ1nZt2SDra4EhVHJd0i+11ne+RHerDFxCHmvhDI2LR9l8k/VPtVyn/HhFHmpilB9sl7ZR02PZ7ndv+GhH/aG6kn6WHJO3pnAA+lvRgw/P8qIjYb3uvpENq/2RmVtJ0s1MtZ96mCvz88WIckAChAwkQOpAAoQMJEDqQQOOh255qeoZeDNq8EjOvhn6ft/HQJfX1HbSCQZtXYubV0Nfz9kPoAAorcsHMFVdcEWNjY5WOnZ+f15VXXlnp2MOHD1/MWEDP2le1dhcRlY9d+ntKiIhlgxS5BHZsbEyvv17/pevj4+O1r4nlev2G7QelohkZGSmyriR99913xda+EA/dgQQIHUiA0IEECB1IgNCBBCqFPmh7sAM4X9fQB3QPdgBLVDmjD9we7ADOVyX0gd6DHUCNL8bZnrI9Y3tmfn6+rmUB1KBK6JX2YI+I6YiYjIjJqteuA1gdVUIfuD3YAZyv65taBnQPdgBLVHr3WudDCvigAmBAcWUckAChAwkQOpAAoQMJEDqQQJHNIW0X2cCr5Ce/rllT5t+8Qfy02lJ7xg3ifTE6Olps7VJ7xq20OSRndCABQgcSIHQgAUIHEiB0IAFCBxIgdCABQgcSIHQgAUIHEiB0IAFCBxIgdCABQgcSIHQgAUIHEiB0IAFCBxIgdCABQgcSIHQgAUIHEqj0IYs/RYktg0ttySxJs7OzRda96aabiqwrlds+udS6Jf/+Ss08MjJSZF2p3HbPK+GMDiRA6EAChA4kQOhAAoQOJEDoQAKEDiTQNXTbV9t+x/ZR20ds71qNwQDUp8oFM4uSHouIQ7Z/Iemg7X9FxNHCswGoSdczekSciohDnV9/I2lO0ubSgwGoT0/P0W1vkTQhaX+RaQAUUflad9sbJL0q6ZGIOLPC/5+SNFXjbABqUil028NqR74nIvatdExETEua7hxf5h0GAH6SKq+6W9ILkuYi4pnyIwGoW5Xn6Nsl7ZR0h+33Ov/
9vvBcAGrU9aF7RPxbUv1vLgewargyDkiA0IEECB1IgNCBBAgdSMAlds8cxAtmhobKbIh78ODBIutK0rZt24qsOzo6WmTds2fPFlm3pI0bNxZb+8yZZReYXrRWq6WIWPZTMs7oQAKEDiRA6EAChA4kQOhAAoQOJEDoQAKEDiRA6EAChA4kQOhAAoQOJEDoQAKEDiRA6EAChA4kQOhAAoQOJEDoQAKEDiRA6EAChA4kwHbPHe1Ph65fifv3B7Ozs0XWnZiYKLJuqfu4pA0bNhRbu8T21wsLC2q1Wmz3DGRE6EAChA4kQOhAAoQOJEDoQAKEDiRQOXTba23P2n6j5EAA6tfLGX2XpLlSgwAop1Lotsck3SVpd9lxAJRQ9Yz+rKTHJbXKjQKglK6h275b0lcRcbDLcVO2Z2zP1DYdgFpUOaNvl3SP7U8lvSLpDtsvXXhQRExHxGRETNY8I4CL1DX0iHgyIsYiYouk+yW9HREPFJ8MQG34OTqQwFAvB0fEu5LeLTIJgGI4owMJEDqQAKEDCRA6kAChAwkU2wW2xI6fJXdULbVD6fDwcJF1JWlxcbHIuq+99lqRde+7774i60pSq1Xm6uxNmzYVWVeS5ufna1+z1WopItgFFsiI0IEECB1IgNCBBAgdSIDQgQQIHUiA0IEECB1IgNCBBAgdSIDQgQQIHUiA0IEECB1IgNCBBAgdSIDQgQQIHUiA0IEECB1IgF1gO0rtAjuIM69ZU+bf/w8//LDIupJ0zTXXFFm35C6+CwsLRdZlF1ggKUIHEiB0IAFCBxIgdCABQgcSIHQggUqh277M9l7bH9ies31r6cEA1Geo4nHPSXozIv5g+xJJ6wrOBKBmXUO3vVHS7ZL+JEkRcU7SubJjAahTlYfu45JOS3rR9qzt3bbXF54LQI2qhD4k6UZJz0fEhKRvJT1x4UG2p2zP2J6peUYAF6lK6CclnYyI/Z2v96od/nkiYjoiJiNiss4BAVy8rqFHxJeSTti+rnPTDklHi04FoFZVX3V/SNKezivuH0t6sNxIAOpWKfSIeE8SD8mBAcWVcUAChA4kQOhAAoQOJEDoQAKEDiRQbLvn2hctrNQWxyW3ey5lEGc+ceJEkXWvvfbaIutKZbbrPnv2rFqtFts9AxkROpAAoQMJEDqQAKEDCRA6kAChAwkQOpAAoQMJEDqQAKEDCRA6kAChAwkQOpAAoQMJEDqQAKEDCRA6kAChAwkQOpAAoQMJDNQusKV2apXK7Xxacubvv/++yLrDw8NF1l1YWCiybknHjh0rtvbWrVtrXzMiFBHsAgtkROhAAoQOJEDoQAKEDiRA6EAChA4kUCl024/aPmL7fdsv2x4tPRiA+nQN3fZmSQ9LmoyIGyStlXR/6cEA1KfqQ/chSZfaHpK0TtIX5UYCULeuoUfE55KelnRc0ilJX0fEW6UHA1CfKg/dL5d0r6RxSVdJWm/7gRWOm7I9Y3um/jEBXIwqD93vlPRJRJyOiAVJ+yTdduFBETEdEZMRMVn3kAAuTpXQj0u6xfY625a0Q9Jc2bEA1KnKc/T9kvZKOiTpcOf3TBeeC0CNhqocFBFPSXqq8CwACuHKOCABQgcSIHQgAUIHEiB0IAFCBxKo9OO1ftFqtYqt3b4WqH6ltpGWpKGhMn99i4uLRdYtaWRkpMi6N998c5F1JenAgQO1r7lz584Vb+eMDiRA6EAChA4kQOhAAoQOJEDoQAKEDiRA6EAChA4kQOhAAoQOJEDoQAKEDiRA6EAChA4kQOhAAoQOJEDoQAKEDiRA6EAChA4k4BK7lNo+Lemziof/UtJ/ah+inEGbV2Lm1dAv8/4qIjZdeGOR0HtheyYiJhsdogeDNq/EzKuh3+floTuQAKEDCfRD6NNND9CjQZtXYubV0NfzNv4cHUB5/XBGB1AYoQMJEDqQAKEDCRA6kMD/AJXmsXtHOu7eAAAAAElFTkSuQmCC\n", 1006 | "text/plain": [ 
1007 | "
" 1008 | ] 1009 | }, 1010 | "metadata": { 1011 | "needs_background": "light" 1012 | }, 1013 | "output_type": "display_data" 1014 | } 1015 | ], 1016 | "source": [ 1017 | "# Let's look ath the image of the confusion matrix\n", 1018 | "plt.matshow(conf_mx, cmap=plt.cm.gray)\n", 1019 | "plt.show()" 1020 | ] 1021 | }, 1022 | { 1023 | "cell_type": "markdown", 1024 | "metadata": {}, 1025 | "source": [ 1026 | "The 5s look slightly darker than the other digits, which could mean that there are\n", 1027 | "fewer images of 5s in the dataset or that the classifier does not perform as\n", 1028 | "well on 5s as on other digits. In fact, you can verify that both are the case.\n", 1029 | "\n", 1030 | "Let’s focus the plot on the errors. First, you need to divide each value in\n", 1031 | "the confusion matrix by the number of images in the corresponding class\n", 1032 | "so that you can compare error rates instead of absolute numbers of errors:" 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "code", 1037 | "execution_count": 60, 1038 | "metadata": {}, 1039 | "outputs": [ 1040 | { 1041 | "data": { 1042 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPoAAAECCAYAAADXWsr9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAALyUlEQVR4nO3dzYvd9RXH8c9n7iSTp5rGWJQ8UIOUSihUy6BWoYvYRZ9IRStYUGw32fTBlkJpu/EfELGLUhhss1HaRSpSi7QV2i66MDTGQGPSYrA2iU1sfJgYSjKTmTldzA3EjHp/o7+T372e9wuEzHA9HpJ5+7v35jffcUQIwIfbWNcLAMhH6EABhA4UQOhAAYQOFEDoQAGdhW77C7b/afuI7R91tUdTtrfa/rPtQ7ZfsP1A1zs1Ybtn+3nbv+t6lyZsf9T2Htv/sH3Y9me73mkQ29/vf00ctP0r26u63ulSnYRuuyfpZ5K+KGm7pK/b3t7FLsswJ+kHEbFd0i2SvjUCO0vSA5IOd73EMvxU0u8j4npJn9aQ7257s6TvSpqMiE9J6km6p9utlurqin6TpCMR8VJEzEr6taSvdrRLIxFxIiL29399RotfgJu73eq92d4i6cuSHu16lyZsr5f0OUm/kKSImI2I6U6XamZc0mrb45LWSPpPx/ss0VXomyUdu+jj4xryaC5m+1pJN0ra2/Eqgzwi6YeSFjreo6ltkk5J2t1/ufGo7bVdL/VeIuIVSQ9JOirphKTTEfHHbrdaijfjlsn2Okm/kfS9iHir633eje2vSPpvRDzX9S7LMC7pM5J+HhE3SvqfpKF+/8b2Bi0+G90maZOktbbv7XarpboK/RVJWy/6eEv/c0PN9gotRv54RDzR9T4D3CZpp+2XtfjSaIftx7pdaaDjko5HxIVnSnu0GP4w+7ykf0XEqYg4L+kJSbd2vNMSXYX+N0mfsL3N9kotvnnx2452acS2tfja8XBEPNz1PoNExI8jYktEXKvF398/RcTQXWkuFhEnJR2z/cn+p26XdKjDlZo4KukW22v6XyO3awjfQBzv4j8aEXO2vy3pD1p8l/KXEfFCF7ssw22S7pP0d9sH+p/7SUQ83d1KH0rfkfR4/wLwkqRvdrzPe4qIvbb3SNqvxb+ZeV7SVLdbLWW+TRX48OPNOKAAQgcKIHSgAEIHCiB0oIDOQ7e9q+sdlmPU9pXY+XIY9n07D13SUP8GvYNR21di58thqPcdhtABJEu5Ycb2yN2Fs3j3YvtG8YakXq/X9QpaWFjQ2Fjz69D8/HzKHhs3bmz0uHPnzmnVquWdN/H666+/n5UGioglX8yd3AI7jJb7h9TUzMxMylwp739O69evT5mbFaMknTlzJmXuzp07U+ZK0u7du9NmX4qn7kABhA4UQOhAAYQOFEDoQAGNQh+1M9gBvN3A0Ef0DHYAF2lyRR+5M9gBvF2T0Ef6DHYALd4Z1//unaG+sR+oqknojc5gj4gp9U+/HMV73YEPsyZP3UfuDHYAbzfwij6iZ7ADuEij1+j9H1LADyoARhR3xgEFEDpQAKEDBRA6UAChAwVwZlzfihUrul5h2c6dO5cyd2FhIWXuW2+9lTJXyjs/b3p6OmWulHMI57udy8cVHSiA0IECCB0ogNCBAggdKIDQgQIIHSiA0IECCB0ogNCBAggdKIDQgQIIHSiA0IECCB0ogNCBAggdKIDQgQIIHSiA0IECCB0ogNCBAlKOex4bG9Pq1aszRqfJOor4+uuvT5krSWfPnk2Z+9prr6XM3bJlS8pcKe/P784770yZK0lPPfVU2uxLcUUHCiB0oABCBwogdKAAQgcKIHSgAEIHChgYuu2ttv9s+5DtF2w/cDkWA9CeJjfMzEn6QUTst/0RSc/ZfiYiDiXvBqAlA6/oEXEiIvb3f31G0mFJm7MXA9CeZb1Gt32tpBsl7U3ZBkCKxve6214n6TeSvhcRS24str1L0q7+r1tbEMAH1yh02yu0GPn
jEfHEOz0mIqYkTUlSr9eL1jYE8IE1edfdkn4h6XBEPJy/EoC2NXmNfpuk+yTtsH2g/8+XkvcC0KKBT90j4q+SeNENjDDujAMKIHSgAEIHCiB0oABCBwpIOQU2IrSwsND63IyZF9xwww0pcw8cOJAyN9Pdd9+dMvfpp59OmStJK1asSJk7MTGRMleSNm3a1PrMkydPvuPnuaIDBRA6UAChAwUQOlAAoQMFEDpQAKEDBRA6UAChAwUQOlAAoQMFEDpQAKEDBRA6UAChAwUQOlAAoQMFEDpQAKEDBRA6UAChAwUQOlBAynHPkjQ3N9f6zMWf4Jzj6NGjKXN7vV7KXEman59Pmfvkk0+mzF2zZk3KXEk6f/58ytzp6emUuZK0ffv21me+275c0YECCB0ogNCBAggdKIDQgQIIHSiA0IECGoduu2f7edu/y1wIQPuWc0V/QNLhrEUA5GkUuu0tkr4s6dHcdQBkaHpFf0TSDyUt5K0CIMvA0G1/RdJ/I+K5AY/bZXuf7X0R0dqCAD64Jlf02yTttP2ypF9L2mH7sUsfFBFTETEZEZOZ33wCYPkGhh4RP46ILRFxraR7JP0pIu5N3wxAa/h7dKCAZX0/ekT8RdJfUjYBkIYrOlAAoQMFEDpQAKEDBRA6UIAz7mLr9XqRceJn1kmfkjQxMZEyd8eOHSlzJWnv3r0pc0+cOJEy97rrrkuZK0nHjh1LmTszM5MyV5K2bt3a+sxXX31Vs7OzS+5Y44oOFEDoQAGEDhRA6EABhA4UQOhAAYQOFEDoQAGEDhRA6EABhA4UQOhAAYQOFEDoQAGEDhRA6EABhA4UQOhAAYQOFEDoQAGEDhSQdgrs2rVrW5+beSLnunXrUua++eabKXMl6corr0yZe/XVV6fMPXToUMpcScr6Ud0333xzylxJevbZZ1PmRgSnwAIVETpQAKEDBRA6UAChAwUQOlAAoQMFNArd9kdt77H9D9uHbX82ezEA7Rlv+LifSvp9RHzN9kpJ7f9MZABpBoZue72kz0n6hiRFxKyk2dy1ALSpyVP3bZJOSdpt+3nbj9pu//5WAGmahD4u6TOSfh4RN0r6n6QfXfog27ts77O9L+P+eQDvX5PQj0s6HhF7+x/v0WL4bxMRUxExGRGTWd9gAOD9GRh6RJyUdMz2J/uful1S3rchAWhd03fdvyPp8f477i9J+mbeSgDa1ij0iDggaTJ3FQBZuDMOKIDQgQIIHSiA0IECCB0ogNCBAlKOex4bG4uJiYnW587NzbU+84INGzakzD179mzKXElatWpVytzp6emUufPz8ylzJSnrtus77rgjZa4krVy5svWZzzzzjN544w2OewYqInSgAEIHCiB0oABCBwogdKAAQgcKIHSgAEIHCiB0oABCBwogdKAAQgcKIHSgAEIHCiB0oABCBwogdKAAQgcKIHSgAEIHCkg7BTbjhMurrrqq9ZkXzM7OpszNOl1Wkl588cWUuTfddFPK3IMHD6bMlaSZmZmUuZknD2/cuLH1mdPT05qbm+MUWKAiQgcKIHSgAEIHCiB0oABCBwogdKCARqHb/r7tF2wftP0r2zk/xhNAioGh294s6buSJiPiU5J6ku7JXgxAe5o+dR+XtNr2uKQ1kv6TtxKAtg0MPSJekfSQpKOSTkg6HRF/zF4MQHuaPHXfIOmrkrZJ2iRpre173+Fxu2zvs70v4/55AO9fk6fun5f0r4g4FRHnJT0h6dZLHxQRUxExGRGT9pJ76gF0qEnoRyXdYnuNFwu+XdLh3LUAtKnJa/S9kvZI2i/p7/1/Zyp5LwAtGm/yoIh4UNKDybsASMKdcUABhA4UQOhAAYQOFEDoQAGEDhSQctxzr9eLVava/07W+fn51mdekHE8tSRdccUVKXMl6ZprrkmZe+TIkZS5p0+fTpkrSTt37kyZe9ddd6XMlaT7778/ZW5EcNwzUBGhAwUQOlAAoQMFEDpQAKEDBRA6UAChAwUQOlAAoQMFEDp
QAKEDBRA6UAChAwUQOlAAoQMFEDpQAKEDBRA6UAChAwUQOlBAyimwtk9J+nfDh18l6bXWl8gzavtK7Hw5DMu+H4+Ij136yZTQl8P2voiY7HSJZRi1fSV2vhyGfV+eugMFEDpQwDCEPtX1Ass0avtK7Hw5DPW+nb9GB5BvGK7oAJIROlAAoQMFEDpQAKEDBfwfaoXCaAh0MeEAAAAASUVORK5CYII=\n", 1043 | "text/plain": [ 1044 | "
" 1045 | ] 1046 | }, 1047 | "metadata": { 1048 | "needs_background": "light" 1049 | }, 1050 | "output_type": "display_data" 1051 | } 1052 | ], 1053 | "source": [ 1054 | "row_sums = conf_mx.sum(axis=1, keepdims=True)\n", 1055 | "norm_conf_mx = conf_mx / row_sums\n", 1056 | "\n", 1057 | "np.fill_diagonal(norm_conf_mx, 0)\n", 1058 | "plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n", 1059 | "plt.show()" 1060 | ] 1061 | }, 1062 | { 1063 | "cell_type": "markdown", 1064 | "metadata": {}, 1065 | "source": [ 1066 | "Analyzing individual errors can also be a good way to gain insights on\n", 1067 | "what your classifier is doing and why it is failing, but it is more difficult\n", 1068 | "and time-consuming. For example, let's ploit 3s and 5s:" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": 78, 1074 | "metadata": {}, 1075 | "outputs": [], 1076 | "source": [ 1077 | "def plot_digits(instances, images_per_row=10, **options):\n", 1078 | " \"\"\"\n", 1079 | " Plots digit with a grid form - correctly & incorrecly guessed\n", 1080 | " \"\"\"\n", 1081 | " size = 28\n", 1082 | " images_per_row = min(len(instances), images_per_row)\n", 1083 | " images = [instance.reshape(size,size) for instance in instances.values]\n", 1084 | " n_rows = (len(instances) - 1) // images_per_row + 1\n", 1085 | " row_images = []\n", 1086 | " n_empty = n_rows * images_per_row - len(instances)\n", 1087 | " images.append(np.zeros((size, size * n_empty)))\n", 1088 | " for row in range(n_rows):\n", 1089 | " rimages = images[row * images_per_row : (row + 1) * images_per_row]\n", 1090 | " row_images.append(np.concatenate(rimages, axis=1))\n", 1091 | " image = np.concatenate(row_images, axis=0)\n", 1092 | " plt.imshow(image, cmap = mpl.cm.binary, **options)\n", 1093 | " plt.axis(\"off\")" 1094 | ] 1095 | }, 1096 | { 1097 | "cell_type": "code", 1098 | "execution_count": 79, 1099 | "metadata": {}, 1100 | "outputs": [ 1101 | { 1102 | "data": { 1103 | "image/png": "\n", 1104 | 
"text/plain": [ 1105 | "
" 1106 | ] 1107 | }, 1108 | "metadata": { 1109 | "needs_background": "light" 1110 | }, 1111 | "output_type": "display_data" 1112 | } 1113 | ], 1114 | "source": [ 1115 | "cl_a, cl_b = 3, 5\n", 1116 | "X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n", 1117 | "X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]\n", 1118 | "X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n", 1119 | "X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n", 1120 | "\n", 1121 | "plt.figure(figsize=(8,8))\n", 1122 | "plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n", 1123 | "plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n", 1124 | "plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n", 1125 | "plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n", 1126 | "plt.show()" 1127 | ] 1128 | }, 1129 | { 1130 | "cell_type": "markdown", 1131 | "metadata": {}, 1132 | "source": [ 1133 | "Most misclassified images seem like obvious errors to us, and it’s hard to understand why the classifier made the mistakes it did. The reason is that we used a simple `SGDClassifier`, which is a linear model. All it does is assign a weight per class to each pixel, and when it sees a new image it just sums up the weighted pixel intensities to get a score for each class. Since 3s and 5s differ only by a few pixels, this model will easily confuse them.\n", 1134 | "\n", 1135 | "The main difference between 3s and 5s is the position of the small line that joins the top line to the bottom arc. If you draw a 3 with the junction\n", 1136 | "slightly shifted to the left, the classifier might classify it as a 5, and vice versa. In other words, this classifier is quite sensitive to image shifting and rotation. So one way to reduce the 3/5 confusion would be to preprocess the images to ensure that they are well centered and not too rotated.
This will probably help reduce other errors as well.\n", 1137 | "\n", 1138 | "## Multilabel Classification\n", 1139 | "\n", 1140 | "In some cases you may want your classifier to output multiple classes for each instance. Say a classifier has been trained to recognize three faces: Alice, Bob, and Charlie. Then when it is shown a picture of Alice and Charlie, it should output [1, 0, 1] (meaning “Alice yes, Bob no, Charlie yes”). Such a classification system that outputs multiple binary tags is called a _multilabel classification_ system." 1141 | ] 1142 | }, 1143 | { 1144 | "cell_type": "code", 1145 | "execution_count": 81, 1146 | "metadata": {}, 1147 | "outputs": [ 1148 | { 1149 | "data": { 1150 | "text/plain": [ 1151 | "KNeighborsClassifier()" 1152 | ] 1153 | }, 1154 | "execution_count": 81, 1155 | "metadata": {}, 1156 | "output_type": "execute_result" 1157 | } 1158 | ], 1159 | "source": [ 1160 | "from sklearn.neighbors import KNeighborsClassifier\n", 1161 | "\n", 1162 | "y_train_large = (y_train >= 7)\n", 1163 | "y_train_odd = (y_train % 2 == 1)\n", 1164 | "y_multilabel = np.c_[y_train_large, y_train_odd]\n", 1165 | "\n", 1166 | "knn_clf = KNeighborsClassifier()\n", 1167 | "knn_clf.fit(X_train, y_multilabel)" 1168 | ] 1169 | }, 1170 | { 1171 | "cell_type": "markdown", 1172 | "metadata": {}, 1173 | "source": [ 1174 | "This code creates a `y_multilabel` array containing two target labels for each digit image: the first indicates whether or not the digit is large (7, 8, or 9), and the second indicates whether or not it is odd. It then trains a `KNeighborsClassifier`, which supports multilabel classification (not all classifiers do), on these multiple targets."
1175 | ] 1176 | }, 1177 | { 1178 | "cell_type": "code", 1179 | "execution_count": 83, 1180 | "metadata": {}, 1181 | "outputs": [ 1182 | { 1183 | "data": { 1184 | "text/plain": [ 1185 | "array([[False, True]])" 1186 | ] 1187 | }, 1188 | "execution_count": 83, 1189 | "metadata": {}, 1190 | "output_type": "execute_result" 1191 | } 1192 | ], 1193 | "source": [ 1194 | "# Prediction with Multilabel classifier\n", 1195 | "knn_clf.predict([some_digit])" 1196 | ] 1197 | }, 1198 | { 1199 | "cell_type": "markdown", 1200 | "metadata": {}, 1201 | "source": [ 1202 | "And it gets it right! The digit 5 is indeed not large (`False`) and odd (`True`).\n", 1203 | "\n", 1204 | "One approach to evaluating a multilabel classifier is to measure the $F_1$ score for each individual label, then compute the average score. Setting `average='macro'` in `f1_score()` does exactly that, giving every label equal weight." 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "code", 1209 | "execution_count": 85, 1210 | "metadata": {}, 1211 | "outputs": [ 1212 | { 1213 | "data": { 1214 | "text/plain": [ 1215 | "0.976410265560605" 1216 | ] 1217 | }, 1218 | "execution_count": 85, 1219 | "metadata": {}, 1220 | "output_type": "execute_result" 1221 | } 1222 | ], 1223 | "source": [ 1224 | "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)\n", 1225 | "f1_score(y_multilabel, y_train_knn_pred, average='macro')" 1226 | ] 1227 | }, 1228 | { 1229 | "cell_type": "markdown", 1230 | "metadata": {}, 1231 | "source": [ 1232 | "If, for example, we have many more pictures of Alice than of Bob or Charlie, we may want to give more weight to the classifier's score on pictures of Alice. A simple option is to give each label a weight equal to its support (i.e., the number of instances with that target label). This can be done by setting `average=\"weighted\"`.\n", 1233 | "\n", 1234 | "## Multioutput Classification\n", 1235 | "\n", 1236 | "_Multioutput–multiclass classification_ (or simply _multioutput classification_) is a generalization of multilabel classification where each label can be multiclass, i.e., it can have more than two possible values.\n", 1237 | "\n", 1238 | "Let's first build a system that removes noise from images.
The classifier's output is multilabel (one label per pixel) and each label can have multiple values (pixel intensity ranges from 0 to 255)." 1239 | ] 1240 | }, 1241 | { 1242 | "cell_type": "code", 1243 | "execution_count": 94, 1244 | "metadata": {}, 1245 | "outputs": [], 1246 | "source": [ 1247 | "noise = np.random.randint(0, 100, (len(X_train), 784))\n", 1248 | "X_train_mod = X_train + noise\n", 1249 | "noise = np.random.randint(0, 100, (len(X_test), 784))\n", 1250 | "X_test_mod = X_test + noise\n", 1251 | "y_train_mod = X_train\n", 1252 | "y_test_mod = X_test" 1253 | ] 1254 | }, 1255 | { 1256 | "cell_type": "code", 1257 | "execution_count": 98, 1258 | "metadata": {}, 1259 | "outputs": [ 1260 | { 1261 | "data": { 1262 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAACmCAYAAAB5qlzZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQ8UlEQVR4nO3dS2wV9BLH8SmFCgoFCgVa+qClSFveT0HKQ4Iohq26YK+JiSYYdaeJcWOi0cSokejWHcaYGFHRQAQUAQEpD6GUR2l5FAqFFmgB6V3e3Du/OTnV8ufR72c5nVNOT4/jSec/88/p6ekxAEAaA+72EwCA/oSiCwAJUXQBICGKLgAkRNEFgIQougCQ0MBMX2xpaXHnycaPHy9zGxsbXez48eMyd/bs2S7W2dkpc8eNG+diZ86ckblHjx51serqapl7+vRpF5szZ47Mra+vd7EbN2642IgRI+TjL1686GIVFRVZ/1uTJ0+WudevX3exUaNGydxbt25l9bzMzCZOnOhiXV1dMle9DhcuXJC5ra2tLrZ48eIcmXzncVYSd5p8b/NJFwASougCQEIUXQBIKOPfdNXfU/fv3y9zi4uLs/5H8/LyXKygoEDmdnR0uFhTU5PMvXnzposNHjxY5k6YMMHF1N95zcza29td7NFHH3Wxc+fOycerv73m5Og/ZU6ZMsXFzp49K3PHjBnjYkOHDpW5atx77969Mnf48OFZPd7MbOBA/xZqbm6WudHfvIH+hE+6AJAQRRcAEqLoAkBCFF0ASIiiCwAJZTy9cODAAReLOvxqCmnq1Kky9/fff3cxdaLBzKy8vNzF1HSVmZ6WU5NyZmYPP/ywi6lTCmZmlZWVWT2HaHIsmtBSGhoaXGzRokUyVz0H9Tsz0ydRohMJ6jW7evWqzFW/n7Fjx8rcaOoQ6E/4pAsACVF0ASAhii4AJETRBYCEMjbS1FhtSUmJzFXNoqiJVVNT42Kq0WOmm1sLFiyQuVu2bHGxJ598Uua2tLS4WNTMO3nypIv1Zj2lGmWOmoHd3d0u1tbWJnMHDRqUVSzT91DU63Dt2jWZe+nSJRdTY89mcUMQ6E/4pAsACVF0ASAhii4AJETRBYCEKLoAkFDG0wvq4kK14NpMX+qouv5mZrm5uS4WXTapOuF///23zFVLvbdt2yZz1SmBkSNHylw13jtggP//VTSCqy5vnDVrlsxdsmSJi6mxaTO9QDxaJq9+b0OGDJG56nWIFp6r0evDhw/LXPWaAf0
N/xUAQEIUXQBIiKILAAlRdAEgoYyNNNWoiXbOqrHaqOGVn5/vYlEjTY3QRrfKFhUVudju3buzfg6lpaUyV/3M6tbd+fPny8f/+eefLrZv3z6ZqyxevFjGVeNO7Qk202O80biuEo1pq9csGhU/f/68i0W7d4EHFZ90ASAhii4AJETRBYCEKLoAkBBFFwASynh6YdiwYS72yCOPyFy1zDpafK3GatVJCTN9SmD06NEyV50ymDBhgsxVC9pVzEwvaFcLy9XYtJk+KRGNAatbd3/99VeZu3DhQheLTow89NBDMq6o31t0KkItR49GxdXvEuhv+KQLAAlRdAEgIYouACRE0QWAhDI20lTDqrm5WeaqcdBozFQ1aqKbg3/77TcXixpL06ZNc7H6+nqZu3LlShdbtWqVzFVNL/UzXL58WT5eOX78uIyrnyHaS6xEr6NqBkZNNzV2rG5aNjMbPHiwi9XV1clc9fpEjdn+YP369S72+eefy1y1J1m99mZma9ascbFojLuqqirTU8QdwCddAEiIogsACVF0ASAhii4AJETRBYCEcnp6ejJ93X0xWr6tRoajm2l37NjhYhMnTpS56pSBuvXXzGzQoEEuFt1Mq/69aGx5+vTpLqZGhtVrEGltbZVxdaoh6vCr8WJ1+sHMbPXq1S6mlr6b6eXoixYtkrm3b9+WcUX9zJMmTcrJ+hv0rYxv/BQqKipc7MSJE3fk31LvFTOz2traO/Lv3QnRJQNvvPGGi82dO/dOP51syPc2n3QBICGKLgAkRNEFgIQougCQUMYxYNX4GDJkiMxVTaxbt27J3Jwc//dltZfVzOz111/P+jmMHDnSxb7++muZq/bsfvnllzJXNf7y8vJcrKGhQT5e3Xgb7bdVDY/o9mO1wzj6edXvJ2qOzZgxw8WiXbhdXV0uFo0i9+b24f7giy++cDF1c7SZbngdPHhQ5u7Zs8fFNm/eLHO3b9/uYmVlZS7W1NQkH98b6j0Y7cZWt4Or52qmG2z3SCNN4pMuACRE0QWAhCi6AJAQRRcAEqLoAkBCGceAGxsb3Rej0dGdO3e6WHR6obCw0MVUJ95M30JbWVkpc9UYb3d3t8xVy9ij56u6wcuWLXOxn376ST5enZSIRo7VbcBLliyRuU8//bSLRd3gl19+2cWeeuopmVteXu5i0YJ2NaKcm5src1taWlystLS0344BpxT996Xe26rzr/777i11YkfdDG5mVl1d7WLRbduffPKJi7300ku9fHZ3BGPAAHC3UXQBICGKLgAkRNEFgIQyNtJ6xBe3bdsmc9V+2ui2UjU6GjXoVAMn2umrbjYdOFBPOqtR1fHjx8tcNb6omkKqYWamx4M7Ojpkrhpx/vHHH2WuGvmdN2+ezF27dq2LRbt31Q3K0WujmiPRjl11+/CIESNopPVjX331lYw/++yzLha9Xzdt2uRiBQUF/+6J9Q0aaQBwt1F0ASAhii4AJETRBYCEKLoAkFDGJeYbN250sbq6Opl75MgRF1O36JrpZeNRh16NCS5YsEDmqptU29vbZa46aRCdilBjuGqBdHQi4bHHHnMxdeOumVlnZ6eLbdiwQeaePXvWxXrzmkeLrdXC8miJ+S+//OJiy5cvl7nq9xMtaMeDR12KEI3rqlNVb731lsy9R04qZI1PugCQEEUXABKi6AJAQhRdAEgoYyNNNax2794tc9Ufs6OR4ccff9zFopHho0ePutiNGzdk7rBhw1wsunX3/PnzLhbdjquaRcOHD3exaGfpzZs3XUzt8zXTt6BeuXJF5qqxZxUzMxswwP//Ndq9q55vdBusakhGY8DRTl70D2rvrWqumekGa7R7937DJ10ASIiiCwAJUXQBICGKLgAklLGRpppmarLJzKykpMTFogskDx065GI1NTUyV01oRU0odUliRO3/Vc0mM305ptp7O2XKFPl41Vj69ttvZe5HH33kYqpBaGb23nvvuVh+fr7MVY1KtSfYTO8Kji7tVK9DNNn
Xm98P7l9bt26V8XfffTfr7/HNN9+42NSpU//xc7qX8EkXABKi6AJAQhRdAEiIogsACVF0ASChjKcXli1b5mLbt2+XuXv27HGx3NxcmavGT9VYrZneGVtYWChzz50752Lq5IGZ2a5du1wsuklXUT/Dli1bZK66BTcae1Y7Z1944QWZq8Yir169KnPV66tOHpiZFRcXu5jal2xmVlFR4WLRz3bt2jUZx4Plu+++k3E1vr9ixQqZu3Dhwj59TvcSPukCQEIUXQBIiKILAAlRdAEgoYyNtFOnTrlYtJ9WNVSi8VU1bhuN4Ko/vk+aNEnmqh25DQ0NMlc1ltTIsZkeX1V7b5cuXSofr5pba9eulbllZWUuVl1dLXPVOPWsWbNkrhrtzcvLk7nd3d1ZPS8z3ag8duyYzFV7eqOfDfeH69evu9j3338vc1XtePvtt2VuNKL+IOCTLgAkRNEFgIQougCQEEUXABKi6AJAQhlPL6iufXSjqzoloG79NTM7cOCAi6nRUzOzjo4OF4vGTFWHPuqOq+5oe3u7zFUnM1TXNlrevW7dOheLlo2vWrXKxaLl6EVFRS7W2Ngoc6dPn+5iBw8elLlqYbk6yWJmdvjwYReLfrauri4Zx/1LLdJXKwHM9Hs7qhEPMj7pAkBCFF0ASIiiCwAJUXQBIKGMjTTVaKmrq5O5anQ02u06YsQIF4tuGVZNJLU3N/r31E28ZnrH7bhx42Tu0KFDXUw1Djdu3Cgfv2HDBhcrLy+Xuc8995yLzZw5U+ZeunTJxaL9warhFTW2VJNQvQZmusm4Y8cOmatuYMb9Ibq9+p133nGxaDf2m2++2afP6X7FJ10ASIiiCwAJUXQBICGKLgAkRNEFgIQynl5Qt8VG46Cqu93a2ipz1XJzNRpsZnbhwgUXq6mpkblq1LW+vl7mqu+hlmybme3evdvFxowZ42J79+6Vj1djy2PHjpW56kSCWlZupm8D/uuvv2SuOtUwfvx4mZuTk+Ni6kSDWTzyq0QnVHBvaWtrc7FXXnlF5qqR8WeeeUbmPsg3/PYGn3QBICGKLgAkRNEFgIQougCQUE5PT0/4xSNHjrgvqv22ZvoP6vPmzZO5alw2uuFXjQxHDS/1s0S3F589e9bFoqaQGvl98cUXXUw1oMzMSktLXezTTz+VuaqRFjWg1L7jEydOyFw1rhs9X/V91a3MZnp3avR9VQN1/vz5OvnOi9/4/Ygah1+wYIGL7dq1Sz6+qqrKxaLbgPvhGLh8b/NJFwASougCQEIUXQBIiKILAAlRdAEgoYxjwGrUdfTo0TJXdUG3bdsmc9X4qeram+lx2eg2YDWuG3VMm5ubXWzu3LkyV40il5SUuFg0Iv3aa6+5mFr6bmZWW1vrYtGJBPUcoqXt6pbg2bNny1w1MvzDDz/I3IKCAhcrKyuTuYWFhTKOu0e9L6KTCsoHH3zgYv3wlEKv8EkXABKi6AJAQhRdAEiIogsACWVspKnbdXNzc7POjRo1agS3s7Mz01P5H1GDbsmSJS6mRnjNdJPw6NGjMvfVV191sYsXL7rYmjVr5ONVY0m9BhE1lmtm1tLS4mKqsRXF9+3bJ3PV72LUqFEyVzXdojFtFY/GtNG3Tp48KeMrV67M6vHvv/++jK9evfofP6f+ik+6AJAQRRcAEqLoAkBCFF0ASIiiCwAJZTy9oMZaoyXmamH58ePHZa66WXbAAF3/1amIvLw8mbtz504Xi8aW1Xjx+vXrZa5apK669s8//7x8vDq9UFlZKXM3b97sYuXl5Vk/r9OnT2edO2jQIJmrllhHy+7PnDnjYpcvX846N1p0j761bt06GY9ONfy/pUuXyni0sB4xPukCQEIUXQBIiKILAAlRdAEgoYyNNLX3NhpfPXbsmItFjRp1869qgpnpvbNdXV0yt66uzsWiBtCWLVtc7LPPPpO5aoRWNQPVTmEz3VBUO3rNzAYO9L+SqBnY1NSU9XNQY8vRyPCmTZtcLBrtXbFihYu
1tbXJ3GnTpsk4+pZ6b3/88cd34ZlA4ZMuACRE0QWAhCi6AJAQRRcAEqLoAkBCGU8vFBcXu5g60WBm9scff7hYRUWFzFUd9tLSUpmrFmqrEV4zfbJCjZ6a6RtPb9y4IXPVWKsaf5w8ebJ8fG+WmKsTBdFob1FRkYsNHz5c5qpTHNEpA3WCQt08bGZ25coVF2tvb5e5hw8fdrEZM2bIXPxzW7dudbFofF+pqqpysaFDh/6r54T/4pMuACRE0QWAhCi6AJAQRRcAEsrYSFPjoOPGjZO51dXVLrZ3716Zq/bhjhw5UuaqfZ9PPPGEzFVNM7VH1kyP8UZ7YNXY8Ycffuhi0W7S27dvu5gam47Mnz9fxlUzMBoZVs2tqNGpGimXLl2SuaopOmfOHJmrxpZxd82cOdPFfv75ZxeLRsbRe3zSBYCEKLoAkBBFFwASougCQEIUXQBIKCda8m1mdurUKffFaEm2uh333LlzMrewsNDFolMR+/fvd7Ha2lqZq0aG1eipmT6pEI0iq1FX1aFvaWmRj1fj1GrRtJnZlClTXCxaBt/Q0OBialm5mdny5ctd7NChQzI3Pz/fxdTv10zf4hzdAq1OQKxcufJuXScbv/GBviHf23zSBYCEKLoAkBBFFwASougCQEIZG2kAgL7FJ10ASIiiCwAJUXQBICGKLgAkRNEFgIQougCQ0H8A3dtNUwSfDK8AAAAASUVORK5CYII=\n", 1263 | "text/plain": [ 1264 | "
" 1265 | ] 1266 | }, 1267 | "metadata": { 1268 | "needs_background": "light" 1269 | }, 1270 | "output_type": "display_data" 1271 | } 1272 | ], 1273 | "source": [ 1274 | "def plot_digit(data):\n", 1275 | " image = data.reshape(28, 28)\n", 1276 | " plt.imshow(image, cmap = mpl.cm.binary,\n", 1277 | " interpolation=\"nearest\")\n", 1278 | " plt.axis(\"off\")\n", 1279 | " \n", 1280 | "# Let's take a look at the example\n", 1281 | "some_index = 0\n", 1282 | "plt.subplot(121); plot_digit(X_test_mod.iloc[some_index].values)\n", 1283 | "plt.subplot(122); plot_digit(y_test_mod.iloc[some_index].values)\n", 1284 | "plt.show()" 1285 | ] 1286 | }, 1287 | { 1288 | "cell_type": "code", 1289 | "execution_count": 101, 1290 | "metadata": {}, 1291 | "outputs": [ 1292 | { 1293 | "data": { 1294 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFO0lEQVR4nO3dsUvUfxzH8TtzaG27QW0pCAKhqTYdHSIapMnVoqV/QVysPyBy9H9Q0EEImoKGtkDczEHbCpxEvN/2g8B7f/X0utedj8fom+/dh+jpB/zwuWt3u90WkGdi2AsALiZOCCVOCCVOCCVOCDXZMPenXBi89kU/tHNCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCKHFCqMlhL2BQ2u12z9nKykr57IMHD8r59PR0OZ+bmyvncBl2TgglTgglTgglTgglTgglTgglTgjV7na71bwcJnv37l3P2cePH8tnG/5Nru3t27c9Z0+fPi2f7XQ613rv4+Pjcv7ly5ees4ODg/LZ79+/l/MPHz6U8zdv3pTzMXbhobydE0KJE0KJE0KJE0KJE0KJE0KJE0KN7Tln5cWLF+X88ePH5fzHjx/lfGtrq5xXd00HfcZavfeg37/pvT9//txzNj8/f8OrieKcE0aJOCGUOCGUOCGUOCGUOCGUOCHU2H5ubWVzc3PYSxhJZ2dn5fz169flfGNjo5yfnJxceU3jzM4JocQJocQJocQJocQJocQJocQJoW7lOSf9mZys/7s0nVM23efkb3ZOCCVOCCVOCCVOCCVOCCVOCHUrPxqTwbhz5045f/jwYTnf29u7yeWMEh+NCaNEnBBKnBBKnBBKnBBKnBBKnBDKlTEu7du3b+X8/Py8nC8uLt7kcsaenRNCiRNCiRNCiRNCiRNCiRNCiRNCuc/JpS0tLZXzpnPQ3d3dcn7//v0rr2lMuM8Jo0ScEEqcEEqcEEqcEEqcEEqcEMp9Tv4yMdH/7+u1tbVyfovPMfti54RQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQ7nP
eMqenp+W8Oos8Pj4unz06OirnnU6nnN9i7nPCKBEnhBInhBInhBInhBInhHJlbMw0HZXcvXu3nFdHa+vr6+Wzjkpulp0TQokTQokTQokTQokTQokTQokTQjnnHDP7+/vlvOGKYOvevXs9Z8+fP+9rTfTHzgmhxAmhxAmhxAmhxAmhxAmhxAmhnHOOmJ8/f5bz2dnZa73++/fve86mpqau9dpcjZ0TQokTQokTQokTQokTQokTQokTQjnnDHN+fl7O19bWynnTfc1Hjx6V81evXpVz/h07J4QSJ4QSJ4QSJ4QSJ4QSJ4RqN/zpvf67PDfu8PCwnM/MzFzr9Xd2dsr5wsLCtV6fvrQv+qGdE0KJE0KJE0KJE0KJE0KJE0KJE0I55wxTfQVfq9Vq/f79u5w/e/asnH/9+vWqS2LwnHPCKBEnhBInhBInhBInhBInhBInhPLRmEPw69evnrM/f/6Uz7bbFx6J/e/Tp099rYk8dk4IJU4IJU4IJU4IJU4IJU4IJU4I5ZxzCDqdTs9Z0znm6upqOX/y5ElfayKPnRNCiRNCiRNCiRNCiRNCiRNCiRNC+dzaIZiY6P07sen7Mbe3t296OQyfz62FUSJOCCVOCCVOCCVOCCVOCOXK2BAsLy/3nL18+fLfLYRodk4IJU4IJU4IJU4IJU4IJU4IJU4I5coYDJ8rYzBKxAmhxAmhxAmhxAmhxAmhxAmhmu5z1t9HBwyMnRNCiRNCiRNCiRNCiRNCiRNC/QfJN7COt705zQAAAABJRU5ErkJggg==\n", 1295 | "text/plain": [ 1296 | "
" 1297 | ] 1298 | }, 1299 | "metadata": { 1300 | "needs_background": "light" 1301 | }, 1302 | "output_type": "display_data" 1303 | } 1304 | ], 1305 | "source": [ 1306 | "# Let's run the classifier\n", 1307 | "knn_clf.fit(X_train_mod, y_train_mod)\n", 1308 | "clean_digit = knn_clf.predict([X_test_mod.iloc[some_index].values])\n", 1309 | "plot_digit(clean_digit)" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "markdown", 1314 | "metadata": {}, 1315 | "source": [ 1316 | "Looks close enough to the target! This concludes our tour of classification." 1317 | ] 1318 | }, 1319 | { 1320 | "cell_type": "code", 1321 | "execution_count": null, 1322 | "metadata": {}, 1323 | "outputs": [], 1324 | "source": [] 1325 | } 1326 | ], 1327 | "metadata": { 1328 | "kernelspec": { 1329 | "display_name": "handson", 1330 | "language": "python", 1331 | "name": "handson" 1332 | }, 1333 | "language_info": { 1334 | "codemirror_mode": { 1335 | "name": "ipython", 1336 | "version": 3 1337 | }, 1338 | "file_extension": ".py", 1339 | "mimetype": "text/x-python", 1340 | "name": "python", 1341 | "nbconvert_exporter": "python", 1342 | "pygments_lexer": "ipython3", 1343 | "version": "3.8.5" 1344 | } 1345 | }, 1346 | "nbformat": 4, 1347 | "nbformat_minor": 4 1348 | } 1349 | --------------------------------------------------------------------------------