├── .gitignore ├── .readthedocs.yml ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── Quickstart.rst ├── _static │ ├── examples │ │ ├── plot_calibration_curve.png │ │ ├── plot_confusion_matrix.png │ │ ├── plot_cumulative_gain.png │ │ ├── plot_elbow_curve.png │ │ ├── plot_feature_importances.png │ │ ├── plot_ks_statistic.png │ │ ├── plot_learning_curve.png │ │ ├── plot_lift_curve.png │ │ ├── plot_pca_2d_projection.png │ │ ├── plot_pca_component_variance.png │ │ ├── plot_precision_recall_curve.png │ │ ├── plot_roc_curve.png │ │ └── plot_silhouette.png │ ├── quickstart_plot_confusion_matrix.png │ ├── quickstart_plot_confusion_matrix2.png │ ├── quickstart_plot_precision_recall_curve.png │ └── readme_collage.jpg ├── apidocs.rst ├── cluster.rst ├── conf.py ├── decomposition.rst ├── estimators.rst ├── functionsapidocs.rst ├── index.rst ├── make.bat └── metrics.rst ├── environment.yml ├── examples ├── jupyter_notebooks │ ├── plot_confusion_matrix.ipynb │ ├── plot_cumulative_gain.ipynb │ ├── plot_elbow_curve.ipynb │ ├── plot_feature_importance.ipynb │ ├── plot_ks_statistic.ipynb │ ├── plot_learning_curve.ipynb │ ├── plot_lift_curve.ipynb │ ├── plot_pca_2d_projection.ipynb │ ├── plot_pca_component_variance.ipynb │ ├── plot_precision_recall_curve.ipynb │ ├── plot_roc_curve.ipynb │ └── plot_silhouette.ipynb ├── p_r_curves.png ├── plot_calibration_curve.py ├── plot_confusion_matrix.py ├── plot_cumulative_gain.py ├── plot_elbow_curve.py ├── plot_feature_importances.py ├── plot_ks_statistic.py ├── plot_learning_curve.py ├── plot_lift_curve.py ├── plot_pca_2d_projection.py ├── plot_pca_component_variance.py ├── plot_precision_recall.py ├── plot_roc.py ├── plot_silhouette.py └── roc_curves.png ├── requirements.txt ├── scikitplot ├── __init__.py ├── classifiers.py ├── cluster.py ├── clustering.py ├── decomposition.py ├── estimators.py ├── helpers.py ├── metrics.py ├── plotters.py └── tests │ ├── __init__.py │ ├── 
test_classifiers.py │ ├── test_cluster.py │ ├── test_clustering.py │ ├── test_decomposition.py │ ├── test_estimators.py │ ├── test_metrics.py │ └── test_plotters.py ├── setup.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | conda: 2 | file: environment.yml 3 | 4 | requirements_file: 5 | requirements.txt 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.5" 5 | - "3.6" 6 | # command to install dependencies 7 | before_script: # configure a headless display to test plot generation 8 | - "export DISPLAY=:99.0" 9 | - "sh -e /etc/init.d/xvfb start" 10 | - sleep 3 # give xvfb some time to start 11 | install: 12 | - pip install --upgrade pip setuptools wheel 13 | - pip install --only-binary=numpy,scipy numpy scipy 14 | - pip install -r requirements.txt 15 | # command to run tests 16 | script: py.test -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Steps for contributing 2 | 3 | Fixing a bug you found in Scikit-plot? Suggesting a feature? Adding your own plotting function? Listed here are some guidelines to keep in mind when contributing. 4 | 5 | 1. **Open an issue** along with a detailed explanation. For bug reports, include the code to reproduce the bug. For feature requests, explain why you think the feature could be useful. 6 | 7 | 2. **Fork the repository**. If you're contributing code, clone the forked repository onto your local machine. 8 | 9 | 3. 
**Run the tests** to make sure they pass on your machine. Simply run `pytest` at the root folder and make sure all tests pass. 10 | 11 | 4. **Create a new branch**. Please do not commit directly to the master branch. Create your own branch and place your additions there. 12 | 13 | 5. **Write your code**. Please follow PEP8 coding standards. Also, if you're adding a function, you must [write a docstring using the Google format](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) detailing the API of your function. Take a look at the docstrings of the other Scikit-plot functions to get an idea of what the docstring of yours should look like. 14 | 15 | 6. **Write/modify the corresponding unit tests**. After adding in your code and the corresponding unit tests, run `pytest` again to make sure they pass. 16 | 17 | 7. **Submit a pull request**. After submitting a PR, if all tests pass, your code will be reviewed and merged promptly. 18 | 19 | Thank you for taking the time to make Scikit-plot better! -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2018] [Reiichiro Nakano] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to Scikit-plot 2 | 3 | [![PyPI version](https://badge.fury.io/py/scikit-plot.svg)](https://badge.fury.io/py/scikit-plot) 4 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg)]() 5 | [![Build Status](https://travis-ci.org/reiinakano/scikit-plot.svg?branch=master)](https://travis-ci.org/reiinakano/scikit-plot) 6 | [![PyPI](https://img.shields.io/pypi/pyversions/scikit-plot.svg)]() 7 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.293191.svg)](https://doi.org/10.5281/zenodo.293191) 8 | 9 | ### Single line functions for detailed visualizations 10 | 11 | ### The quickest and easiest way to go from analysis... 12 | 13 | ![roc_curves](docs/_static/readme_collage.jpg) 14 | 15 | ### ...to this. 16 | 17 | Scikit-plot is the result of an unartistic data scientist's dreadful realization that *visualization is one of the most crucial components in the data science process, not just a mere afterthought*. 
18 | 19 | Gaining insights is simply a lot easier when you're looking at a colored heatmap of a confusion matrix complete with class labels rather than a single-line dump of numbers enclosed in brackets. Besides, if you ever need to present your results to someone (virtually any time anybody hires you to do data science), you show them visualizations, not a bunch of numbers in Excel. 20 | 21 | That said, there are a number of visualizations that frequently pop up in machine learning. Scikit-plot is a humble attempt to provide aesthetically-challenged programmers (such as myself) the opportunity to generate quick and beautiful graphs and plots with as little boilerplate as possible. 22 | 23 | ## Okay then, prove it. Show us an example. 24 | 25 | Say we use Naive Bayes in multi-class classification and decide we want to visualize the results of a common classification metric, the Area under the Receiver Operating Characteristic curve. Since the ROC is only valid in binary classification, we want to show the respective ROC of each class if it were the positive class. As an added bonus, let's show the micro-averaged and macro-averaged curve in the plot as well. 26 | 27 | Let's use scikit-plot with the sample digits dataset from scikit-learn. 28 | 29 | ```python 30 | # The usual train-test split mumbo-jumbo 31 | from sklearn.datasets import load_digits 32 | from sklearn.model_selection import train_test_split 33 | from sklearn.naive_bayes import GaussianNB 34 | 35 | X, y = load_digits(return_X_y=True) 36 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33) 37 | nb = GaussianNB() 38 | nb.fit(X_train, y_train) 39 | predicted_probas = nb.predict_proba(X_test) 40 | 41 | # The magic happens here 42 | import matplotlib.pyplot as plt 43 | import scikitplot as skplt 44 | skplt.metrics.plot_roc(y_test, predicted_probas) 45 | plt.show() 46 | ``` 47 | ![roc_curves](examples/roc_curves.png) 48 | 49 | Pretty. 50 | 51 | And... That's it. 
Encapsulated in that small example is the entire philosophy of Scikit-plot: **single line functions for detailed visualization**. You simply browse the plots available in the documentation, and call the function with the necessary arguments. Scikit-plot tries to stay out of your way as much as possible. No unnecessary bells and whistles. And when you *do* need the bells and whistles, each function offers a myriad of parameters for customizing various elements in your plots. 52 | 53 | Finally, compare and [view the non-scikit-plot way of plotting the multi-class ROC curve](http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html). Which one would you rather do? 54 | 55 | ## Maximum flexibility. Compatibility with non-scikit-learn objects. 56 | 57 | Although Scikit-plot is loosely based around the scikit-learn interface, you don't actually need Scikit-learn objects to use the available functions. As long as you provide the functions what they're asking for, they'll happily draw the plots for you. 58 | 59 | Here's a quick example to generate the precision-recall curves of a Keras classifier on a sample dataset. 60 | 61 | ```python 62 | # Import what's needed for the Functions API 63 | import matplotlib.pyplot as plt 64 | import scikitplot as skplt 65 | 66 | # This is a Keras classifier. We'll generate probabilities on the test set. 67 | keras_clf.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=2) 68 | probas = keras_clf.predict_proba(X_test, batch_size=64) 69 | 70 | # Now plot. 71 | skplt.metrics.plot_precision_recall_curve(y_test, probas) 72 | plt.show() 73 | ``` 74 | ![p_r_curves](examples/p_r_curves.png) 75 | 76 | You can see clearly here that `skplt.metrics.plot_precision_recall_curve` needs only the ground truth y-values and the predicted probabilities to generate the plot. This lets you use *anything* you want as the classifier, from Keras NNs to NLTK Naive Bayes to that groundbreaking classifier algorithm you just wrote. 
77 | 78 | The possibilities are endless. 79 | 80 | ## Installation 81 | 82 | Installation is simple! First, make sure you have the dependencies [Scikit-learn](http://scikit-learn.org) and [Matplotlib](http://matplotlib.org/) installed. 83 | 84 | Then just run: 85 | ```bash 86 | pip install scikit-plot 87 | ``` 88 | 89 | Or if you want the latest development version, clone this repo and run 90 | ```bash 91 | python setup.py install 92 | ``` 93 | at the root folder. 94 | 95 | If using conda, you can install Scikit-plot by running: 96 | ```bash 97 | conda install -c conda-forge scikit-plot 98 | ``` 99 | 100 | ## Documentation and Examples 101 | 102 | Explore the full features of Scikit-plot. 103 | 104 | You can find detailed documentation [here](http://scikit-plot.readthedocs.io). 105 | 106 | Examples are found in the [examples folder of this repo](examples/). 107 | 108 | ## Contributing to Scikit-plot 109 | 110 | Reporting a bug? Suggesting a feature? Want to add your own plot to the library? Visit our [contributor guidelines](CONTRIBUTING.md). 111 | 112 | ## Citing Scikit-plot 113 | 114 | Are you using Scikit-plot in an academic paper? You should be! Reviewers love eye candy. 115 | 116 | If so, please consider citing Scikit-plot with DOI [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.293191.svg)](https://doi.org/10.5281/zenodo.293191) 117 | 118 | #### APA 119 | 120 | > Reiichiro Nakano. (2018). reiinakano/scikit-plot: 0.3.7 [Data set]. Zenodo. http://doi.org/10.5281/zenodo.293191 121 | 122 | #### IEEE 123 | 124 | > [1]Reiichiro Nakano, “reiinakano/scikit-plot: 0.3.7”. Zenodo, 19-Feb-2017. 125 | 126 | #### ACM 127 | 128 | > [1]Reiichiro Nakano 2018. reiinakano/scikit-plot: 0.3.7. Zenodo. 129 | 130 | Happy plotting! 
131 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = Scikit-plot 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/Quickstart.rst: -------------------------------------------------------------------------------- 1 | .. Quickstart file describing a quick plot with scikit-plot 2 | 3 | 4 | First steps with Scikit-plot 5 | ============================ 6 | 7 | Eager to use Scikit-plot? Let's get started! This section of the documentation will teach you the basic philosophy behind Scikit-plot by running you through a quick example. 8 | 9 | Installation 10 | ------------ 11 | 12 | Before anything else, make sure you've installed the latest version of Scikit-plot. Scikit-plot is on PyPI, so simply run:: 13 | 14 | $ pip install scikit-plot 15 | 16 | to install the latest version. 17 | 18 | Alternatively, you can clone the `source repository <https://github.com/reiinakano/scikit-plot>`_ and run:: 19 | 20 | $ python setup.py install 21 | 22 | at the root folder. 23 | 24 | Scikit-plot depends on `Scikit-learn <http://scikit-learn.org>`_ and `Matplotlib <http://matplotlib.org/>`_ to do its magic, so make sure you have them installed as well. 
25 | 26 | Your First Plot 27 | --------------- 28 | 29 | For our quick example, let's show how well a Random Forest can classify the digits dataset bundled with Scikit-learn. A popular way to evaluate a classifier's performance is by viewing its confusion matrix. 30 | 31 | Before we begin plotting, we'll need to import the following for Scikit-plot:: 32 | 33 | >>> import matplotlib.pyplot as plt 34 | 35 | :mod:`matplotlib.pyplot` is used by Matplotlib to make plotting work like it does in MATLAB and deals with things like axes, figures, and subplots. But don't worry. Unless you're an advanced user, you won't need to understand any of that while using Scikit-plot. All you need to remember is that we use the :func:`matplotlib.pyplot.show` function to show any plots generated by Scikit-plot. 36 | 37 | Let's begin by generating our sample digits dataset:: 38 | 39 | >>> from sklearn.datasets import load_digits 40 | >>> X, y = load_digits(return_X_y=True) 41 | 42 | Here, ``X`` and ``y`` contain the features and labels of our classification dataset, respectively. 43 | 44 | We'll proceed by creating an instance of a RandomForestClassifier object from Scikit-learn with some initial parameters:: 45 | 46 | >>> from sklearn.ensemble import RandomForestClassifier 47 | >>> random_forest_clf = RandomForestClassifier(n_estimators=5, max_depth=5, random_state=1) 48 | 49 | Let's use :func:`sklearn.model_selection.cross_val_predict` to generate predicted labels on our dataset:: 50 | 51 | >>> from sklearn.model_selection import cross_val_predict 52 | >>> predictions = cross_val_predict(random_forest_clf, X, y) 53 | 54 | For those not familiar with what :func:`cross_val_predict` does, it generates cross-validated estimates for each sample point in our dataset. Comparing the cross-validated estimates with the true labels, we'll be able to get evaluation metrics such as accuracy, precision, recall, and in our case, the confusion matrix. 
55 | 56 | To plot and show our confusion matrix, we'll use the function :func:`~scikitplot.metrics.plot_confusion_matrix`, passing it both the true labels and predicted labels. We'll also set the optional argument ``normalize=True`` so the values displayed in our confusion matrix plot will be in the range [0, 1]. Finally, to show our plot, we'll call ``plt.show()``. 57 | 58 | >>> import scikitplot as skplt 59 | >>> skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True) 60 | 61 | >>> plt.show() 62 | 63 | .. image:: _static/quickstart_plot_confusion_matrix.png 64 | :align: center 65 | :alt: Confusion matrix 66 | 67 | And that's it! A quick glance at our confusion matrix shows that our classifier isn't doing so well with identifying the digits 1, 8, and 9. Hmm. Perhaps a bit more tweaking of our Random Forest's hyperparameters is in order. 68 | 69 | One more example 70 | ---------------- 71 | 72 | Finally, let's show an example wherein we *don't* use Scikit-learn. 73 | 74 | Here's a quick example to generate the precision-recall curves of a Keras classifier on a sample dataset. 75 | 76 | >>> # Import what's needed for the Functions API 77 | >>> import matplotlib.pyplot as plt 78 | >>> import scikitplot as skplt 79 | >>> # This is a Keras classifier. We'll generate probabilities on the test set. 80 | >>> keras_clf.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=2) 81 | >>> probas = keras_clf.predict_proba(X_test, batch_size=64) 82 | >>> # Now plot. 83 | >>> skplt.metrics.plot_precision_recall_curve(y_test, probas) 84 | 85 | >>> plt.show() 86 | 87 | .. image:: _static/quickstart_plot_precision_recall_curve.png 88 | :align: center 89 | :alt: Precision Recall Curves 90 | 91 | And again, that's it! As in the example above, all we needed to do was pass the ground truth labels and predicted probabilities to :func:`~scikitplot.metrics.plot_precision_recall_curve` to generate the precision-recall curves. 
This means you can use literally any classifier you want to generate the precision-recall curves, from Keras classifiers to NLTK Naive Bayes to XGBoost, as long as you pass in the predicted probabilities in the correct format. 92 | 93 | Now what? 94 | --------- 95 | 96 | The recommended way to start using Scikit-plot is to just go through the documentation for the various modules and choose which plots you think would be useful for your work. 97 | 98 | Happy plotting! 99 | -------------------------------------------------------------------------------- /docs/_static/examples/plot_calibration_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_calibration_curve.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_confusion_matrix.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_cumulative_gain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_cumulative_gain.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_elbow_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_elbow_curve.png -------------------------------------------------------------------------------- 
/docs/_static/examples/plot_feature_importances.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_feature_importances.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_ks_statistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_ks_statistic.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_learning_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_learning_curve.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_lift_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_lift_curve.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_pca_2d_projection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_pca_2d_projection.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_pca_component_variance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_pca_component_variance.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_precision_recall_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_precision_recall_curve.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_roc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_roc_curve.png -------------------------------------------------------------------------------- /docs/_static/examples/plot_silhouette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/examples/plot_silhouette.png -------------------------------------------------------------------------------- /docs/_static/quickstart_plot_confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/quickstart_plot_confusion_matrix.png -------------------------------------------------------------------------------- /docs/_static/quickstart_plot_confusion_matrix2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/quickstart_plot_confusion_matrix2.png 
-------------------------------------------------------------------------------- /docs/_static/quickstart_plot_precision_recall_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/quickstart_plot_precision_recall_curve.png -------------------------------------------------------------------------------- /docs/_static/readme_collage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/docs/_static/readme_collage.jpg -------------------------------------------------------------------------------- /docs/apidocs.rst: -------------------------------------------------------------------------------- 1 | .. apidocs file containing the API Documentation 2 | .. _factoryapidocs: 3 | 4 | Factory API Reference 5 | ===================== 6 | 7 | This document contains the plotting methods that are embedded into scikit-learn objects by the factory functions :func:`~scikitplot.clustering_factory` and :func:`~scikitplot.classifier_factory`. 8 | 9 | .. admonition:: Important Note 10 | 11 | If you want to use stand-alone functions and not bother with the factory functions, view the :ref:`functionsapidocs` instead. 12 | 13 | Classifier Plots 14 | ---------------- 15 | 16 | .. autofunction:: scikitplot.classifier_factory 17 | 18 | .. automodule:: scikitplot.classifiers 19 | :members: plot_learning_curve, plot_confusion_matrix_with_cv, plot_roc_curve_with_cv, plot_ks_statistic_with_cv, plot_precision_recall_curve_with_cv, plot_feature_importances 20 | 21 | Clustering Plots 22 | ---------------- 23 | 24 | .. autofunction:: scikitplot.clustering_factory 25 | 26 | .. 
automodule:: scikitplot.clustering 27 | :members: plot_silhouette, plot_elbow_curve -------------------------------------------------------------------------------- /docs/cluster.rst: -------------------------------------------------------------------------------- 1 | .. apidocs file containing the API Documentation 2 | .. _clusterdocs: 3 | 4 | Clusterer Module (API Reference) 5 | ================================ 6 | 7 | .. automodule:: scikitplot.cluster 8 | :members: plot_elbow_curve -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Scikit-plot documentation build configuration file, created by 4 | # sphinx-quickstart on Sun Feb 12 17:56:21 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 18 | # 19 | import os 20 | import sys 21 | # sys.path.insert(0, os.path.abspath('.')) 22 | sys.path.insert(0, os.path.abspath('../')) 23 | 24 | # -- General configuration ------------------------------------------------ 25 | 26 | # If your documentation needs a minimal Sphinx version, state it here. 27 | # 28 | # needs_sphinx = '1.0' 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 
33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.doctest', 36 | 'sphinx.ext.napoleon' 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # The suffix(es) of source filenames. 43 | # You can specify multiple suffixes as a list of strings: 44 | # 45 | # source_suffix = ['.rst', '.md'] 46 | source_suffix = '.rst' 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | # General information about the project. 52 | project = u'Scikit-plot' 53 | copyright = u'2017, Reiichiro S. Nakano' 54 | author = u'Reiichiro S. Nakano' 55 | 56 | # The version info for the project you're documenting, acts as replacement for 57 | # |version| and |release|, also used in various other places throughout the 58 | # built documents. 59 | # 60 | # The short X.Y version. 61 | version = u'' 62 | # The full version, including alpha/beta/rc tags. 63 | release = u'' 64 | 65 | # The language for content autogenerated by Sphinx. Refer to documentation 66 | # for a list of supported languages. 67 | # 68 | # This is also used if you do content translation via gettext catalogs. 69 | # Usually you set "language" from the command line for these cases. 70 | language = None 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # These patterns also affect html_static_path and html_extra_path 75 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 76 | 77 | # The name of the Pygments (syntax highlighting) style to use. 78 | pygments_style = 'sphinx' 79 | 80 | # If true, `todo` and `todoList` produce output, else they produce nothing. 81 | todo_include_todos = False 82 | 83 | 84 | # -- Options for HTML output ---------------------------------------------- 85 | 86 | # The theme to use for HTML and HTML Help pages. See the documentation for 87 | # a list of builtin themes. 
88 | # 89 | html_theme = 'sphinx_rtd_theme' 90 | 91 | # Theme options are theme-specific and customize the look and feel of a theme 92 | # further. For a list of options available for each theme, see the 93 | # documentation. 94 | # 95 | # html_theme_options = {} 96 | 97 | # Add any paths that contain custom static files (such as style sheets) here, 98 | # relative to this directory. They are copied after the builtin static files, 99 | # so a file named "default.css" will overwrite the builtin "default.css". 100 | html_static_path = ['_static'] 101 | 102 | 103 | # -- Options for HTMLHelp output ------------------------------------------ 104 | 105 | # Output file base name for HTML help builder. 106 | htmlhelp_basename = 'Scikit-plotdoc' 107 | 108 | 109 | # -- Options for LaTeX output --------------------------------------------- 110 | 111 | latex_elements = { 112 | # The paper size ('letterpaper' or 'a4paper'). 113 | # 114 | # 'papersize': 'letterpaper', 115 | 116 | # The font size ('10pt', '11pt' or '12pt'). 117 | # 118 | # 'pointsize': '10pt', 119 | 120 | # Additional stuff for the LaTeX preamble. 121 | # 122 | # 'preamble': '', 123 | 124 | # Latex figure (float) alignment 125 | # 126 | # 'figure_align': 'htbp', 127 | } 128 | 129 | # Grouping the document tree into LaTeX files. List of tuples 130 | # (source start file, target name, title, 131 | # author, documentclass [howto, manual, or own class]). 132 | latex_documents = [ 133 | (master_doc, 'Scikit-plot.tex', u'Scikit-plot Documentation', 134 | u'Reiichiro S. Nakano', 'manual'), 135 | ] 136 | 137 | 138 | # -- Options for manual page output --------------------------------------- 139 | 140 | # One entry per manual page. List of tuples 141 | # (source start file, name, description, authors, manual section). 
142 | man_pages = [ 143 | (master_doc, 'scikit-plot', u'Scikit-plot Documentation', 144 | [author], 1) 145 | ] 146 | 147 | 148 | # -- Options for Texinfo output ------------------------------------------- 149 | 150 | # Grouping the document tree into Texinfo files. List of tuples 151 | # (source start file, target name, title, author, 152 | # dir menu entry, description, category) 153 | texinfo_documents = [ 154 | (master_doc, 'Scikit-plot', u'Scikit-plot Documentation', 155 | author, 'Scikit-plot', 'One line description of project.', 156 | 'Miscellaneous'), 157 | ] 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /docs/decomposition.rst: -------------------------------------------------------------------------------- 1 | .. apidocs file containing the API Documentation 2 | .. _decompositiondocs: 3 | 4 | Decomposition Module (API Reference) 5 | ==================================== 6 | 7 | .. automodule:: scikitplot.decomposition 8 | :members: plot_pca_component_variance, plot_pca_2d_projection -------------------------------------------------------------------------------- /docs/estimators.rst: -------------------------------------------------------------------------------- 1 | .. apidocs file containing the API Documentation 2 | .. _estimatorssdocs: 3 | 4 | Estimators Module (API Reference) 5 | ================================= 6 | 7 | .. automodule:: scikitplot.estimators 8 | :members: plot_learning_curve, plot_feature_importances -------------------------------------------------------------------------------- /docs/functionsapidocs.rst: -------------------------------------------------------------------------------- 1 | .. apidocs file containing the API Documentation 2 | .. _functionsapidocs: 3 | 4 | Functions API Reference 5 | ======================= 6 | 7 | This document contains the stand-alone plotting functions for maximum flexibility. 
If you want to use factory functions :func:`~scikitplot.clustering_factory` and :func:`~scikitplot.classifier_factory`, use the :ref:`factoryapidocs` instead. 8 | 9 | .. automodule:: scikitplot.plotters 10 | :members: plot_learning_curve, plot_confusion_matrix, plot_roc_curve, plot_ks_statistic, plot_precision_recall_curve, plot_feature_importances, plot_silhouette, plot_elbow_curve, plot_pca_component_variance, plot_pca_2d_projection 11 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Scikit-plot documentation master file, created by 2 | sphinx-quickstart on Sun Feb 12 17:56:21 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Scikit-plot's documentation! 7 | ======================================= 8 | 9 | The quickest and easiest way to go from analysis... 10 | --------------------------------------------------- 11 | 12 | .. image:: _static/readme_collage.jpg 13 | :align: center 14 | :alt: All plots 15 | 16 | ...to this. 17 | ----------- 18 | 19 | Scikit-plot is the result of an unartistic data scientist's dreadful realization that *visualization is one of the most crucial components in the data science process, not just a mere afterthought*. 20 | 21 | Gaining insights is simply a lot easier when you're looking at a colored heatmap of a confusion matrix complete with class labels rather than a single-line dump of numbers enclosed in brackets. Besides, if you ever need to present your results to someone (virtually any time anybody hires you to do data science), you show them visualizations, not a bunch of numbers in Excel. 22 | 23 | That said, there are a number of visualizations that frequently pop up in machine learning. 
Scikit-plot is a humble attempt to provide aesthetically-challenged programmers (such as myself) the opportunity to generate quick and beautiful graphs and plots with as little boilerplate as possible. 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | :name: mastertoc 28 | 29 | First Steps with Scikit-plot 30 | Metrics Module 31 | Estimators Module 32 | Clusterer Module 33 | Decomposition Module 34 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=Scikit-plot 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/metrics.rst: -------------------------------------------------------------------------------- 1 | .. apidocs file containing the API Documentation 2 | .. _metricsdocs: 3 | 4 | Metrics Module (API Reference) 5 | ============================== 6 | 7 | .. 
automodule:: scikitplot.metrics 8 | :members: plot_confusion_matrix, plot_roc, plot_ks_statistic, plot_precision_recall, plot_silhouette, plot_calibration_curve, plot_cumulative_gain, plot_lift_curve -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: scikit-plot-env 2 | 3 | dependencies: 4 | - python=3 5 | - numpy 6 | - scipy 7 | - scikit-learn 8 | - matplotlib 9 | - joblib 10 | -------------------------------------------------------------------------------- /examples/jupyter_notebooks/plot_feature_importance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### An example showing the plot_feature_importances method used by a scikit-learn classifier\n", 8 | "\n", 9 | "In this example, we'll be plotting the feature importances in a `RandomForestClassifier` for the Iris dataset. In order for this to work, we need to first create an instance of our classifier then fit it to our data." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Populating the interactive namespace from numpy and matplotlib\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "from sklearn.ensemble import RandomForestClassifier\n", 27 | "from sklearn.datasets import load_iris as load_data\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "# Import scikit-plot\n", 31 | "import scikitplot as skplt\n", 32 | "\n", 33 | "%pylab inline\n", 34 | "pylab.rcParams['figure.figsize'] = (12, 12)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", 46 | " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", 47 | " min_impurity_split=1e-07, min_samples_leaf=1,\n", 48 | " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", 49 | " n_estimators=10, n_jobs=1, oob_score=False, random_state=1,\n", 50 | " verbose=0, warm_start=False)" 51 | ] 52 | }, 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "# Load data\n", 60 | "X, y = load_data(return_X_y=True)\n", 61 | "\n", 62 | "# Create classifier instance and fit\n", 63 | "classifier = RandomForestClassifier(random_state=1)\n", 64 | "classifier.fit(X,y)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAsYAAAK7CAYAAADx1EmqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X+0rXdd2Pn3x1wQkAraXIeSEJOxESYiWnvFqmhvV6UF\nrA2OMIC1LKzKZEbKdM3QkeVaoyx12tKOs1oXaCa4spiKy1T5ZSrRoA4BFFoTFCIBQ2PAJtFK+Kko\nApHv/HF26OZ6k3OSnHNPcvN6rXVW9vNj7+d79n7uvu/zvc/OmbVWAABwf/c5hz0AAAC4NxDGAACQ\nMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDFwGpiZ983Mx2fmY1tfj7yHj3l8Zm7erzHu8Zgvn5kf\nOZXHvCMz86KZecVhjwPgVBLGwOniW9ZaD936+v3DHMzMHDnM498T9+WxA9wTwhg4rc3M35iZt8zM\nR2bmHTNzfGvbd87Mu2fmj2fmxpn5HzfrP6/6xeqR2zPQJ87onjirvJm5/r6Zubb6k5k5srnfq2bm\n1pl578w8f4/jPndm1maMN83Mh2fmopn56pm5dvP9vGRr/+fMzK/PzEtm5qMz8zsz87e3tj9yZi6f\nmQ/NzA0z8z1b2140M6+cmVfMzB9VF1XfXz1j872/486er+3nYmb+t5l5/8z8wcx859b2B8/Mj87M\n723G92sz8+DdXiOAU8msAHDampmzqtdV/7D6pepvV6+amcestW6t3l/9verG6hurX5yZq9davzkz\nT65esdY6e+vx9nLYZ1XfXH2g+nT176uf36w/u/qVmbl+rXXlHr+Nr6nO34zv8s338U3VA6rfmpmf\nW2u9cWvfV1ZnVv999eqZOW+t9aHqsuqd1SOrx1S/PDO/u9b6/zb3vbB6evXs6nM3j/FX11rfsTWW\nO3y+NtsfUT2sOqt6YvXKmXntWuvD1f9VfVn1ddV/2Yz103t4jQBOGTPGwOnitZsZx4/MzGs3676j\numKtdcVa69NrrV+urqmeUrXWet1a63fXjjdWr6++4R6O48fWWjettT5efXV1dK31Q2utT661bqxe\nVj3zLjzeD6+1/myt9frqT6qfWWu9f611S/Xm6q9t7fv+6l+vtT611vp31fXVN8/Mo6qvr75v81hv\nr36ynQi+3VvXWq/dPE8fP9lA9vB8far6oc3xr6g+Vj16Zj6n+kfV/7LWumWt9edrrbestT7RLq8R\nwKlkxhg4XTx1rfUrJ6z74urpM/MtW+seUL2hajMr/IPVl7YzUfCQ6rfv4ThuOuH4j5yZj2ytO6Od\noN2rP9y6/fGTLD90a/mWtdbaWv69dmaIH1l9aK31xydsO3YH4z6pPTxfH1xr3ba1/Keb8Z1ZPaj6\n3ZM87J2+RgCnkjAGTmc3VT+11vqeEzfMzOdWr2pn1vTn11qf2sw03369xDrxPu3M2D5ka/kRJ9ln\n+343Ve9da51/dwZ/N5w1M7MVx+e0c/nF71dfODN/aSuOz6lu2brvid/vZy3v4fm6Mx+o/qz6kuod\nJ2y7w9cI4FRzKQVwOntF9S0z83dn5oyZedDmQ2JnVw9s51raW6vbNrOhf2frvn9Y/eWZedjWurdX\nT5mZL5yZR1T/ZJfj/0b1x5sP5D14M4bHzsxX79t3+Nm+qHr+zDxgZp5e/XftXKZwU/WW6p9vnoPH\nVd/VzvNzR/6wOndzGUTt/nzdobXWp6tLq/978yHAM2bmazexfWevEcApJYyB09YmCC9s5/+wcGs7\ns5P/tPqczczp86ufrT5cfXs7s6u33/d3qp+pbtxct/zI6qfamfF8XzvX1/67XY7/5+18WO0rq/e2\nM3P6k+18QO0g/Md2Pqj3ger/rJ621vrgZtuzqnPbmT1+TfWDJ7n0ZNvPbf77wZn5zd2erz14QTuX\nXVxdfah6cTuvwx2+RnfhsQH2xXz25WgA3BfNzHOq715rPeGwx
wJwX+UncgAASBgDAEDlUgoAAKjM\nGAMAQHWI/x/jM888c5177rmHdXgAAO4n3va2t31grXV0t/0OLYzPPffcrrnmmsM6PAAA9xMz83t7\n2c+lFAAAkDAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWM\nAQCg2mMYz8yTZub6mblhZl54B/scn5m3z8x1M/PG/R0mAAAcrCO77TAzZ1QvrZ5Y3VxdPTOXr7Xe\ntbXPw6sfr5601vrPM/NFBzVgAAA4CHuZMX58dcNa68a11iery6oLT9jn26tXr7X+c9Va6/37O0wA\nADhYewnjs6qbtpZv3qzb9qXVF8zMVTPztpl59skeaGaeOzPXzMw1t956690bMQAAHID9+vDdkeqv\nV99c/d3q/5iZLz1xp7XWJWutY2utY0ePHt2nQwMAwD236zXG1S3Vo7aWz96s23Zz9cG11p9UfzIz\nb6q+onrPvowSAAAO2F5mjK+uzp+Z82bmgdUzq8tP2OfnqyfMzJGZeUj1NdW793eoAABwcHadMV5r\n3TYzz6uurM6oLl1rXTczF222X7zWevfM/FJ1bfXp6ifXWu88yIEDAMB+mrXWoRz42LFj65prrjmU\nYwMAcP8xM29bax3bbT+/+Q4AABLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgD93LHjx/v\n+PHjhz0MAO4HhDEAACSMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCg\nEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsbcixw/frzjx48f9jAAgPsp\nYQwAAAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAA\nKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwB\nAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTC\nGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBA\nJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEA\nAFTCGAAAqj2G8cw8aWaun5kbZuaFJ9l+fGY+OjNv33z9wP4PFQAADs6R3XaYmTOql1ZPrG6urp6Z\ny9da7zph1zevtf7eAYwRAAAO3F5mjB9f3bDWunGt9cnqsurCgx0WAACcWnsJ47Oqm7aWb96sO9HX\nzcy1M/OLM/Nl+zI6AAA4RXa9lGKPfrM6Z631sZl5SvXa6vwTd5qZ51bPrTrnnHP26dAAAHDP7WXG\n+JbqUVvLZ2/WfcZa64/WWh/b3L6iesDMnHniA621LllrHVtrHTt69Og9GDYAAOyvvYTx1dX5M3Pe\nzDywemZ1+fYOM/OImZnN7cdvHveD+z1YAAA4KLteSrHWum1mnlddWZ1RXbrWum5mLtpsv7h6WvU/\nzcxt1cerZ6611gGOGwAA9tWerjHeXB5xxQnrLt66/ZLqJfs7NAAAOHX85jsAAEgYAwBAJYwBAKAS\nxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAA\nKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwB\nAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTC\nGAAAKmEMAACVMAYAgEoYA
wBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBA\nJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEA\nAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoY\nAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACo\nhDEAAFTCGAAAKmEMAADVHsN4Zp40M9fPzA0z88I72e+rZ+a2mXna/g0RAAAO3q5hPDNnVC+tnlxd\nUD1rZi64g/1eXL1+vwcJAAAHbS8zxo+vblhr3bjW+mR1WXXhSfb7x9Wrqvfv4/gAAOCU2EsYn1Xd\ntLV882bdZ8zMWdW3Vj9xZw80M8+dmWtm5ppbb731ro4VAAAOzH59+O5fV9+31vr0ne201rpkrXVs\nrXXs6NGj+3RoAAC4547sYZ9bqkdtLZ+9WbftWHXZzFSdWT1lZm5ba712X0YJAAAHbC9hfHV1/syc\n104QP7P69u0d1lrn3X57Zl5e/YIoBgDgvmTXMF5r3TYzz6uurM6oLl1rXTczF222X3zAYwQAgAO3\nlxnj1lpXVFecsO6kQbzWes49HxYAAJxafvMdAAAkjAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwA\nAJUwBgCAShgDcB90/Pjxjh8/ftjDAE4zwhgAABLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCA\nShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQ1ZHD\nHgD3wMxhj+BgnI7f11qHPQIAYBdmjAEAIGEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEA\nAFTCGAAAKmEMAACVXwkNp5fT8ddp3+50/N78qnCAexUzxgAAkDAGAIBKGAMAQCWMAQCgEsYAAFAJ\nYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAA\nlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYA\nAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACph\nDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAAKo9hvHMPGlmrp+ZG2bmhSfZfuHM\nXDszb5+Za2bmCfs/VAAAODhHdtthZs6oXlo9sbq5unpmLl9rvWtrt1+tLl9rrZl5XPWz1WMOYsAA\nAHAQ9jJj/PjqhrXWjWutT1aXVRdu77DW+thaa20WP69aAQDAfchewvis6qat5Zs36z7LzHzrzPxO\n9brqH53sgWbmuZtLLa659dZb7854AQDgQOzbh+/WWq9Zaz2memr1w3ewzyVrrWNrrWNHjx7dr0MD\nAMA9tpcwvqV61Nby2Zt1J7XWelP1387MmfdwbAAAcMrsJYyvrs6fmfNm5oHVM6vLt3eYmb86M7O5\n/VXV51Yf3O/BAgDAQdn1/0qx1rptZp5XXVmdUV261rpuZi7abL+4+rbq2TPzqerj1TO2PowHAAD3\neruGcdVa64rqihPWXbx1+8XVi/d3aAAAcOr4zXcAAJAwBgCAShgDAEAljAEAoBLGAABQCWMAAKiE\nMQAAVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCA\nShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAAB
QCWMA\nAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUw\nBgCAShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQ\nCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAACqOnLY\nAwDggM0c9ggOzun4va112COA+y0zxgAAkDAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgA\nACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWM\nAQCgEsYAAFAJYwAAqPYYxjPzpJm5fmZumJkXnmT7P5iZa2fmt2fmLTPzFfs/VAAAODi7hvHMnFG9\ntHpydUH1rJm54ITd3lv9zbXWl1c/XF2y3wMFAICDtJcZ48dXN6y1blxrfbK6rLpwe4e11lvWWh/e\nLP6H6uz9HSYAABysvYTxWdVNW8s3b9bdke+qfvGeDAoAAE61I/v5YDPzt9oJ4yfcwfbnVs+tOuec\nc/bz0AAAcI/sZcb4lupRW8tnb9Z9lpl5XPWT1YVrrQ+e7IHWWpestY6ttY4dPXr07owXAAAOxF7C\n+Orq/Jk5b2YeWD2zunx7h5k5p3p19Q/XWu/Z/2ECAMDB2vVSirXWbTPzvOrK6ozq0rXWdTNz0Wb7\nxdUPVH+5+vGZqbptrXXs4IYNAAD7a0/XGK+1rqiuOGHdxVu3v7v67v0dGgAAnDp+8x0AACSMAQCg\nEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgA\nACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWM\nAQCgEsYAAFAJYwAAqIQxAABUwhgAAKo6ctgDgNtdddgDAADu18wYAwBAwhgAACphDAAAlTAGAIBK\nGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAA\nqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAG\nAE4Tx48f7/jx44c9DO7DhDEAACSMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMA\nQCWMAQCgEsYAAFDVkcMeAMCdueqwBwDA/YYZYwAASBgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIY\nAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAao9hPDNP\nmpnrZ+aGmXnhSbY/ZmbeOjOfmJkX7P8wAQDgYB3ZbYeZOaN6afXE6ubq6pm5fK31rq3dPlQ9v3rq\ngYwSAAAO2F5mjB9f3bDWunGt9cnqsurC7R3WWu9fa11dfeoAxggAAAduL2F8VnXT1vLNm3V32cw8\nd2aumZlrbr311rvzEAAAcCBO6Yfv1lqXrLWOrbWOHT169FQeGgAA7tRewviW6lFby2dv1gEAwGlj\nL2F8dXX+zJw3Mw+snlldfrDDAgCAU2vX/yvFWuu2mXledWV1RnXpWuu6mblos/3imXlEdU31+dWn\nZ+afVBestf7oAMcOAAD7ZtcwrlprXVFdccK6i7du/5d2LrEAAID7JL/5DgAAEsYAAFAJYwAAqIQx\nAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBK\nGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAA
AVR057AEAwF111WEPADgtmTEGAICEMQAA\nVMIYAAAqYQwAAJUwBgCAShgDAEAljAEAoBLGAABQCWMAAKiEMQAAVMIYAAAqYQwAAJUwBgCAShgD\nAEAljAEAoBLGAABQCWMAAKjqyGEPAAA4BDOHPYKDc7p9b2sd9gjuN8wYAwBAwhgAACphDAAAlTAG\nAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJ\nYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAA\nlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYA\nAFAJYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCg2mMYz8yTZub6mblhZl54ku0zMz+2\n2X7tzHzV/g8VAAAOzq5hPDNnVC+tnlxdUD1rZi44YbcnV+dvvp5b/cQ+jxMAAA7UXmaMH1/dsNa6\nca31yeqy6sIT9rmw+rdrx3+oHj4zf2WfxwoAAAfmyB72Oau6aWv55upr9rDPWdUfbO80M89tZ0a5\nc845566OlROtddgj4N7GOcHJOC84mdPxvDh+fOe/V111mKPgPuyUfvhurXXJWuvYWuvY0aNHT+Wh\nAQDgTu0ljG+pHrW1fPZm3V3dBwAA7rX2EsZXV+fPzHkz88DqmdXlJ+xzefXszf+d4m9UH11r/cGJ\nDwQAAPdWu15jvNa6bWaeV11ZnVFduta6bmYu2my/uLqiekp1Q/Wn1Xce3JABAGD/7eXDd621rmgn\nfrfXXbx1e1Xfu79DAwCAU8dvvgMAgIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJ\nYwAAqIQxAABUwhgAACphDAAAlTAGAIBKGAMAQCWMAQCgEsYAAFAJYwAAqIQxAABUwhgAACphDAAA\nlTAGAIBKGAMAQFVHDnsAAAD74aqrrjrsIXAfZ8YYAAASxgAAUAljAACohDEAAFTCGAAAKmEMAACV\nMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAA\nUAljAACohDEAAFTCGAAAKmEMAACVMAYAgEoYAwBAJYwBAKASxgAAUAljAACohDEAAFTCGAAAqpq1\n1uEceObW6vcO5eDcm51ZfeCwB8G9jvOCk3FecDLOC07mi9daR3fb6dDCGE5mZq5Zax077HFw7+K8\n4GScF5yM84J7wqUUAACQMAYAgEoYc+9zyWEPgHsl5wUn47zgZJwX3G2uMQYAgMwYAwBAJYwBAKAS\nxhyimTk+M79wN+73yJl55R1su2pmjm1uf//W+nNn5p13f7QctDs6H+7uebKH4z11Zi7YWv7MucPB\nm5nnzMwj97Dfy2fmaXtdvw/j8r5xL3BPz4893O+imXn2SdZ/5jWfma+cmadsbXvRzLzgrh6L+xZh\nzH3OWuv311p7eSP8/t134X7sqdUFu+7FQXlOtWv4HALvG/cOz+kAz4+11sVrrX+7y25fWT1ll304\nzQhj7tDMfN7MvG5m3jEz75yZZ2zW//WZeePMvG1mrpyZv7JZf9XM/JuZeftm/8dv1j9+Zt46M781\nM2+ZmUfvctzXzczjNrd/a2Z+YHP7h2bme074if7BM3PZzLx7Zl5TPXiz/l9UD96M5ac3D33GzLxs\nZq6bmdfPzIMP4Gk7bR3W+XCSMVw6M7+xuf+Fm/XPmZlXz8wvzcx/mpl/uXWf75qZ92zu87KZecnM\nfF3196t/tRnfl2x2f/pmv/fMzDfs01N32tv8mfydmfnpzZ/FV87MQzbb/sL5sZnhO1b99Ob5f/DM\n/MDMXL05Vy6Zmbk
Lx7+zc/DFJ76mM/OQmfnZmXnXzLxmZv7jzBzzvnEwTvX5MTNfNDNv29z+iplZ\nM3POZvl3N6//Z2Z/N2N4x8y8o/rezboHVj9UPWMzhmdsHv6CzXl148w8/6CeMw7RWsuXr5N+Vd9W\nvWxr+WHVA6q3VEc3655RXbq5fdXt+1ffWL1zc/vzqyOb299UvWpz+3j1Cyc57gvbeXN6WHV1deVm\n/RuqR1fnbj32/7p1/MdVt1XHNssf23rMczfbvnKz/LPVdxz2c3xf+jrE8+Ez66t/dvvrVj28ek/1\nee3MLt24GdOD2vl1849qZ8bpfdUXbsb65uolm/u/vHra1nGuqn50c/sp1a8c9nN+X/na/Pla1ddv\nli+tXrCH8+PY1mN84dbtn6q+5WSv09Y+L6+etodj/IXXdDO2/2dz+7HeN07L8+O6dt5rntfO3yP/\noPri6q2b7S+qXrC5fW31jZvb/6r/+l71nNvfL7bu85bqc9v5tdMfrB5w2M+vr/39OhLcsd+ufnRm\nXtxOmLx5Zh7bzl8kv7z5gf2M6g+27vMzVWutN83M58/Mw6u/VP2/M3N+O2+OD9jluG+unl+9t3pd\n9cTN7MJ5a63rZ+bcrX2/sfqxzTGvnZlr7+Rx37vWevvm9tvaebNm7w7rfNj2d6q/P//1Or8HVeds\nbv/qWuujVTPzrnb+EjyzeuNa60Ob9T9XfemdPP6rN/91ftx1N621fn1z+xXt/Bn+pe78/Nj2t2bm\nf68e0s4PMtdV/34Px330Lsc42Wv6hOrfVK213ul945Q41efHW6qvb+fviH9WPamadv5++YzNe9LD\n11pv2qz6qerJd/K4r1trfaL6xMy8v/pvqpvvZH/uY4Qxd2it9Z6Z+ap2Zlp+ZGZ+tXpNdd1a62vv\n6G4nWf7h6g1rrW/dRO1Vuxz66nb+Ge3G6pfbiZvvaecvpXviE1u3/7zNZRfszSGeD9um+ra11vWf\ntXLma/qLr+/deX+7/THu7v3vz072Wk93fn5UNTMPqn68nRnCm2bmRe380LMXux3jnr6m3jf2x6k+\nP95UfUM7PyD/fPV9m2O+7q4P/bPsx/sM92KuMeYOzc4ngv90rfWKdv556auq66ujM/O1m30eMDNf\ntnW32687fUL10c0M3sOqWzbbn7Pbcddan6xuqp5evbWdn/Bf0M4b3YneVH375piPbedyitt9ambu\nymwkd+KwzocTXFn949uvL5yZv7bL/ldXf3NmvmBmjrRzOcjt/rid2Wv2xzm3nwft/Jn8te78/Nh+\n/m+PnA9FIGDZAAABwUlEQVTMzEPbuURir3Y7B0/m16v/YbP/BdWXb23zvnEwTvX58ebqO6r/tNb6\ndPWhdn6o/7XtndZaH6k+snmPqp1LLm7nPeJ+SBhzZ768+o2ZeXv1g9WPbKL1adWLNx9UeHv1dVv3\n+bOZ+a3q4uq7Nuv+ZfXPN+v3+tP1m6v3r7U+vrl9dif8E9jGT1QPnZl3t/NBie1Z5Uuqa7c+RMM9\nc5jnw+1+uJ1LL66dmes2y3dorXVLO/+M+hvtxND7qo9uNl9W/dPZ+RDfl5z8EbgLrq++d/Nn8Quq\nn9jl/Hh5dfHmfPpE9bLqne388HP1Xg+6h3PwZH68nSB7V/Uj7fyz/O3nhfeNg3FKz4+11vvamZG+\nfULl16qPrLU+fJLdv7N66eZY2x/qe0M7H7bb/vAdpzm/Epp9MzNXtfNhhmsOeywcvnvL+TAzD11r\nfWwzY/yadj7c85rDHNPpZnNJzC+stR57yEPZk5k5o50PTf3Z5oeiX6kevQk19tl97fzg/s21McDp\n7kUz803t/HPs66vXHvJ4OHwPqd6wuWRiqv9ZFANlxhgAACrXGAMAQCWMAQCgEsYAAFAJYwAAqIQx\nAABU9f8DEwy8J7uC8AcAAAAASUVORK5CYII=\n", 75 | "text/plain": [ 76 
| "" 77 | ] 78 | }, 79 | "metadata": {}, 80 | "output_type": "display_data" 81 | } 82 | ], 83 | "source": [ 84 | "skplt.estimators.plot_feature_importances(classifier, feature_names=['petal length', 'petal width',\n", 85 | " 'sepal length', 'sepal width'])\n", 86 | "plt.show()" 87 | ] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.6.1" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 1 111 | } 112 | -------------------------------------------------------------------------------- /examples/jupyter_notebooks/plot_lift_curve.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### An example showing the plot_lift_curve method \n", 8 | "\n", 9 | "In this example, we'll be plotting a lift curve to describe the classification performance of a `LogisticRegression` classifier using the breast cancer dataset from scikit-learn. Here, we'll be using the `scikitplot.metrics.plot_lift_curve` method." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 3, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%matplotlib inline\n", 19 | "\"\"\"\n", 20 | "An example showing the plot_lift_curve method used\n", 21 | "by a scikit-learn classifier\n", 22 | "\"\"\"\n", 23 | "from __future__ import absolute_import\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from sklearn.linear_model import LogisticRegression\n", 26 | "from sklearn.datasets import load_breast_cancer as load_data\n", 27 | "# Import scikit-plot\n", 28 | "import scikitplot as skplt" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEWCAYAAAB1xKBvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzt3Xd8FVX++P/XO42QQofQCUiTIiUgoK4GUUF0ZV39KljQtSBrWdtaVv24urrruq66a1kVwbX8FESxy8q6CmKhBaR3aSb0FtLbff/+mOGShJSbkJtJct/Px+M+cufMmZn3uYH7zsyZOUdUFWOMMaYyYV4HYIwxpn6whGGMMSYgljCMMcYExBKGMcaYgFjCMMYYExBLGMYYYwJiCcOYYkTkFyKyodhyLxFZLiIZIvI7L2MzxmuWMExIEpFtInJO6XJV/VZVexUruheYq6rxqvqciLwuIo9Xsm8Rkd+JyGoRyRKRVBF5T0T613Q7jKlNljCMqVgXYE0Vt/kncDvwO6AF0BP4CLigqgcXkYiqbmNMsFjCMKYYEUkWkVT3/dfASOAFEckUkUnAlcC97vKnZWzfA7gFmKCqX6tqnqpmq+rbqvpXt848Ebmh2DbXish3xZZVRG4RkU3AJhF5SUT+Xuo4H4vIXe779iIyS0T2ichWu3RmgsUShjHlUNWzgW+BW1U1TlWnAG8Df3OXf1nGZqOAVFVdfIKH/xUwDOgDTAcuFxEBEJHmwHnADBEJAz4FVgAd3OPfISKjT/D4xhzHEoYxNaslsKsG9vOEqh5U1RycpKXAL9x1lwILVHUnMBRorap/UtV8Vd0CvAqMr4EYjCnBro8aU7MOAO1qYD8/H32jqioiM4AJwHzgCuD/c1d3AdqLyOFi24bjJBljapSdYRhTNZUN7/wV0FFEhlRQJwuIKbbcNoDjTAcuFZEuOJeqZrnlPwNbVbVZsVe8qo6tJE5jqswShgllkSISXewVyBn3HqBbeStVdRPwL2C624Ee5e57vIjc71ZbDvxaRGJEpDtwfWUHVdUfgf3AVGCOqh49o1gMZIjIfSLSWETCRaSfiAwNoC3GVIklDBPKZgM5xV6PBLDNNKCPiBwWkY/KqfM74AXgReAw8BNwMU7nNMCzQD5O8nkDpyM9EO8A57g/AVDVIuBCYCCwlWNJpWmA+zQmYGITKBljjAmEnWEYY4wJiCUMY4wxAbGEYYwxJiBBew5DRDoBbwIJOLcITlHV
f5aqcw/OUAtHYzkZ5yGkgyKyDcgAioBCVa3oNkVjjDFBFrRObxFpB7RT1WUiEg8sBX6lqmvLqf9L4E53OAbchDFEVfcHesxWrVppYmJiteLNysoiNja2WtvWV9bmhi/U2gvW5qpaunTpflVtHUjdoJ1hqOou3CESVDVDRNbhjHVTZsLAeYp1+okcMzExkZSUlGptO2/ePJKTk0/k8PWOtbnhC7X2grW5qkRke8B1a+O2WhFJxBnSoJ+qHiljfQyQCnRX1YNu2VbgEM7lrFfcgd/K2vckYBJAQkJC0owZM6oVY2ZmJnFxcdXatr6yNjd8odZesDZX1ciRI5cGfMlfVYP6AuJwLkf9uoI6lwOflirr4P5sgzMS55mVHSspKUmra+7cudXetr6yNjd8odZeVWtzVQEpGuD3eVDvkhKRSJwxb95W1Q8qqDqeUpejVDXN/bkX+BA4NVhxGmOMqVzQEoY7dv80YJ2qPlNBvabAWcDHxcpi3Y5yRCQWZ+z/1cGK1RhjTOWCObz56cDVwCoRWe6WPQB0BlDVl92yi4H/qmpWsW0TgA/d+WIigHdU9YsgxmqMMaYSwbxL6jtAAqj3OvB6qbItwICgBGaMMaZaQn4Cpay8Qlb8fJi1B4qI2hzwIx8NQl1rc/tmjUlsFVr3zxtTn4R8wkg9lMMVUxc5C0sWeRuMF+pYm5+9fAAXD+rodRjGmDKEfMIwdcud765gZWo6w7u1ZHTfsiaiM8Z4JeQTRkxUOCO6teTw4UM0a9bc63BqVV1pc0ZeAavTjj3P+e/vt/Hv77cx67enkdTF+/iMMY6QTxidWsQwfdJw99H64V6HU6vqUptvm/4jn67YWaLskpd+IKlLc6ZOHELz2CiPIjPGHGXDm5s64S8X9+PJS/ozqnebEuVLtx/i3Ge/Yf7GfUdHADDGeMQShqkT4qMjuXxoZ16dOIRzTi6ZNPZn5jPxtcW8vWiHR9EZY8AShqljwsKEqdcMZfEDo4gML/kYz1sLAh5U0xgTBJYwTJ3Upkk0MyaN4IJT2vnLNuzJYOOeDA+jMia0WcIwdVZSl+a8eMVgxvY/dnvtZ6U6xo0xtccShqnzfnlKe//7577ezPYDWRXUNsYEiyUMU+eN7N2G2Khw//K976/0MBpjQpclDFPnRUeGc9nQTv7lxdsOsjs918OIjAlNljBMvfB/F/ShSbTznKkqnPPMN/h89lyGMbXJEoapF8LChIcu6ONfzswrZNyL3/OP/21k8167c8qY2mAJw9QbY4vdYguwKi2df/xvE+c8M5/cgiKPojImdFjCMPVGXKMI5txxZpnrhv3lK7tEZUyQWcIw9UqvtvG8fFUSvxzQvkR5ek4BC7cc8CgqY0KDJQxT74zp15bnJwxizaOjS5Q/OWcD976/gplLfvYoMmMaNksYpt6KbRTBZ7ed4V9e8fNhZqakcu+slXy3qe5MPWtMQ2EJw9Rrfds3oX+HpseVv73IBio0pqYFLWGISCcRmSsia0VkjYjcXkadZBFJF5Hl7uvhYuvGiMgGEdksIvcHK05Tv4kIL1wxiGtPS+Sq4Z395V+u3cP+zDwPIzOm4QnmjHuFwN2qukxE4oGlIvKlqq4tVe9bVb2weIGIhAMvAucCqcASEfmkjG2NoUvLWB65qC8A63ZlsHT7IQp9yuhn57PkwXMIC5NK9mCMCUTQzjBUdZeqLnPfZwDrgA4Bbn4qsFlVt6hqPjADGBecSE1DMr7YECIHsvI5/cmv2WRDohtTI6Q2pr0UkURgPtBPVY8UK08GZuGcRewEfq+qa0TkUmCMqt7g1rsaGKaqt5ax70nAJICEhISkGTNmVCvGzMxM4uLiqrVtfdUQ25xXqNz0v+wSZU0bCX/7RWMaRUiDbHNFQq29YG2uqpEjRy5V1SGB1A3mJSkARCQOJyncUTxZuJYBXVQ1U0TGAh8BPaqyf1WdAkwBGDJkiCYnJ1crznnz5lHdbeurhtrmL/tncO6z8/3L6XnK
oSYnMf7Uzg22zeUJtfaCtTmYgnqXlIhE4iSLt1X1g9LrVfWIqma672cDkSLSCkgDOhWr2tEtM6ZSPRLimfXb00qUvbVwO7VxNm1MQxbMu6QEmAasU9VnyqnT1q2HiJzqxnMAWAL0EJGuIhIFjAc+CVaspuFJ6tKc5Q+fS6MI55/4mp1H6P7gf/g5w+dxZMbUX8E8wzgduBo4u9hts2NFZLKITHbrXAqsFpEVwHPAeHUUArcCc3A6y2eq6pogxmoaoGYxUSWGECnyKVNX5dmZhjHVFLQ+DFX9DqjwfkZVfQF4oZx1s4HZQQjNhJBJZ3bjox/TKHQHJtx+xMeyHYdI6tLC48iMqX/sSW/ToPVMiOfru5Np3zTaX/bwx2tsZFtjqsEShmnwOreM4dVrjt01uGbnEW6b8aNdmjKmiixhmJDQt31ThiY29y9/vnIX976/0sOIjKl/LGGYkPHkJaeUWH5/WSo/7cv0KBpj6h9LGCZkdGsdx5/PaOxfVoVp3231MCJj6hdLGCakdIgLY8ak4f7lWUtTbVRbYwJkCcOEnGFdW/jn0Mgr9PHWAps7w5hAWMIwIUdEuPHMbv7ltxZuJ7egyMOIjKkfLGGYkDS2X1s6NHP6Mw5m5TNrWarHERlT91nCMCEpIjyM687o6l+e+u1We5jPmEpYwjAh6/KhnYiPdkbH2bo/i/+t2+NxRMbUbZYwTMiKaxTBFcOOzQP+0jc/2dPfxlTAEoYJaded3pWocOe/wY87DrPgpwMeR2RM3WUJw4S0hCbRXDqko3/5hbmbPYzGmLrNEoYJeb896yTCw5yR+H/46QDLdhzyOCJj6iZLGCbkdWoRw7hiEy29+LWdZRhTFksYxgA3jzwJcaf7+mr9XtbuPOJtQMbUQZYwjAG6t4lnTN+2/uUX59lZhjGlWcIwxnXLyO7+97NX7bKhz40pxRKGMa5+HZoysldrwBn6/F9zf/I4ImPqlqAlDBHpJCJzRWStiKwRkdvLqHOliKwUkVUi8oOIDCi2bptbvlxEUoIVpzHF3Xr2sbOMj5ansW1/lofRGFO3BPMMoxC4W1X7AMOBW0SkT6k6W4GzVLU/8BgwpdT6kao6UFWHYEwtSOrSgtO7twSgyKf886tNHkdkTN0RtIShqrtUdZn7PgNYB3QoVecHVT160/tCoCPGeOyuc3v633+0PI3NezM8jMaYukNqY+wcEUkE5gP9VLXM+xVF5PdAb1W9wV3eChwCFHhFVUuffRzdbhIwCSAhISFpxowZ1YoxMzOTuLi4am1bX1mby/d0Si6r9jtzZJzaNpybB0YHO7SgsN9xaDiRNo8cOXJpwFdxVDWoLyAOWAr8uoI6I3HOQFoWK+vg/mwDrADOrOxYSUlJWl1z586t9rb1lbW5fMt3HNIu933mf63dmR7cwILEfseh4UTaDKRogN/nQb1LSkQigVnA26r6QTl1TgGmAuNU1T/ym6qmuT/3Ah8CpwYzVmOKG9CpGeecnOBffvbLjR5GY0zdEMy7pASYBqxT1WfKqdMZ+AC4WlU3FiuPFZH4o++B84DVwYrVmLIU78v479o9XDl1oU2yZEJaMM8wTgeuBs52b41dLiJjRWSyiEx26zwMtAT+Ver22QTgOxFZASwGPlfVL4IYqzHH6dO+CWP7H3v6+/vNB/jwxzQPIzLGWxHB2rGqfgdIJXVuAG4oo3wLMOD4LYypXXed24vZq3b7l5/5ciMXnNKO6MhwD6Myxhv2pLcxFejeJo4v7zzTv5x2OIc3F2zzLB5jvGQJw5hK9EiI57Fxff3LL3y9mcPZ+R5GZIw3LGEYE4Dxp3ama6tYAI7kFvKveTbOlAk9ljCMCUBkeBj3jO7lX379h22kHc7xMCJjap8lDGMCdH6/tgzs1AyA/EIfT/93g8cRGVO7LGEYEyAR4Q/n9/Yvf/hjms3MZ0KKJQxjqmBYt5aM6t0GcObMeOyztUeHsjGmwbOEYUwV3X9+b8LDnEeMFmw5wJw1
uyvZwpiGwRKGMVXUIyGeq4d38S8//vk6cguKPIzImNphCcOYarjznJ40j4kEIPVQDtO+2+pxRMYEnyUMY6qhaUwkd5137DbbF+duZnd6rocRGRN8ljCMqaYJQzvRu208ANn5Rfzti/UeR2RMcFnCMKaaIsLDePjCY9PUf/BjGst2HKpgC2PqN0sYxpyA07q3YkzfY0Og//HjNRTZnBmmgbKEYcwJemDsyURFOP+VVqWl886i7R5HZExwWMIw5gR1bhnDLcnd/ct/m7OBvRnWAW4aHksYxtSAm87q5h/NNiO3kL98vs7jiIypeZYwjKkB0ZHh/KnYnBkfLd/JD5v3exiRMTXPEoYxNeQXPVpz4Snt/MsPfbyavEJ7Atw0HJYwjKlB/3dhH+IaRQCwZV8Wr87f4nFExtScoCUMEekkInNFZK2IrBGR28uoIyLynIhsFpGVIjK42LprRGST+7omWHEaU5MSmkRz93k9/cvPf72ZrfuzPIzImJoTzDOMQuBuVe0DDAduEZE+peqcD/RwX5OAlwBEpAXwR2AYcCrwRxFpHsRYjakxVw/vQr8OTQDIK/Rx/6yV+OzZDNMABC1hqOouVV3mvs8A1gEdSlUbB7ypjoVAMxFpB4wGvlTVg6p6CPgSGBOsWI2pSRHhYfz116f4h0BftPUg05fs8DgqY05crfRhiEgiMAhYVGpVB+DnYsupbll55cbUC/06NOWmM7v5l5+YvZ5d6TYHuKnfIoJ9ABGJA2YBd6hqjc9nKSKTcC5nkZCQwLx586q1n8zMzGpvW19Zm4NrYKTSNkbYna1k5hUyeeo33DG4ESJSK8cH+x2Hitpqc1AThohE4iSLt1X1gzKqpAGdii13dMvSgORS5fPKOoaqTgGmAAwZMkSTk5PLqlapefPmUd1t6ytrc/A1P+kg/+/lBQCs2FfEkeY9GTew9k6W7XccGmqrzcG8S0qAacA6VX2mnGqfABPdu6WGA+mquguYA5wnIs3dzu7z3DJj6pWhiS2YOOLY7HyPfrqW/Zl5HkZkTPUFsw/jdOBq4GwRWe6+xorIZBGZ7NaZDWwBNgOvAjcDqOpB4DFgifv6k1tmTL1z75jedGjWGICDWfk8+OEqVO2uKVP/BO2SlKp+B1R4sVad/zW3lLPuNeC1IIRmTK2KaxTBE7/uz8TXFgMwZ80eZi1L49Kkjh5HZkzV2JPextSCM3u25urhxS5NfbKG1EPZHkZkTNVZwjCmlvxhbG8SW8YAkJFXyD3v2QN9pn6xhGFMLYmJiuDpywbiPs/Hgi0H+PcP2zyNyZiqsIRhTC1K6tKc3yaf5F9+8ov1bNqT4WFExgTOEoYxtez2UT05uZ0z1lR+oY+7Zq4gv9DncVTGVM4ShjG1LCoijH9cPpCo8GPzgD/93w0eR2VM5SxhGOOBXm3juXdML//yK/O38M3GfR5GZEzlLGEY45HrTu/KWT1b+5fvnrmcvRm5HkZkTMUCShgi8mQgZcaYwIWFCU9fNoDW8Y0A2J+Zz13vrrBbbU2dFegZxrlllJ1fk4EYE4paxTXi2csGcnQA2+827+cVm9bV1FEVJgwR+a2IrAJ6u1OoHn1tBVbWTojGNGxn9GjF5LOO3Wr79H838OOOQx5GZEzZKjvDWAj8EvjY/Xn0laSqVwU5NmNCxl3n9mRQ52YAFPqUW9/5kcPZ+R5HZUxJlSWMqaq6DWijqtuLvWzkWGNqUGR4GM+NH0R8tDMeaNrhHO6aaf0Zpm6pLGGEicgDQE8Ruav0qzYCNCZUdGoRw1OXDvAvf71+Ly9985OHERlTUmUJYzxQhDMMenwZL2NMDRrTry2Tis0F/tScDVw1dRHpOQUeRmWMo8L5MFR1A/CkiKxU1f/UUkzGhLR7Rvdi+Y7DLN7mXPn9bvN+Jk5bxLs3jSA6Mtzj6Ewoq+wuqaMd233skpQxtSMyPIznrxhEq7hG/rIVqek8+OFqm6nPeKqyS1Kx7s84jr8cFRfEuIwJaQlN
opl+47ASZbOWpfLv77d5E5AxVH5J6hX356Ol14nIHcEKyhgDPRLi2frEWO55fyXvL00F4M+z19G7bTyndW/lcXQmFJ3IWFJ2ScqYIBMRHv9VPwZ0cp7RKPIpt7yzjJ8P2vSupvadSMKQGovCGFOu6MhwXrkqyT/m1KHsAm58M4Xs/EKPIzOh5kQSRoW9byLymojsFZHV5ay/R0SWu6/VIlIkIi3cddtEZJW7LuUEYjSmQWjbNJqXrxpMZLjzd9r63Rnc+e5ye7DP1KrK7pLKEJEjZbwygPaV7Pt1YEx5K1X1KVUdqKoDgT8A35R6gnyku35IgG0xpkFL6tKCx8b18y/PWbOHJ79Y72FEJtRUmDBUNV5Vm5TxilfVyjrM5wOBDiEyAZgeYF1jQtb4Uztz/Rld/cuvzN/C9MU7PIzIhBIJ5n3dIpIIfKaq/SqoEwOkAt2PnmG4o+Eewrns9YqqTqlg+0nAJICEhISkGTNmVCvWzMxM4uJC605ha3P95FPl+R/z+HFvEQBhAnclRdOv1fEP9TWE9laVtblqRo4cuTTQKzl1IWFcDlylqr8sVtZBVdNEpA3wJXCbe8ZSoSFDhmhKSvW6PObNm0dycnK1tq2vrM31V1ZeIZe9soA1O48AEN8ogg9uPo0eCSVH7Gko7a0Ka3PViEjACaMuTNE6nlKXo1Q1zf25F/gQONWDuIyps2IbRTDtmqG0bRINQEZeIb95fYlN8WqCytOEISJNgbNw5ts4WhYrIvFH3wPnAWXeaWVMKGvbNJpp1w4hJsq5FJV6KIdrX1tCRq4NVGiCI2gJQ0SmAwuAXiKSKiLXi8hkEZlcrNrFwH9VNatYWQLwnYisABYDn6vqF8GK05j6rG/7prxwxSDCw5zbbdfuOsJNby0lr7DI48hMQ1ThnU4nQlUnBFDndZzbb4uXbQEGlFXfGHO8s3sn8MTF/bl3ljNr8g8/HeCumSt4fvwgjyMzDU3QEoYxpvZcNrQT+zLzeGrOBgA+X7mLVrFRJDexB/tMzakLnd7GmBpwc/JJXDOii3/5jQXb+WyL9WeYmmMJw5gGQkR4+Jd9uaB/O3/ZrE0F9mCfqTGWMIxpQMLDhGcuH8BpJ7X0lz3w4So++jHNw6hMQ2EJw5gGplFEOK9cnUS/Dk0AUIW731vBF6t3exyZqe8sYRjTAMVHR/LmdcPoEOfcblvkU26bvoy5G/Z6HJmpzyxhGNNAtYiN4p6h0XRt5cy0XFCkTH5rKT/8tN/jyEx9ZQnDmAasWaMw3r5hGB2aNQYgr9DHDW+ksHR7oANJG3OMJQxjGrj2zRoz/cbhJDRxZuzLzi/imteWsHT7IY8jM/WNJQxjQkDnljG8fcNwWsZGAZCZV8jEaYtYss3ONEzg7EnvvAzY/gMtDqyCjXleR1OrrM1VIGHQbiDEta75oGpJ9zZxvHPjcK54dSEHsvLJyi/imtcW8+9rhzKsW8vKd2BCniWMwz/DO5dxCsAqr4OpXdbmaujzK4hrA+pz7ldFy/jpc2e8L29dedu566FkWaN4GPIbaNqp2D7c/Wix5aPrwiKgZXcIO35CpV5t45kxaTgTXl3E/sw8svOLuPbfS5h27RBOO6nVCXwwJhRYwjCmKtZ+5M1xV82sWv3oZjD2KaLyIsHng7BjV597JDhJ44pXF7I3I4+cgiKue30JUycO5YweljRM+SxhRMVCj/M4cOAALVuG1mm5tTlAGbtgdz07Fcs9DB/cyGkAm3vD5W9D044Q6Uy41L1NHO/eNIIJUxay+0guuQU+rn9jCf+6cjCjTk7wNHRTd1nCaN4FrnyPVSE4raO1uQp2r4Zt3wHq9GcgIM5Dcc5Pd7nEurJ+hpUqo4J1AmkpsO5TyM8+tv8S9cOKlYdB5l7IKdWRvW89vJDkbDP8t3DenyEsjK6tYnn3puFMmLKQnem55BX6uOmtpTx92QDGDexQ9c/INHiWMIwJRNt+zqu2nXwhnPNI1bbZ/D/4+nE48BPk
HSm2QmHhv5zXgCugywi69LuUd28awZVTF7HjYDaFPuWOd5eTnlPAxBGJNdcO0yDYbbXGNDTdz4FJ8+APP7Oq3wMQ1/b4OivegU9ugzcvolPjfN6fPIJeCfGA04/+8MdreP6rTajafBrmGEsYxjRgB1oNg99vgIcPwpDrj6+QugT+OYA2y59n5oRODOrczL/q6S838vjn6/D5LGkYhyUMY0JBWDhc+AzcvhJGPQw9Rh9bl3sYvn6cpi8PZGaXTzire3P/qmnfbeXeWSspKPJ5ELSpayxhGBNKmneBX9wNV86Ecf9yO+mPiVzyMv+O+hsX94n3l72/NJXr30ghM6+wtqM1dUzQEoaIvCYie0VkdTnrk0UkXUSWu6+Hi60bIyIbRGSziNwfrBiNCWmDroRrZ8Pga6BxC39x2Ja5PLvll8xt8VduDf+Q08NWsXzjNi57eQF7juR6GLDxWjDPMF4HxlRS51tVHei+/gQgIuHAi8D5QB9ggoj0CWKcxoSuLiPgoufgnp9gxK0lVnXNXsnvI9/j7agnWNzoZk7a8wUXv/g9G/dkeBSs8VrQEoaqzgeqM7LZqcBmVd2iqvnADGBcjQZnjCkpLAzOexzG/LXM1dFSwPNRL/BszgP8/qX3+GGzzakRiiSYt82JSCLwmaoedwO7iCQDs4BUYCfwe1VdIyKXAmNU9Qa33tXAMFW9tfQ+3PWTgEkACQkJSTNmzKhWrJmZmcTFxVVr2/rK2tzwVae94iuk+aHltN63gJjsVJoeWV9ifaGG8XzRrynoNY5TOzWpyXBrRKj9juHE2jxy5MilqjokkLpePri3DOiiqpkiMhb4COhR1Z2o6hRgCsCQIUO0uk8uzwvBp56tzQ1f9dt7zrG3OYfgg0mw6b8ARIiPOyPeZ92mxXzT9HkmXXgWYWFSI/HWhFD7HUPttdmzu6RU9YiqZrrvZwORItIKSAM6Fava0S0zxnihcXO48j2Y+DFFMW38xSeH7WDysnHM+edkMjOPVLAD01B4ljBEpK2IM5iOiJzqxnIAWAL0EJGuIhIFjAc+8SpOY4yrWzLhtywkf/AN+Ip9dZyfPoPdz57Fzm0bvIvN1Ipg3lY7HVgA9BKRVBG5XkQmi8hkt8qlwGoRWQE8B4xXRyFwKzAHWAfMVNU1wYrTGFMFsS2JuuhpdOInZEc09Rd3L9pC49fPYd3C/3gYnAm2oPVhqOqESta/ALxQzrrZwOxgxGWMOXHh3X5BzP2bWTvjAbpveo0oKaI5R4j7z5Wk/PwHhlz6+2Oj8ZoGw570NsZUT0QUfa76O1vGTucAztlGpBQxZM3jLH7+anJzczwO0NQ0SxjGmBPSe9ho8q/7ik3hJ/nLTj34KVv/nszOn3/yMDJT0yxhGGNOWLvOPehw1zekNDl2O+7JheuJnpbMqvkeTWtrapwlDGNMjYiJjSfpjvdY0uNOitTpv2jBEfp+dS2LXr8PX1GRxxGaE2UJwxhTYyQsjKFXPsKm899hP87cGmGiDNv2Mmv/Ppoj+3d7HKE5EZYwjDE1rvfwsTD5W1ZHneIv65ezhOwXTmddylceRmZOhCUMY0xQtGrbmV73fMX37Sb6y9qyn5M+/X98/+Yf7RJVPWQJwxgTNJGRUZx+0/MsO/1l0okFIEqKOH3LP1j9t3PZv3uHxxGaqrCEYYwJusHnTiDnN3PZFNHTX3ZK3lLCXj6d1XNnehiZqQpLGMaYWtG2Sy+63vsti9pfg6/YXVT9vrmRJS/dSF5ulscRmsp4Oby5MSbERERFM2zSc6z+dhQJX91Oaw4BMHTPTLb8bTF66TRO6lP51AwFBQWkpqaSm3v8lLFNmzZl3bp1NR57XRZIm6Ojo+na48k1AAAZdElEQVTYsSORkZHVPo4lDGNMrev3i3Ec6DWUZf++jsE5CwDo5ttG3rtj+KH7bQyb8CDhEeV/PaWmphIfH09iYiJSasyqjIwM4uPjgxp/XVNZm1WVAwcOkJqaSteuXat9HLsk
ZYzxRMs27Rl0z2wWnfwAeer81dtICjjtp2fY8ORZ7Nxa/l/Mubm5tGzZ8rhkYcomIrRs2bLMM7KqsIRhjPGMhIUx7PL72DvhC34qNhZVn4LVNHv9LJbO+jvq85W9rSWLKqmJz8sShjHGc516D6HzfT+wsNMNFKrztRQjeSSteozVfzuHXTaIYZ1gCcMYUydERkUz/Pqn2TzuI7ZJR395/9ylxE49gwWznsNXVPbZhhd2797N+PHjOemkk0hKSmLs2LFs3LiRbdu20a9fv6AcMy8vj8svv5zu3bszbNgwtm3bFpTjlMcShjGmTuk9+Cza3LOI79tc4b/9tolkM2LV/7HmybNJ27LW4widTuSLL76Y5ORkfvrpJ5YuXcoTTzzBnj17gnrcadOm0bx5czZv3sydd97JfffdF9TjlWZ3SRlj6pyYmDhOv/kl1i/6FU2++B3t1Rm0sH/+j+S8cRb7fvkJqj5Ewki8//OgxbHtrxeUWT537lwiIyOZPHmyv2zAgAHONsX+6t+2bRtXX301WVnOMyYvvPACp512Grt27eLyyy/nyJEjFBYW8tJLL3Haaadx/fXXk5KSgohw3XXXceedd5Y47scff8wjjzwCwKWXXsqtt96KqtZgiytmCcMYU2f1Hjaa3H5LWPzWfSTtmk64KI0ln0YFR8jbtR5p1tmTuFavXk1SUlKl9dq0acOXX35JdHQ0mzZtYsKECaSkpPDOO+8wevRoHnzwQYqKisjOzmb58uWkpaWxevVqAA4fPnzc/tLS0ujUqRMAERERNG3alAMHDtCoUaOabWA5LGEYY+q06NgmnDr5JTb9OJ7wT39HN982p5w89NAmb4OrREFBAbfeeivLly8nPDycjRs3AjB06FCuu+46CgoK+NWvfsXAgQPp1q0bW7Zs4bbbbuOCCy7gvPPO8zj64wUtYYjIa8CFwF5VPa4HSESuBO4DBMgAfquqK9x129yyIqBQVSt/9NMY06D1GHQW+X0W8/07j9Acp29DBLb9rj35RFAY14GYJi1q5cG9vn378v7771da79lnnyUhIYEVK1bg8/mIjo4G4Mwzz2T+/Pl8/vnnXHvttdx1111MnDiRFStWMGfOHF5++WVmzpzJa6+9VmJ/HTp04Oeff6Zjx44UFhaSnp5Oy5YtyczMDEo7Swtmp/frwJgK1m8FzlLV/sBjwJRS60eq6kBLFsaYo6IaNeL03zxBUWwCOdL4WDmFxGRuJ3v3RnyFBUGP4+yzzyYvL48pU459ba1cuZJvv/22RL309HTatWtHWFgYb731FkXukO7bt28nISGBG2+8kRtuuIFly5axf/9+fD4fl1xyCY8//jjLli077rgXXXQRb7zxBgDvv/8+Z599dq0+jxK0MwxVnS8iiRWs/6HY4kKgY3l1jTGmuIjIKKLb9iLr8F4a5ewmAud22xhfFr7sbLKKsmncoj1hYcH5m1hE+PDDD7njjjt48skniY6OJjExkX/84x8l6t18881ccsklvPnmm4wZM4bYWGeI93nz5vHUU08RGRlJXFwcb775JmlpafzmN7/B5z6o+MQTTxx33Ouvv56rr76a7t2706JFC2bMmBGU9pWnrvRhXA/8p9iyAv8VEQVeUdXSZx/GmBAnIsQ2T6AwrjlZB1OJLUoHnClhY/P3kb/7MIXxHYiJbx6U47dv356ZM8semv1ox3WPHj1YuXKlv/zJJ58E4JprruGaa645bruyziqKi46O5r333qtuyCdMgnlLlnuG8VlZfRjF6owE/gWcoaoH3LIOqpomIm2AL4HbVHV+OdtPAiYBJCQkJFU342ZmZhIXF1etbesra3PD11Db27RpU7p3716irCg/h+i8fUSTX6I8kxgKG7ciPCKqNkOsVUVFRYSHh1dab/PmzaSnp5coGzly5NJAL/17eoYhIqcAU4HzjyYLAFVNc3/uFZEPgVOBMhOGe/YxBWDIkCGanJxcrVjmzZtHdbetr6zNDV9Dbe+6devK
6NiOx+drxeF9acQXHiRcnEs7cWTjy/6Z7KjmNG7RgfDwunJhpeYE2tEfHR3NoEGDqn0cz570FpHOwAfA1aq6sVh5rIjEH30PnAes9iZKY0x9EhYWRnhMM7RNb7LCmxwrFyWu4CC6Zy1Zh/bU6sNuDUkwb6udDiQDrUQkFfgjEAmgqi8DDwMtgX+5vfxHb59NAD50yyKAd1T1i2DFaYxpeCIiGxGRcBK5WUfgSBrR6gzrHUERETk7ycvZjy++A43jm3kcaf0SzLukJlSy/gbghjLKtwADghWXMSZ0RMc2QWPiyUrfT1T2biIpBKAR+ZCxlazMWCKadaBR41iPI60fbPBBY0yDJiLENmtNWEIfMiNbUaTHnluI1SyiDm4ka88WCvLzPIyyfrCEYYwJCeHh4cS17oSv9ckl+jdEILYonfB968jat4OiAB/882J48/nz5zN48GAiIiICetK8plnCMMaElMioRsQmnEResx5kS4y/PEyU2IIDsHctmQfS8LlPZZfFq+HNO3fuzOuvv84VV1wR1OOUp+HdX2aMCS2PNC2xGOgoUuWN7xoOHH1yJfN3G4hp1ua4J8a9Gt48MTERIGhPsFfGEoYxxpQjLncX+bv3UdC4DTFNWyPuF7VXw5t7zRKGMcZUIIpConJ2kp+zj4KYNsQ0bRXwtja8uTHG1CWPlBzqoiaGN/cVFZFzeDeN8vb7BzaMooCo7DTys/fSrUt7T4Y395p1ehtjTClh4eHEtuyAJPQlM6o1hcW+KqMo4IIh3cjNPMxzz/7dP7psbQxv7jVLGMYYU47w8AjiWnWENn3JjGxNoTpfmSLCR1P/zndf/5fu3bpwcu9e3H///bRt27bE9jfffDNvvPEGAwYMYP369SWGNx8wYACDBg3i3Xff5fbbbyctLY3k5GQGDhzIVVddVebw5kuWLKFjx46899573HTTTfTt2zf4H0IxdknKGGMqERERQVzrjhQWJpB5aDeNCw7Svm1rZr7ypL9OAeHkN4qlcbM2QRvefOjQoaSmptZEk6rFzjCMMSZAERGRxLXuBG36kBXVqsSlqkiKiM3bg2/PGjL3/0xhQfBn/qttdoZhjDFVFB4RSWyrThQVtSPz8F4a5e0nEqd/IgIfcfn7Kdp7gMzIZkQ1bUtUo2iPI64ZljCMMaaawsMjiGvZHp8vgczD+4jK3UeUO8BhuChxhYfw7T9EVng84U0SiI45sbu3vGYJwxhjTlBYWDhxLdri87Uh68gBIrL3OiPiAmECsb4MOJxBTnpjNLY1jeNb4E7hUK9YwjDGmBoSFhZGbLPWaNNW5GQcgqy9NNYc//rGmgOZO8jP3EV+dCsaN21Vr2YArD+RGmNMPSEiNG7SApq0IDc7g6Ije4gpyuDoSUUUBUTl7qIwdw+Zkc2JatKmXvRz2F1SxhhTDeHh4QwcOJABAwYwePBgfvjhhzLrRcfEE9u2OwWtTiYzogVFeuxrNwIfcQUHiNy/jqzdm8jJOOyfPvbaa6/1P01+ww03sHbt2uA3qhJ2hmGMMdXQuHFjli9fDsCcOXP4wx/+wDfffFNu/ahG0US16UJRYQcy0/cRlbff30EuArG+TMjIJC8jisLoFqj6/NtOnTo1uI0JkJ1hGGPqPRHxv5o0aVJiecqUKf56U6ZMKbGu9Ku6jhw5QvPmzQHIzMxk1KhRDB48mP79+/Pxxx8DkJWVxQUXXMDgpCSGn3UuH8xfQ3ZcZ75ftYWzLrmBpDFXMPqKmzm4J43Y3N2Qc5jc9L3k5WSRnJxMSkoKAHFxcTz44IMMGDCA4cOH++fg2LdvH5dccglDhw5l6NChfP/999VuT3nsDMMYY6ohJyeHgQMHkpuby65du/j6668BiI6O5sMPP6RJkybs37+f4cOHc9FFF/HFF1/Qvn17Pv/8c8AZZyqycQz3PPo0M9+dRbNo5dNZM3nwyRd57ZlHECC6KItGhzbiy88mJ+MQPl8RWVlZDB8+nD//+c/ce++9vPrq
q9x+++3cfvvt3HnnnZxxxhns2LGD0aNHs27duhptsyUMY0y9d/S6P1Q8Wu2kSZOYNGlSjRyz+CWpBQsWMHHiRFavXo2q8sADDzB//nzCwsJIS0tjz5499O/fn7vvvpv77ruPCy+8kF/84hesXr2a1atXc+FFvwKgsLCQhNYtySOyxLHC8NE4dy++3auJiorivHPOBiApKYkvv/wSgP/9738l+jmOHDlCZmYmcXFx1JSgJgwReQ24ENirqsdNcivOOeA/gbFANnCtqi5z110DPORWfVxV3whmrMYYU10jRoxg//797Nu3j9mzZ7Nv3z6WLl1KZGQkiYmJ5Obm0rNnT5YtW8bs2bN56KGHGDVqFBdffDF9+/ZlwYIFJfanqhRFxpEXFk2xXEgEPiIjwml0aCM5hxtTkJtFgTsEic/nY+HChf4h1IMh2H0YrwNjKlh/PtDDfU0CXgIQkRbAH4FhwKnAH0WkeVAjNcaYalq/fj1FRUW0bNmS9PR02rRpQ2RkJHPnzmX79u0A7Ny5k5iYGK666iruueceli1bRq9evdi3b58/YRQUFLBmzRpEhPDIKBo1a0dh65MpkkgKCS9xzMaaQ3TeATTnEJq1j1Gjzub555/3rz969lOTgpowVHU+cLCCKuOAN9WxEGgmIu2A0cCXqnpQVQ8BX1Jx4jlhI0eOLLcjrLqdZklJSeXWK35avHTp0gr3uXTpUn/dSZMmlVuv9JSRFe0zVNt09PfckNpU0e+prH/X9b1NIsL27dtJSUnxv4rPo52bm1tiXenX0fm1wZlzu7x6pW9jLb0+JyeHnj170rNnTy655BLeeOMNwsPDGT16NN988w3du3fnmWeeITExkZUrVzJr1iz69+/PwIEDefTRR3nooYfYvHkzjz76KLfccgs9e/akV69evPPOO6SkpJCZmQlAZFQ0hEWwcV8+KTuL8Cmk7CwiZWcRPx3ycTDbRxNfOv/8v1tIWbKEU045hT59+vDyyy9T07zuw+gA/FxsOdUtK6/8OCIyCefshISEBObNm1fjQW7YsMG/3w0bNlRYt/jxMzIyyq23c+fOgPeZkpLi39fOnTvLrZeRkRFw+zds2ED79u2ZN29eg2pTQ/w9WZuOHaM8BQUF/vVHJzQqT3Z2tr9OQQUjyhYVFVV4zEWLFvnfJyQk0KxZMzIyMoiNjS1zprz27dszYsQIevXq5S/btm0bvXr1KpFEj3riiSdo27YtGRkZfPDBB/4zlfnz5/vrjBo1ilGjRgHQuFUnpk6bVmIfpePPzc09se9IVQ3qC0gEVpez7jPgjGLLXwFDgN8DDxUr/z/g95UdKykpSatr7ty51d62vrI2N3wNtb1r164td92RI0dqMRJv+Xw+zTpyUDN2btDcnKxK65f1uQEpGuD3udfPYaQBnYotd3TLyis3xhjjEhFi4pujce1oFB0T9ON5nTA+ASaKYziQrqq7gDnAeSLSXJzO7vPcMmOMAUreSmsqVxOfV7Bvq50OJAOtRCQV586nSABVfRmYjXNL7Wac22p/4647KCKPAUvcXf1JVSvqPDfGhJDo6GgOHDhAy5Ytj+scN8dTVQ4cOHDCt9wGNWGo6oRK1itwSznrXgOO7zkyxoS8jh07kpqayr59+45bl5ubG9RnEeqiQNocHR1Nx44dT+g4Xt8lZYwxVRYZGUnXrl3LXDdv3jwGDRpUyxF5q7ba7HUfhjHGmHrCEoYxxpiAWMIwxhgTEGlIt6aJyD5gezU3bwXsr8Fw6gNrc8MXau0Fa3NVdVHV1oFUbFAJ40SISIqqDvE6jtpkbW74Qq29YG0OJrskZYwxJiCWMIwxxgTEEsYxxw8X2fBZmxu+UGsvWJuDxvowjDHGBMTOMIwxxgTEEoYxxpiAhFzCEJExIrJBRDaLyP1lrG8kIu+66xeJSGLtR1lzAmjvXSKyVkRWishXItLFizhrUmVtLlbvEhFREan3t2AG0mYRucz9Xa8RkXdqO8aaFsC/7c4i
MldEfnT/fY/1Is6aIiKvicheEVldznoRkefcz2OliAyu8SACnWmpIbyAcOAnoBsQBawA+pSqczPwsvt+PPCu13EHub0jgRj3/W/rc3sDbbNbLx6YDywEhngddy38nnsAPwLN3eU2XsddC22eAvzWfd8H2OZ13CfY5jOBwZQ/g+lY4D+AAMOBRTUdQ6idYZwKbFbVLaqaD8wAxpWqMw54w33/PjBK6u+A+5W2V1Xnqmq2u7gQZ3bD+iyQ3zHAY8CTQG5tBhckgbT5RuBFVT0EoKp7aznGmhZImxVo4r5vCpQ/KXk9oKrzgYrmBRoHvKmOhUAzEWlXkzGEWsLoAPxcbDnVLSuzjqoWAulAy1qJruYF0t7irsf5C6U+q7TN7ql6J1X9vDYDC6JAfs89gZ4i8r2ILBSRMbUWXXAE0uZHgKvcydtmA7fVTmieqer/9yqz+TAMACJyFTAEOMvrWIJJRMKAZ4BrPQ6ltkXgXJZKxjmLnC8i/VX1sKdRBdcE4HVVfVpERgBviUg/VfV5HVh9FWpnGGlAp2LLHd2yMuuISATOqeyBWomu5gXSXkTkHOBB4CJVzaul2IKlsjbHA/2AeSKyDeda7yf1vOM7kN9zKvCJqhao6lZgI04Cqa8CafP1wEwAVV0AROMM0tdQBfT//USEWsJYAvQQka4iEoXTqf1JqTqfANe47y8Fvla3R6keqrS9IjIIeAUnWdT369pQSZtVNV1VW6lqoqom4vTbXKSqKd6EWyMC+Xf9Ec7ZBSLSCucS1ZbaDLKGBdLmHcAoABE5GSdhHD+na8PxCTDRvVtqOJCuqrtq8gAhdUlKVQtF5FZgDs5dFq+p6hoR+ROQoqqfANNwTl0343Qwjfcu4hMTYHufAuKA99y+/R2qepFnQZ+gANvcoATY5jnAeSKyFigC7lHV+nrmHGib7wZeFZE7cTrAr63Hf/whItNxkn4rt1/mj0AkgKq+jNNPMxbYDGQDv6nxGOrx52eMMaYWhdolKWOMMdVkCcMYY0xALGEYY4wJiCUMY4wxAbGEYYwxJiCWMEydIyJFIrJcRFaLyHsiEuNRHHd4dWz3+E+5I8s+5WEMieWNjmpCjyUMUxflqOpAVe0H5AOTA91QRMJrMI47AM8SBjAJOEVV7/EwBmP8LGGYuu5boDs4412JyGL37OOVo8lBRDJF5GkRWQGMEJGhIvKDiKxw68eLSLj7F/sSd66Am9xtk0Vknoi8LyLrReRt90nZ3wHtgbkiMtet+5KIpLh/9T96NEARGetuu9Sdj+AztzzWncNgsTsnw3Gj5rrHeso9m1olIpe75Z/gPFC59GhZsW3Ocj+D5e5+40UkTpz5TJa5+xnn1k10Y3tdRDa67TtHnEEIN4nIqW69R0TkLRFZ4JbfWEasZX6GJoR4Pca7vexV+gVkuj8jgI9x5uk4GfgUiHTX/QuY6L5X4DL3fRTOkBdD3eUm7n4mAQ+5ZY2AFKArzpOz6Tjj7oQBC4Az3HrbgFbF4mrh/gwH5gGn4Aw38TPQ1V03HfjMff8X4Cr3fTOc8ZtiS7X1EuBLd58JOMNZtCv+OZTx+XwKnO6+j3PbFwE0ccta4TztK0AiUAj0d9u3FHjNXTcO+Mjd5hGcOSUau9v/jJMwE3HnXyjvM/T634u9au9lZximLmosIstxvpB24AzXMgpIApa460bhTJ4DzlAXs9z3vYBdqroEQFWPqDNM/Xk44+wsBxbhDFl/dPC9xaqaqs4opstxviTLcpmILMOZiKgvzqQ8vYEt6gzoB07COOo84H73mPNwkkvnUvs8A5iuqkWqugf4BhhayefzPfCMexbUzG2fAH8RkZXA/3CGtU5w629V1VVu+9YAX6mqAqtKtfVjVc1R1f3AXJw5J4qr6DM0ISCkxpIy9UaOqg4sXiDOQFdvqOofyqifq6pFlexTgNtUdU6p/SYDxUfoLaKM/xci0hX4Pc6ZyyEReR0nAVR2zEtUdUMl9apEVf8qIp/jjBv0
vYiMxhl1tzWQpKoF4ozEezS+4u3zFVv2UbKtpccJKr1c5mdoQoedYZj64ivgUhFpAyAiLaTs+cc3AO1EZKhbL16cYernAL8VkUi3vKeIxFZyzAyc4dDBubSVBaSLSAJwfrHjdZNjc78X72+YA9zmJrujIwOX9i1wuds/0BpnGs7FFQUlIie5ZwxP4oza2htnGP69brIYCVRnbvZxIhItIi1xLtUtKbW+Op+haUDsDMPUC6q6VkQeAv4rziRIBcAtwPZS9fLdTuLnRaQxkAOcA0zFufyyzP0C3wf8qpLDTgG+EJGdqjpSRH4E1uNc3//ePV6OiNzs1sui5JfsY8A/gJVuzFuBC0sd40NgBE7/gQL3quruSuK6w00KRy8x/QcnsX0qIqtwLuWtr2QfZVmJcymqFfCYqu4slgihep+haUBstFpjTpCIxKlqpvsl+iKwSVWf9TquqhCRR3A62f/udSym7rJLUsacuBvdjuA1OJeGXvE4HmOCws4wjDHGBMTOMIwxxgTEEoYxxpiAWMIwxhgTEEsYxhhjAmIJwxhjTED+f50BnQ594eFGAAAAAElFTkSuQmCC\n", 39 | "text/plain": [ 40 | "" 41 | ] 42 | }, 43 | "metadata": {}, 44 | "output_type": "display_data" 45 | } 46 | ], 47 | "source": [ 48 | "X, y = load_data(return_X_y=True)\n", 49 | "lr = LogisticRegression()\n", 50 | "lr.fit(X, y)\n", 51 | "probas = lr.predict_proba(X)\n", 52 | "skplt.metrics.plot_lift_curve(y_true=y, y_probas=probas)\n", 53 | "plt.show()" 54 | ] 55 | } 56 | ], 57 | "metadata": { 58 | "kernelspec": { 59 | "display_name": "Python 3", 60 | "language": "python", 61 | "name": "python3" 62 | }, 63 | "language_info": { 64 | "codemirror_mode": { 65 | "name": "ipython", 66 | "version": 3 67 | }, 68 | "file_extension": ".py", 69 | "mimetype": "text/x-python", 70 | "name": "python", 71 | "nbconvert_exporter": "python", 72 | "pygments_lexer": "ipython3", 73 | "version": "3.5.2" 74 | } 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 2 78 | } 79 | -------------------------------------------------------------------------------- /examples/p_r_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/examples/p_r_curves.png -------------------------------------------------------------------------------- /examples/plot_calibration_curve.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the 
plot_calibration_curve method 3 | used by a scikit-learn classifier 4 | """ 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.naive_bayes import GaussianNB 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.svm import LinearSVC 9 | from sklearn.datasets import make_classification 10 | import matplotlib.pyplot as plt 11 | import scikitplot as skplt 12 | 13 | X, y = make_classification(n_samples=100000, n_features=20, 14 | n_informative=2, n_redundant=2, 15 | random_state=20) 16 | 17 | X_train, y_train, X_test, y_test = X[:1000], y[:1000], X[1000:], y[1000:] 18 | 19 | rf_probas = RandomForestClassifier().fit(X_train, y_train).predict_proba(X_test) 20 | lr_probas = LogisticRegression().fit(X_train, y_train).predict_proba(X_test) 21 | nb_probas = GaussianNB().fit(X_train, y_train).predict_proba(X_test) 22 | sv_scores = LinearSVC().fit(X_train, y_train).decision_function(X_test) 23 | 24 | probas_list = [rf_probas, lr_probas, nb_probas, sv_scores] 25 | clf_names=['Random Forest', 26 | 'Logistic Regression', 27 | 'Gaussian Naive Bayes', 28 | 'Support Vector Machine'] 29 | 30 | skplt.metrics.plot_calibration_curve(y_test, 31 | probas_list=probas_list, 32 | clf_names=clf_names, 33 | n_bins=10) 34 | plt.show() 35 | -------------------------------------------------------------------------------- /examples/plot_confusion_matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_confusion_matrix method 3 | used by a scikit-learn classifier 4 | """ 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.datasets import load_digits as load_data 7 | import matplotlib.pyplot as plt 8 | import scikitplot as skplt 9 | 10 | X, y = load_data(return_X_y=True) 11 | rf = RandomForestClassifier() 12 | rf.fit(X, y) 13 | preds = rf.predict(X) 14 | skplt.metrics.plot_confusion_matrix(y_true=y, y_pred=preds) 15 | plt.show() 16 | 
-------------------------------------------------------------------------------- /examples/plot_cumulative_gain.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_cumulative_gain method used 3 | by a scikit-learn classifier 4 | """ 5 | from __future__ import absolute_import 6 | import matplotlib.pyplot as plt 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.datasets import load_breast_cancer as load_data 9 | import scikitplot as skplt 10 | 11 | 12 | X, y = load_data(return_X_y=True) 13 | lr = LogisticRegression() 14 | lr.fit(X, y) 15 | probas = lr.predict_proba(X) 16 | skplt.metrics.plot_cumulative_gain(y_true=y, y_probas=probas) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /examples/plot_elbow_curve.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_elbow_curve 3 | method used by a scikit-learn clusterer 4 | """ 5 | from __future__ import absolute_import 6 | import matplotlib.pyplot as plt 7 | import scikitplot as skplt 8 | from sklearn.cluster import KMeans 9 | from sklearn.datasets import load_iris as load_data 10 | 11 | 12 | X, y = load_data(return_X_y=True) 13 | kmeans = KMeans(random_state=1) 14 | skplt.cluster.plot_elbow_curve(kmeans, X, cluster_ranges=range(1, 11)) 15 | plt.show() 16 | -------------------------------------------------------------------------------- /examples/plot_feature_importances.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_feature_importances 3 | method used by a scikit-learn classifier 4 | """ 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn.datasets import load_iris as load_data 7 | import matplotlib.pyplot as plt 8 | import scikitplot as skplt 9 | 10 | X, y = load_data(return_X_y=True) 11 | rf =
RandomForestClassifier() 12 | rf.fit(X, y) 13 | skplt.estimators.plot_feature_importances(rf, 14 | feature_names=['petal length', 15 | 'petal width', 16 | 'sepal length', 17 | 'sepal width']) 18 | plt.show() 19 | -------------------------------------------------------------------------------- /examples/plot_ks_statistic.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_ks_statistic method used 3 | by a scikit-learn classifier 4 | """ 5 | from __future__ import absolute_import 6 | import matplotlib.pyplot as plt 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.datasets import load_breast_cancer as load_data 9 | import scikitplot as skplt 10 | 11 | 12 | X, y = load_data(return_X_y=True) 13 | lr = LogisticRegression() 14 | lr.fit(X, y) 15 | probas = lr.predict_proba(X) 16 | skplt.metrics.plot_ks_statistic(y_true=y, y_probas=probas) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /examples/plot_learning_curve.py: -------------------------------------------------------------------------------- 1 | """An example showing the plot_learning_curve method used by a scikit-learn classifier""" 2 | from __future__ import absolute_import 3 | import matplotlib.pyplot as plt 4 | from sklearn.ensemble import RandomForestClassifier 5 | from sklearn.datasets import load_breast_cancer as load_data 6 | import scikitplot as skplt 7 | 8 | 9 | X, y = load_data(return_X_y=True) 10 | rf = RandomForestClassifier() 11 | skplt.estimators.plot_learning_curve(rf, X, y) 12 | plt.show() 13 | -------------------------------------------------------------------------------- /examples/plot_lift_curve.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_lift_curve method used 3 | by a scikit-learn classifier 4 | """ 5 | from __future__ import absolute_import 6 | import 
matplotlib.pyplot as plt 7 | from sklearn.linear_model import LogisticRegression 8 | from sklearn.datasets import load_breast_cancer as load_data 9 | import scikitplot as skplt 10 | 11 | 12 | X, y = load_data(return_X_y=True) 13 | lr = LogisticRegression() 14 | lr.fit(X, y) 15 | probas = lr.predict_proba(X) 16 | skplt.metrics.plot_lift_curve(y_true=y, y_probas=probas) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /examples/plot_pca_2d_projection.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_pca_2d_projection 3 | method used by a scikit-learn PCA object 4 | """ 5 | from sklearn.decomposition import PCA 6 | from sklearn.datasets import load_digits as load_data 7 | import scikitplot as skplt 8 | import matplotlib.pyplot as plt 9 | 10 | X, y = load_data(return_X_y=True) 11 | pca = PCA(random_state=1) 12 | pca.fit(X) 13 | skplt.decomposition.plot_pca_2d_projection(pca, X, y) 14 | plt.show() 15 | -------------------------------------------------------------------------------- /examples/plot_pca_component_variance.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_pca_component_variance 3 | method used by a scikit-learn PCA object 4 | """ 5 | from sklearn.decomposition import PCA 6 | from sklearn.datasets import load_digits as load_data 7 | import scikitplot as skplt 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | X, y = load_data(return_X_y=True) 12 | pca = PCA(random_state=1) 13 | pca.fit(X) 14 | skplt.decomposition.plot_pca_component_variance(pca) 15 | plt.show() 16 | -------------------------------------------------------------------------------- /examples/plot_precision_recall.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_precision_recall method 3 | used by a scikit-learn classifier 
4 | """ 5 | from __future__ import absolute_import 6 | import matplotlib.pyplot as plt 7 | from sklearn.naive_bayes import GaussianNB 8 | from sklearn.datasets import load_digits as load_data 9 | import scikitplot as skplt 10 | 11 | 12 | X, y = load_data(return_X_y=True) 13 | nb = GaussianNB() 14 | nb.fit(X, y) 15 | probas = nb.predict_proba(X) 16 | skplt.metrics.plot_precision_recall(y_true=y, y_probas=probas) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /examples/plot_roc.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_roc method 3 | used by a scikit-learn classifier 4 | """ 5 | from __future__ import absolute_import 6 | import matplotlib.pyplot as plt 7 | from sklearn.naive_bayes import GaussianNB 8 | from sklearn.datasets import load_digits as load_data 9 | import scikitplot as skplt 10 | 11 | 12 | X, y = load_data(return_X_y=True) 13 | nb = GaussianNB() 14 | nb.fit(X, y) 15 | probas = nb.predict_proba(X) 16 | skplt.metrics.plot_roc(y_true=y, y_probas=probas) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /examples/plot_silhouette.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example showing the plot_silhouette method 3 | used by a scikit-learn clusterer 4 | """ 5 | from __future__ import absolute_import 6 | import matplotlib.pyplot as plt 7 | from sklearn.cluster import KMeans 8 | from sklearn.datasets import load_iris as load_data 9 | import scikitplot as skplt 10 | 11 | X, y = load_data(return_X_y=True) 12 | kmeans = KMeans(n_clusters=4, random_state=1) 13 | cluster_labels = kmeans.fit_predict(X) 14 | skplt.metrics.plot_silhouette(X, cluster_labels) 15 | plt.show() 16 | -------------------------------------------------------------------------------- /examples/roc_curves.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/reiinakano/scikit-plot/2dd3e6a76df77edcbd724c4db25575f70abb57cb/examples/roc_curves.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=1.4.0 2 | scikit-learn>=0.18 3 | scipy>=0.9 4 | joblib>=0.10 5 | -------------------------------------------------------------------------------- /scikitplot/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | from . import metrics, cluster, decomposition, estimators 3 | __version__ = '0.3.7' 4 | 5 | 6 | from scikitplot.classifiers import classifier_factory 7 | from scikitplot.clustering import clustering_factory 8 | -------------------------------------------------------------------------------- /scikitplot/classifiers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, \ 2 | unicode_literals 3 | import six 4 | import warnings 5 | import types 6 | 7 | import numpy as np 8 | 9 | from sklearn.model_selection import StratifiedKFold 10 | from sklearn.base import clone 11 | from sklearn.utils import deprecated 12 | 13 | from scikitplot import plotters 14 | from scikitplot.plotters import plot_feature_importances 15 | from scikitplot.plotters import plot_learning_curve 16 | 17 | 18 | @deprecated('This will be removed in v0.4.0. The Factory ' 19 | 'API has been deprecated. Please migrate ' 20 | 'existing code into the various new modules ' 21 | 'of the Functions API. 
Please note that the ' 22 | 'interface of those functions will likely be ' 23 | 'different from that of the Factory API.') 24 | def classifier_factory(clf): 25 | """Embeds scikit-plot instance methods in an sklearn classifier. 26 | 27 | Args: 28 | clf: Scikit-learn classifier instance 29 | 30 | Returns: 31 | The same scikit-learn classifier instance passed in **clf** 32 | with embedded scikit-plot instance methods. 33 | 34 | Raises: 35 | TypeError: If **clf** does not contain the instance methods 36 | necessary for scikit-plot instance methods. 37 | """ 38 | required_methods = ['fit', 'score', 'predict'] 39 | 40 | for method in required_methods: 41 | if not hasattr(clf, method): 42 | raise TypeError('"{}" is not in clf. Did you pass a ' 43 | 'classifier instance?'.format(method)) 44 | 45 | optional_methods = ['predict_proba'] 46 | 47 | for method in optional_methods: 48 | if not hasattr(clf, method): 49 | warnings.warn('{} not in clf. Some plots may ' 50 | 'not be possible to generate.'.format(method)) 51 | 52 | additional_methods = { 53 | 'plot_learning_curve': plot_learning_curve, 54 | 'plot_confusion_matrix': plot_confusion_matrix_with_cv, 55 | 'plot_roc_curve': plot_roc_curve_with_cv, 56 | 'plot_ks_statistic': plot_ks_statistic_with_cv, 57 | 'plot_precision_recall_curve': plot_precision_recall_curve_with_cv, 58 | 'plot_feature_importances': plot_feature_importances 59 | } 60 | 61 | for key, fn in six.iteritems(additional_methods): 62 | if hasattr(clf, key): 63 | warnings.warn('"{}" method already in clf. ' 64 | 'Overriding anyway.
This may ' 65 | 'result in unintended behavior.'.format(key)) 66 | setattr(clf, key, types.MethodType(fn, clf)) 67 | return clf 68 | 69 | 70 | def plot_confusion_matrix_with_cv(clf, X, y, labels=None, true_labels=None, 71 | pred_labels=None, title=None, 72 | normalize=False, hide_zeros=False, 73 | x_tick_rotation=0, do_cv=True, cv=None, 74 | shuffle=True, random_state=None, ax=None, 75 | figsize=None, cmap='Blues', 76 | title_fontsize="large", 77 | text_fontsize="medium"): 78 | """Generates the confusion matrix for a given classifier and dataset. 79 | 80 | Args: 81 | clf: Classifier instance that implements ``fit`` and ``predict`` 82 | methods. 83 | 84 | X (array-like, shape (n_samples, n_features)): 85 | Training vector, where n_samples is the number of samples and 86 | n_features is the number of features. 87 | 88 | y (array-like, shape (n_samples) or (n_samples, n_features)): 89 | Target relative to X for classification. 90 | 91 | labels (array-like, shape (n_classes), optional): List of labels to 92 | index the matrix. This may be used to reorder or select a subset of 93 | labels. If none is given, those that appear at least once in ``y`` 94 | are used in sorted order. 95 | (new in v0.2.5) 96 | 97 | true_labels (array-like, optional): The true labels to display. 98 | If none is given, then all of the labels are used. 99 | 100 | pred_labels (array-like, optional): The predicted labels to display. 101 | If none is given, then all of the labels are used. 102 | 103 | title (string, optional): Title of the generated plot. Defaults to 104 | "Confusion Matrix" if `normalize` is False. Else, defaults to 105 | "Normalized Confusion Matrix". 106 | 107 | normalize (bool, optional): If True, normalizes the confusion matrix 108 | before plotting. Defaults to False. 109 | 110 | hide_zeros (bool, optional): If True, does not plot cells containing a 111 | value of zero. Defaults to False.
112 | 113 | x_tick_rotation (int, optional): Rotates x-axis tick labels by the 114 | specified angle. This is useful in cases where there are numerous 115 | categories and the labels overlap each other. 116 | 117 | do_cv (bool, optional): If True, the classifier is cross-validated on 118 | the dataset using the cross-validation strategy in `cv` to generate 119 | the confusion matrix. If False, the confusion matrix is generated 120 | without training or cross-validating the classifier. This assumes 121 | that the classifier has already been called with its `fit` method 122 | beforehand. 123 | 124 | cv (int, cross-validation generator, iterable, optional): Determines 125 | the cross-validation strategy to be used for splitting. 126 | 127 | Possible inputs for cv are: 128 | - None, to use the default 3-fold cross-validation, 129 | - integer, to specify the number of folds. 130 | - An object to be used as a cross-validation generator. 131 | - An iterable yielding train/test splits. 132 | 133 | For integer/None inputs, if ``y`` is binary or multiclass, 134 | :class:`StratifiedKFold` used. If the estimator is not a classifier 135 | or if ``y`` is neither binary nor multiclass, :class:`KFold` is 136 | used. 137 | 138 | shuffle (bool, optional): Used when do_cv is set to True. Determines 139 | whether to shuffle the training data before splitting using 140 | cross-validation. Default set to True. 141 | 142 | random_state (int :class:`RandomState`): Pseudo-random number generator 143 | state used for random sampling. 144 | 145 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 146 | plot the learning curve. If None, the plot is drawn on a new set of 147 | axes. 148 | 149 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 150 | e.g. (6, 6). Defaults to ``None``. 151 | 152 | cmap (string or :class:`matplotlib.colors.Colormap` instance, optional): 153 | Colormap used for plotting the projection. 
View Matplotlib Colormap 154 | documentation for available options. 155 | https://matplotlib.org/users/colormaps.html 156 | 157 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 158 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 159 | "large". 160 | 161 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 162 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 163 | "medium". 164 | 165 | 166 | Returns: 167 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 168 | drawn. 169 | 170 | Example: 171 | >>> rf = classifier_factory(RandomForestClassifier()) 172 | >>> rf.plot_confusion_matrix(X, y, normalize=True) 173 | 174 | >>> plt.show() 175 | 176 | .. image:: _static/examples/plot_confusion_matrix.png 177 | :align: center 178 | :alt: Confusion matrix 179 | """ 180 | y = np.array(y) 181 | 182 | if not do_cv: 183 | y_pred = clf.predict(X) 184 | y_true = y 185 | 186 | else: 187 | if cv is None: 188 | cv = StratifiedKFold(shuffle=shuffle, random_state=random_state) 189 | elif isinstance(cv, int): 190 | cv = StratifiedKFold(n_splits=cv, shuffle=shuffle, 191 | random_state=random_state) 192 | else: 193 | pass 194 | 195 | clf_clone = clone(clf) 196 | 197 | preds_list = [] 198 | trues_list = [] 199 | for train_index, test_index in cv.split(X, y): 200 | X_train, X_test = X[train_index], X[test_index] 201 | y_train, y_test = y[train_index], y[test_index] 202 | clf_clone.fit(X_train, y_train) 203 | preds = clf_clone.predict(X_test) 204 | preds_list.append(preds) 205 | trues_list.append(y_test) 206 | y_pred = np.concatenate(preds_list) 207 | y_true = np.concatenate(trues_list) 208 | 209 | ax = plotters.plot_confusion_matrix(y_true=y_true, y_pred=y_pred, 210 | labels=labels, true_labels=true_labels, 211 | pred_labels=pred_labels, 212 | title=title, normalize=normalize, 213 | hide_zeros=hide_zeros, 214 | x_tick_rotation=x_tick_rotation, ax=ax, 215 | figsize=figsize, cmap=cmap, 216 | 
title_fontsize=title_fontsize, 217 | text_fontsize=text_fontsize) 218 | 219 | return ax 220 | 221 | 222 | def plot_roc_curve_with_cv(clf, X, y, title='ROC Curves', do_cv=True, 223 | cv=None, shuffle=True, random_state=None, 224 | curves=('micro', 'macro', 'each_class'), 225 | ax=None, figsize=None, cmap='nipy_spectral', 226 | title_fontsize="large", text_fontsize="medium"): 227 | """Generates the ROC curves for a given classifier and dataset. 228 | 229 | Args: 230 | clf: Classifier instance that implements ``fit`` and ``predict_proba`` 231 | methods. 232 | 233 | X (array-like, shape (n_samples, n_features)): 234 | Training vector, where n_samples is the number of samples and 235 | n_features is the number of features. 236 | 237 | y (array-like, shape (n_samples) or (n_samples, n_features)): 238 | Target relative to X for classification. 239 | 240 | title (string, optional): Title of the generated plot. Defaults to 241 | "ROC Curves". 242 | 243 | do_cv (bool, optional): If True, the classifier is cross-validated on 244 | the dataset using the cross-validation strategy in `cv` to generate 245 | the ROC curves. If False, the ROC curves are generated 246 | without training or cross-validating the classifier. This assumes 247 | that the classifier has already been called with its `fit` method 248 | beforehand. 249 | 250 | cv (int, cross-validation generator, iterable, optional): Determines 251 | the cross-validation strategy to be used for splitting. 252 | 253 | Possible inputs for cv are: 254 | - None, to use the default 3-fold cross-validation, 255 | - integer, to specify the number of folds. 256 | - An object to be used as a cross-validation generator. 257 | - An iterable yielding train/test splits. 258 | 259 | For integer/None inputs, if ``y`` is binary or multiclass, 260 | :class:`StratifiedKFold` used. If the estimator is not a classifier 261 | or if ``y`` is neither binary nor multiclass, :class:`KFold` is 262 | used.
263 | 264 | shuffle (bool, optional): Used when do_cv is set to True. Determines 265 | whether to shuffle the training data before splitting using 266 | cross-validation. Default set to True. 267 | 268 | random_state (int :class:`RandomState`): Pseudo-random number generator 269 | state used for random sampling. 270 | 271 | curves (array-like): A listing of which curves should be plotted on the 272 | resulting plot. Defaults to `("micro", "macro", "each_class")` 273 | i.e. "micro" for micro-averaged curve, "macro" for macro-averaged 274 | curve 275 | 276 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 277 | plot the learning curve. If None, the plot is drawn on a new set of 278 | axes. 279 | 280 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 281 | e.g. (6, 6). Defaults to ``None``. 282 | 283 | cmap (string or :class:`matplotlib.colors.Colormap` instance, optional): 284 | Colormap used for plotting the projection. View Matplotlib Colormap 285 | documentation for available options. 286 | https://matplotlib.org/users/colormaps.html 287 | 288 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 289 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 290 | "large". 291 | 292 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 293 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 294 | "medium". 295 | 296 | Returns: 297 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 298 | drawn. 299 | 300 | Example: 301 | >>> nb = classifier_factory(GaussianNB()) 302 | >>> nb.plot_roc_curve(X, y, random_state=1) 303 | 304 | >>> plt.show() 305 | 306 | .. image:: _static/examples/plot_roc_curve.png 307 | :align: center 308 | :alt: ROC Curves 309 | """ 310 | y = np.array(y) 311 | 312 | if not hasattr(clf, 'predict_proba'): 313 | raise TypeError('"predict_proba" method not in classifier. 
' 314 | 'Cannot calculate ROC Curve.') 315 | 316 | if not do_cv: 317 | probas = clf.predict_proba(X) 318 | y_true = y 319 | 320 | else: 321 | if cv is None: 322 | cv = StratifiedKFold(shuffle=shuffle, random_state=random_state) 323 | elif isinstance(cv, int): 324 | cv = StratifiedKFold(n_splits=cv, shuffle=shuffle, 325 | random_state=random_state) 326 | else: 327 | pass 328 | 329 | clf_clone = clone(clf) 330 | 331 | preds_list = [] 332 | trues_list = [] 333 | for train_index, test_index in cv.split(X, y): 334 | X_train, X_test = X[train_index], X[test_index] 335 | y_train, y_test = y[train_index], y[test_index] 336 | clf_clone.fit(X_train, y_train) 337 | preds = clf_clone.predict_proba(X_test) 338 | preds_list.append(preds) 339 | trues_list.append(y_test) 340 | probas = np.concatenate(preds_list, axis=0) 341 | y_true = np.concatenate(trues_list) 342 | 343 | # Compute ROC curve and ROC area for each class 344 | ax = plotters.plot_roc_curve(y_true=y_true, y_probas=probas, title=title, 345 | curves=curves, ax=ax, figsize=figsize, 346 | cmap=cmap, title_fontsize=title_fontsize, 347 | text_fontsize=text_fontsize) 348 | 349 | return ax 350 | 351 | 352 | def plot_ks_statistic_with_cv(clf, X, y, title='KS Statistic Plot', 353 | do_cv=True, cv=None, shuffle=True, 354 | random_state=None, ax=None, figsize=None, 355 | title_fontsize="large", text_fontsize="medium"): 356 | """Generates the KS Statistic plot for a given classifier and dataset. 357 | 358 | Args: 359 | clf: Classifier instance that implements "fit" and "predict_proba" 360 | methods. 361 | 362 | X (array-like, shape (n_samples, n_features)): 363 | Training vector, where n_samples is the number of samples and 364 | n_features is the number of features. 365 | 366 | y (array-like, shape (n_samples) or (n_samples, n_features)): 367 | Target relative to X for classification. 368 | 369 | title (string, optional): Title of the generated plot. Defaults to 370 | "KS Statistic Plot". 
371 | 372 | do_cv (bool, optional): If True, the classifier is cross-validated on 373 | the dataset using the cross-validation strategy in `cv` to generate 374 | the KS Statistic plot. If False, the KS Statistic plot is generated 375 | without training or cross-validating the classifier. This assumes 376 | that the classifier has already been called with its `fit` method 377 | beforehand. 378 | 379 | cv (int, cross-validation generator, iterable, optional): Determines 380 | the cross-validation strategy to be used for splitting. 381 | 382 | Possible inputs for cv are: 383 | - None, to use the default 3-fold cross-validation, 384 | - integer, to specify the number of folds. 385 | - An object to be used as a cross-validation generator. 386 | - An iterable yielding train/test splits. 387 | 388 | For integer/None inputs, if ``y`` is binary or multiclass, 389 | :class:`StratifiedKFold` used. If the estimator is not a classifier 390 | or if ``y`` is neither binary nor multiclass, :class:`KFold` is 391 | used. 392 | 393 | shuffle (bool, optional): Used when do_cv is set to True. Determines 394 | whether to shuffle the training data before splitting using 395 | cross-validation. Default set to True. 396 | 397 | random_state (int :class:`RandomState`): Pseudo-random number generator 398 | state used for random sampling. 399 | 400 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 401 | plot the curve. If None, the plot is drawn on a new set of 402 | axes. 403 | 404 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 405 | e.g. (6, 6). Defaults to ``None``. 406 | 407 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 408 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 409 | "large". 410 | 411 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 412 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 413 | "medium".
414 | 415 | Returns: 416 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 417 | drawn. 418 | 419 | Example: 420 | >>> lr = classifier_factory(LogisticRegression()) 421 | >>> lr.plot_ks_statistic(X, y, random_state=1) 422 | 423 | >>> plt.show() 424 | 425 | .. image:: _static/examples/plot_ks_statistic.png 426 | :align: center 427 | :alt: KS Statistic 428 | """ 429 | y = np.array(y) 430 | 431 | if not hasattr(clf, 'predict_proba'): 432 | raise TypeError('"predict_proba" method not in classifier. ' 433 | 'Cannot calculate ROC Curve.') 434 | 435 | if not do_cv: 436 | probas = clf.predict_proba(X) 437 | y_true = y 438 | 439 | else: 440 | if cv is None: 441 | cv = StratifiedKFold(shuffle=shuffle, random_state=random_state) 442 | elif isinstance(cv, int): 443 | cv = StratifiedKFold(n_splits=cv, shuffle=shuffle, 444 | random_state=random_state) 445 | else: 446 | pass 447 | 448 | clf_clone = clone(clf) 449 | 450 | preds_list = [] 451 | trues_list = [] 452 | for train_index, test_index in cv.split(X, y): 453 | X_train, X_test = X[train_index], X[test_index] 454 | y_train, y_test = y[train_index], y[test_index] 455 | clf_clone.fit(X_train, y_train) 456 | preds = clf_clone.predict_proba(X_test) 457 | preds_list.append(preds) 458 | trues_list.append(y_test) 459 | probas = np.concatenate(preds_list, axis=0) 460 | y_true = np.concatenate(trues_list) 461 | 462 | ax = plotters.plot_ks_statistic(y_true, probas, title=title, 463 | ax=ax, figsize=figsize, 464 | title_fontsize=title_fontsize, 465 | text_fontsize=text_fontsize) 466 | 467 | return ax 468 | 469 | 470 | def plot_precision_recall_curve_with_cv(clf, X, y, 471 | title='Precision-Recall Curve', 472 | do_cv=True, cv=None, shuffle=True, 473 | random_state=None, 474 | curves=('micro', 'each_class'), 475 | ax=None, figsize=None, 476 | cmap='nipy_spectral', 477 | title_fontsize="large", 478 | text_fontsize="medium"): 479 | """Generates the Precision-Recall curve for a given classifier and dataset. 
480 | 481 | Args: 482 | clf: Classifier instance that implements "fit" and "predict_proba" 483 | methods. 484 | 485 | X (array-like, shape (n_samples, n_features)): 486 | Training vector, where n_samples is the number of samples and 487 | n_features is the number of features. 488 | 489 | y (array-like, shape (n_samples) or (n_samples, n_features)): 490 | Target relative to X for classification. 491 | 492 | title (string, optional): Title of the generated plot. Defaults to 493 | "Precision-Recall Curve". 494 | 495 | do_cv (bool, optional): If True, the classifier is cross-validated on 496 | the dataset using the cross-validation strategy in `cv` to generate 497 | the precision-recall curve. If False, the curve is generated 498 | without training or cross-validating the classifier. This assumes 499 | that the classifier has already been called with its `fit` method 500 | beforehand. 501 | 502 | cv (int, cross-validation generator, iterable, optional): Determines 503 | the cross-validation strategy to be used for splitting. 504 | 505 | Possible inputs for cv are: 506 | - None, to use the default 3-fold cross-validation, 507 | - integer, to specify the number of folds. 508 | - An object to be used as a cross-validation generator. 509 | - An iterable yielding train/test splits. 510 | 511 | For integer/None inputs, if ``y`` is binary or multiclass, 512 | :class:`StratifiedKFold` used. If the estimator is not a classifier 513 | or if ``y`` is neither binary nor multiclass, :class:`KFold` is 514 | used. 515 | 516 | shuffle (bool, optional): Used when do_cv is set to True. Determines 517 | whether to shuffle the training data before splitting using 518 | cross-validation. Default set to True. 519 | 520 | random_state (int :class:`RandomState`): Pseudo-random number generator 521 | state used for random sampling. 522 | 523 | curves (array-like): A listing of which curves should be plotted on the 524 | resulting plot. Defaults to `("micro", "each_class")` 525 | i.e.
"micro" for micro-averaged curve 526 | 527 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 528 | plot the learning curve. If None, the plot is drawn on a new set of 529 | axes. 530 | 531 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 532 | e.g. (6, 6). Defaults to ``None``. 533 | 534 | cmap (string or :class:`matplotlib.colors.Colormap` instance, optional): 535 | Colormap used for plotting the projection. View Matplotlib Colormap 536 | documentation for available options. 537 | https://matplotlib.org/users/colormaps.html 538 | 539 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 540 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 541 | "large". 542 | 543 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 544 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 545 | "medium". 546 | 547 | Returns: 548 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 549 | drawn. 550 | 551 | Example: 552 | >>> nb = classifier_factory(GaussianNB()) 553 | >>> nb.plot_precision_recall_curve(X, y, random_state=1) 554 | 555 | >>> plt.show() 556 | 557 | .. image:: _static/examples/plot_precision_recall_curve.png 558 | :align: center 559 | :alt: Precision Recall Curve 560 | """ 561 | y = np.array(y) 562 | 563 | if not hasattr(clf, 'predict_proba'): 564 | raise TypeError('"predict_proba" method not in classifier. 
' 565 | 'Cannot calculate Precision-Recall Curve.') 566 | 567 | if not do_cv: 568 | probas = clf.predict_proba(X) 569 | y_true = y 570 | 571 | else: 572 | if cv is None: 573 | cv = StratifiedKFold(shuffle=shuffle, random_state=random_state) 574 | elif isinstance(cv, int): 575 | cv = StratifiedKFold(n_splits=cv, shuffle=shuffle, 576 | random_state=random_state) 577 | else: 578 | pass 579 | 580 | clf_clone = clone(clf) 581 | 582 | preds_list = [] 583 | trues_list = [] 584 | for train_index, test_index in cv.split(X, y): 585 | X_train, X_test = X[train_index], X[test_index] 586 | y_train, y_test = y[train_index], y[test_index] 587 | clf_clone.fit(X_train, y_train) 588 | preds = clf_clone.predict_proba(X_test) 589 | preds_list.append(preds) 590 | trues_list.append(y_test) 591 | probas = np.concatenate(preds_list, axis=0) 592 | y_true = np.concatenate(trues_list) 593 | 594 | # Compute Precision-Recall curve and area for each class 595 | ax = plotters.plot_precision_recall_curve(y_true, probas, title=title, 596 | curves=curves, ax=ax, 597 | figsize=figsize, cmap=cmap, 598 | title_fontsize=title_fontsize, 599 | text_fontsize=text_fontsize) 600 | return ax 601 | -------------------------------------------------------------------------------- /scikitplot/cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`scikitplot.cluster` module includes plots built specifically for 3 | scikit-learn clusterer instances e.g. KMeans. You can use your own clusterers, 4 | but these plots assume specific properties shared by scikit-learn estimators. 5 | The specific requirements are documented per function. 
6 | """ 7 | from __future__ import absolute_import, division, print_function, \ 8 | unicode_literals 9 | 10 | import time 11 | 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | 15 | from sklearn.base import clone 16 | from joblib import Parallel, delayed 17 | 18 | 19 | def plot_elbow_curve(clf, X, title='Elbow Plot', cluster_ranges=None, n_jobs=1, 20 | show_cluster_time=True, ax=None, figsize=None, 21 | title_fontsize="large", text_fontsize="medium"): 22 | """Plots elbow curve of different values of K for KMeans clustering. 23 | 24 | Args: 25 | clf: Clusterer instance that implements ``fit``,``fit_predict``, and 26 | ``score`` methods, and an ``n_clusters`` hyperparameter. 27 | e.g. :class:`sklearn.cluster.KMeans` instance 28 | 29 | X (array-like, shape (n_samples, n_features)): 30 | Data to cluster, where n_samples is the number of samples and 31 | n_features is the number of features. 32 | 33 | title (string, optional): Title of the generated plot. Defaults to 34 | "Elbow Plot" 35 | 36 | cluster_ranges (None or :obj:`list` of int, optional): List of 37 | n_clusters for which to plot the explained variances. Defaults to 38 | ``range(1, 12, 2)``. 39 | 40 | n_jobs (int, optional): Number of jobs to run in parallel. Defaults to 41 | 1. 42 | 43 | show_cluster_time (bool, optional): Include plot of time it took to 44 | cluster for a particular K. 45 | 46 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 47 | plot the curve. If None, the plot is drawn on a new set of axes. 48 | 49 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 50 | e.g. (6, 6). Defaults to ``None``. 51 | 52 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 53 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 54 | "large". 55 | 56 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 57 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 58 | "medium". 
59 | 60 | Returns: 61 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 62 | drawn. 63 | 64 | Example: 65 | >>> import scikitplot as skplt 66 | >>> kmeans = KMeans(random_state=1) 67 | >>> skplt.cluster.plot_elbow_curve(kmeans, cluster_ranges=range(1, 30)) 68 | 69 | >>> plt.show() 70 | 71 | .. image:: _static/examples/plot_elbow_curve.png 72 | :align: center 73 | :alt: Elbow Curve 74 | """ 75 | if cluster_ranges is None: 76 | cluster_ranges = range(1, 12, 2) 77 | else: 78 | cluster_ranges = sorted(cluster_ranges) 79 | 80 | if not hasattr(clf, 'n_clusters'): 81 | raise TypeError('"n_clusters" attribute not in classifier. ' 82 | 'Cannot plot elbow method.') 83 | 84 | tuples = Parallel(n_jobs=n_jobs)(delayed(_clone_and_score_clusterer) 85 | (clf, X, i) for i in cluster_ranges) 86 | clfs, times = zip(*tuples) 87 | 88 | if ax is None: 89 | fig, ax = plt.subplots(1, 1, figsize=figsize) 90 | 91 | ax.set_title(title, fontsize=title_fontsize) 92 | ax.plot(cluster_ranges, np.absolute(clfs), 'b*-') 93 | ax.grid(True) 94 | ax.set_xlabel('Number of clusters', fontsize=text_fontsize) 95 | ax.set_ylabel('Sum of Squared Errors', fontsize=text_fontsize) 96 | ax.tick_params(labelsize=text_fontsize) 97 | 98 | if show_cluster_time: 99 | ax2_color = 'green' 100 | ax2 = ax.twinx() 101 | ax2.plot(cluster_ranges, times, ':', alpha=0.75, color=ax2_color) 102 | ax2.set_ylabel('Clustering duration (seconds)', 103 | color=ax2_color, alpha=0.75, 104 | fontsize=text_fontsize) 105 | ax2.tick_params(colors=ax2_color, labelsize=text_fontsize) 106 | 107 | return ax 108 | 109 | 110 | def _clone_and_score_clusterer(clf, X, n_clusters): 111 | """Clones and scores clusterer instance. 112 | 113 | Args: 114 | clf: Clusterer instance that implements ``fit``,``fit_predict``, and 115 | ``score`` methods, and an ``n_clusters`` hyperparameter. 116 | e.g. 
:class:`sklearn.cluster.KMeans` instance 117 | 118 | X (array-like, shape (n_samples, n_features)): 119 | Data to cluster, where n_samples is the number of samples and 120 | n_features is the number of features. 121 | 122 | n_clusters (int): Number of clusters 123 | 124 | Returns: 125 | score: Score of clusters 126 | 127 | time: Number of seconds it took to fit cluster 128 | """ 129 | start = time.time() 130 | clf = clone(clf) 131 | setattr(clf, 'n_clusters', n_clusters) 132 | return clf.fit(X).score(X), time.time() - start 133 | -------------------------------------------------------------------------------- /scikitplot/clustering.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, \ 2 | unicode_literals 3 | import six 4 | import warnings 5 | import types 6 | 7 | from sklearn.utils import deprecated 8 | 9 | from scikitplot.plotters import plot_silhouette, plot_elbow_curve 10 | 11 | 12 | @deprecated('This will be removed in v0.4.0. The Factory ' 13 | 'API has been deprecated. Please migrate ' 14 | 'existing code into the various new modules ' 15 | 'of the Functions API. Please note that the ' 16 | 'interface of those functions will likely be ' 17 | 'different from that of the Factory API.') 18 | def clustering_factory(clf): 19 | """Embeds scikit-plot plotting methods in an sklearn clusterer instance. 20 | 21 | Args: 22 | clf: Scikit-learn clusterer instance 23 | 24 | Returns: 25 | The same scikit-learn clusterer instance passed in **clf** with 26 | embedded scikit-plot instance methods. 27 | 28 | Raises: 29 | ValueError: If **clf** does not contain the instance methods necessary 30 | for scikit-plot instance methods. 31 | """ 32 | required_methods = ['fit', 'fit_predict'] 33 | 34 | for method in required_methods: 35 | if not hasattr(clf, method): 36 | raise TypeError('"{}" is not in clf. 
Did you ' 37 | 'pass a clusterer instance?'.format(method)) 38 | 39 | additional_methods = { 40 | 'plot_silhouette': plot_silhouette, 41 | 'plot_elbow_curve': plot_elbow_curve 42 | } 43 | 44 | for key, fn in six.iteritems(additional_methods): 45 | if hasattr(clf, key): 46 | warnings.warn('"{}" method already in clf. ' 47 | 'Overriding anyway. This may ' 48 | 'result in unintended behavior.'.format(key)) 49 | setattr(clf, key, types.MethodType(fn, clf)) 50 | return clf 51 | -------------------------------------------------------------------------------- /scikitplot/decomposition.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`scikitplot.decomposition` module includes plots built specifically 3 | for scikit-learn estimators that are used for dimensionality reduction 4 | e.g. PCA. You can use your own estimators, but these plots assume specific 5 | properties shared by scikit-learn estimators. The specific requirements are 6 | documented per function. 7 | """ 8 | from __future__ import absolute_import, division, print_function, \ 9 | unicode_literals 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | 15 | def plot_pca_component_variance(clf, title='PCA Component Explained Variances', 16 | target_explained_variance=0.75, ax=None, 17 | figsize=None, title_fontsize="large", 18 | text_fontsize="medium"): 19 | """Plots PCA components' explained variance ratios. (new in v0.2.2) 20 | 21 | Args: 22 | clf: PCA instance that has the ``explained_variance_ratio_`` attribute. 23 | 24 | title (string, optional): Title of the generated plot. Defaults to 25 | "PCA Component Explained Variances" 26 | 27 | target_explained_variance (float, optional): Looks for the minimum 28 | number of principal components that satisfies this value and 29 | emphasizes it on the plot. Defaults to 0.75 30 | 31 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 32 | plot the curve. 
If None, the plot is drawn on a new set of axes. 33 | 34 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 35 | e.g. (6, 6). Defaults to ``None``. 36 | 37 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 38 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 39 | "large". 40 | 41 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 42 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 43 | "medium". 44 | 45 | Returns: 46 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 47 | drawn. 48 | 49 | Example: 50 | >>> import scikitplot as skplt 51 | >>> pca = PCA(random_state=1) 52 | >>> pca.fit(X) 53 | >>> skplt.decomposition.plot_pca_component_variance(pca) 54 | 55 | >>> plt.show() 56 | 57 | .. image:: _static/examples/plot_pca_component_variance.png 58 | :align: center 59 | :alt: PCA Component variances 60 | """ 61 | if not hasattr(clf, 'explained_variance_ratio_'): 62 | raise TypeError('"clf" does not have explained_variance_ratio_ ' 63 | 'attribute. 
Has the PCA been fitted?') 64 | 65 | if ax is None: 66 | fig, ax = plt.subplots(1, 1, figsize=figsize) 67 | 68 | ax.set_title(title, fontsize=title_fontsize) 69 | 70 | cumulative_sum_ratios = np.cumsum(clf.explained_variance_ratio_) 71 | 72 | # Magic code for figuring out closest value to target_explained_variance 73 | idx = np.searchsorted(cumulative_sum_ratios, target_explained_variance) 74 | 75 | ax.plot(range(len(clf.explained_variance_ratio_) + 1), 76 | np.concatenate(([0], np.cumsum(clf.explained_variance_ratio_))), 77 | '*-') 78 | ax.grid(True) 79 | ax.set_xlabel('First n principal components', fontsize=text_fontsize) 80 | ax.set_ylabel('Explained variance ratio of first n components', 81 | fontsize=text_fontsize) 82 | ax.set_ylim([-0.02, 1.02]) 83 | if idx < len(cumulative_sum_ratios): 84 | ax.plot(idx+1, cumulative_sum_ratios[idx], 'ro', 85 | label='{0:0.3f} Explained variance ratio for ' 86 | 'first {1} components'.format(cumulative_sum_ratios[idx], 87 | idx+1), 88 | markersize=4, markeredgewidth=4) 89 | ax.axhline(cumulative_sum_ratios[idx], 90 | linestyle=':', lw=3, color='black') 91 | ax.tick_params(labelsize=text_fontsize) 92 | ax.legend(loc="best", fontsize=text_fontsize) 93 | 94 | return ax 95 | 96 | 97 | def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection', 98 | biplot=False, feature_labels=None, 99 | ax=None, figsize=None, cmap='Spectral', 100 | title_fontsize="large", text_fontsize="medium"): 101 | """Plots the 2-dimensional projection of PCA on a given dataset. 102 | 103 | Args: 104 | clf: Fitted PCA instance that can ``transform`` given data set into 2 105 | dimensions. 106 | 107 | X (array-like, shape (n_samples, n_features)): 108 | Feature set to project, where n_samples is the number of samples 109 | and n_features is the number of features. 110 | 111 | y (array-like, shape (n_samples) or (n_samples, n_features)): 112 | Target relative to X for labeling. 113 | 114 | title (string, optional): Title of the generated plot. 
Defaults to 115 | "PCA 2-D Projection" 116 | 117 | biplot (bool, optional): If True, the function will generate and plot 118 | biplots. If False, the biplots are not generated. 119 | 120 | feature_labels (array-like, shape (n_features), optional): List of labels 121 | that represent each feature of X. Its index position must also be 122 | relative to the features. If ``None`` is given, then labels will be 123 | automatically generated for each feature. 124 | e.g. "Variable0", "Variable1", "Variable2" ... 125 | 126 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 127 | plot the curve. If None, the plot is drawn on a new set of axes. 128 | 129 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 130 | e.g. (6, 6). Defaults to ``None``. 131 | 132 | cmap (string or :class:`matplotlib.colors.Colormap` instance, optional): 133 | Colormap used for plotting the projection. View Matplotlib Colormap 134 | documentation for available options. 135 | https://matplotlib.org/users/colormaps.html 136 | 137 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 138 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 139 | "large". 140 | 141 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 142 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 143 | "medium". 144 | 145 | Returns: 146 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 147 | drawn. 148 | 149 | Example: 150 | >>> import scikitplot as skplt 151 | >>> pca = PCA(random_state=1) 152 | >>> pca.fit(X) 153 | >>> skplt.decomposition.plot_pca_2d_projection(pca, X, y) 154 | 155 | >>> plt.show() 156 | 157 | ..
image:: _static/examples/plot_pca_2d_projection.png 158 | :align: center 159 | :alt: PCA 2D Projection 160 | """ 161 | transformed_X = clf.transform(X) 162 | if ax is None: 163 | fig, ax = plt.subplots(1, 1, figsize=figsize) 164 | 165 | ax.set_title(title, fontsize=title_fontsize) 166 | classes = np.unique(np.array(y)) 167 | 168 | colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, len(classes))) 169 | 170 | for label, color in zip(classes, colors): 171 | ax.scatter(transformed_X[y == label, 0], transformed_X[y == label, 1], 172 | alpha=0.8, lw=2, label=label, color=color) 173 | 174 | if biplot: 175 | xs = transformed_X[:, 0] 176 | ys = transformed_X[:, 1] 177 | vectors = np.transpose(clf.components_[:2, :]) 178 | vectors_scaled = vectors * [xs.max(), ys.max()] 179 | for i in range(vectors.shape[0]): 180 | ax.annotate("", xy=(vectors_scaled[i, 0], vectors_scaled[i, 1]), 181 | xycoords='data', xytext=(0, 0), textcoords='data', 182 | arrowprops={'arrowstyle': '-|>', 'ec': 'r'}) 183 | 184 | ax.text(vectors_scaled[i, 0] * 1.05, vectors_scaled[i, 1] * 1.05, 185 | feature_labels[i] if feature_labels else "Variable" + str(i), 186 | color='b', fontsize=text_fontsize) 187 | 188 | ax.legend(loc='best', shadow=False, scatterpoints=1, 189 | fontsize=text_fontsize) 190 | ax.set_xlabel('First Principal Component', fontsize=text_fontsize) 191 | ax.set_ylabel('Second Principal Component', fontsize=text_fontsize) 192 | ax.tick_params(labelsize=text_fontsize) 193 | 194 | return ax 195 | -------------------------------------------------------------------------------- /scikitplot/estimators.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`scikitplot.estimators` module includes plots built specifically for 3 | scikit-learn estimator (classifier/regressor) instances e.g. Random Forest. 4 | You can use your own estimators, but these plots assume specific properties 5 | shared by scikit-learn estimators. 
The specific requirements are documented per 6 | function. 7 | """ 8 | from __future__ import absolute_import, division, print_function, \ 9 | unicode_literals 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | 14 | from sklearn.model_selection import learning_curve 15 | 16 | 17 | def plot_feature_importances(clf, title='Feature Importance', 18 | feature_names=None, max_num_features=20, 19 | order='descending', x_tick_rotation=0, ax=None, 20 | figsize=None, title_fontsize="large", 21 | text_fontsize="medium"): 22 | """Generates a plot of a classifier's feature importances. 23 | 24 | Args: 25 | clf: Classifier instance that has a ``feature_importances_`` attribute, 26 | e.g. :class:`sklearn.ensemble.RandomForestClassifier` or 27 | :class:`xgboost.XGBClassifier`. 28 | 29 | title (string, optional): Title of the generated plot. Defaults to 30 | "Feature importances". 31 | 32 | feature_names (None, :obj:`list` of string, optional): Determines the 33 | feature names used to plot the feature importances. If None, 34 | feature names will be numbered. 35 | 36 | max_num_features (int): Determines the maximum number of features to 37 | plot. Defaults to 20. 38 | 39 | order ('ascending', 'descending', or None, optional): Determines the 40 | order in which the feature importances are plotted. Defaults to 41 | 'descending'. 42 | 43 | x_tick_rotation (int, optional): Rotates x-axis tick labels by the 44 | specified angle. This is useful in cases where there are numerous 45 | categories and the labels overlap each other. 46 | 47 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 48 | plot the curve. If None, the plot is drawn on a new set of axes. 49 | 50 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 51 | e.g. (6, 6). Defaults to ``None``. 52 | 53 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 54 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 55 | "large". 
56 | 57 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 58 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 59 | "medium". 60 | 61 | Returns: 62 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 63 | drawn. 64 | 65 | Example: 66 | >>> import scikitplot as skplt 67 | >>> rf = RandomForestClassifier() 68 | >>> rf.fit(X, y) 69 | >>> skplt.estimators.plot_feature_importances( 70 | ... rf, feature_names=['petal length', 'petal width', 71 | ... 'sepal length', 'sepal width']) 72 | 73 | >>> plt.show() 74 | 75 | .. image:: _static/examples/plot_feature_importances.png 76 | :align: center 77 | :alt: Feature Importances 78 | """ 79 | if not hasattr(clf, 'feature_importances_'): 80 | raise TypeError('"feature_importances_" attribute not in classifier. ' 81 | 'Cannot plot feature importances.') 82 | 83 | importances = clf.feature_importances_ 84 | 85 | if hasattr(clf, 'estimators_')\ 86 | and isinstance(clf.estimators_, list)\ 87 | and hasattr(clf.estimators_[0], 'feature_importances_'): 88 | std = np.std([tree.feature_importances_ for tree in clf.estimators_], 89 | axis=0) 90 | 91 | else: 92 | std = None 93 | 94 | if order == 'descending': 95 | indices = np.argsort(importances)[::-1] 96 | 97 | elif order == 'ascending': 98 | indices = np.argsort(importances) 99 | 100 | elif order is None: 101 | indices = np.array(range(len(importances))) 102 | 103 | else: 104 | raise ValueError('Invalid argument {} for "order"'.format(order)) 105 | 106 | if ax is None: 107 | fig, ax = plt.subplots(1, 1, figsize=figsize) 108 | 109 | if feature_names is None: 110 | feature_names = indices 111 | else: 112 | feature_names = np.array(feature_names)[indices] 113 | 114 | max_num_features = min(max_num_features, len(importances)) 115 | 116 | ax.set_title(title, fontsize=title_fontsize) 117 | 118 | if std is not None: 119 | ax.bar(range(max_num_features), 120 | importances[indices][:max_num_features], color='r', 121 | 
yerr=std[indices][:max_num_features], align='center') 122 | else: 123 | ax.bar(range(max_num_features), 124 | importances[indices][:max_num_features], 125 | color='r', align='center') 126 | 127 | ax.set_xticks(range(max_num_features)) 128 | ax.set_xticklabels(feature_names[:max_num_features], 129 | rotation=x_tick_rotation) 130 | ax.set_xlim([-1, max_num_features]) 131 | ax.tick_params(labelsize=text_fontsize) 132 | return ax 133 | 134 | 135 | def plot_learning_curve(clf, X, y, title='Learning Curve', cv=None, 136 | shuffle=False, random_state=None, 137 | train_sizes=None, n_jobs=1, scoring=None, 138 | ax=None, figsize=None, title_fontsize="large", 139 | text_fontsize="medium"): 140 | """Generates a plot of the train and test learning curves for a classifier. 141 | 142 | Args: 143 | clf: Classifier instance that implements ``fit`` and ``predict`` 144 | methods. 145 | 146 | X (array-like, shape (n_samples, n_features)): 147 | Training vector, where n_samples is the number of samples and 148 | n_features is the number of features. 149 | 150 | y (array-like, shape (n_samples) or (n_samples, n_features)): 151 | Target relative to X for classification or regression; 152 | None for unsupervised learning. 153 | 154 | title (string, optional): Title of the generated plot. Defaults to 155 | "Learning Curve" 156 | 157 | cv (int, cross-validation generator, iterable, optional): Determines 158 | the cross-validation strategy to be used for splitting. 159 | 160 | Possible inputs for cv are: 161 | - None, to use the default 3-fold cross-validation, 162 | - integer, to specify the number of folds. 163 | - An object to be used as a cross-validation generator. 164 | - An iterable yielding train/test splits. 165 | 166 | For integer/None inputs, if ``y`` is binary or multiclass, 167 | :class:`StratifiedKFold` used. If the estimator is not a classifier 168 | or if ``y`` is neither binary nor multiclass, :class:`KFold` is 169 | used. 
170 | 171 | shuffle (bool, optional): Determines 172 | whether to shuffle the training data before splitting using 173 | cross-validation. Defaults to False. 174 | 175 | random_state (int :class:`RandomState`): Pseudo-random number generator 176 | state used for random sampling. 177 | 178 | train_sizes (iterable, optional): Determines the training sizes used to 179 | plot the learning curve. If None, ``np.linspace(.1, 1.0, 5)`` is 180 | used. 181 | 182 | n_jobs (int, optional): Number of jobs to run in parallel. Defaults to 183 | 1. 184 | 185 | scoring (string, callable or None, optional): default: None 186 | A string (see scikit-learn model evaluation documentation) or a 187 | scorer callable object / function with signature 188 | scorer(estimator, X, y). 189 | 190 | ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to 191 | plot the curve. If None, the plot is drawn on a new set of axes. 192 | 193 | figsize (2-tuple, optional): Tuple denoting figure size of the plot 194 | e.g. (6, 6). Defaults to ``None``. 195 | 196 | title_fontsize (string or int, optional): Matplotlib-style fontsizes. 197 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 198 | "large". 199 | 200 | text_fontsize (string or int, optional): Matplotlib-style fontsizes. 201 | Use e.g. "small", "medium", "large" or integer-values. Defaults to 202 | "medium". 203 | 204 | Returns: 205 | ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was 206 | drawn. 207 | 208 | Example: 209 | >>> import scikitplot as skplt 210 | >>> rf = RandomForestClassifier() 211 | >>> skplt.estimators.plot_learning_curve(rf, X, y) 212 | 213 | >>> plt.show() 214 | 215 | ..
def binary_ks_curve(y_true, y_probas):
    """Generate the points needed to plot the KS Statistic curve.

    Args:
        y_true (array-like, shape (n_samples)): True labels of the data.

        y_probas (array-like, shape (n_samples)): Probability predictions of
            the positive class.

    Returns:
        thresholds (numpy.ndarray): The X-axis values of the KS Statistic
            plot: the sorted distinct values of `y_probas`, with 0 prepended
            and 1 appended when they are not already the first/last value.

        pct1 (numpy.ndarray): For each threshold, the fraction of samples of
            the first class (``classes[0]``) scoring at or below it.

        pct2 (numpy.ndarray): For each threshold, the fraction of samples of
            the second class (``classes[1]``) scoring at or below it.

        ks_statistic (float): The KS Statistic, i.e. the maximum absolute
            vertical distance between the two curves.

        max_distance_at (float): The X-axis value at which the maximum
            vertical distance between the two curves is seen.

        classes (np.ndarray, shape (2)): The two sorted labels making up
            `y_true`.

    Raises:
        ValueError: If `y_true` is not composed of 2 classes. The KS Statistic
            is only relevant in binary classification.
    """
    y_true, y_probas = np.asarray(y_true), np.asarray(y_probas)

    # np.unique yields the sorted classes together with an integer code per
    # sample, which is all the label encoding this function needs.
    classes, encoded_labels = np.unique(y_true, return_inverse=True)
    if len(classes) != 2:
        raise ValueError('Cannot calculate KS statistic for data with '
                         '{} category/ies'.format(len(classes)))
    in_first_class = encoded_labels == 0
    data1 = np.sort(y_probas[in_first_class])
    data2 = np.sort(y_probas[np.logical_not(in_first_class)])

    # Every distinct score is a threshold; the empirical CDF of each class at
    # a threshold is the count of its sorted scores that are <= threshold.
    thresholds = np.unique(np.concatenate((data1, data2)))
    pct1 = np.searchsorted(data1, thresholds, side='right') / float(len(data1))
    pct2 = np.searchsorted(data2, thresholds, side='right') / float(len(data2))

    # Anchor both curves at (0, 0) and (1, 1) so they span the whole axis.
    if thresholds[0] != 0:
        thresholds = np.insert(thresholds, 0, [0.0])
        pct1 = np.insert(pct1, 0, [0.0])
        pct2 = np.insert(pct2, 0, [0.0])
    if thresholds[-1] != 1:
        thresholds = np.append(thresholds, [1.0])
        pct1 = np.append(pct1, [1.0])
        pct2 = np.append(pct2, [1.0])

    # The KS statistic is the largest *absolute* gap; a signed maximum would
    # report 0 whenever the second class's curve lies above the first's.
    differences = np.abs(pct1 - pct2)
    max_index = np.argmax(differences)
    ks_statistic, max_distance_at = (differences[max_index],
                                     thresholds[max_index])

    return thresholds, pct1, pct2, ks_statistic, max_distance_at, classes


