├── .Rbuildignore ├── COPYING ├── DESCRIPTION ├── LICENSE ├── NAMESPACE ├── R ├── boxplot.R ├── calibration.R ├── classifier_plots.R ├── density.R ├── example.R ├── example_gen.R ├── individual_plots.R ├── lift.R ├── metrics.R ├── notation_key.R ├── positives.R ├── precision.R ├── roc.R └── style.R ├── README.md ├── boris-git.toml ├── boris.toml ├── data └── example_predictions.rda ├── img ├── notation.png └── notation.svg ├── inst └── img │ └── notation.png ├── man ├── accuracy_plot.Rd ├── calculate_auc.Rd ├── calibration_plot.Rd ├── classifierplots.Rd ├── classifierplots_folder.Rd ├── density_plot.Rd ├── example_predictions.Rd ├── figures │ └── example.png ├── lift_plot.Rd ├── notation_key_plot.Rd ├── positives_plot.Rd ├── precision_plot.Rd ├── propensity_plot.Rd ├── recall_plot.Rd ├── roc_plot.Rd └── sigmoid.Rd └── master.toml /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^README\.md$ 2 | ^R/example_gen\.R$ 3 | ^.*\.Rproj$ 4 | ^\.Rproj\.user$ 5 | ^cran-comments\.md$ # 6 | ^NEWS\.md$ 7 | ^bin$ 8 | ^classifierplots.*.tar.gz$ 9 | ^example.png$ 10 | ^example_output$ 11 | ^classifierplots\.Rcheck$ 12 | ^.*\.toml$ 13 | ^LOCALNOTES\.md$ 14 | ^img$ 15 | ^example2$ 16 | ^exampledata\.psv$ 17 | ^COPYING$ 18 | ^CRAN-RELEASE$ 19 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | Copyright 2017, Ambiata, All Rights Reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | 2. 
Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of 15 | its contributors may be used to endorse or promote products derived 16 | from this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: classifierplots 2 | Title: Generates a Visualization of Classifier Performance as a Grid of Diagnostic Plots 3 | Version: 1.4.0 4 | Authors@R: c( 5 | person("Aaron", "Defazio", email = "aaron.defazio@gmail.com", role = c("aut", "cre")), 6 | person("Huw", "Campbell", email = "huw.campbell@ambiata.com", role = c("aut"))) 7 | Description: 8 | Generates a visualization of binary classifier performance as a grid of 9 | diagnostic plots with just one function call. 
Includes ROC curves, 10 | prediction density, accuracy, precision, recall and calibration plots, all using 11 | ggplot2 for easy modification. 12 | Debug your binary classifiers faster and easier! 13 | Depends: 14 | R (>= 3.1), 15 | ggplot2 (>= 2.2), 16 | data.table (>= 1.10), 17 | Imports: 18 | Rcpp (>= 0.12), 19 | grid, 20 | ROCR, 21 | caret, 22 | gridExtra (>= 2.2), 23 | stats, 24 | utils, 25 | png, 26 | Suggests: 27 | testthat, 28 | License: BSD 3-clause License + file LICENSE 29 | Encoding: UTF-8 30 | BugReports: https://github.com/adefazio/classifierplots/issues 31 | URL: https://github.com/adefazio/classifierplots 32 | LazyData: true 33 | RoxygenNote: 5.0.1 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2017 2 | COPYRIGHT HOLDER: Ambiata 3 | ORGANIZATION: Ambiata 4 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(accuracy_plot) 4 | export(calibration_plot) 5 | export(classifierplots) 6 | export(classifierplots_folder) 7 | export(density_plot) 8 | export(lift_plot) 9 | export(notation_key_plot) 10 | export(positives_plot) 11 | export(precision_plot) 12 | export(propensity_plot) 13 | export(recall_plot) 14 | export(roc_plot) 15 | export(sigmoid) 16 | import(data.table) 17 | import(ggplot2) 18 | importFrom(ROCR,performance) 19 | importFrom(ROCR,prediction) 20 | importFrom(caret,createResample) 21 | importFrom(ggplot2,ggproto) 22 | importFrom(grDevices,dev.new) 23 | importFrom(grDevices,dev.off) 24 | importFrom(grDevices,pdf) 25 | importFrom(grDevices,x11) 26 | importFrom(grid,gpar) 27 | importFrom(grid,grid.draw) 28 | importFrom(grid,grobName) 29 | importFrom(grid,grobTree) 30 | importFrom(grid,rasterGrob) 31 | importFrom(grid,textGrob) 
32 | importFrom(gridExtra,arrangeGrob) 33 | importFrom(gridExtra,grid.arrange) 34 | importFrom(png,readPNG) 35 | importFrom(stats,dnorm) 36 | importFrom(stats,qbeta) 37 | importFrom(stats,quantile) 38 | importFrom(stats,sd) 39 | importFrom(utils,sessionInfo) 40 | -------------------------------------------------------------------------------- /R/boxplot.R: -------------------------------------------------------------------------------- 1 | 2 | "%||%" <- function(a, b) { 3 | if (!is.null(a)) a else b 4 | } 5 | 6 | #' @importFrom grid grobName 7 | ggname <- function(prefix, grob) { 8 | grob$name <- grobName(grob, prefix) 9 | grob 10 | } 11 | 12 | geom_ambiboxplot <- function(mapping = NULL, data = NULL, 13 | stat = "boxplot", position = "dodge", 14 | ..., 15 | width=0.3, 16 | outlier.colour = NULL, 17 | outlier.color = NULL, 18 | outlier.shape = 19, 19 | outlier.size = 0.8, 20 | outlier.stroke = 0.5, 21 | na.rm = FALSE, 22 | show.legend = NA, 23 | inherit.aes = TRUE) { 24 | layer( 25 | data = data, 26 | mapping = mapping, 27 | stat = stat, 28 | geom = GeomAmbiBoxplot, 29 | position = position, 30 | show.legend = show.legend, 31 | inherit.aes = inherit.aes, 32 | params = list( 33 | outlier.colour = outlier.color %||% outlier.colour, 34 | outlier.shape = outlier.shape, 35 | outlier.size = outlier.size, 36 | outlier.stroke = outlier.stroke, 37 | na.rm = na.rm, 38 | width = width, 39 | ... 40 | ) 41 | ) 42 | } 43 | 44 | #' @importFrom ggplot2 ggproto 45 | #' @importFrom grid grobTree 46 | GeomAmbiBoxplot <- ggproto("GeomAmbiBoxplot", Geom, 47 | setup_data = function(data, params) { 48 | # Set width to a default if it is not part of data. 
49 | data$width <- data$width %||% 50 | params$width %||% (resolution(data$x, FALSE) * 0.9) 51 | 52 | if (!is.null(data$outliers)) { 53 | suppressWarnings({ 54 | out_min <- vapply(data$outliers, min, numeric(1)) 55 | out_max <- vapply(data$outliers, max, numeric(1)) 56 | }) 57 | 58 | data$ymin_final <- pmin(out_min, data$ymin) 59 | data$ymax_final <- pmax(out_max, data$ymax) 60 | } 61 | 62 | data$xmin <- data$x - data$width / 2.0 63 | data$xmax <- data$x + data$width / 2.0 64 | data$barmin <- data$x - data$width / 3.0 65 | data$barmax <- data$x + data$width / 3.0 66 | 67 | #print("setup_data end") 68 | #browser() 69 | data 70 | }, 71 | draw_group = function(self, data, panel_scales, coord, fatten = 2, 72 | outlier.colour = NULL, outlier.shape = 19, 73 | outlier.size = 0.8, outlier.stroke = 0.5) { 74 | #print("draw group called") 75 | #browser() 76 | common <- data.frame( 77 | colour = alpha(data$fill, data$alpha), 78 | size = data$size, 79 | fill = alpha(data$fill, data$alpha), 80 | group = data$group, 81 | alpha = 1.0, 82 | stringsAsFactors = FALSE 83 | ) 84 | 85 | whiskers <- data.frame( 86 | x = data$x, 87 | xend = data$x, 88 | y = c(data$upper, data$lower), 89 | yend = c(data$ymax, data$ymin), 90 | alpha = common$alpha, 91 | colour = common$fill, 92 | common 93 | ) 94 | 95 | box <- data.frame( 96 | xmin = data$xmin, 97 | xmax = data$xmax, 98 | ymin = data$lower, 99 | ymax = data$upper, 100 | alpha = data$alpha, 101 | common 102 | ) 103 | 104 | crossbar <- data.frame( 105 | x = data$barmin, 106 | xend = data$barmax, 107 | y = data$middle, 108 | yend = data$middle, 109 | colour = data$colour, 110 | size = data$size, 111 | group = data$group, 112 | alpha = 1.0 113 | ) 114 | 115 | if (!is.null(data$outliers) && length(data$outliers[[1]]) >= 1) { 116 | outliers <- data.frame( 117 | y = data$outliers[[1]], 118 | x = data$x[1], 119 | colour = outlier.colour %||% data$fill[1], 120 | shape = outlier.shape %||% data$shape[1], 121 | size = outlier.size %||% data$size[1],
122 | stroke = outlier.stroke %||% data$stroke[1], 123 | fill = outlier.colour %||% data$fill[1], 124 | alpha = 1.0, 125 | stringsAsFactors = FALSE 126 | ) 127 | outliers_grob <- GeomPoint$draw_panel(outliers, panel_scales, coord) 128 | } else { 129 | outliers_grob <- NULL 130 | } 131 | 132 | ggname("geom_ambiboxplot", grobTree( 133 | outliers_grob, 134 | GeomSegment$draw_panel(whiskers, panel_scales, coord), 135 | GeomRect$draw_panel(box, panel_scales, coord), 136 | GeomSegment$draw_panel(crossbar, panel_scales, coord) 137 | )) 138 | }, 139 | 140 | draw_key = draw_key_boxplot, 141 | 142 | default_aes = aes(weight = 1, colour = "#FAFAFA", fill = "#51a7f9", size = 0.5, 143 | alpha = 1, shape = 19, width=0.3), 144 | 145 | required_aes = c("x", "lower", "upper", "middle", "ymin", "ymax") 146 | ) 147 | -------------------------------------------------------------------------------- /R/calibration.R: -------------------------------------------------------------------------------- 1 | #' @title calibration_plot 2 | #' @description Returns a ggplot2 plot object containing a calibration plot: the observed positive rate in each of ten equal-width predicted-probability bands, with Beta posterior intervals. 3 | #' @param test.y List of known labels on the test set 4 | #' @param pred.prob List of probability predictions on the test set 5 | #' @export 6 | calibration_plot <- function(test.y, pred.prob) { 7 | nbuckets = 10 8 | bucket_array <- seq(0.0, 1.0, by=0.1) 9 | positive_in_band <- function(bucket) { 10 | in_bucket_indicator <- pred.prob >= bucket_array[bucket] & pred.prob < bucket_array[bucket+1] 11 | bucket_size <- sum(in_bucket_indicator) 12 | positive <- sum(test.y[in_bucket_indicator] == 1) 13 | return(qbeta(c(llb=0.025, lb=0.25, y=0.5, ub=0.75, uub=0.975), 0.5+positive, 0.5+bucket_size-positive)) 14 | } 15 | tbl <- data.table(bucket = 1:nbuckets, percentage = 5+bucket_array[1:nbuckets]*100, 16 | blb=bucket_array[1:nbuckets], bub=bucket_array[(1:nbuckets) + 1]) 17 | tbl <- cbind(tbl, 100*t(sapply(tbl$bucket, positive_in_band))) 18 | 19 | ggplot(tbl,
aes(x=percentage, y=y)) + 20 | geom_ribbon(aes(ymin=llb, ymax=uub), fill=green_str, alpha=0.2) + 21 | geom_ribbon(aes(ymin=lb, ymax=ub), fill=green_str, alpha=0.4) + 22 | geom_abline(slope=1.0, intercept=0, linetype="dotted") + 23 | scale_x_continuous(name="Predicted probability (%)", limits=c(0,100.0), breaks=seq(5, 95.0, 10.0)) + 24 | scale_y_continuous(name="Smoothed true probability (%)", limits=c(0,100.0)) + 25 | ggtitle("Calibration") 26 | } 27 | 28 | calibration_rolling_window <- function(test.y, pred.prob, granularity=0.02) { 29 | check_classifier_input_and_init(test.y, pred.prob) 30 | step_array <- seq(0.0, 1.0, by=granularity) 31 | thesh_steps <- round(quantile(pred.prob, step_array), digits=4) 32 | pred.order <- order(pred.prob, decreasing=T) 33 | 34 | # We choose the window size based on the amount of data, heuristically. 35 | if(length(test.y) > 2000) { 36 | window_split <- 20.0 37 | } else { 38 | window_split <- 10.0 39 | } 40 | 41 | window_radius <- abs(thesh_steps[2] - thesh_steps[length(thesh_steps)-1])/window_split 42 | #print(paste("Window radius: ", window_radius)) 43 | 44 | propensity_tbl_perc <- data.table( 45 | part=1:length(step_array), percentage=100*step_array, 46 | threshold=thesh_steps, step_array=step_array) 47 | propensity_tbl_perc[, propensity := 48 | propensity_at_prediction_level(test.y, pred.prob, threshold, window_radius), by=c("threshold")] 49 | 50 | 51 | 52 | 53 | extreme_percentile <- quantile(pred.prob, 0.975) 54 | upper_prop <- propensity_at_prediction_level(test.y, pred.prob, extreme_percentile, window_radius) 55 | range_max <- max(extreme_percentile, upper_prop) 56 | 57 | return(ggplot(propensity_tbl_perc, aes(x=100*threshold, y=100.0*propensity)) + 58 | geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours + 59 | geom_abline(slope=1.0, intercept=0, linetype="dotted") + 60 | scale_x_continuous(name="Predicted probability (%)", limits=c(0,100.0*range_max)) + 61 | scale_y_continuous(name="True 
probability (%)", limits=c(0,100.0*range_max)) + 62 | ggtitle("Calibration")) 63 | } 64 | -------------------------------------------------------------------------------- /R/classifier_plots.R: -------------------------------------------------------------------------------- 1 | 2 | #' The main functions you want are \code{\link{classifierplots}} or \code{\link{classifierplots_folder}}. 3 | #' @docType package 4 | #' @name classifierplots 5 | NULL 6 | 7 | #' \figure{example.png} 8 | #' @title classifierplots 9 | #' @description Produce a suite of classifier diagnostic plots 10 | #' @param test.y List of known labels on the test set 11 | #' @param pred.prob List of probability predictions on the test set 12 | #' @import data.table ggplot2 13 | #' @importFrom gridExtra grid.arrange 14 | #' @importFrom ROCR performance 15 | #' @importFrom ROCR prediction 16 | #' @export 17 | #' @examples 18 | #' \dontrun{ 19 | #' classifierplots(example_predictions$test.y, example_predictions$pred.prob) 20 | #' } 21 | classifierplots <- function(test.y, pred.prob) { 22 | produce_classifier_plots(test.y, pred.prob, show=T) 23 | } 24 | 25 | #' @title classifierplots_folder 26 | #' @description Produce a suite of classifier diagnostic plots, saving them to disk.
27 | #' @param test.y List of known labels on the test set 28 | #' @param pred.prob List of probability predictions on the test set 29 | #' @param folder Directory to save plots into 30 | #' @param height height of separately saved plots 31 | #' @param width width of separately saved plots 32 | #' @export 33 | classifierplots_folder <- function(test.y, pred.prob, folder, height=5, width=5) { 34 | produce_classifier_plots(test.y, pred.prob, folder=folder, height=height, width=width, show=F) 35 | } 36 | 37 | #' @importFrom grid grid.draw 38 | #' @importFrom gridExtra arrangeGrob 39 | #' @importFrom grDevices dev.new 40 | #' @importFrom grDevices dev.off 41 | #' @importFrom grDevices pdf 42 | #' @importFrom grDevices x11 43 | #' @importFrom utils sessionInfo 44 | produce_classifier_plots <- function( 45 | test.y, pred.prob, 46 | folder=NULL, 47 | height=5, width=5, show=F) { 48 | check_classifier_input_and_init(test.y, pred.prob) 49 | n <- length(test.y) 50 | 51 | if(!is.null(folder)) { 52 | dir.create(folder, showWarnings = F) 53 | } 54 | 55 | # Subsample data if it is huge 56 | if(n > 500000) { 57 | sel.ind <- sample.int(n, 500000) 58 | 59 | test.y <- test.y[sel.ind] 60 | pred.prob <- pred.prob[sel.ind] 61 | 62 | print("Data was subsampled to 500k points for the purpose of plotting") 63 | } 64 | 65 | if(!is.null(folder)) { 66 | auc <- calculate_auc(test.y, pred.prob) 67 | write(auc, file=paste0(folder, "/auc.txt")) 68 | } 69 | 70 | saveplot <- function(plt, plt.name) { 71 | if(!is.null(folder)) { 72 | full.plt.name <- paste0(folder, "/", plt.name) 73 | ggsave(plot=plt, filename=full.plt.name, width = width, height = height) 74 | print(paste("Saved plot:", full.plt.name)) 75 | } 76 | invisible() 77 | } 78 | 79 | roc.plt <- roc_plot(test.y, pred.prob) 80 | saveplot(roc.plt, "ROC.pdf") 81 | 82 | positives.plt <- positives_plot(test.y, pred.prob) 83 | saveplot(positives.plt, "positives.pdf") 84 | 85 | cal.plt <- calibration_plot(test.y, pred.prob) 86 |
saveplot(cal.plt, "calibration.pdf") 87 | 88 | notat.plt <- notation_key_plot() 89 | 90 | dens.plt <- density_plot(test.y, pred.prob) 91 | saveplot(dens.plt, "density.pdf") 92 | 93 | acc.plt.perc <- accuracy_plot(test.y, pred.prob) 94 | saveplot(acc.plt.perc, "accuracy.pdf") 95 | 96 | prec.plt.perc <- precision_plot(test.y, pred.prob) 97 | saveplot(prec.plt.perc, "precision.pdf") 98 | 99 | lift.plt <- lift_plot(test.y, pred.prob) 100 | saveplot(lift.plt, "lift.pdf") 101 | 102 | recall.plt <- recall_plot(test.y, pred.prob) 103 | saveplot(recall.plt, "recall.pdf") 104 | 105 | 106 | if(!is.null(folder)) { 107 | print("Saving all plots in one file ...") 108 | all.plt.full.name <- paste0(folder, "/ALL.pdf") 109 | g <- gridExtra::arrangeGrob(roc.plt, positives.plt, cal.plt, notat.plt, 110 | dens.plt, acc.plt.perc, prec.plt.perc, recall.plt, ncol=4) 111 | 112 | ggsave(filename=all.plt.full.name, g, width=25, height=13) 113 | print(paste("Saved plot:", all.plt.full.name)) 114 | } 115 | 116 | if(show) { 117 | dev.new(width=25, height=13, dpi=55) 118 | #return(gridExtra::grid.arrange(roc.plt, positives.plt, cal.plt, notat.plt, 119 | # dens.plt, acc.plt.perc, prec.plt.perc, recall.plt, ncol=4)) 120 | g <- gridExtra::arrangeGrob(roc.plt, positives.plt, cal.plt, notat.plt, 121 | dens.plt, acc.plt.perc, prec.plt.perc, recall.plt, ncol=4) 122 | grid::grid.draw(g) 123 | return(g) 124 | } else { 125 | invisible() 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /R/density.R: -------------------------------------------------------------------------------- 1 | #' @title density_plot 2 | #' @description Returns a ggplot2 plot object containing a score density plot. 
3 | #' @param test.y List of known labels on the test set 4 | #' @param pred.prob List of probability predictions on the test set 5 | #' @importFrom grid textGrob 6 | #' @importFrom grid gpar 7 | #' @importFrom stats dnorm 8 | #' @importFrom stats qbeta 9 | #' @importFrom stats quantile 10 | #' @importFrom stats sd 11 | #' @export 12 | density_plot <- function(test.y, pred.prob) { 13 | check_classifier_input_and_init(test.y, pred.prob) 14 | print("Generating score density plot") 15 | ground.truth <- factor(test.y) 16 | density_tbl <- data.table(Prediction=pred.prob, `Ground Truth`=ground.truth) 17 | 18 | mp <- max(quantile(pred.prob[ground.truth == 1], 0.95), 19 | quantile(pred.prob[ground.truth != 1], 0.95)) 20 | 21 | if(mp < 0.4) { 22 | limits <- c(mp*1.1, 0.0) 23 | } else { 24 | limits <- c(1.0, 0.0) 25 | } 26 | 27 | annotation <- paste0("Test set size: ", ifelse(length(test.y)==500000, ">= 500,000", length(test.y)), 28 | "\nNegative cases: ", format(100*sum(test.y != 1)/length(test.y), digits=3), 29 | "%\nPositive cases: ", format(100*sum(test.y == 1)/length(test.y), digits=3), "%") 30 | 31 | plt <- ggplot(density_tbl) + 32 | geom_density(aes(x=Prediction, fill=`Ground Truth`), alpha=0.4, color="#00000000", size=1.5) + 33 | scale_x_reverse(name="Probability threshold", limits=limits) + 34 | scale_y_continuous(name="Density", expand=c(0,0)) + 35 | ggtitle("Prediction density") + 36 | annotation_custom(grob=grid::textGrob(annotation, x=0.05, y=0.87, just=c("left", "top"), 37 | gp = grid::gpar(col=fontgrey_str))) + 38 | legend_theme + classifier_theme + 39 | #theme(text=element_text(size=16, color="#444444")) + 40 | classifier_colours 41 | return(plt) 42 | } 43 | -------------------------------------------------------------------------------- /R/example.R: -------------------------------------------------------------------------------- 1 | 2 | #' Generated using the gen_example() function included in the GitHub source 3 | #' @name example_predictions 4 | #' @docType data 5 |
#' @keywords data 6 | NULL 7 | -------------------------------------------------------------------------------- /R/example_gen.R: -------------------------------------------------------------------------------- 1 | 2 | # Just run during development to produce the example data 3 | gen_example <- function() { 4 | if (!requireNamespace("LiblineaR", quietly = TRUE)) { 5 | stop("LiblineaR needed for this function to work. Please install it.", 6 | call. = FALSE) 7 | } 8 | if (!requireNamespace("SVMMaj", quietly = TRUE)) { 9 | stop("SVMMaj needed for this function to work. Please install it.", 10 | call. = FALSE) 11 | } 12 | 13 | X <- as.data.table(AusCredit$X) 14 | #y <- AusCredit$y 15 | y <- factor(AusCredit$y, labels=c(0, 1), levels=c("Rejected", "Accepted")) 16 | 17 | # LiblineaR requires a matrix datatype. We also drop the intercept (bias) column with 0+. 18 | # If y were part of the data table, we would use "y ~ 0 + ." instead 19 | X.mm <- model.matrix(~ 0 + ., data=X) 20 | 21 | # Here I pull out just train/test sets. Be sure to pull out a validation set 22 | # as well if you're tuning hyperparameters. 23 | smpl_frac <- 0.5 24 | #set.seed(42) 25 | #train.ind <- sample.int(nrow(X), smpl_frac*nrow(X)) 26 | train.ind <- c(1:345) 27 | train.mm <- X.mm[train.ind,] 28 | test.mm <- X.mm[-train.ind,] 29 | train.data <- X[train.ind,] 30 | test.data <- X[-train.ind,] 31 | train.y <- y[train.ind] 32 | test.y <- y[-train.ind] 33 | 34 | # Defaults are pretty reasonable.
35 | fit.ll <- LiblineaR(data=train.mm, target=train.y, type=0, cost=1, epsilon=0.0001, verbose=T) 36 | 37 | pred.ll <- predict(fit.ll, test.mm, proba=T) 38 | pred.prob <- pred.ll$probabilities[,"1"] 39 | 40 | test.y <- as.numeric(test.y) - 1 41 | values <- train.mm[,"X2"] 42 | 43 | #classifierplots(test.y, pred.prob) 44 | example_predictions <- list(test.y=test.y, pred.prob=pred.prob) 45 | usethis::use_data(example_predictions) 46 | } 47 | -------------------------------------------------------------------------------- /R/individual_plots.R: -------------------------------------------------------------------------------- 1 | 2 | check_predictions <- function(pred.prob) { 3 | 4 | if(max(pred.prob) > 1) { 5 | stop(paste("Pred.prob not in [0,1]. Max:", max(pred.prob), 6 | ". You can use the sigmoid(x) function in this package to map to [0,1].")) 7 | } 8 | 9 | if(min(pred.prob) < 0) { 10 | stop(paste("Pred.prob not in [0,1]. Min:", min(pred.prob), 11 | ". You can use the sigmoid(x) function in this package to map to [0,1].")) 12 | } 13 | } 14 | 15 | check_classifier_input_and_init <- function(test.y, pred.prob) { 16 | 17 | if(length(test.y) != length(pred.prob)) { 18 | stop(paste("Length of test.y:", length(test.y), "did not match pred.prob:", length(pred.prob))) 19 | } 20 | yvals <- unique(test.y) 21 | if(length(yvals) != 2) { 22 | stop(paste("test.y must contain exactly 2 unique values, but had:", length(yvals))) 23 | } 24 | if(sum(yvals == 1.0) != 1) { 25 | stop(paste("This code expects test.y to be numerical, with the positive class indicated by '1'. There was no 1 in test.y!")) 26 | } 27 | 28 | check_predictions(pred.prob) 29 | } 30 | 31 | #' @title sigmoid 32 | #' @description Logistic sigmoid function, which maps any real number to the [0,1] interval. Vectorised over numeric input.
33 | #' @param x Numeric vector of real values 34 | #' @export 35 | sigmoid <- function(x) { 1.0/(1.0+exp(-x)) } 36 | 37 | #' @title propensity_plot 38 | #' @description Returns a ggplot2 plot object containing a propensity @@ percentile plot 39 | #' @param test.y List of known labels on the test set 40 | #' @param pred.prob List of probability predictions on the test set 41 | #' @param granularity Default 0.02, quantile step between points in the plot. 42 | #' @export 43 | propensity_plot <- function(test.y, pred.prob, granularity=0.02) { 44 | check_classifier_input_and_init(test.y, pred.prob) 45 | step_array <- seq(0.0, 1.0, by=granularity) 46 | thesh_steps <- round(quantile(pred.prob, step_array), digits=4) 47 | pred.order <- order(pred.prob, decreasing=T) 48 | 49 | propensity_tbl_perc <- data.table( 50 | part=1:length(step_array), percentage=100 - 100*step_array, 51 | threshold=thesh_steps, step_array=step_array) 52 | propensity_tbl_perc[, propensity := 53 | propensity_at_threshold(test.y, pred.prob, part, pred.order, thesh_steps), by=c("part")] 54 | 55 | return(ggplot(propensity_tbl_perc, aes(x=percentage, y=100.0*propensity)) + 56 | geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours + 57 | scale_x_continuous(name="Instance decile (non-cumulative %)", breaks=seq(0.0, 100.0, 10.0)) + 58 | scale_y_continuous(name="Smoothed positive (%)") + 59 | ggtitle("Positive rate (rolling window)")) 60 | } 61 | 62 | #' @title accuracy_plot 63 | #' @description Returns a ggplot2 plot object containing an accuracy @@ percentile plot 64 | #' @param test.y List of known labels on the test set 65 | #' @param pred.prob List of probability predictions on the test set 66 | #' @param granularity Default 0.02, quantile step between points in the plot.
67 | #' @param show_numbers Show values at deciles as numbers above the plot line (TRUE/FALSE, default TRUE). 68 | #' @export 69 | accuracy_plot <- function(test.y, pred.prob, granularity=0.02, show_numbers=T) { 70 | check_classifier_input_and_init(test.y, pred.prob) 71 | step_array <- seq(0.0, 1.0, by=granularity) 72 | thesh_steps <- round(quantile(pred.prob, step_array), digits=4) 73 | accuracy_tbl_perc <- data.table(percentage=100 - 100*step_array, threshold=thesh_steps) 74 | accuracy_tbl_perc[, accuracy := sapply(threshold, function(x) accuracy_at_threshold(x, test.y, pred.prob))] 75 | accuracy_tbl_perc[, accuracy_lb := sapply(threshold, function(x) accuracy_at_threshold_p(0.025, x, test.y, pred.prob))] 76 | accuracy_tbl_perc[, accuracy_ub := sapply(threshold, function(x) accuracy_at_threshold_p(0.975, x, test.y, pred.prob))] 77 | 78 | if(show_numbers) { 79 | deciles <- seq(0, 100, 10) 80 | accuracy_tbl_perc[percentage %in% deciles, dec_lbl := paste0(format(100*accuracy, digits=2), "%")] 81 | numbers <- geom_text(aes(x=percentage, y=102*accuracy, label=dec_lbl), 82 | hjust=0.3, vjust=-1.0, size=4, color=I(blue_str)) 83 | } else { 84 | numbers <- NULL 85 | } 86 | 87 | return(ggplot(accuracy_tbl_perc, aes(x=percentage, y=100.0*accuracy)) + 88 | geom_ribbon(aes(ymin=100.0*accuracy_lb, ymax=100.0*accuracy_ub), fill=green_str, alpha=0.2) + 89 | geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours + 90 | scale_x_continuous(name="k% (thresholded to positive class)", breaks=seq(0.0, 100.0, 10.0)) + 91 | scale_y_continuous(name="Accuracy (%)", limits=c(0,100), breaks=seq(0.0, 100.0, 10.0)) + 92 | numbers + 93 | ggtitle("Accuracy @ k")) 94 | } 95 | 96 | #' @title recall_plot 97 | #' @description Returns a ggplot2 plot object containing a recall (sensitivity) @@ percentile plot 98 | #' @param test.y List of known labels on the test set 99 | #' @param pred.prob List of probability predictions on the test set 100 | #' @param granularity Default 0.02, quantile step between points in the plot.
101 | #' @param show_numbers Show values at deciles as numbers on the plot (TRUE/FALSE, default TRUE). 102 | #' @export 103 | recall_plot <- function(test.y, pred.prob, granularity=0.02, show_numbers=T) { 104 | check_classifier_input_and_init(test.y, pred.prob) 105 | step_array <- seq(0.0, 1.0, by=granularity) 106 | thesh_steps <- round(quantile(pred.prob, step_array), digits=4) 107 | tbl <- data.table(percentage=100 - 100*step_array, threshold=thesh_steps) 108 | tbl[, sensitivity := sapply(threshold, function(x) sensitivity_at_threshold(x, test.y, pred.prob))] 109 | tbl[, sensitivity_lb := sapply(threshold, function(x) sensitivity_at_threshold_p(0.025, x, test.y, pred.prob))] 110 | tbl[, sensitivity_ub := sapply(threshold, function(x) sensitivity_at_threshold_p(0.975, x, test.y, pred.prob))] 111 | 112 | if(show_numbers) { 113 | deciles <- seq(10, 100, 10) 114 | tbl[percentage %in% deciles, dec_lbl := paste0(format(100*sensitivity, digits=2), "%")] 115 | numbers <- geom_text(aes(x=percentage, y=100*sensitivity+2*sensitivity, label=dec_lbl), 116 | hjust=0.3, vjust=3.0, size=4, color=I(blue_str)) 117 | } else { 118 | numbers <- NULL 119 | } 120 | 121 | return(ggplot(tbl, aes(x=percentage, y=100.0*sensitivity)) + 122 | geom_ribbon(aes(ymin=100.0*sensitivity_lb, ymax=100.0*sensitivity_ub), fill=green_str, alpha=0.2) + 123 | geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours + 124 | scale_x_continuous(name="k% (thresholded to positive class)", breaks=seq(0.0, 100.0, 10.0), limits=c(0,100), expand=c(0, 0.3)) + 125 | scale_y_continuous(name="Recall (%)", breaks=seq(0.0, 100.0, 10.0), limits=c(0,100), expand=c(0, 0.3)) + 126 | numbers + 127 | ggtitle("Recall @ k")) 128 | } 129 | 130 | # Column names used in data.table and ggplot2 aes() expressions must be declared here to avoid R CMD check global-variable NOTEs 131 | utils::globalVariables(c( 132 | "Prediction", "Ground Truth", "accuracy", "threshold", 133 | "precision", "sensitivity", "percentage", "fpr", "tpr", 134 | "propensity", "positive_perc", "bucket", "dec_lbl", "part", 135 | "ymin",
"ymax", "sensitivity_lb", "sensitivity_ub", "accuracy_lb", "accuracy_ub", "precision_lb", "precision_ub", "lift", "lift_lb", "lift_ub", "y", "llb", "lb", "ub", "uub")) 136 | -------------------------------------------------------------------------------- /R/lift.R: -------------------------------------------------------------------------------- 1 | #' @title lift_plot 2 | #' @description Returns a ggplot2 plot object containing a lift @@ percentile plot 3 | #' @param test.y List of known labels on the test set 4 | #' @param pred.prob List of probability predictions on the test set 5 | #' @param granularity Default 0.02, quantile step between points in the plot. 6 | #' @param show_numbers Show values at deciles as numbers on the plot (TRUE/FALSE, default TRUE). 7 | #' @export 8 | lift_plot <- function(test.y, pred.prob, granularity=0.02, show_numbers=T) { 9 | check_classifier_input_and_init(test.y, pred.prob) 10 | 11 | step_array <- seq(0.0, 1.0, by=granularity) 12 | thesh_steps <- round(quantile(pred.prob, step_array), digits=4) 13 | tbl <- data.table(percentage=100 - 100*step_array, threshold=thesh_steps) 14 | tbl[, precision := sapply(threshold, function(x) precision_at_threshold(x, test.y, pred.prob))] 15 | tbl[, precision_lb := sapply(threshold, function(x) precision_at_threshold_p(0.025, x, test.y, pred.prob))] 16 | tbl[, precision_ub := sapply(threshold, function(x) precision_at_threshold_p(0.975, x, test.y, pred.prob))] 17 | 18 | baseline_rate <- tbl[percentage == 100, precision] 19 | 20 | tbl[, lift :=(precision-baseline_rate)/baseline_rate] 21 | tbl[, lift_lb :=(precision_lb-baseline_rate)/baseline_rate] 22 | tbl[, lift_ub :=(precision_ub-baseline_rate)/baseline_rate] 23 | 24 | if(show_numbers) { 25 | deciles <- seq(10, 100, 10) 26 | tbl[percentage %in% deciles, dec_lbl := paste0(format(100*lift, digits=2), "%")] 27 | numbers <- geom_text(aes(x=percentage, y=100*lift+2*lift, label=dec_lbl), 28 | hjust=0.3, vjust=-1.0, size=4, color=I(blue_str)) 29 | } else { 30 | numbers <- NULL 31 | } 32 | 33 | return(ggplot(tbl, aes(x=percentage, y=100*lift)) + 34 | 
geom_ribbon(aes(ymin=100.0*lift_lb, ymax=100.0*lift_ub),
fill=green_str, alpha=0.2) + 35 | geom_abline(slope=0.0, intercept=0, linetype="dotted") + 36 | geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours + 37 | scale_x_continuous(name="k% (thresholded to positive class)", breaks=seq(0.0, 100.0, 10.0)) + 38 | scale_y_continuous(name="Relative lift (%)") + 39 | numbers + 40 | ggtitle("Lift")) 41 | } 42 | -------------------------------------------------------------------------------- /R/metrics.R: -------------------------------------------------------------------------------- 1 | 2 | ######################################### 3 | 4 | propensity_at_threshold <- function(test.y, prob.y, part, pred.order, part_quantiles) { 5 | window_each_side <- 4 6 | #part_quantiles is an increasing sequence of quantiles 7 | part_lb <- max(1, part-window_each_side) 8 | part_ub <- min(part+1+window_each_side, length(part_quantiles)) 9 | in_part_indicator <- (prob.y < part_quantiles[part_ub] & 10 | prob.y >= part_quantiles[part_lb]) 11 | 12 | part_orders <- pred.order[in_part_indicator] 13 | 14 | # Use a Gaussian weighting function, scaled to fit the data window's order statistics. 15 | part_weights <- dnorm(part_orders, mean=mean(part_orders), sd=sd(part_orders)) 16 | part_weights <- part_weights/sum(part_weights) 17 | 18 | # Standard weighted proportion equation 19 | rate_prop <- t(part_weights) %*% (test.y[in_part_indicator] == 1) 20 | #browser() 21 | return(rate_prop) 22 | } 23 | 24 | propensity_at_prediction_level <- function(test.y, prob.y, pred.level, window_radius) { 25 | # Select the predictions within window_radius of pred.level 26 | in_part_indicator <- (prob.y < pred.level+window_radius & 27 | prob.y >= pred.level-window_radius) 28 | 29 | prob.sub <- prob.y[in_part_indicator] 30 | 31 | # Use a Gaussian weighting function, centred on the window's mean prediction and scaled by its standard deviation.
32 | pweights <- dnorm(prob.sub, mean=mean(prob.sub), sd=sd(prob.sub)) 33 | pweights <- pweights/sum(pweights) 34 | 35 | # Standard weighted proportion equation 36 | rate_prop <- t(pweights) %*% (test.y[in_part_indicator] == 1) 37 | 38 | return(rate_prop) 39 | } 40 | 41 | # Posterior quantiles of accuracy (correct/n) 42 | accuracy_at_threshold_p <- function(p, threshold, test.y, prob.y) { 43 | test.y.bin <- test.y == 1 44 | pred.y.bin <- prob.y >= threshold 45 | correct <- sum(pred.y.bin == test.y.bin) 46 | # correct/length(test.y) is the point estimate; return a Beta posterior quantile instead 47 | return(qbeta(p, correct, length(test.y)-correct)) 48 | } 49 | 50 | accuracy_at_threshold <- function(threshold, test.y, prob.y) { 51 | return(accuracy_at_threshold_p(0.5, threshold, test.y, prob.y)) 52 | } 53 | 54 | 55 | # Posterior quantiles of precision TP/(TP+FP) 56 | precision_at_threshold_p <- function(p, threshold, test.y, prob.y) { 57 | test.y.bin <- test.y == 1 58 | pred.y.bin <- prob.y >= threshold 59 | true_positives <- sum(pred.y.bin & test.y.bin) 60 | false_positives <- sum(pred.y.bin & (!test.y.bin)) 61 | return(qbeta(p, true_positives, false_positives)) 62 | } 63 | 64 | precision_at_threshold <- function(threshold, test.y, prob.y) { 65 | return(precision_at_threshold_p(0.5, threshold, test.y, prob.y)) 66 | } 67 | 68 | 69 | 70 | 71 | # Posterior quantiles of sensitivity (TP/P) 72 | sensitivity_at_threshold_p <- function(p, threshold, test.y, prob.y) { 73 | test.y.bin <- test.y == 1 74 | pred.y.bin <- prob.y >= threshold 75 | true_positives <- sum(pred.y.bin & test.y.bin) 76 | return(qbeta(p, true_positives, sum(test.y.bin)-true_positives)) 77 | } 78 | 79 | sensitivity_at_threshold <- function(threshold, test.y, prob.y) { 80 | return(sensitivity_at_threshold_p(0.5, threshold, test.y, prob.y)) 81 | } 82 | 83 | -------------------------------------------------------------------------------- /R/notation_key.R:
-------------------------------------------------------------------------------- 1 | 2 | 3 | #' @title notation_key_plot 4 | #' @description Returns a grid graphical object (grob) showing the notation key: definitions of the metrics used in the plots. 5 | #' @importFrom png readPNG 6 | #' @importFrom grid rasterGrob 7 | #' @export 8 | notation_key_plot <- function() { 9 | img <- png::readPNG(system.file("img", "notation.png", package="classifierplots")) 10 | g <- grid::rasterGrob(img, interpolate=TRUE) 11 | 12 | return(g) 13 | } 14 | -------------------------------------------------------------------------------- /R/positives.R: -------------------------------------------------------------------------------- 1 | 2 | #' @title positives_plot 3 | #' @description Returns a ggplot2 plot object containing a positives-per-decile plot. 4 | #' @param test.y List of known labels on the test set 5 | #' @param pred.prob List of probability predictions on the test set 6 | #' @export 7 | positives_plot <- function(test.y, pred.prob) { 8 | check_classifier_input_and_init(test.y, pred.prob) 9 | 10 | nbuckets <- 10 11 | bucket_array <- seq(1.0, 0.0, by=-0.1) 12 | bucket_quantiles <- quantile(pred.prob, bucket_array) 13 | positive_in_bucket <- function(bucket) { 14 | in_bucket_indicator <- pred.prob < bucket_quantiles[bucket] & pred.prob >= bucket_quantiles[bucket+1] 15 | bucket_size <- sum(in_bucket_indicator) 16 | positive <- sum(test.y[in_bucket_indicator] == 1) 17 | return(qbeta(c(llb=0.025, lb=0.25, y=0.5, ub=0.75, uub=0.975), positive, bucket_size-positive)) 18 | } 19 | tbl <- data.table(bucket = 1:nbuckets, percentage = 100.0-bucket_array[1:nbuckets]*100) 20 | tbl <- cbind(tbl, 100*t(sapply(tbl$bucket, positive_in_bucket))) 21 | 22 | return(ggplot(tbl, aes(x=percentage, y=y, ymin=llb, lower=lb, middle=y, upper=ub, ymax=uub)) + 23 | geom_ambiboxplot(fill=green_str, stat="identity", position="identity", width=8) + 24 | classifier_theme + classifier_colours + 25 | scale_x_continuous(name="Instance decile (non-cumulative %)", breaks=seq(0.0, 100.0,
10.0)) + 26 | scale_y_continuous(name="Positive instances in decile (%)") + 27 | ggtitle("Positive instances per decile")) 28 | } 29 | 30 | -------------------------------------------------------------------------------- /R/precision.R: -------------------------------------------------------------------------------- 1 | #' @title precision_plot 2 | #' @description Returns a ggplot2 plot object containing a precision @@ percentile plot 3 | #' @param test.y List of known labels on the test set 4 | #' @param pred.prob List of probability predictions on the test set 5 | #' @param granularity Default 0.02, probability step between points in plot. 6 | #' @param show_numbers Show numbers at deciles T/F default T. 7 | #' @export 8 | precision_plot <- function(test.y, pred.prob, granularity=0.02, show_numbers=T) { 9 | check_classifier_input_and_init(test.y, pred.prob) 10 | step_array <- seq(0.0, 1.0, by=granularity) 11 | thesh_steps <- round(quantile(pred.prob, step_array), digits=4) 12 | precision_tbl_perc <- data.table(percentage=100 - 100*step_array, threshold=thesh_steps) 13 | precision_tbl_perc[, precision := 100.0*sapply(threshold, function(x) precision_at_threshold(x, test.y, pred.prob))] 14 | precision_tbl_perc[, precision_lb := sapply(threshold, function(x) precision_at_threshold_p(0.025, x, test.y, pred.prob))] 15 | precision_tbl_perc[, precision_ub := sapply(threshold, function(x) precision_at_threshold_p(0.975, x, test.y, pred.prob))] 16 | 17 | 18 | if(show_numbers) { 19 | deciles <- seq(10, 100, 10) 20 | precision_tbl_perc[percentage %in% deciles, dec_lbl := paste0(format(precision, digits=2), "%")] 21 | numbers <- geom_text(aes(x=percentage, y=precision+0.02*precision, label=dec_lbl), 22 | hjust=0.3, vjust=-1.0, size=4, color=I(blue_str)) 23 | } else { 24 | numbers <- NULL 25 | } 26 | 27 | mp <- max(precision_tbl_perc[!is.na(precision),]$precision) 28 | minp <- min(precision_tbl_perc[!is.na(precision),]$precision) 29 | 30 | # Smart y breaks calculation 31 |
if(mp <= 0.2) { 32 | breaks <- seq(0.0, 0.2, 0.01) 33 | } else { 34 | if(mp <= 2) { 35 | breaks <- seq(0.0, 2.0, 0.1) 36 | } else { 37 | if(mp <= 20) { 38 | breaks <- seq(0.0, 20.0, 1.0) 39 | } else { 40 | breaks <- seq(0.0, 100.0, 10.0) 41 | } 42 | } 43 | } 44 | 45 | if(mp <= 20) { 46 | limits <- c(0, mp*1.1) 47 | } else { 48 | limits <- c(0,100) 49 | } 50 | 51 | return(ggplot(precision_tbl_perc, aes(x=percentage, y=precision)) + 52 | geom_ribbon(aes(ymin=100.0*precision_lb, ymax=100.0*precision_ub), fill=green_str, alpha=0.2) + 53 | geom_abline(slope=0.0, intercept=minp, linetype="dotted") + 54 | geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours + 55 | scale_x_continuous(name="k% (thresholded to positive class)", breaks=seq(0.0, 100.0, 10.0)) + 56 | scale_y_continuous(name="Precision (%)", limits=limits, breaks=breaks) + 57 | numbers + 58 | ggtitle("Precision @ k")) 59 | } 60 | -------------------------------------------------------------------------------- /R/roc.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | utils::globalVariables(c( 4 | "accuracy_lb", "accuracy_ub", "preds", "y", "llb", "uub", "lb", "ub", 5 | "precision_lb", "precision_ub", "lift", "lift_ub", "lift_lb", "tp", 6 | "resample", "fp", "fpr_step", "50%", "2.5%", "97.5%")) 7 | 8 | #' @title calculate_auc 9 | #' @description Computes the AUC from predictions and ground-truth labels 10 | #' @param test.y List of known labels on the test set 11 | #' @param pred.prob List of probability predictions on the test set 12 | #' @return AUC as a percentage (0-100) 13 | calculate_auc <- function(test.y, pred.prob) { 14 | n <- length(test.y) 15 | 16 | print("(AUC) Sorting data ...") 17 | test.y.bin <- test.y == 1 18 | roc_tbl <- data.table(y=test.y.bin, preds=pred.prob) 19 | roc_tbl <- roc_tbl[order(preds)] 20 | 21 | npositives <- as.double(sum(test.y.bin)) 22 | nnegatives <- as.double(n - npositives) 23 | 24 | print("(AUC) Calculating ranks ...") 25 | # Main AUC calculation.
We use its equivalence with the Mann-Whitney U statistic, 26 | # which is faster to compute. 27 | # Note that tied predictions are given a rank equal to the mean of the tied set. 28 | roc_tbl[, rank := mean(.I), by=preds] 29 | 30 | r1 <- roc_tbl[y == T, sum(rank)] 31 | u1 <- r1 - (npositives*(npositives+1))/2.0 32 | auc <- 100*u1/(npositives*nnegatives) 33 | return(auc) 34 | } 35 | 36 | #' @title roc_plot 37 | #' @description Produces a smoothed ROC curve as a ggplot2 plot object. A confidence interval is produced by bootstrapping; by default this is turned off for large datasets (more than 50,000 instances). 38 | #' @param test.y List of known labels on the test set 39 | #' @param pred.prob List of probability predictions on the test set 40 | #' @param resamps How many bootstrap samples to use 41 | #' @param force_bootstrap TRUE/FALSE to force bootstrapping on or off. The default, NULL, bootstraps only when there are at most 50,000 instances. 42 | #' @export 43 | #' @importFrom caret createResample 44 | roc_plot <- function(test.y, pred.prob, resamps=2000, force_bootstrap=NULL) { 45 | #check_classifier_input_and_init(test.y, pred.prob) 46 | 47 | n <- length(test.y) 48 | test.y.bin <- test.y == 1 49 | nbins <- 50 50 | 51 | npositives <- sum(test.y.bin) 52 | nnegatives <- n - npositives 53 | 54 | # Step size, in negatives, between plotted ROC points (at least 1) 55 | negative_steps <- max(1, floor(nnegatives/nbins)) 56 | 57 | print("Calculating AUC ...") 58 | auc <- calculate_auc(test.y, pred.prob) 59 | print(paste("AUC:", auc)) 60 | 61 | # No point in the CI calculation for large datasets 62 | big_data_cutoff <- 50000 63 | if(!is.null(force_bootstrap)) { 64 | bootstrap <- force_bootstrap 65 | } else { 66 | bootstrap <- n <= big_data_cutoff 67 | } 68 | 69 | # Negated to get the correct sort order later 70 | pos_pred_probs <- -pred.prob[test.y.bin] 71 | neg_pred_probs <- -pred.prob[!test.y.bin] 72 | 73 | if(bootstrap) { 74 | print("Bootstrapping ROC curves") 75 | 76 | pos_pred_boots <- pos_pred_probs[c(caret::createResample(pos_pred_probs, times=resamps, list=F))] 77 | neg_pred_boots <-
neg_pred_probs[c(caret::createResample(neg_pred_probs, times=resamps, list=F))] 78 | 79 | roc_tbl <- data.table( 80 | preds=c(pos_pred_boots, neg_pred_boots), 81 | y=c(rep(T, length(pos_pred_boots)), rep(F, length(neg_pred_boots))), 82 | resample=c(rep(1:resamps, each=length(pos_pred_probs)), 83 | rep(1:resamps, each=length(neg_pred_probs)))) 84 | setkey(roc_tbl, "resample", "preds") 85 | 86 | roc_tbl[, tp := cumsum(y), by=resample] 87 | roc_tbl[, fp := cumsum(!y), by=resample] 88 | roc_tbl[, fpr_step := ((fp %% negative_steps) == 0), by=resample] 89 | 90 | substeps_tbl <- roc_tbl[fpr_step == T, ] 91 | # there can be multiple rows with the same fpr, so we pick the last. 92 | subind <- substeps_tbl[, .I[.N], by = c("resample", "fp")] 93 | roc_tbl_sub <- substeps_tbl[subind$V1] 94 | roc_tbl_sub_stats <- roc_tbl_sub[, as.list(quantile(tp, c(0.025, 0.5, 0.975))), keyby=fp] 95 | 96 | print("Eval AUC") 97 | roc_tbl[, rank := mean(.I), by = c("resample", "preds")] 98 | 99 | r1 <- roc_tbl[y == T, sum(rank) - .N*n*(resample-1), keyby="resample"]$V1 100 | u1 <- r1 - (npositives*(npositives+1))/2.0 101 | aucs <- 1.0 - u1/(npositives*nnegatives) 102 | 103 | #auc_tbl_ex <- roc_tbl[roc_tbl[, .I[.N], by = c("resample", "fp")]$V1] 104 | #aucs <- auc_tbl_ex[, sum(tp)/(npositives*nnegatives), by=resample]$V1 105 | auc_bounds <- 100.0*quantile(aucs, c(0.025, 0.5, 0.975)) 106 | 107 | # Make sure we print enough digits so that the lower and upper bounds are not the same 108 | digits_use <- 3 109 | if(format(auc_bounds[1], digits=digits_use) == format(auc_bounds[3], digits=digits_use)) { 110 | digits_use <- 5 111 | } 112 | } else { 113 | roc_tbl <- data.table( 114 | preds=c(pos_pred_probs, neg_pred_probs), 115 | y=c(rep(T, length(pos_pred_probs)), rep(F, length(neg_pred_probs)))) 116 | setkey(roc_tbl, "preds") 117 | 118 | roc_tbl[, tp := cumsum(y)] 119 | roc_tbl[, fp := cumsum(!y)] 120 | roc_tbl[, fpr_step := ((fp %% negative_steps) == 0)] 121 | 122 | substeps_tbl <- roc_tbl[fpr_step == 
T, ] 123 | # there can be multiple rows with the same fpr, so we pick the last. 124 | subind <- substeps_tbl[, .I[.N], by = c("fp")] 125 | roc_tbl_sub_stats <- substeps_tbl[subind$V1] 126 | roc_tbl_sub_stats[, `50%` := tp] 127 | } 128 | 129 | print("Producing ROC plot") 130 | 131 | plt <- ggplot(roc_tbl_sub_stats, aes( 132 | x=100.0*fp/nnegatives, 133 | y=100.0*`50%`/npositives)) + 134 | geom_line(color=green_str, size=1.5) + 135 | geom_abline(slope=1.0, intercept=0, linetype="dotted") + 136 | annotate("text", x=62.5, y=37.5, 137 | label=paste0("AUC ", format(auc, digits=3), "%"), 138 | parse=F, size=7, colour=fontgrey_str) + 139 | scale_x_continuous(name="False Positive Rate (%) (1-Specificity)", 140 | limits=c(0.0, 100.0), expand=c(0, 0.3)) + 141 | scale_y_continuous(name="True Positive Rate (%) (Sensitivity)", 142 | limits=c(0.0, 100.0), expand=c(0, 0.3)) + 143 | classifier_theme 144 | 145 | if(bootstrap) { 146 | plt <- plt + 147 | geom_ribbon(aes( 148 | ymin=100.0*`2.5%`/npositives, 149 | ymax=100.0*`97.5%`/npositives), 150 | fill=green_str, alpha=0.2) + 151 | annotate("text", x=62.5, y=30, 152 | label=paste0("95% CI: ", format(auc_bounds[1], digits=digits_use), "% - ", format(auc_bounds[3], digits=digits_use), "%"), 153 | parse=F, size=4.5, colour=fontgrey_str) 154 | } 155 | 156 | return(plt + ggtitle("ROC")) 157 | } 158 | -------------------------------------------------------------------------------- /R/style.R: -------------------------------------------------------------------------------- 1 | 2 | fontgrey_str <- "#444444" 3 | green_str <- "#78b45a" 4 | blue_str <- "#51a7f9" 5 | grey_str <- "#7F7F7F" 6 | 7 | 8 | # We are not really using these at the moment, but it eases future changes 9 | classifier_theme <- theme() 10 | classifier_colours <- theme() 11 | 12 | legend_theme <- theme( 13 | legend.position=c(0.5, 0.95), 14 | legend.background=element_rect(fill=alpha('white', 0.0)), 15 | legend.direction="horizontal") 16 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Classifierplots 2 | 3 | Generates a visualization of binary classifier performance as a grid of diagnostic plots with just one function call. Includes ROC curves, prediction density, accuracy, precision, recall and calibration plots, all using ggplot2 for easy modification. 4 | Debug your binary classifiers faster and more easily! 5 | 6 | **classifierplots is now on CRAN** 7 | 8 | Install with 9 | 10 | install.packages("classifierplots") 11 | 12 | Tested on Windows, Mac OS and Linux. 13 | 14 | ### Usage 15 | 16 | The main function to use when running R interactively is: 17 | 18 | classifierplots(test.y, pred.prob) 19 | 20 | where test.y holds the ground-truth labels (1 for the positive class) and pred.prob the predicted probabilities in [0,1]. 21 | 22 | To save the results to disk as a folder of separate plots, plus a single ALL.pdf grid, use 23 | 24 | classifierplots_folder(test.y, pred.prob, folder) 25 | 26 | There are also functions that produce each individual plot as a ggplot2 object. 27 | 28 | ##### Runnable example 29 | 30 | The small example_predictions dataset is included with the package: 31 | 32 | library(classifierplots) 33 | # Plot to window 34 | classifierplots(example_predictions$test.y, example_predictions$pred.prob) 35 | # Save output directly to disk 36 | classifierplots_folder(example_predictions$test.y, example_predictions$pred.prob, "outfolder") 37 | 38 | ![Example](/man/figures/example.png?raw=true "Example") 39 | 40 | 41 | # Instructions for using the repository version 42 | 43 | We recommend using the CRAN version, but you can also install from GitHub or from a local clone of the repository. The **devtools package is required** for working with the current development version.
44 | 45 | ### Installing from GitHub with devtools 46 | 47 | library(devtools) 48 | install_github("ambiata/classifierplots") 49 | 50 | ### Building locally 51 | 52 | Once you have cloned the repository, run this from a shell in the project directory: 53 | 54 | R CMD build . 55 | 56 | This produces a tarball named after the version in DESCRIPTION, e.g. classifierplots_1.4.0.tar.gz. To check it, run: 57 | 58 | R CMD check --as-cran classifierplots_1.4.0.tar.gz 59 | 60 | ### Installing locally 61 | 62 | Run: 63 | 64 | R CMD INSTALL classifierplots_1.4.0.tar.gz 65 | 66 | If you also need to install the dependencies (you probably do), run this first: 67 | 68 | install.packages(c('Rcpp', 'tibble', 'caret', 'gridExtra', 'ggplot2', 'ROCR', 'png', 'data.table'), dependencies=T, type='source') 69 | 70 | If you're not using Linux, you may be able to omit the type='source' part to speed up the install. 71 | 72 | ### Development 73 | 74 | Open an R session in the project's directory, then run: 75 | 76 | devtools::load_all() 77 | 78 | Run it again to reload the package's definitions without restarting R.
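A quick way to sanity-check the rank-based AUC in `calculate_auc` (an internal function, reachable after `devtools::load_all()`) is to compare the Mann-Whitney U formulation against a brute-force count over every positive/negative pair. The following is a standalone sketch in base R, not package code: `auc_rank` and `auc_pairs` are hypothetical helper names, the scores are synthetic, and both helpers return a fraction in [0, 1] rather than the percentage that `calculate_auc` returns.

```r
# Hypothetical helpers for illustration; not part of classifierplots.

# Rank-based AUC via the Mann-Whitney U statistic. Tied scores receive
# their average rank, so each tied positive/negative pair counts one half.
auc_rank <- function(labels, scores) {
  pos <- labels == 1
  n_pos <- sum(pos)
  n_neg <- length(labels) - n_pos
  r <- rank(scores)  # default ties.method = "average"
  u <- sum(r[pos]) - n_pos * (n_pos + 1) / 2
  u / (n_pos * n_neg)
}

# Brute-force AUC: fraction of positive/negative pairs in which the positive
# scores higher, with ties counted as one half. Uses O(n_pos * n_neg) memory.
auc_pairs <- function(labels, scores) {
  diffs <- outer(scores[labels == 1], scores[labels != 1], "-")
  mean((diffs > 0) + 0.5 * (diffs == 0))
}

set.seed(1)
y <- rbinom(200, 1, 0.3)
# Synthetic scores, rounded to two decimals so that ties occur.
p <- round(plogis(rnorm(200, mean = 2 * y - 1)), 2)

print(all.equal(auc_rank(y, p), auc_pairs(y, p)))  # TRUE
```

With the package loaded, `calculate_auc(y, p)` should match `100 * auc_rank(y, p)` up to floating-point error, since it also gives tied predictions the mean rank of their tied set.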
79 | -------------------------------------------------------------------------------- /boris-git.toml: -------------------------------------------------------------------------------- 1 | [boris] 2 | version = 1 3 | 4 | [build.dist] 5 | git = "refs/heads/master" 6 | 7 | [build.branches] 8 | git = "refs/heads/topic/*" 9 | 10 | [build.all-*] 11 | git = "refs/heads/**" 12 | -------------------------------------------------------------------------------- /boris.toml: -------------------------------------------------------------------------------- 1 | [boris] 2 | version = 1 3 | 4 | [build.dist] 5 | 6 | [build.branches] 7 | 8 | [build.all-rebased] 9 | command = [["rebased"]] 10 | -------------------------------------------------------------------------------- /data/example_predictions.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefazio/classifierplots/6d2484899ccff159049b1345a1dcb2f01264c64d/data/example_predictions.rda -------------------------------------------------------------------------------- /img/notation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefazio/classifierplots/6d2484899ccff159049b1345a1dcb2f01264c64d/img/notation.png -------------------------------------------------------------------------------- /img/notation.svg: -------------------------------------------------------------------------------- [SVG markup not preserved in this dump. Text labels of the notation key: True Positive Rate (Recall/Sensitivity) = True Positives / Positives, "Fraction of positive instances classified as positive"; False Positive Rate (1 - Specificity) = False Positives / Negatives; Accuracy = Correct Predictions / Test set size; Precision = True Positives / (True Positives + False Positives), "Fraction of positive predictions that are correct"; a confusion matrix of Actual (1, 0) against Prediction (1, 0) with cells True Positives, False Positives, False Negatives, True Negatives.] -------------------------------------------------------------------------------- /inst/img/notation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefazio/classifierplots/6d2484899ccff159049b1345a1dcb2f01264c64d/inst/img/notation.png -------------------------------------------------------------------------------- /man/accuracy_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/individual_plots.R 3 | \name{accuracy_plot} 4 | \alias{accuracy_plot} 5 | \title{accuracy_plot} 6 | \usage{ 7 | accuracy_plot(test.y, pred.prob, granularity = 0.02, show_numbers = T) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{granularity}{Default 0.02, probability step between points in plot.} 15 | 16 | \item{show_numbers}{Show values as numbers above the plot line} 17 | } 18 | \description{ 19 | Returns a ggplot2 plot object containing an accuracy @ percentile plot 20 | } 21 | 22 | -------------------------------------------------------------------------------- /man/calculate_auc.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/roc.R 3 | \name{calculate_auc} 4 | \alias{calculate_auc} 5 | \title{calculate_auc} 6 | \usage{ 7 | calculate_auc(test.y, pred.prob) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | } 14 | \value{ 15 | AUC as a percentage (0-100) 16
| } 17 | \description{ 18 | Computes the AUC from predictions and ground-truth labels 19 | } 20 | 21 | -------------------------------------------------------------------------------- /man/calibration_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/calibration.R 3 | \name{calibration_plot} 4 | \alias{calibration_plot} 5 | \title{calibration_plot} 6 | \usage{ 7 | calibration_plot(test.y, pred.prob) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | } 14 | \description{ 15 | Returns a ggplot2 plot object containing a smoothed propensity @ prediction level plot 16 | } 17 | 18 | -------------------------------------------------------------------------------- /man/classifierplots.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/classifier_plots.R 3 | \docType{package} 4 | \name{classifierplots} 5 | \alias{classifierplots} 6 | \alias{classifierplots-package} 7 | \title{The main functions you want are \code{\link{classifierplots}} or \code{\link{classifierplots_folder}}.} 8 | \usage{ 9 | classifierplots(test.y, pred.prob) 10 | } 11 | \arguments{ 12 | \item{test.y}{List of known labels on the test set} 13 | 14 | \item{pred.prob}{List of probability predictions on the test set} 15 | } 16 | \description{ 17 | The main functions you want are \code{\link{classifierplots}} or \code{\link{classifierplots_folder}}.
18 | 19 | Produce a suite of classifier diagnostic plots 20 | } 21 | \details{ 22 | \figure{example.png} 23 | } 24 | \examples{ 25 | \dontrun{ 26 | classifierplots(example_predictions$test.y, example_predictions$pred.prob) 27 | } 28 | } 29 | 30 | -------------------------------------------------------------------------------- /man/classifierplots_folder.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/classifier_plots.R 3 | \name{classifierplots_folder} 4 | \alias{classifierplots_folder} 5 | \title{classifierplots_folder} 6 | \usage{ 7 | classifierplots_folder(test.y, pred.prob, folder, height = 5, width = 5) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{folder}{Directory to save plots into} 15 | 16 | \item{height}{Height of separately saved plots} 17 | 18 | \item{width}{Width of separately saved plots} 19 | } 20 | \description{ 21 | Produce a suite of classifier diagnostic plots, saving them to disk. 22 | } 23 | 24 | -------------------------------------------------------------------------------- /man/density_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/density.R 3 | \name{density_plot} 4 | \alias{density_plot} 5 | \title{density_plot} 6 | \usage{ 7 | density_plot(test.y, pred.prob) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | } 14 | \description{ 15 | Returns a ggplot2 plot object containing a score density plot.
16 | } 17 | 18 | -------------------------------------------------------------------------------- /man/example_predictions.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/example.R 3 | \docType{data} 4 | \name{example_predictions} 5 | \alias{example_predictions} 6 | \title{Generated using gen_example, included in the GitHub source} 7 | \description{ 8 | Generated using gen_example, included in the GitHub source 9 | } 10 | \keyword{data} 11 | 12 | -------------------------------------------------------------------------------- /man/figures/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefazio/classifierplots/6d2484899ccff159049b1345a1dcb2f01264c64d/man/figures/example.png -------------------------------------------------------------------------------- /man/lift_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lift.R 3 | \name{lift_plot} 4 | \alias{lift_plot} 5 | \title{lift_plot} 6 | \usage{ 7 | lift_plot(test.y, pred.prob, granularity = 0.02, show_numbers = T) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{granularity}{Default 0.02, probability step between points in plot.} 15 | 16 | \item{show_numbers}{Show numbers at deciles T/F default T.} 17 | } 18 | \description{ 19 | Returns a ggplot2 plot object containing a relative lift @ percentile plot 20 | } 21 | 22 | -------------------------------------------------------------------------------- /man/notation_key_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit
documentation in R/notation_key.R 3 | \name{notation_key_plot} 4 | \alias{notation_key_plot} 5 | \title{notation_key_plot} 6 | \usage{ 7 | notation_key_plot() 8 | } 9 | \description{ 10 | Returns a grid graphical object (grob) showing the notation key: definitions of the metrics used in the plots. 11 | } 12 | 13 | -------------------------------------------------------------------------------- /man/positives_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/positives.R 3 | \name{positives_plot} 4 | \alias{positives_plot} 5 | \title{positives_plot} 6 | \usage{ 7 | positives_plot(test.y, pred.prob) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | } 14 | \description{ 15 | Returns a ggplot2 plot object containing a positives-per-decile plot. 16 | } 17 | 18 | -------------------------------------------------------------------------------- /man/precision_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/precision.R 3 | \name{precision_plot} 4 | \alias{precision_plot} 5 | \title{precision_plot} 6 | \usage{ 7 | precision_plot(test.y, pred.prob, granularity = 0.02, show_numbers = T) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{granularity}{Default 0.02, probability step between points in plot.} 15 | 16 | \item{show_numbers}{Show numbers at deciles T/F default T.} 17 | } 18 | \description{ 19 | Returns a ggplot2 plot object containing a precision @ percentile plot 20 | } 21 | 22 | -------------------------------------------------------------------------------- /man/propensity_plot.Rd: -------------------------------------------------------------------------------- 1
| % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/individual_plots.R 3 | \name{propensity_plot} 4 | \alias{propensity_plot} 5 | \title{propensity_plot} 6 | \usage{ 7 | propensity_plot(test.y, pred.prob, granularity = 0.02) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{granularity}{Default 0.02, probability step between points in plot.} 15 | } 16 | \description{ 17 | Returns a ggplot2 plot object containing a propensity @ percentile plot 18 | } 19 | 20 | -------------------------------------------------------------------------------- /man/recall_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/individual_plots.R 3 | \name{recall_plot} 4 | \alias{recall_plot} 5 | \title{recall_plot} 6 | \usage{ 7 | recall_plot(test.y, pred.prob, granularity = 0.02, show_numbers = T) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 | \item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{granularity}{Default 0.02, probability step between points in plot.} 15 | 16 | \item{show_numbers}{Show numbers at deciles T/F default T.} 17 | } 18 | \description{ 19 | Returns a ggplot2 plot object containing a sensitivity @ percentile plot 20 | } 21 | 22 | -------------------------------------------------------------------------------- /man/roc_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/roc.R 3 | \name{roc_plot} 4 | \alias{roc_plot} 5 | \title{roc_plot} 6 | \usage{ 7 | roc_plot(test.y, pred.prob, resamps = 2000, force_bootstrap = NULL) 8 | } 9 | \arguments{ 10 | \item{test.y}{List of known labels on the test set} 11 | 12 |
\item{pred.prob}{List of probability predictions on the test set} 13 | 14 | \item{resamps}{How many bootstrap samples to use} 15 | 16 | \item{force_bootstrap}{TRUE/FALSE to force bootstrapping on or off. The default, NULL, bootstraps only when there are at most 50,000 instances.} 17 | } 18 | \description{ 19 | Produces a smoothed ROC curve as a ggplot2 plot object. A confidence interval is produced by bootstrapping; by default this is turned off for large datasets (more than 50,000 instances). 20 | } 21 | 22 | -------------------------------------------------------------------------------- /man/sigmoid.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/individual_plots.R 3 | \name{sigmoid} 4 | \alias{sigmoid} 5 | \title{sigmoid} 6 | \usage{ 7 | sigmoid(x) 8 | } 9 | \arguments{ 10 | \item{x}{Numeric value or vector} 11 | } 12 | \description{ 13 | Logistic sigmoid function that maps any real number to the (0, 1) interval. Supports numeric vectors. 14 | } 15 | 16 | -------------------------------------------------------------------------------- /master.toml: -------------------------------------------------------------------------------- 1 | [master] 2 | runner = "s3://ambiata-dispensary-v2/dist/master/master-r/linux/x86_64/20161201235845-c915cd4/master-r-20161201235845-c915cd4" 3 | sha = "618f8c45244110e96260278381f5d27f30e999a9" 4 | version = 1 5 | 6 | [build.dist] 7 | PUBLISH = "true" 8 | PUBLISH_S3 = "$AMBIATA_ARTEFACTS_MASTER" 9 | 10 | [build.branches] 11 | PUBLISH = "true" 12 | PUBLISH_S3 = "$AMBIATA_ARTEFACTS_BRANCHES" 13 | --------------------------------------------------------------------------------