def validate_labels(known_classes, passed_labels, argument_name):
    """Validates the labels passed into the true_labels or pred_labels
    arguments in the plot_confusion_matrix function.

    Raises a ValueError exception if any of the passed labels are not in the
    set of known classes or if there are duplicate labels. Otherwise returns
    None.

    Args:
        known_classes (array-like):
            The classes that are known to appear in the data.
        passed_labels (array-like):
            The labels that were passed in through the argument.
        argument_name (str):
            The name of the argument being validated.

    Raises:
        ValueError: If `passed_labels` contains duplicates, or contains a
            label that is not present in `known_classes`.

    Example:
        >>> known_classes = ["A", "B", "C"]
        >>> passed_labels = ["A", "B"]
        >>> validate_labels(known_classes, passed_labels, "true_labels")
    """
    known_classes = np.array(known_classes)
    passed_labels = np.array(passed_labels)

    unique_labels, unique_indexes = np.unique(passed_labels, return_index=True)

    if len(passed_labels) != len(unique_labels):
        # unique_indexes holds the first occurrence of every label, so what
        # is left after deleting those positions are the repeated entries.
        repeated = np.delete(passed_labels, unique_indexes)
        duplicate_labels = [str(x) for x in repeated]

        msg = "The following duplicate labels were passed into {0}: {1}" \
            .format(argument_name, ", ".join(duplicate_labels))
        raise ValueError(msg)

    passed_labels_absent = ~np.isin(passed_labels, known_classes)

    if np.any(passed_labels_absent):
        absent_labels = [str(x) for x in passed_labels[passed_labels_absent]]

        msg = ("The following labels "
               "were passed into {0}, "
               "but were not found in "
               "labels: {1}").format(argument_name, ", ".join(absent_labels))
        raise ValueError(msg)

    return
def cumulative_gain_curve(y_true, y_score, pos_label=None):
    """This function generates the points necessary to plot the Cumulative
    Gain curve.

    Note: This implementation is restricted to the binary classification task.

    Args:
        y_true (array-like, shape (n_samples)): True labels of the data.

        y_score (array-like, shape (n_samples)): Target scores, can either be
            probability estimates of the positive class, confidence values, or
            non-thresholded measure of decisions (as returned by
            decision_function on some classifiers).

        pos_label (int or str, default=None): Label considered as positive and
            others are considered negative. When None, the labels must be
            drawn from {0, 1} or {-1, 1} and 1 is taken as the positive label.

    Returns:
        percentages (numpy.ndarray): An array containing the X-axis values for
            plotting the Cumulative Gains chart: the fraction of samples
            considered, starting at 0.

        gains (numpy.ndarray): An array containing the Y-axis values for one
            curve of the Cumulative Gains chart: the fraction of positive
            samples found among the highest-scored samples considered.

    Raises:
        ValueError: If `y_true` and `y_score` differ in shape, if `pos_label`
            is None and the labels are not drawn from {0, 1} or {-1, 1}, or
            if no sample in `y_true` carries the positive label.
    """
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)

    # A shorter y_score would otherwise silently drop samples of y_true when
    # y_true is indexed with the sorted score positions below.
    if y_true.shape != y_score.shape:
        raise ValueError("y_true and y_score must have the same shape, got "
                         "{} and {}".format(y_true.shape, y_score.shape))

    # ensure binary classification if pos_label is not specified
    classes = np.unique(y_true)
    if (pos_label is None and
        not (np.array_equal(classes, [0, 1]) or
             np.array_equal(classes, [-1, 1]) or
             np.array_equal(classes, [0]) or
             np.array_equal(classes, [-1]) or
             np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    n_positives = np.sum(y_true)
    if n_positives == 0:
        # Without this guard the normalisation below is 0 / 0 and the whole
        # curve comes back as NaN.
        raise ValueError("y_true contains no samples of the positive class; "
                         "the cumulative gain curve is undefined")

    # Rank the samples from highest to lowest score.
    sorted_indices = np.argsort(y_score)[::-1]
    y_true = y_true[sorted_indices]
    gains = np.cumsum(y_true)

    percentages = np.arange(start=1, stop=len(y_true) + 1)

    gains = gains / float(n_positives)
    percentages = percentages / float(len(y_true))

    # Both axes start at the origin.
    gains = np.insert(gains, 0, [0])
    percentages = np.insert(percentages, 0, [0])

    return percentages, gains
| def fit(self): 43 | pass 44 | 45 | def predict(self): 46 | pass 47 | 48 | def score(self): 49 | pass 50 | 51 | class NotClassifier: 52 | def __init__(self): 53 | pass 54 | 55 | self.Classifier = Classifier 56 | self.PartialClassifier = PartialClassifier 57 | self.NotClassifier = NotClassifier 58 | 59 | def test_instance_validation(self): 60 | 61 | clf = self.Classifier() 62 | scikitplot.classifier_factory(clf) 63 | 64 | not_clf = self.NotClassifier() 65 | self.assertRaises(TypeError, scikitplot.classifier_factory, not_clf) 66 | 67 | partial_clf = self.PartialClassifier() 68 | with warnings.catch_warnings(record=True) as w: 69 | warnings.simplefilter('always') 70 | scikitplot.classifier_factory(partial_clf) 71 | assert len(w) == 2 72 | assert issubclass(w[-1].category, UserWarning) 73 | assert " not in clf. Some plots may not be possible to generate." in str(w[-1].message) 74 | 75 | def test_method_insertion(self): 76 | 77 | clf = self.Classifier() 78 | scikitplot.classifier_factory(clf) 79 | assert hasattr(clf, 'plot_learning_curve') 80 | assert hasattr(clf, 'plot_confusion_matrix') 81 | assert hasattr(clf, 'plot_roc_curve') 82 | assert hasattr(clf, 'plot_ks_statistic') 83 | assert hasattr(clf, 'plot_precision_recall_curve') 84 | assert hasattr(clf, 'plot_feature_importances') 85 | 86 | with warnings.catch_warnings(record=True) as w: 87 | warnings.simplefilter('always') 88 | scikitplot.classifier_factory(clf) 89 | 90 | assert len(w) == 7 91 | for warning in w[1:]: 92 | assert issubclass(warning.category, UserWarning) 93 | assert ' method already in clf. ' \ 94 | 'Overriding anyway. This may ' \ 95 | 'result in unintended behavior.' 
class TestPlotLearningCurve(unittest.TestCase):
    # Smoke tests for the ``plot_learning_curve`` method that
    # ``scikitplot.classifier_factory`` attaches to a classifier instance.

    def setUp(self):
        # Seed first so the shuffled dataset is the same on every run.
        np.random.seed(0)
        features, targets = load_data(return_X_y=True)
        order = np.random.permutation(len(features))
        self.X = features[order]
        self.y = targets[order]

    def tearDown(self):
        # Release every figure opened by the test.
        plt.close("all")

    def test_string_classes(self):
        # Labels that are not all integers must be accepted.
        np.random.seed(0)
        estimator = LogisticRegression()
        scikitplot.classifier_factory(estimator)
        estimator.plot_learning_curve(self.X,
                                      convert_labels_into_string(self.y))

    def test_cv(self):
        # Default cross-validation, then an explicit fold count.
        np.random.seed(0)
        estimator = LogisticRegression()
        scikitplot.classifier_factory(estimator)
        estimator.plot_learning_curve(self.X, self.y)
        estimator.plot_learning_curve(self.X, self.y, cv=5)

    def test_train_sizes(self):
        np.random.seed(0)
        estimator = LogisticRegression()
        scikitplot.classifier_factory(estimator)
        sizes = np.linspace(0.1, 1.0, 8)
        estimator.plot_learning_curve(self.X, self.y, train_sizes=sizes)

    def test_n_jobs(self):
        np.random.seed(0)
        estimator = LogisticRegression()
        scikitplot.classifier_factory(estimator)
        estimator.plot_learning_curve(self.X, self.y, n_jobs=-1)

    def test_ax(self):
        # Without ``ax`` a new axes is returned; with ``ax`` that same axes
        # object must come back.
        np.random.seed(0)
        estimator = LogisticRegression()
        scikitplot.classifier_factory(estimator)
        fig, ax = plt.subplots(1, 1)
        returned = estimator.plot_learning_curve(self.X, self.y)
        assert ax is not returned
        returned = estimator.plot_learning_curve(self.X, self.y, ax=ax)
        assert ax is returned
159 | ax = clf.plot_confusion_matrix(self.X, convert_labels_into_string(self.y)) 160 | 161 | def test_cv(self): 162 | np.random.seed(0) 163 | clf = LogisticRegression() 164 | scikitplot.classifier_factory(clf) 165 | ax = clf.plot_confusion_matrix(self.X, self.y) 166 | ax = clf.plot_confusion_matrix(self.X, self.y, cv=5) 167 | 168 | def test_normalize(self): 169 | np.random.seed(0) 170 | clf = LogisticRegression() 171 | scikitplot.classifier_factory(clf) 172 | ax = clf.plot_confusion_matrix(self.X, self.y, normalize=True) 173 | ax = clf.plot_confusion_matrix(self.X, self.y, normalize=False) 174 | 175 | def test_labels(self): 176 | np.random.seed(0) 177 | clf = LogisticRegression() 178 | scikitplot.classifier_factory(clf) 179 | ax = clf.plot_confusion_matrix(self.X, self.y, labels=[0, 1, 2]) 180 | 181 | def test_true_pred_labels(self): 182 | np.random.seed(0) 183 | clf = LogisticRegression() 184 | scikitplot.classifier_factory(clf) 185 | 186 | true_labels = [0, 1] 187 | pred_labels = [0, 2] 188 | 189 | ax = clf.plot_confusion_matrix(self.X, self.y, true_labels=true_labels, 190 | pred_labels=pred_labels) 191 | 192 | def test_cmap(self): 193 | np.random.seed(0) 194 | clf = LogisticRegression() 195 | scikitplot.classifier_factory(clf) 196 | ax = clf.plot_confusion_matrix(self.X, self.y, cmap='nipy_spectral') 197 | ax = clf.plot_confusion_matrix(self.X, self.y, cmap=plt.cm.nipy_spectral) 198 | 199 | def test_do_cv(self): 200 | np.random.seed(0) 201 | clf = LogisticRegression() 202 | scikitplot.classifier_factory(clf) 203 | ax = clf.plot_confusion_matrix(self.X, self.y) 204 | self.assertRaises(NotFittedError, clf.plot_confusion_matrix, self.X, self.y, do_cv=False) 205 | 206 | def test_shuffle(self): 207 | np.random.seed(0) 208 | clf = LogisticRegression() 209 | scikitplot.classifier_factory(clf) 210 | ax = clf.plot_confusion_matrix(self.X, self.y, shuffle=True) 211 | ax = clf.plot_confusion_matrix(self.X, self.y, shuffle=False) 212 | 213 | def test_ax(self): 214 | 
class TestPlotROCCurve(unittest.TestCase):
    # Tests for the ``plot_roc_curve`` method attached by
    # ``scikitplot.classifier_factory`` and for the function of the same
    # name in ``scikitplot.plotters``.

    def setUp(self):
        # Seed first so the shuffled dataset is the same on every run.
        np.random.seed(0)
        self.X, self.y = load_data(return_X_y=True)
        p = np.random.permutation(len(self.X))
        self.X, self.y = self.X[p], self.y[p]

    def tearDown(self):
        plt.close("all")

    def test_string_classes(self):
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        ax = clf.plot_roc_curve(self.X, convert_labels_into_string(self.y))

    def test_predict_proba(self):
        np.random.seed(0)

        # A classifier without ``predict_proba`` cannot produce a ROC curve.
        class DummyClassifier:
            def __init__(self):
                pass

            def fit(self):
                pass

            def predict(self):
                pass

            def score(self):
                pass

        clf = DummyClassifier()
        scikitplot.classifier_factory(clf)
        self.assertRaises(TypeError, clf.plot_roc_curve, self.X, self.y)

    def test_do_cv(self):
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        ax = clf.plot_roc_curve(self.X, self.y)
        self.assertRaises(AttributeError, clf.plot_roc_curve, self.X, self.y,
                          do_cv=False)

    def test_ax(self):
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        fig, ax = plt.subplots(1, 1)
        out_ax = clf.plot_roc_curve(self.X, self.y)
        assert ax is not out_ax
        out_ax = clf.plot_roc_curve(self.X, self.y, ax=ax)
        assert ax is out_ax

    def test_cmap(self):
        # Both a colormap name and a colormap object are accepted.
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        ax = clf.plot_roc_curve(self.X, self.y, cmap='nipy_spectral')
        ax = clf.plot_roc_curve(self.X, self.y, cmap=plt.cm.nipy_spectral)

    def test_curve_diffs(self):
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        ax_macro = clf.plot_roc_curve(self.X, self.y, curves='macro')
        ax_micro = clf.plot_roc_curve(self.X, self.y, curves='micro')
        ax_class = clf.plot_roc_curve(self.X, self.y, curves='each_class')
        # assertNotEqual compares only its first two arguments (the third is
        # the failure message), so every pair is checked explicitly.
        self.assertNotEqual(ax_macro, ax_micro)
        self.assertNotEqual(ax_macro, ax_class)
        self.assertNotEqual(ax_micro, ax_class)

    def test_invalid_curve_arg(self):
        np.random.seed(0)
        clf = LogisticRegression()
        scikitplot.classifier_factory(clf)
        self.assertRaises(ValueError, clf.plot_roc_curve, self.X, self.y,
                          curves='zzz')

    def test_array_like(self):
        # Plain Python lists must be accepted by the function-style API.
        ax = skplt.plot_roc_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]])
scikitplot.classifier_factory(clf) 347 | X, y = load_data(return_X_y=True) 348 | self.assertRaises(ValueError, clf.plot_ks_statistic, X, y) 349 | 350 | def test_do_cv(self): 351 | np.random.seed(0) 352 | clf = LogisticRegression() 353 | scikitplot.classifier_factory(clf) 354 | ax = clf.plot_ks_statistic(self.X, self.y) 355 | self.assertRaises(AttributeError, clf.plot_ks_statistic, self.X, self.y, 356 | do_cv=False) 357 | 358 | def test_ax(self): 359 | np.random.seed(0) 360 | clf = LogisticRegression() 361 | scikitplot.classifier_factory(clf) 362 | fig, ax = plt.subplots(1, 1) 363 | out_ax = clf.plot_ks_statistic(self.X, self.y) 364 | assert ax is not out_ax 365 | out_ax = clf.plot_ks_statistic(self.X, self.y, ax=ax) 366 | assert ax is out_ax 367 | 368 | def test_array_like(self): 369 | ax = skplt.plot_ks_statistic([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 370 | 371 | 372 | class TestPlotPrecisionRecall(unittest.TestCase): 373 | def setUp(self): 374 | np.random.seed(0) 375 | self.X, self.y = load_data(return_X_y=True) 376 | p = np.random.permutation(len(self.X)) 377 | self.X, self.y = self.X[p], self.y[p] 378 | 379 | def tearDown(self): 380 | plt.close("all") 381 | 382 | def test_string_classes(self): 383 | np.random.seed(0) 384 | clf = LogisticRegression() 385 | scikitplot.classifier_factory(clf) 386 | ax = clf.plot_precision_recall_curve(self.X, convert_labels_into_string(self.y)) 387 | 388 | def test_predict_proba(self): 389 | np.random.seed(0) 390 | 391 | class DummyClassifier: 392 | def __init__(self): 393 | pass 394 | 395 | def fit(self): 396 | pass 397 | 398 | def predict(self): 399 | pass 400 | 401 | def score(self): 402 | pass 403 | 404 | clf = DummyClassifier() 405 | scikitplot.classifier_factory(clf) 406 | self.assertRaises(TypeError, clf.plot_precision_recall_curve, self.X, self.y) 407 | 408 | def test_do_cv(self): 409 | np.random.seed(0) 410 | clf = LogisticRegression() 411 | scikitplot.classifier_factory(clf) 412 | ax = clf.plot_precision_recall_curve(self.X, 
self.y) 413 | self.assertRaises(AttributeError, clf.plot_precision_recall_curve, self.X, self.y, 414 | do_cv=False) 415 | 416 | def test_ax(self): 417 | np.random.seed(0) 418 | clf = LogisticRegression() 419 | scikitplot.classifier_factory(clf) 420 | fig, ax = plt.subplots(1, 1) 421 | out_ax = clf.plot_precision_recall_curve(self.X, self.y) 422 | assert ax is not out_ax 423 | out_ax = clf.plot_precision_recall_curve(self.X, self.y, ax=ax) 424 | assert ax is out_ax 425 | 426 | def test_curve_diffs(self): 427 | np.random.seed(0) 428 | clf = LogisticRegression() 429 | scikitplot.classifier_factory(clf) 430 | ax_micro = clf.plot_precision_recall_curve(self.X, self.y, curves='micro') 431 | ax_class = clf.plot_precision_recall_curve(self.X, self.y, curves='each_class') 432 | self.assertNotEqual(ax_micro, ax_class) 433 | 434 | def test_cmap(self): 435 | np.random.seed(0) 436 | clf = LogisticRegression() 437 | scikitplot.classifier_factory(clf) 438 | ax = clf.plot_precision_recall_curve(self.X, self.y, cmap='nipy_spectral') 439 | ax = clf.plot_precision_recall_curve(self.X, self.y, cmap=plt.cm.nipy_spectral) 440 | 441 | def test_invalid_curve_arg(self): 442 | np.random.seed(0) 443 | clf = LogisticRegression() 444 | scikitplot.classifier_factory(clf) 445 | self.assertRaises(ValueError, clf.plot_precision_recall_curve, self.X, self.y, 446 | curves='zzz') 447 | 448 | def test_array_like(self): 449 | ax = skplt.plot_precision_recall_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 450 | 451 | 452 | class TestFeatureImportances(unittest.TestCase): 453 | def setUp(self): 454 | np.random.seed(0) 455 | self.X, self.y = load_data(return_X_y=True) 456 | p = np.random.permutation(len(self.X)) 457 | self.X, self.y = self.X[p], self.y[p] 458 | 459 | def tearDown(self): 460 | plt.close("all") 461 | 462 | def test_string_classes(self): 463 | np.random.seed(0) 464 | clf = RandomForestClassifier() 465 | scikitplot.classifier_factory(clf) 466 | clf.fit(self.X, convert_labels_into_string(self.y)) 467 
| ax = clf.plot_feature_importances() 468 | 469 | def test_feature_importances_in_clf(self): 470 | np.random.seed(0) 471 | clf = LogisticRegression() 472 | scikitplot.classifier_factory(clf) 473 | clf.fit(self.X, self.y) 474 | self.assertRaises(TypeError, clf.plot_feature_importances) 475 | 476 | def test_feature_names(self): 477 | np.random.seed(0) 478 | clf = RandomForestClassifier() 479 | scikitplot.classifier_factory(clf) 480 | clf.fit(self.X, self.y) 481 | ax = clf.plot_feature_importances(feature_names=["a", "b", "c", "d"]) 482 | 483 | def test_max_num_features(self): 484 | np.random.seed(0) 485 | clf = RandomForestClassifier() 486 | scikitplot.classifier_factory(clf) 487 | clf.fit(self.X, self.y) 488 | ax = clf.plot_feature_importances(max_num_features=2) 489 | ax = clf.plot_feature_importances(max_num_features=4) 490 | ax = clf.plot_feature_importances(max_num_features=6) 491 | 492 | def test_order(self): 493 | np.random.seed(0) 494 | clf = RandomForestClassifier() 495 | scikitplot.classifier_factory(clf) 496 | clf.fit(self.X, self.y) 497 | ax = clf.plot_feature_importances(order='ascending') 498 | ax = clf.plot_feature_importances(order='descending') 499 | ax = clf.plot_feature_importances(order=None) 500 | 501 | def test_ax(self): 502 | np.random.seed(0) 503 | clf = RandomForestClassifier() 504 | scikitplot.classifier_factory(clf) 505 | clf.fit(self.X, self.y) 506 | fig, ax = plt.subplots(1, 1) 507 | out_ax = clf.plot_feature_importances() 508 | assert ax is not out_ax 509 | out_ax = clf.plot_feature_importances(ax=ax) 510 | assert ax is out_ax 511 | 512 | 513 | if __name__ == '__main__': 514 | unittest.main() 515 | -------------------------------------------------------------------------------- /scikitplot/tests/test_cluster.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | import numpy as np 4 | from sklearn.datasets import load_iris as load_data 5 | 
class TestPlotElbow(unittest.TestCase):
    # Smoke tests for ``scikitplot.cluster.plot_elbow_curve``.

    def setUp(self):
        # Seed first so the shuffled dataset is the same on every run.
        np.random.seed(0)
        features, targets = load_data(return_X_y=True)
        order = np.random.permutation(len(features))
        self.X = features[order]
        self.y = targets[order]

    def tearDown(self):
        # Release every figure opened by the test.
        plt.close("all")

    def test_n_clusters_in_clf(self):
        np.random.seed(0)

        # Bare-bones clusterer exposing only ``fit`` and ``fit_predict``;
        # the plot is expected to reject it with a TypeError.
        class DummyClusterer:
            def __init__(self):
                pass

            def fit(self):
                pass

            def fit_predict(self):
                pass

        clusterer = DummyClusterer()
        self.assertRaises(TypeError, plot_elbow_curve, clusterer, self.X)

    def test_cluster_ranges(self):
        np.random.seed(0)
        clusterer = KMeans()
        plot_elbow_curve(clusterer, self.X, cluster_ranges=range(1, 10))

    def test_ax(self):
        # Without ``ax`` a new axes is returned; with ``ax`` that same axes
        # object must come back.
        np.random.seed(0)
        clusterer = KMeans()
        fig, ax = plt.subplots(1, 1)
        returned = plot_elbow_curve(clusterer, self.X)
        assert ax is not returned
        returned = plot_elbow_curve(clusterer, self.X, ax=ax)
        assert ax is returned

    def test_n_jobs(self):
        np.random.seed(0)
        clusterer = KMeans()
        plot_elbow_curve(clusterer, self.X, n_jobs=2)

    def test_show_cluster_time(self):
        np.random.seed(0)
        clusterer = KMeans()
        plot_elbow_curve(clusterer, self.X, show_cluster_time=False)
pass 17 | 18 | def fit(self): 19 | pass 20 | 21 | def fit_predict(self): 22 | pass 23 | 24 | class NotClusterer: 25 | def __init__(self): 26 | pass 27 | 28 | self.Clusterer = Clusterer 29 | self.NotClusterer = NotClusterer 30 | 31 | def test_instance_validation(self): 32 | 33 | clf = self.Clusterer() 34 | scikitplot.clustering_factory(clf) 35 | 36 | not_clf = self.NotClusterer() 37 | self.assertRaises(TypeError, scikitplot.clustering_factory, not_clf) 38 | 39 | def test_method_insertion(self): 40 | 41 | clf = self.Clusterer() 42 | scikitplot.clustering_factory(clf) 43 | assert hasattr(clf, 'plot_silhouette') 44 | assert hasattr(clf, 'plot_elbow_curve') 45 | 46 | with warnings.catch_warnings(record=True) as w: 47 | warnings.simplefilter('always') 48 | scikitplot.clustering_factory(clf) 49 | 50 | assert len(w) >= 2 51 | for warning in w[1:]: 52 | assert issubclass(warning.category, UserWarning) 53 | assert ' method already in clf. ' \ 54 | 'Overriding anyway. This may ' \ 55 | 'result in unintended behavior.' 
class TestPlotSilhouette(unittest.TestCase):
    # Smoke tests for the ``plot_silhouette`` method that
    # ``scikitplot.clustering_factory`` attaches to a clusterer instance.

    def setUp(self):
        # Seed first so the shuffled dataset is the same on every run.
        np.random.seed(0)
        features, targets = load_data(return_X_y=True)
        order = np.random.permutation(len(features))
        self.X = features[order]
        self.y = targets[order]

    def tearDown(self):
        # Release every figure opened by the test.
        plt.close("all")

    def test_copy(self):
        np.random.seed(0)
        clusterer = KMeans()
        scikitplot.clustering_factory(clusterer)
        # Default call: the clusterer passed in must still look unfitted.
        clusterer.plot_silhouette(self.X)
        assert not hasattr(clusterer, "cluster_centers_")
        # copy=False: the clusterer passed in ends up fitted.
        clusterer.plot_silhouette(self.X, copy=False)
        assert hasattr(clusterer, "cluster_centers_")

    def test_cmap(self):
        # Both a colormap name and a colormap object are accepted.
        np.random.seed(0)
        clusterer = KMeans()
        scikitplot.clustering_factory(clusterer)
        clusterer.plot_silhouette(self.X, cmap='Spectral')
        clusterer.plot_silhouette(self.X, cmap=plt.cm.Spectral)

    def test_ax(self):
        # Without ``ax`` a new axes is returned; with ``ax`` that same axes
        # object must come back.
        np.random.seed(0)
        clusterer = KMeans()
        scikitplot.clustering_factory(clusterer)
        fig, ax = plt.subplots(1, 1)
        returned = clusterer.plot_silhouette(self.X)
        assert ax is not returned
        returned = clusterer.plot_silhouette(self.X, ax=ax)
        assert ax is returned
clf.plot_elbow_curve(self.X) 128 | 129 | def test_ax(self): 130 | np.random.seed(0) 131 | clf = KMeans() 132 | scikitplot.clustering_factory(clf) 133 | fig, ax = plt.subplots(1, 1) 134 | out_ax = clf.plot_elbow_curve(self.X) 135 | assert ax is not out_ax 136 | out_ax = clf.plot_elbow_curve(self.X, ax=ax) 137 | assert ax is out_ax 138 | 139 | if __name__ == '__main__': 140 | unittest.main() 141 | -------------------------------------------------------------------------------- /scikitplot/tests/test_decomposition.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | 4 | from sklearn.datasets import load_iris as load_data 5 | from sklearn.decomposition import PCA 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from scikitplot.decomposition import plot_pca_component_variance 11 | from scikitplot.decomposition import plot_pca_2d_projection 12 | 13 | 14 | class TestPlotPCAComponentVariance(unittest.TestCase): 15 | 16 | def setUp(self): 17 | np.random.seed(0) 18 | self.X, self.y = load_data(return_X_y=True) 19 | p = np.random.permutation(len(self.X)) 20 | self.X, self.y = self.X[p], self.y[p] 21 | 22 | def tearDown(self): 23 | plt.close("all") 24 | 25 | def test_target_explained_variance(self): 26 | np.random.seed(0) 27 | clf = PCA() 28 | clf.fit(self.X) 29 | plot_pca_component_variance(clf, target_explained_variance=0) 30 | plot_pca_component_variance(clf, target_explained_variance=0.5) 31 | plot_pca_component_variance(clf, target_explained_variance=1) 32 | plot_pca_component_variance(clf, target_explained_variance=1.5) 33 | 34 | def test_fitted(self): 35 | np.random.seed(0) 36 | clf = PCA() 37 | self.assertRaises(TypeError, plot_pca_component_variance, clf) 38 | 39 | def test_ax(self): 40 | np.random.seed(0) 41 | clf = PCA() 42 | clf.fit(self.X) 43 | fig, ax = plt.subplots(1, 1) 44 | out_ax = plot_pca_component_variance(clf) 45 | assert ax is not 
out_ax 46 | out_ax = plot_pca_component_variance(clf, ax=ax) 47 | assert ax is out_ax 48 | 49 | 50 | class TestPlotPCA2DProjection(unittest.TestCase): 51 | 52 | def setUp(self): 53 | np.random.seed(0) 54 | self.X, self.y = load_data(return_X_y=True) 55 | p = np.random.permutation(len(self.X)) 56 | self.X, self.y = self.X[p], self.y[p] 57 | 58 | def tearDown(self): 59 | plt.close("all") 60 | 61 | def test_ax(self): 62 | np.random.seed(0) 63 | clf = PCA() 64 | clf.fit(self.X) 65 | fig, ax = plt.subplots(1, 1) 66 | out_ax = plot_pca_2d_projection(clf, self.X, self.y) 67 | assert ax is not out_ax 68 | out_ax = plot_pca_2d_projection(clf, self.X, self.y, ax=ax) 69 | assert ax is out_ax 70 | 71 | def test_cmap(self): 72 | np.random.seed(0) 73 | clf = PCA() 74 | clf.fit(self.X) 75 | plot_pca_2d_projection(clf, self.X, self.y, cmap='Spectral') 76 | plot_pca_2d_projection(clf, self.X, self.y, cmap=plt.cm.Spectral) 77 | 78 | def test_biplot(self): 79 | np.random.seed(0) 80 | clf = PCA() 81 | clf.fit(self.X) 82 | ax = plot_pca_2d_projection(clf, self.X, self.y, biplot=True, 83 | feature_labels=load_data().feature_names) 84 | -------------------------------------------------------------------------------- /scikitplot/tests/test_estimators.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | from sklearn.datasets import load_iris as load_data 4 | from sklearn.linear_model import LogisticRegression 5 | from sklearn.ensemble import RandomForestClassifier 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from scikitplot.estimators import plot_feature_importances 10 | from scikitplot.estimators import plot_learning_curve 11 | 12 | 13 | def convert_labels_into_string(y_true): 14 | return ["A" if x == 0 else x for x in y_true] 15 | 16 | 17 | class TestFeatureImportances(unittest.TestCase): 18 | def setUp(self): 19 | np.random.seed(0) 20 | self.X, self.y = 
load_data(return_X_y=True) 21 | p = np.random.permutation(len(self.X)) 22 | self.X, self.y = self.X[p], self.y[p] 23 | 24 | def tearDown(self): 25 | plt.close("all") 26 | 27 | def test_string_classes(self): 28 | np.random.seed(0) 29 | clf = RandomForestClassifier() 30 | clf.fit(self.X, convert_labels_into_string(self.y)) 31 | plot_feature_importances(clf) 32 | 33 | def test_feature_importances_in_clf(self): 34 | np.random.seed(0) 35 | clf = LogisticRegression() 36 | clf.fit(self.X, self.y) 37 | self.assertRaises(TypeError, plot_feature_importances, clf) 38 | 39 | def test_feature_names(self): 40 | np.random.seed(0) 41 | clf = RandomForestClassifier() 42 | clf.fit(self.X, self.y) 43 | plot_feature_importances(clf, feature_names=["a", "b", "c", "d"]) 44 | 45 | def test_max_num_features(self): 46 | np.random.seed(0) 47 | clf = RandomForestClassifier() 48 | clf.fit(self.X, self.y) 49 | plot_feature_importances(clf, max_num_features=2) 50 | plot_feature_importances(clf, max_num_features=4) 51 | plot_feature_importances(clf, max_num_features=6) 52 | 53 | def test_order(self): 54 | np.random.seed(0) 55 | clf = RandomForestClassifier() 56 | clf.fit(self.X, self.y) 57 | plot_feature_importances(clf, order='ascending') 58 | plot_feature_importances(clf, order='descending') 59 | plot_feature_importances(clf, order=None) 60 | 61 | def test_ax(self): 62 | np.random.seed(0) 63 | clf = RandomForestClassifier() 64 | clf.fit(self.X, self.y) 65 | fig, ax = plt.subplots(1, 1) 66 | out_ax = plot_feature_importances(clf) 67 | assert ax is not out_ax 68 | out_ax = plot_feature_importances(clf, ax=ax) 69 | assert ax is out_ax 70 | 71 | 72 | class TestPlotLearningCurve(unittest.TestCase): 73 | 74 | def setUp(self): 75 | np.random.seed(0) 76 | self.X, self.y = load_data(return_X_y=True) 77 | p = np.random.permutation(len(self.X)) 78 | self.X, self.y = self.X[p], self.y[p] 79 | 80 | def tearDown(self): 81 | plt.close("all") 82 | 83 | def test_string_classes(self): 84 | np.random.seed(0) 85 
| clf = LogisticRegression() 86 | plot_learning_curve(clf, self.X, convert_labels_into_string(self.y)) 87 | 88 | def test_cv(self): 89 | np.random.seed(0) 90 | clf = LogisticRegression() 91 | plot_learning_curve(clf, self.X, self.y) 92 | plot_learning_curve(clf, self.X, self.y, cv=5) 93 | 94 | def test_train_sizes(self): 95 | np.random.seed(0) 96 | clf = LogisticRegression() 97 | plot_learning_curve(clf, self.X, self.y, 98 | train_sizes=np.linspace(0.1, 1.0, 8)) 99 | 100 | def test_n_jobs(self): 101 | np.random.seed(0) 102 | clf = LogisticRegression() 103 | plot_learning_curve(clf, self.X, self.y, n_jobs=-1) 104 | 105 | def test_random_state_and_shuffle(self): 106 | np.random.seed(0) 107 | clf = LogisticRegression() 108 | plot_learning_curve(clf, self.X, self.y, random_state=1, shuffle=True) 109 | 110 | def test_ax(self): 111 | np.random.seed(0) 112 | clf = LogisticRegression() 113 | fig, ax = plt.subplots(1, 1) 114 | out_ax = plot_learning_curve(clf, self.X, self.y) 115 | assert ax is not out_ax 116 | out_ax = plot_learning_curve(clf, self.X, self.y, ax=ax) 117 | assert ax is out_ax 118 | -------------------------------------------------------------------------------- /scikitplot/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | 4 | from sklearn.datasets import load_iris as load_data 5 | from sklearn.datasets import load_breast_cancer 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.svm import LinearSVC 8 | from sklearn.ensemble import RandomForestClassifier 9 | from sklearn.cluster import KMeans 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | from scikitplot.metrics import plot_confusion_matrix 15 | from scikitplot.metrics import plot_roc_curve 16 | from scikitplot.metrics import plot_roc 17 | from scikitplot.metrics import plot_ks_statistic 18 | from scikitplot.metrics import 
plot_precision_recall_curve 19 | from scikitplot.metrics import plot_precision_recall 20 | from scikitplot.metrics import plot_silhouette 21 | from scikitplot.metrics import plot_calibration_curve 22 | from scikitplot.metrics import plot_cumulative_gain 23 | from scikitplot.metrics import plot_lift_curve 24 | 25 | 26 | def convert_labels_into_string(y_true): 27 | return ["A" if x == 0 else x for x in y_true] 28 | 29 | 30 | class TestPlotConfusionMatrix(unittest.TestCase): 31 | def setUp(self): 32 | np.random.seed(0) 33 | self.X, self.y = load_data(return_X_y=True) 34 | p = np.random.permutation(len(self.X)) 35 | self.X, self.y = self.X[p], self.y[p] 36 | 37 | def tearDown(self): 38 | plt.close("all") 39 | 40 | def test_string_classes(self): 41 | np.random.seed(0) 42 | clf = LogisticRegression() 43 | clf.fit(self.X, convert_labels_into_string(self.y)) 44 | preds = clf.predict(self.X) 45 | plot_confusion_matrix(convert_labels_into_string(self.y), preds) 46 | 47 | def test_normalize(self): 48 | np.random.seed(0) 49 | clf = LogisticRegression() 50 | clf.fit(self.X, self.y) 51 | preds = clf.predict(self.X) 52 | plot_confusion_matrix(self.y, preds, normalize=True) 53 | plot_confusion_matrix(self.y, preds, normalize=False) 54 | 55 | def test_labels(self): 56 | np.random.seed(0) 57 | clf = LogisticRegression() 58 | clf.fit(self.X, self.y) 59 | preds = clf.predict(self.X) 60 | plot_confusion_matrix(self.y, preds, labels=[0, 1, 2]) 61 | 62 | def test_hide_counts(self): 63 | np.random.seed(0) 64 | clf = LogisticRegression() 65 | clf.fit(self.X, self.y) 66 | preds = clf.predict(self.X) 67 | plot_confusion_matrix(self.y, preds, hide_counts=True) 68 | 69 | def test_true_pred_labels(self): 70 | np.random.seed(0) 71 | clf = LogisticRegression() 72 | clf.fit(self.X, self.y) 73 | preds = clf.predict(self.X) 74 | 75 | true_labels = [0, 1] 76 | pred_labels = [0, 2] 77 | 78 | plot_confusion_matrix(self.y, preds, 79 | true_labels=true_labels, 80 | pred_labels=pred_labels) 81 | 82 | def 
test_cmap(self): 83 | np.random.seed(0) 84 | clf = LogisticRegression() 85 | clf.fit(self.X, self.y) 86 | preds = clf.predict(self.X) 87 | plot_confusion_matrix(self.y, preds, cmap='nipy_spectral') 88 | plot_confusion_matrix(self.y, preds, cmap=plt.cm.nipy_spectral) 89 | 90 | def test_ax(self): 91 | np.random.seed(0) 92 | clf = LogisticRegression() 93 | clf.fit(self.X, self.y) 94 | preds = clf.predict(self.X) 95 | fig, ax = plt.subplots(1, 1) 96 | out_ax = plot_confusion_matrix(self.y, preds) 97 | assert ax is not out_ax 98 | out_ax = plot_confusion_matrix(self.y, preds, ax=ax) 99 | assert ax is out_ax 100 | 101 | def test_array_like(self): 102 | plot_confusion_matrix([0, 'a'], ['a', 0]) 103 | plot_confusion_matrix([0, 1], [1, 0]) 104 | plot_confusion_matrix(['b', 'a'], ['a', 'b']) 105 | 106 | 107 | class TestPlotROCCurve(unittest.TestCase): 108 | def setUp(self): 109 | np.random.seed(0) 110 | self.X, self.y = load_data(return_X_y=True) 111 | p = np.random.permutation(len(self.X)) 112 | self.X, self.y = self.X[p], self.y[p] 113 | 114 | def tearDown(self): 115 | plt.close("all") 116 | 117 | def test_string_classes(self): 118 | np.random.seed(0) 119 | clf = LogisticRegression() 120 | clf.fit(self.X, convert_labels_into_string(self.y)) 121 | probas = clf.predict_proba(self.X) 122 | plot_roc_curve(convert_labels_into_string(self.y), probas) 123 | 124 | def test_ax(self): 125 | np.random.seed(0) 126 | clf = LogisticRegression() 127 | clf.fit(self.X, self.y) 128 | probas = clf.predict_proba(self.X) 129 | fig, ax = plt.subplots(1, 1) 130 | out_ax = plot_roc_curve(self.y, probas) 131 | assert ax is not out_ax 132 | out_ax = plot_roc_curve(self.y, probas, ax=ax) 133 | assert ax is out_ax 134 | 135 | def test_cmap(self): 136 | np.random.seed(0) 137 | clf = LogisticRegression() 138 | clf.fit(self.X, self.y) 139 | probas = clf.predict_proba(self.X) 140 | plot_roc_curve(self.y, probas, cmap='nipy_spectral') 141 | plot_roc_curve(self.y, probas, cmap=plt.cm.nipy_spectral) 142 | 
143 | def test_curve_diffs(self): 144 | np.random.seed(0) 145 | clf = LogisticRegression() 146 | clf.fit(self.X, self.y) 147 | probas = clf.predict_proba(self.X) 148 | ax_macro = plot_roc_curve(self.y, probas, curves='macro') 149 | ax_micro = plot_roc_curve(self.y, probas, curves='micro') 150 | ax_class = plot_roc_curve(self.y, probas, curves='each_class') 151 | self.assertNotEqual(ax_macro, ax_micro, ax_class) 152 | 153 | def test_invalid_curve_arg(self): 154 | np.random.seed(0) 155 | clf = LogisticRegression() 156 | clf.fit(self.X, self.y) 157 | probas = clf.predict_proba(self.X) 158 | self.assertRaises(ValueError, plot_roc_curve, self.y, probas, 159 | curves='zzz') 160 | 161 | def test_array_like(self): 162 | plot_roc_curve([0, 'a'], [[0.8, 0.2], [0.2, 0.8]]) 163 | plot_roc_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 164 | plot_roc_curve(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 165 | 166 | 167 | class TestPlotROC(unittest.TestCase): 168 | def setUp(self): 169 | np.random.seed(0) 170 | self.X, self.y = load_data(return_X_y=True) 171 | p = np.random.permutation(len(self.X)) 172 | self.X, self.y = self.X[p], self.y[p] 173 | 174 | def tearDown(self): 175 | plt.close("all") 176 | 177 | def test_string_classes(self): 178 | np.random.seed(0) 179 | clf = LogisticRegression() 180 | clf.fit(self.X, convert_labels_into_string(self.y)) 181 | probas = clf.predict_proba(self.X) 182 | plot_roc(convert_labels_into_string(self.y), probas) 183 | 184 | def test_ax(self): 185 | np.random.seed(0) 186 | clf = LogisticRegression() 187 | clf.fit(self.X, self.y) 188 | probas = clf.predict_proba(self.X) 189 | fig, ax = plt.subplots(1, 1) 190 | out_ax = plot_roc(self.y, probas) 191 | assert ax is not out_ax 192 | out_ax = plot_roc(self.y, probas, ax=ax) 193 | assert ax is out_ax 194 | 195 | def test_cmap(self): 196 | np.random.seed(0) 197 | clf = LogisticRegression() 198 | clf.fit(self.X, self.y) 199 | probas = clf.predict_proba(self.X) 200 | plot_roc(self.y, probas, cmap='nipy_spectral') 201 | 
plot_roc(self.y, probas, cmap=plt.cm.nipy_spectral) 202 | 203 | def test_plot_micro(self): 204 | np.random.seed(0) 205 | clf = LogisticRegression() 206 | clf.fit(self.X, self.y) 207 | probas = clf.predict_proba(self.X) 208 | plot_roc(self.y, probas, plot_micro=False) 209 | plot_roc(self.y, probas, plot_micro=True) 210 | 211 | def test_plot_macro(self): 212 | np.random.seed(0) 213 | clf = LogisticRegression() 214 | clf.fit(self.X, self.y) 215 | probas = clf.predict_proba(self.X) 216 | plot_roc(self.y, probas, plot_macro=False) 217 | plot_roc(self.y, probas, plot_macro=True) 218 | 219 | def test_classes_to_plot(self): 220 | np.random.seed(0) 221 | clf = LogisticRegression() 222 | clf.fit(self.X, self.y) 223 | probas = clf.predict_proba(self.X) 224 | plot_roc(self.y, probas, classes_to_plot=[0, 1]) 225 | plot_roc(self.y, probas, classes_to_plot=np.array([0, 1])) 226 | 227 | def test_array_like(self): 228 | plot_roc([0, 'a'], [[0.8, 0.2], [0.2, 0.8]]) 229 | plot_roc([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 230 | plot_roc(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 231 | 232 | 233 | class TestPlotKSStatistic(unittest.TestCase): 234 | def setUp(self): 235 | np.random.seed(0) 236 | self.X, self.y = load_breast_cancer(return_X_y=True) 237 | p = np.random.permutation(len(self.X)) 238 | self.X, self.y = self.X[p], self.y[p] 239 | 240 | def tearDown(self): 241 | plt.close("all") 242 | 243 | def test_string_classes(self): 244 | np.random.seed(0) 245 | clf = LogisticRegression() 246 | clf.fit(self.X, convert_labels_into_string(self.y)) 247 | probas = clf.predict_proba(self.X) 248 | plot_ks_statistic(convert_labels_into_string(self.y), probas) 249 | 250 | def test_two_classes(self): 251 | np.random.seed(0) 252 | # Test this one on Iris (3 classes) 253 | X, y = load_data(return_X_y=True) 254 | clf = LogisticRegression() 255 | clf.fit(X, y) 256 | probas = clf.predict_proba(X) 257 | self.assertRaises(ValueError, plot_ks_statistic, y, probas) 258 | 259 | def test_ax(self): 260 | 
np.random.seed(0) 261 | clf = LogisticRegression() 262 | clf.fit(self.X, self.y) 263 | probas = clf.predict_proba(self.X) 264 | fig, ax = plt.subplots(1, 1) 265 | out_ax = plot_ks_statistic(self.y, probas) 266 | assert ax is not out_ax 267 | out_ax = plot_ks_statistic(self.y, probas, ax=ax) 268 | assert ax is out_ax 269 | 270 | def test_array_like(self): 271 | plot_ks_statistic([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 272 | plot_ks_statistic([0, 'a'], [[0.8, 0.2], [0.2, 0.8]]) 273 | plot_ks_statistic(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 274 | 275 | 276 | class TestPlotPrecisionRecallCurve(unittest.TestCase): 277 | def setUp(self): 278 | np.random.seed(0) 279 | self.X, self.y = load_data(return_X_y=True) 280 | p = np.random.permutation(len(self.X)) 281 | self.X, self.y = self.X[p], self.y[p] 282 | 283 | def tearDown(self): 284 | plt.close("all") 285 | 286 | def test_string_classes(self): 287 | np.random.seed(0) 288 | clf = LogisticRegression() 289 | clf.fit(self.X, convert_labels_into_string(self.y)) 290 | probas = clf.predict_proba(self.X) 291 | plot_precision_recall_curve(convert_labels_into_string(self.y), probas) 292 | 293 | def test_ax(self): 294 | np.random.seed(0) 295 | clf = LogisticRegression() 296 | clf.fit(self.X, self.y) 297 | probas = clf.predict_proba(self.X) 298 | fig, ax = plt.subplots(1, 1) 299 | out_ax = plot_precision_recall_curve(self.y, probas) 300 | assert ax is not out_ax 301 | out_ax = plot_precision_recall_curve(self.y, probas, ax=ax) 302 | assert ax is out_ax 303 | 304 | def test_curve_diffs(self): 305 | np.random.seed(0) 306 | clf = LogisticRegression() 307 | clf.fit(self.X, self.y) 308 | probas = clf.predict_proba(self.X) 309 | ax_micro = plot_precision_recall_curve(self.y, probas, curves='micro') 310 | ax_class = plot_precision_recall_curve(self.y, probas, 311 | curves='each_class') 312 | self.assertNotEqual(ax_micro, ax_class) 313 | 314 | def test_cmap(self): 315 | np.random.seed(0) 316 | clf = LogisticRegression() 317 | clf.fit(self.X, 
self.y) 318 | probas = clf.predict_proba(self.X) 319 | plot_precision_recall_curve(self.y, probas, cmap='nipy_spectral') 320 | plot_precision_recall_curve(self.y, probas, cmap=plt.cm.nipy_spectral) 321 | 322 | def test_invalid_curve_arg(self): 323 | np.random.seed(0) 324 | clf = LogisticRegression() 325 | clf.fit(self.X, self.y) 326 | probas = clf.predict_proba(self.X) 327 | self.assertRaises(ValueError, plot_precision_recall_curve, self.y, 328 | probas, curves='zzz') 329 | 330 | def test_array_like(self): 331 | plot_precision_recall_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 332 | plot_precision_recall_curve([0, 'a'], [[0.8, 0.2], [0.2, 0.8]]) 333 | plot_precision_recall_curve(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 334 | 335 | 336 | class TestPlotPrecisionRecall(unittest.TestCase): 337 | def setUp(self): 338 | np.random.seed(0) 339 | self.X, self.y = load_data(return_X_y=True) 340 | p = np.random.permutation(len(self.X)) 341 | self.X, self.y = self.X[p], self.y[p] 342 | 343 | def tearDown(self): 344 | plt.close("all") 345 | 346 | def test_string_classes(self): 347 | np.random.seed(0) 348 | clf = LogisticRegression() 349 | clf.fit(self.X, convert_labels_into_string(self.y)) 350 | probas = clf.predict_proba(self.X) 351 | plot_precision_recall(convert_labels_into_string(self.y), probas) 352 | 353 | def test_ax(self): 354 | np.random.seed(0) 355 | clf = LogisticRegression() 356 | clf.fit(self.X, self.y) 357 | probas = clf.predict_proba(self.X) 358 | fig, ax = plt.subplots(1, 1) 359 | out_ax = plot_precision_recall(self.y, probas) 360 | assert ax is not out_ax 361 | out_ax = plot_precision_recall(self.y, probas, ax=ax) 362 | assert ax is out_ax 363 | 364 | def test_plot_micro(self): 365 | np.random.seed(0) 366 | clf = LogisticRegression() 367 | clf.fit(self.X, self.y) 368 | probas = clf.predict_proba(self.X) 369 | plot_precision_recall(self.y, probas, plot_micro=True) 370 | plot_precision_recall(self.y, probas, plot_micro=False) 371 | 372 | def test_cmap(self): 373 | 
np.random.seed(0) 374 | clf = LogisticRegression() 375 | clf.fit(self.X, self.y) 376 | probas = clf.predict_proba(self.X) 377 | plot_precision_recall(self.y, probas, cmap='nipy_spectral') 378 | plot_precision_recall(self.y, probas, cmap=plt.cm.nipy_spectral) 379 | 380 | def test_classes_to_plot(self): 381 | np.random.seed(0) 382 | clf = LogisticRegression() 383 | clf.fit(self.X, self.y) 384 | probas = clf.predict_proba(self.X) 385 | plot_precision_recall(self.y, probas, classes_to_plot=[0, 1]) 386 | plot_precision_recall(self.y, probas, classes_to_plot=np.array([0, 1])) 387 | 388 | def test_array_like(self): 389 | plot_precision_recall([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 390 | plot_precision_recall([0, 'a'], [[0.8, 0.2], [0.2, 0.8]]) 391 | plot_precision_recall(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 392 | 393 | 394 | class TestPlotSilhouette(unittest.TestCase): 395 | def setUp(self): 396 | np.random.seed(0) 397 | self.X, self.y = load_data(return_X_y=True) 398 | p = np.random.permutation(len(self.X)) 399 | self.X, self.y = self.X[p], self.y[p] 400 | 401 | def tearDown(self): 402 | plt.close("all") 403 | 404 | def test_plot_silhouette(self): 405 | np.random.seed(0) 406 | clf = KMeans() 407 | cluster_labels = clf.fit_predict(self.X) 408 | plot_silhouette(self.X, cluster_labels) 409 | 410 | def test_string_classes(self): 411 | np.random.seed(0) 412 | clf = KMeans() 413 | cluster_labels = clf.fit_predict(self.X) 414 | plot_silhouette(self.X, convert_labels_into_string(cluster_labels)) 415 | 416 | def test_cmap(self): 417 | np.random.seed(0) 418 | clf = KMeans() 419 | cluster_labels = clf.fit_predict(self.X) 420 | plot_silhouette(self.X, cluster_labels, cmap='Spectral') 421 | plot_silhouette(self.X, cluster_labels, cmap=plt.cm.Spectral) 422 | 423 | def test_ax(self): 424 | np.random.seed(0) 425 | clf = KMeans() 426 | cluster_labels = clf.fit_predict(self.X) 427 | plot_silhouette(self.X, cluster_labels) 428 | fig, ax = plt.subplots(1, 1) 429 | out_ax = 
plot_silhouette(self.X, cluster_labels) 430 | assert ax is not out_ax 431 | out_ax = plot_silhouette(self.X, cluster_labels, ax=ax) 432 | assert ax is out_ax 433 | 434 | def test_array_like(self): 435 | plot_silhouette(self.X.tolist(), self.y.tolist()) 436 | plot_silhouette(self.X.tolist(), convert_labels_into_string(self.y)) 437 | 438 | 439 | class TestPlotCalibrationCurve(unittest.TestCase): 440 | def setUp(self): 441 | np.random.seed(0) 442 | self.X, self.y = load_breast_cancer(return_X_y=True) 443 | p = np.random.permutation(len(self.X)) 444 | self.X, self.y = self.X[p], self.y[p] 445 | self.lr = LogisticRegression() 446 | self.rf = RandomForestClassifier(random_state=8) 447 | self.svc = LinearSVC() 448 | self.lr_probas = self.lr.fit(self.X, self.y).predict_proba(self.X) 449 | self.rf_probas = self.rf.fit(self.X, self.y).predict_proba(self.X) 450 | self.svc_scores = self.svc.fit(self.X, self.y).\ 451 | decision_function(self.X) 452 | 453 | def tearDown(self): 454 | plt.close("all") 455 | 456 | def test_decision_function(self): 457 | plot_calibration_curve(self.y, [self.lr_probas, 458 | self.rf_probas, 459 | self.svc_scores]) 460 | 461 | def test_plot_calibration(self): 462 | plot_calibration_curve(self.y, [self.lr_probas, self.rf_probas]) 463 | 464 | def test_string_classes(self): 465 | plot_calibration_curve(convert_labels_into_string(self.y), 466 | [self.lr_probas, self.rf_probas]) 467 | 468 | def test_cmap(self): 469 | plot_calibration_curve(convert_labels_into_string(self.y), 470 | [self.lr_probas, self.rf_probas], 471 | cmap='Spectral') 472 | plot_calibration_curve(convert_labels_into_string(self.y), 473 | [self.lr_probas, self.rf_probas], 474 | cmap=plt.cm.Spectral) 475 | 476 | def test_ax(self): 477 | plot_calibration_curve(self.y, [self.lr_probas, self.rf_probas]) 478 | fig, ax = plt.subplots(1, 1) 479 | out_ax = plot_calibration_curve(self.y, 480 | [self.lr_probas, self.rf_probas]) 481 | assert ax is not out_ax 482 | out_ax = 
plot_calibration_curve(self.y, 483 | [self.lr_probas, self.rf_probas], 484 | ax=ax) 485 | assert ax is out_ax 486 | 487 | def test_array_like(self): 488 | plot_calibration_curve(self.y, [self.lr_probas.tolist(), 489 | self.rf_probas.tolist()]) 490 | plot_calibration_curve(convert_labels_into_string(self.y), 491 | [self.lr_probas.tolist(), 492 | self.rf_probas.tolist()]) 493 | 494 | def test_invalid_probas_list(self): 495 | self.assertRaises(ValueError, plot_calibration_curve, 496 | self.y, 'notalist') 497 | 498 | def test_not_binary(self): 499 | wrong_y = self.y.copy() 500 | wrong_y[-1] = 3 501 | self.assertRaises(ValueError, plot_calibration_curve, 502 | wrong_y, [self.lr_probas, self.rf_probas]) 503 | 504 | def test_wrong_clf_names(self): 505 | self.assertRaises(ValueError, plot_calibration_curve, 506 | self.y, [self.lr_probas, self.rf_probas], 507 | ['One']) 508 | 509 | def test_wrong_probas_shape(self): 510 | self.assertRaises(ValueError, plot_calibration_curve, 511 | self.y, [self.lr_probas.reshape(-1), 512 | self.rf_probas]) 513 | self.assertRaises(ValueError, plot_calibration_curve, 514 | self.y, [np.random.randn(1, 5)]) 515 | 516 | 517 | class TestPlotCumulativeGain(unittest.TestCase): 518 | def setUp(self): 519 | np.random.seed(0) 520 | self.X, self.y = load_breast_cancer(return_X_y=True) 521 | p = np.random.permutation(len(self.X)) 522 | self.X, self.y = self.X[p], self.y[p] 523 | 524 | def tearDown(self): 525 | plt.close("all") 526 | 527 | def test_string_classes(self): 528 | np.random.seed(0) 529 | clf = LogisticRegression() 530 | clf.fit(self.X, convert_labels_into_string(self.y)) 531 | probas = clf.predict_proba(self.X) 532 | plot_cumulative_gain(convert_labels_into_string(self.y), probas) 533 | 534 | def test_two_classes(self): 535 | np.random.seed(0) 536 | # Test this one on Iris (3 classes) 537 | X, y = load_data(return_X_y=True) 538 | clf = LogisticRegression() 539 | clf.fit(X, y) 540 | probas = clf.predict_proba(X) 541 | 
self.assertRaises(ValueError, plot_cumulative_gain, y, probas) 542 | 543 | def test_ax(self): 544 | np.random.seed(0) 545 | clf = LogisticRegression() 546 | clf.fit(self.X, self.y) 547 | probas = clf.predict_proba(self.X) 548 | fig, ax = plt.subplots(1, 1) 549 | out_ax = plot_cumulative_gain(self.y, probas) 550 | assert ax is not out_ax 551 | out_ax = plot_cumulative_gain(self.y, probas, ax=ax) 552 | assert ax is out_ax 553 | 554 | def test_array_like(self): 555 | plot_cumulative_gain([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 556 | plot_cumulative_gain([0, 'a'], [[0.8, 0.2], [0.2, 0.8]]) 557 | plot_cumulative_gain(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 558 | 559 | 560 | class TestPlotLiftCurve(unittest.TestCase): 561 | def setUp(self): 562 | np.random.seed(0) 563 | self.X, self.y = load_breast_cancer(return_X_y=True) 564 | p = np.random.permutation(len(self.X)) 565 | self.X, self.y = self.X[p], self.y[p] 566 | 567 | def tearDown(self): 568 | plt.close("all") 569 | 570 | def test_string_classes(self): 571 | np.random.seed(0) 572 | clf = LogisticRegression() 573 | clf.fit(self.X, convert_labels_into_string(self.y)) 574 | probas = clf.predict_proba(self.X) 575 | plot_lift_curve(convert_labels_into_string(self.y), probas) 576 | 577 | def test_two_classes(self): 578 | np.random.seed(0) 579 | # Test this one on Iris (3 classes) 580 | X, y = load_data(return_X_y=True) 581 | clf = LogisticRegression() 582 | clf.fit(X, y) 583 | probas = clf.predict_proba(X) 584 | self.assertRaises(ValueError, plot_lift_curve, y, probas) 585 | 586 | def test_ax(self): 587 | np.random.seed(0) 588 | clf = LogisticRegression() 589 | clf.fit(self.X, self.y) 590 | probas = clf.predict_proba(self.X) 591 | fig, ax = plt.subplots(1, 1) 592 | out_ax = plot_lift_curve(self.y, probas) 593 | assert ax is not out_ax 594 | out_ax = plot_lift_curve(self.y, probas, ax=ax) 595 | assert ax is out_ax 596 | 597 | def test_array_like(self): 598 | plot_lift_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]]) 599 | plot_lift_curve([0, 
'a'], [[0.8, 0.2], [0.2, 0.8]]) 600 | plot_lift_curve(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]]) 601 | -------------------------------------------------------------------------------- /scikitplot/tests/test_plotters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import unittest 3 | import scikitplot.plotters as skplt 4 | from sklearn.datasets import load_iris as load_data 5 | from sklearn.decomposition import PCA 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | class TestPlotPCAComponentVariance(unittest.TestCase): 11 | 12 | def setUp(self): 13 | np.random.seed(0) 14 | self.X, self.y = load_data(return_X_y=True) 15 | p = np.random.permutation(len(self.X)) 16 | self.X, self.y = self.X[p], self.y[p] 17 | 18 | def tearDown(self): 19 | plt.close("all") 20 | 21 | def test_target_explained_variance(self): 22 | np.random.seed(0) 23 | clf = PCA() 24 | clf.fit(self.X) 25 | ax = skplt.plot_pca_component_variance(clf, target_explained_variance=0) 26 | ax = skplt.plot_pca_component_variance(clf, target_explained_variance=0.5) 27 | ax = skplt.plot_pca_component_variance(clf, target_explained_variance=1) 28 | ax = skplt.plot_pca_component_variance(clf, target_explained_variance=1.5) 29 | 30 | def test_fitted(self): 31 | np.random.seed(0) 32 | clf = PCA() 33 | self.assertRaises(TypeError, skplt.plot_pca_component_variance, clf) 34 | 35 | def test_ax(self): 36 | np.random.seed(0) 37 | clf = PCA() 38 | clf.fit(self.X) 39 | fig, ax = plt.subplots(1, 1) 40 | out_ax = skplt.plot_pca_component_variance(clf) 41 | assert ax is not out_ax 42 | out_ax = skplt.plot_pca_component_variance(clf, ax=ax) 43 | assert ax is out_ax 44 | 45 | 46 | class TestPlotPCA2DProjection(unittest.TestCase): 47 | 48 | def setUp(self): 49 | np.random.seed(0) 50 | self.X, self.y = load_data(return_X_y=True) 51 | p = np.random.permutation(len(self.X)) 52 | self.X, self.y = self.X[p], self.y[p] 53 | 54 | def 
tearDown(self): 55 | plt.close("all") 56 | 57 | def test_ax(self): 58 | np.random.seed(0) 59 | clf = PCA() 60 | clf.fit(self.X) 61 | fig, ax = plt.subplots(1, 1) 62 | out_ax = skplt.plot_pca_2d_projection(clf, self.X, self.y) 63 | assert ax is not out_ax 64 | out_ax =skplt.plot_pca_2d_projection(clf, self.X, self.y, ax=ax) 65 | assert ax is out_ax 66 | 67 | def test_cmap(self): 68 | np.random.seed(0) 69 | clf = PCA() 70 | clf.fit(self.X) 71 | fig, ax = plt.subplots(1, 1) 72 | ax = skplt.plot_pca_2d_projection(clf, self.X, self.y, cmap='Spectral') 73 | ax = skplt.plot_pca_2d_projection(clf, self.X, self.y, cmap=plt.cm.Spectral) 74 | 75 | 76 | class TestValidateLabels(unittest.TestCase): 77 | 78 | def test_valid_equal(self): 79 | known_labels = ["A", "B", "C"] 80 | passed_labels = ["A", "B", "C"] 81 | arg_name = "true_labels" 82 | 83 | actual = skplt.validate_labels(known_labels, passed_labels, arg_name) 84 | self.assertEqual(actual, None) 85 | 86 | def test_valid_subset(self): 87 | known_labels = ["A", "B", "C"] 88 | passed_labels = ["A", "B"] 89 | arg_name = "true_labels" 90 | 91 | actual = skplt.validate_labels(known_labels, passed_labels, arg_name) 92 | self.assertEqual(actual, None) 93 | 94 | def test_invalid_one_duplicate(self): 95 | known_labels = ["A", "B", "C"] 96 | passed_labels = ["A", "B", "B"] 97 | arg_name = "true_labels" 98 | 99 | with self.assertRaises(ValueError) as context: 100 | skplt.validate_labels(known_labels, passed_labels, arg_name) 101 | 102 | msg = "The following duplicate labels were passed into true_labels: B" 103 | self.assertEqual(msg, str(context.exception)) 104 | 105 | def test_invalid_two_duplicates(self): 106 | known_labels = ["A", "B", "C"] 107 | passed_labels = ["A", "B", "A", "B"] 108 | arg_name = "true_labels" 109 | 110 | with self.assertRaises(ValueError) as context: 111 | skplt.validate_labels(known_labels, passed_labels, arg_name) 112 | 113 | msg = "The following duplicate labels were passed into true_labels: A, B" 114 | 
self.assertEqual(msg, str(context.exception)) 115 | 116 | def test_invalid_one_missing(self): 117 | known_labels = ["A", "B", "C"] 118 | passed_labels = ["A", "B", "D"] 119 | arg_name = "true_labels" 120 | 121 | with self.assertRaises(ValueError) as context: 122 | skplt.validate_labels(known_labels, passed_labels, arg_name) 123 | 124 | msg = "The following labels were passed into true_labels, but were not found in labels: D" 125 | self.assertEqual(msg, str(context.exception)) 126 | 127 | def test_invalid_two_missing(self): 128 | known_labels = ["A", "B", "C"] 129 | passed_labels = ["A", "E", "B", "D"] 130 | arg_name = "true_labels" 131 | 132 | with self.assertRaises(ValueError) as context: 133 | skplt.validate_labels(known_labels, passed_labels, arg_name) 134 | 135 | msg = "The following labels were passed into true_labels, but were not found in labels: E, D" 136 | self.assertEqual(msg, str(context.exception)) 137 | 138 | def test_numerical_labels(self): 139 | known_labels = [0, 1, 2] 140 | passed_labels = [0, 2] 141 | arg_name = "true_labels" 142 | 143 | actual = skplt.validate_labels(known_labels, passed_labels, arg_name) 144 | self.assertEqual(actual, None) 145 | 146 | def test_invalid_duplicate_numerical_labels(self): 147 | known_labels = [0, 1, 2] 148 | passed_labels = [0, 2, 2] 149 | arg_name = "true_labels" 150 | 151 | with self.assertRaises(ValueError) as context: 152 | skplt.validate_labels(known_labels, passed_labels, arg_name) 153 | 154 | msg = "The following duplicate labels were passed into true_labels: 2" 155 | self.assertEqual(msg, str(context.exception)) 156 | 157 | def test_invalid_missing_numerical_labels(self): 158 | known_labels = [0, 1, 2] 159 | passed_labels = [0, 2, 3] 160 | arg_name = "true_labels" 161 | 162 | with self.assertRaises(ValueError) as context: 163 | skplt.validate_labels(known_labels, passed_labels, arg_name) 164 | 165 | msg = "The following labels were passed into true_labels, but were not found in labels: 3" 166 | 
self.assertEqual(msg, str(context.exception)) 167 | 168 | 169 | if __name__ == '__main__': 170 | unittest.main() 171 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from setuptools import setup, find_packages 3 | from setuptools.command.test import test as TestCommand 4 | import io 5 | import codecs 6 | import os 7 | import sys 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | 12 | def read(*filenames, **kwargs): 13 | encoding = kwargs.get('encoding', 'utf-8') 14 | sep = kwargs.get('sep', '\n') 15 | buf = [] 16 | for filename in filenames: 17 | with io.open(filename, encoding=encoding) as f: 18 | buf.append(f.read()) 19 | return sep.join(buf) 20 | 21 | long_description = read('README.md') 22 | # long_description = '' 23 | 24 | 25 | class PyTest(TestCommand): 26 | def finalize_options(self): 27 | TestCommand.finalize_options(self) 28 | self.test_args = [] 29 | self.test_suite = True 30 | 31 | def run_tests(self): 32 | import pytest 33 | errcode = pytest.main(self.test_args) 34 | sys.exit(errcode) 35 | 36 | setup( 37 | name='scikit-plot', 38 | version='0.3.7', 39 | url='https://github.com/reiinakano/scikit-plot', 40 | license='MIT License', 41 | author='Reiichiro Nakano', 42 | tests_require=['pytest'], 43 | install_requires=[ 44 | 'matplotlib>=1.4.0', 45 | 'scikit-learn>=0.18', 46 | 'scipy>=0.9', 47 | 'joblib>=0.10' 48 | ], 49 | cmdclass={'test': PyTest}, 50 | author_email='reiichiro.s.nakano@gmail.com', 51 | description='An intuitive library to add plotting functionality to scikit-learn objects.', 52 | long_description=long_description, 53 | packages=['scikitplot'], 54 | include_package_data=True, 55 | platforms='any', 56 | test_suite='scikitplot.tests.test_scikitplot', 57 | classifiers = [ 58 | 'Programming Language :: Python', 59 | 'Programming Language :: Python :: 2', 60 | 
'Programming Language :: Python :: 2.7', 61 | 'Programming Language :: Python :: 3', 62 | 'Programming Language :: Python :: 3.5', 63 | 'Programming Language :: Python :: 3.6', 64 | 'Natural Language :: English', 65 | 'Intended Audience :: Developers', 66 | 'Intended Audience :: Science/Research', 67 | 'License :: OSI Approved :: MIT License', 68 | 'Operating System :: OS Independent', 69 | 'Topic :: Scientific/Engineering :: Visualization', 70 | ], 71 | extras_require={ 72 | 'testing': ['pytest'], 73 | } 74 | ) 75 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox (https://tox.readthedocs.io/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | envlist = py27, py35, py36 8 | 9 | [testenv] 10 | commands = py.test -v 11 | deps = pytest 12 | 13 | --------------------------------------------------------------------------------