├── .Rbuildignore ├── .github ├── .gitignore └── workflows │ ├── R-CMD-check.yaml │ └── pkgdown.yaml ├── .gitignore ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── R ├── 0_imports.R ├── data.R ├── glmnet_coef_plot.R ├── utils-pipe.R ├── viz_classbalance.R ├── viz_decision_boundary.R ├── viz_dispersion.R ├── viz_fitted_line.R ├── viz_pca.R ├── viz_pcacm.R ├── viz_prob_distribution.R ├── viz_prob_region.R ├── viz_residuals.R ├── viz_tsne.R └── workflow_utils.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── data-raw ├── fairy_tales.R └── mnist_sample.R ├── data ├── fairy_tales.rda └── mnist_sample.rda ├── horus.Rproj ├── man ├── fairy_tales.Rd ├── figures │ ├── README-unnamed-chunk-2-1.png │ ├── README-unnamed-chunk-3-1.png │ └── logo.png ├── glmnet_coef_plot.Rd ├── mnist_sample.Rd ├── pipe.Rd ├── viz_classbalance.Rd ├── viz_decision_boundary.Rd ├── viz_dispersion.Rd ├── viz_fitted_line.Rd ├── viz_pca.Rd ├── viz_pcacm.Rd ├── viz_prob_region.Rd ├── viz_residuals.Rd └── viz_tsne.Rd ├── pkgdown └── favicon │ ├── apple-touch-icon-120x120.png │ ├── apple-touch-icon-152x152.png │ ├── apple-touch-icon-180x180.png │ ├── apple-touch-icon-60x60.png │ ├── apple-touch-icon-76x76.png │ ├── apple-touch-icon.png │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ └── favicon.ico └── tests ├── figs ├── deps.txt ├── viz_classbalance │ ├── viz-classbalance-n-max-2.svg │ ├── viz-classbalance-n-max-4-ties.svg │ └── viz-classbalance-simple.svg ├── viz_decision_boundary │ ├── viz-decision-boundary-expand.svg │ ├── viz-decision-boundary-resolution.svg │ └── viz-decision-boundary-simple.svg ├── viz_fitted_line │ ├── viz-fitted-line-expand.svg │ ├── viz-fitted-line-resolution.svg │ ├── viz-fitted-line-simple.svg │ └── viz-fitted-line-style.svg ├── viz_pca │ ├── viz-pca-components.svg │ ├── viz-pca-loadings.svg │ └── viz-pca-simple.svg ├── viz_pcacm │ ├── viz-pcacm-components.svg │ ├── viz-pcacm-loadings-components.svg │ ├── viz-pcacm-loadings.svg │ └── viz-pcacm-simple.svg ├── 
viz_prob_region │ ├── viz-prob-region-expand.svg │ ├── viz-prob-region-facet-expand.svg │ ├── viz-prob-region-facet-resolution.svg │ ├── viz-prob-region-facet-simple.svg │ ├── viz-prob-region-resolution.svg │ └── viz-prob-region-simple.svg ├── viz_residuals │ └── viz-residuals-simple.svg └── viz_tsne │ ├── viz-tsne-simple-factor.svg │ └── viz-tsne-simple-numeric.svg ├── testthat.R └── testthat ├── test-viz_classbalance.R ├── test-viz_decision_boundary.R ├── test-viz_fitted_line.R ├── test-viz_pca.R ├── test-viz_pcacm.R ├── test-viz_prob_region.R ├── test-viz_residuals.R └── test-viz_tsne.R /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^codecov\.yml$ 2 | ^appveyor\.yml$ 3 | ^\.travis\.yml$ 4 | ^README\.Rmd$ 5 | ^data-raw$ 6 | ^LICENSE\.md$ 7 | ^horus\.Rproj$ 8 | ^\.Rproj\.user$ 9 | ^\.github$ 10 | ^_pkgdown\.yml$ 11 | ^docs$ 12 | ^pkgdown$ 13 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag. 
2 | # https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | branches: 10 | - main 11 | - master 12 | 13 | name: R-CMD-check 14 | 15 | jobs: 16 | R-CMD-check: 17 | runs-on: ${{ matrix.config.os }} 18 | 19 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | config: 25 | - {os: windows-latest, r: 'release'} 26 | - {os: macOS-latest, r: 'release'} 27 | - {os: ubuntu-20.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"} 28 | - {os: ubuntu-20.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"} 29 | 30 | env: 31 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 32 | RSPM: ${{ matrix.config.rspm }} 33 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 34 | 35 | steps: 36 | - uses: actions/checkout@v2 37 | 38 | - uses: r-lib/actions/setup-r@v1 39 | with: 40 | r-version: ${{ matrix.config.r }} 41 | 42 | - uses: r-lib/actions/setup-pandoc@v1 43 | 44 | - name: Query dependencies 45 | run: | 46 | install.packages('remotes') 47 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 48 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 49 | shell: Rscript {0} 50 | 51 | - name: Cache R packages 52 | if: runner.os != 'Windows' 53 | uses: actions/cache@v2 54 | with: 55 | path: ${{ env.R_LIBS_USER }} 56 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 57 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 58 | 59 | - name: Install system dependencies 60 | if: runner.os == 'Linux' 61 | run: | 62 | while read -r cmd 63 | do 64 | eval sudo $cmd 65 | done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))') 66 | 67 | - name: Install dependencies 68 | run: | 69 | 
remotes::install_deps(dependencies = TRUE) 70 | remotes::install_cran("rcmdcheck") 71 | shell: Rscript {0} 72 | 73 | - name: Check 74 | env: 75 | _R_CHECK_CRAN_INCOMING_REMOTE_: false 76 | run: | 77 | options(crayon.enabled = TRUE) 78 | rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "warning", check_dir = "check") 79 | shell: Rscript {0} 80 | 81 | - name: Upload check results 82 | if: failure() 83 | uses: actions/upload-artifact@main 84 | with: 85 | name: ${{ runner.os }}-r${{ matrix.config.r }}-results 86 | path: check 87 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | - master 6 | 7 | name: pkgdown 8 | 9 | jobs: 10 | pkgdown: 11 | runs-on: macOS-latest 12 | env: 13 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - uses: r-lib/actions/setup-r@v1 18 | 19 | - uses: r-lib/actions/setup-pandoc@v1 20 | 21 | - name: Query dependencies 22 | run: | 23 | install.packages('remotes') 24 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 25 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 26 | shell: Rscript {0} 27 | 28 | - name: Cache R packages 29 | uses: actions/cache@v2 30 | with: 31 | path: ${{ env.R_LIBS_USER }} 32 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 33 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 34 | 35 | - name: Install dependencies 36 | run: | 37 | remotes::install_deps(dependencies = TRUE) 38 | install.packages("pkgdown", type = "binary") 39 | shell: Rscript {0} 40 | 41 | - name: Install package 42 | run: R CMD INSTALL . 
43 | 44 | - name: Deploy package 45 | run: | 46 | git config --local user.email "actions@github.com" 47 | git config --local user.name "GitHub Actions" 48 | Rscript -e 'pkgdown::deploy_to_branch(new_process = FALSE)' 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | docs 6 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: horus 2 | Title: Visual Tools to Help Machine Learning Model Selection 3 | Version: 0.0.0.9000 4 | Authors@R: 5 | person(given = "Emil", 6 | family = "Hvitfeldt", 7 | role = c("aut", "cre"), 8 | email = "emilhhvitfeldt@gmail.com", 9 | comment = c(ORCID = "0000-0002-0679-1945")) 10 | Description: Includes a suite of functions that visually aid 11 | the understanding and model selection of a wide array of machine 12 | learning applications. 
13 | License: MIT + file LICENSE 14 | URL: https://emilhvitfeldt.github.io/horus 15 | Depends: 16 | R (>= 2.10) 17 | Imports: 18 | dplyr, 19 | forcats, 20 | ggplot2, 21 | glue, 22 | magrittr, 23 | parsnip (>= 0.1.5.9000), 24 | purrr, 25 | readr, 26 | rlang, 27 | Rtsne, 28 | tibble, 29 | tidyr, 30 | tidytext, 31 | workflows 32 | Suggests: 33 | covr, 34 | kernlab, 35 | kknn, 36 | ranger, 37 | testthat, 38 | vdiffr, 39 | glmnet 40 | Remotes: 41 | tidymodels/parsnip 42 | Encoding: UTF-8 43 | LazyData: true 44 | RoxygenNote: 7.1.1.9001 45 | Config/testthat/edition: 3 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2018 2 | COPYRIGHT HOLDER: Emil Hvitfeldt 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2018 Emil Hvitfeldt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export("%>%") 4 | export(glmnet_coef_plot) 5 | export(viz_classbalance) 6 | export(viz_decision_boundary) 7 | export(viz_dispersion) 8 | export(viz_fitted_line) 9 | export(viz_pca) 10 | export(viz_pcacm) 11 | export(viz_prob_region) 12 | export(viz_residuals) 13 | export(viz_tsne) 14 | importFrom(Rtsne,Rtsne) 15 | importFrom(dplyr,all_of) 16 | importFrom(dplyr,arrange) 17 | importFrom(dplyr,bind_cols) 18 | importFrom(dplyr,count) 19 | importFrom(dplyr,desc) 20 | importFrom(dplyr,everything) 21 | importFrom(dplyr,filter) 22 | importFrom(dplyr,group_by) 23 | importFrom(dplyr,mutate) 24 | importFrom(dplyr,mutate_) 25 | importFrom(dplyr,pull) 26 | importFrom(dplyr,recode) 27 | importFrom(dplyr,row_number) 28 | importFrom(dplyr,select) 29 | importFrom(dplyr,select_) 30 | importFrom(dplyr,slice) 31 | importFrom(dplyr,summarize) 32 | importFrom(forcats,fct_infreq) 33 | importFrom(forcats,fct_lump) 34 | importFrom(forcats,fct_relevel) 35 | importFrom(forcats,fct_rev) 36 | importFrom(ggplot2,aes) 37 | importFrom(ggplot2,aes_) 38 | importFrom(ggplot2,aes_string) 39 | importFrom(ggplot2,arrow) 40 | importFrom(ggplot2,facet_grid) 41 | importFrom(ggplot2,facet_wrap) 42 | importFrom(ggplot2,geom_abline) 43 | importFrom(ggplot2,geom_bar) 44 | importFrom(ggplot2,geom_histogram) 45 | importFrom(ggplot2,geom_line) 46 | importFrom(ggplot2,geom_point) 47 | importFrom(ggplot2,geom_raster) 48 | importFrom(ggplot2,geom_segment) 49 | importFrom(ggplot2,geom_text) 50 | 
importFrom(ggplot2,ggplot) 51 | importFrom(ggplot2,guide_legend) 52 | importFrom(ggplot2,guides) 53 | importFrom(ggplot2,labs) 54 | importFrom(ggplot2,scale_alpha) 55 | importFrom(ggplot2,scale_fill_gradient2) 56 | importFrom(ggplot2,scale_x_log10) 57 | importFrom(ggplot2,scale_y_discrete) 58 | importFrom(ggplot2,theme_minimal) 59 | importFrom(ggplot2,unit) 60 | importFrom(ggplot2,xlim) 61 | importFrom(glue,glue) 62 | importFrom(magrittr,"%>%") 63 | importFrom(parsnip,augment) 64 | importFrom(purrr,map) 65 | importFrom(readr,parse_number) 66 | importFrom(rlang,.data) 67 | importFrom(rlang,abort) 68 | importFrom(rlang,enquo) 69 | importFrom(rlang,ensym) 70 | importFrom(rlang,eval_tidy) 71 | importFrom(rlang,inform) 72 | importFrom(rlang,quo_name) 73 | importFrom(rlang,set_names) 74 | importFrom(stats,prcomp) 75 | importFrom(stats,predict) 76 | importFrom(stats,var) 77 | importFrom(tibble,as_tibble) 78 | importFrom(tibble,rownames_to_column) 79 | importFrom(tibble,tibble) 80 | importFrom(tidyr,drop_na) 81 | importFrom(tidyr,nest) 82 | importFrom(tidyr,pivot_longer) 83 | importFrom(tidyr,unnest) 84 | importFrom(utils,globalVariables) 85 | -------------------------------------------------------------------------------- /R/0_imports.R: -------------------------------------------------------------------------------- 1 | #' @importFrom dplyr all_of arrange bind_cols count desc everything filter 2 | #' @importFrom dplyr group_by mutate mutate_ pull recode row_number select 3 | #' @importFrom dplyr select_ slice summarize 4 | #' @importFrom forcats fct_infreq fct_lump fct_relevel fct_rev 5 | #' @importFrom ggplot2 aes aes_ aes_string arrow facet_grid facet_wrap 6 | #' @importFrom ggplot2 geom_abline geom_bar geom_histogram geom_line 7 | #' @importFrom ggplot2 geom_point geom_raster geom_segment geom_text ggplot 8 | #' @importFrom ggplot2 guide_legend guides labs scale_alpha scale_fill_gradient2 9 | #' @importFrom ggplot2 scale_y_discrete theme_minimal unit xlim 
scale_x_log10 10 | #' @importFrom glue glue 11 | #' @importFrom magrittr %>% 12 | #' @importFrom parsnip augment 13 | #' @importFrom purrr map 14 | #' @importFrom readr parse_number 15 | #' @importFrom rlang .data abort enquo ensym eval_tidy inform quo_name set_names 16 | #' @importFrom Rtsne Rtsne 17 | #' @importFrom stats prcomp predict var 18 | #' @importFrom tibble as_tibble rownames_to_column tibble 19 | #' @importFrom tidyr drop_na nest pivot_longer unnest 20 | #' @importFrom utils globalVariables 21 | NULL 22 | #' 23 | 24 | # ------------------------------------------------------------------------------ 25 | # nocov 26 | 27 | utils::globalVariables( 28 | c(".pred", ".resid", "name", "value", "contribution", "lambda") 29 | ) 30 | 31 | # nocov end 32 | -------------------------------------------------------------------------------- /R/data.R: -------------------------------------------------------------------------------- 1 | #' Sample of 5 fairy tales from H.C. Andersen 2 | #' 3 | #' Includes the fairy tales "A leaf from heaven", "A story", 4 | #' "The bird of folklore", "The flea and the professor" and "The rags". 5 | #' 6 | #' @format A tibble with 486 obs. of 2 variables. 7 | "fairy_tales" 8 | 9 | #' Down-sampled MNIST 10 | #' 11 | #' This data set is a down-sampled version of the MNIST data set which has been 12 | #' umap'ed down to 2 dimensions. 
#'
#' @source \url{http://yann.lecun.com/exdb/mnist/}
#' @format A data frame with 1200 rows and 3 variables:
#' \describe{
#'   \item{class}{Factor, values 0 through 9.}
#'   \item{umap_1}{Numeric.}
#'   \item{umap_2}{Numeric.}
#' }
"mnist_sample"

# --------------------------------------------------------------------------------
# /R/glmnet_coef_plot.R:
# --------------------------------------------------------------------------------
#' Create Lambda chart for glmnet object
#'
#' Draws one line per row of the fitted coefficient matrix, showing how the
#' coefficient changes over the lambda values stored in the fit. The x-axis is
#' drawn on a log10 scale.
#'
#' @param fit parsnip fit object; its `fit$fit` element must provide `beta`
#'   and `lambda`.
#'
#' @return ggplot2 object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(glmnet)
#' linear_reg_glmnet_spec <-
#'   linear_reg(penalty = 0, mixture = 1) %>%
#'   set_engine('glmnet')
#'
#' lm_fit <- fit(linear_reg_glmnet_spec, data = mtcars, mpg ~ .)
#'
#' glmnet_coef_plot(lm_fit)
glmnet_coef_plot <- function(fit) {
  engine_fit <- fit$fit

  # Transpose so that each row lines up with one entry of `lambda`.
  coef_wide <- as_tibble(t(as.matrix(engine_fit$beta)))
  coef_wide <- mutate(coef_wide, lambda = engine_fit$lambda)

  # Long format: one row per (lambda, coefficient name) pair.
  coef_long <- pivot_longer(coef_wide, -lambda)

  ggplot(coef_long, aes(lambda, value, group = name)) +
    geom_line() +
    scale_x_log10()
}

# --------------------------------------------------------------------------------
# /R/utils-pipe.R:
# --------------------------------------------------------------------------------
#' Pipe operator
#'
#' See \code{magrittr::\link[magrittr]{\%>\%}} for details.
#'
#' @name %>%
#' @rdname pipe
#' @keywords internal
#' @export
#' @usage lhs \%>\% rhs
NULL

# --------------------------------------------------------------------------------
# /R/viz_classbalance.R:
# --------------------------------------------------------------------------------
#' Visualise Class Imbalance
#'
#' @param data A data.frame.
#' @param variable target variable to show balance for.
#' @param n_max integer, maximum number of classes shown before lumping.
#'   Defaults to 25.
#'
#' @return ggplot2 object. When `variable` has more than `n_max` levels the
#'   column is passed through `fct_lump()` before plotting and a message
#'   reports how many of the original classes remain.
#' @export
#'
#' @examples
#' viz_classbalance(mnist_sample, class)
viz_classbalance <- function(data, variable, n_max = 25) {
  enquo_variable <- enquo(variable)
  # Column name as a string, resolved once and reused for `[[` access below.
  variable_name <- quo_name(enquo_variable)

  if (!is.factor(data[[variable_name]])) {
    abort("`variable` must be a factor")
  }

  n_vars <- length(table(pull(data, !!enquo_variable)))
  if (n_vars > n_max) {
    original_levels <- levels(data[[variable_name]])

    data[[variable_name]] <- data[[variable_name]] %>%
      fct_infreq() %>%
      fct_lump(n_max)

    # Count the original levels that survived instead of assuming exactly one
    # "Other" level was appended: the lumping step can keep extra levels or
    # leave the factor untouched, which made `levels - 1` off by one.
    n_shown <- length(
      intersect(levels(data[[variable_name]]), original_levels)
    )

    # Only report when something was actually hidden.
    if (n_shown < n_vars) {
      inform(glue("The number of categories is {n_vars}, only the first ",
                  "{n_shown} are shown."))
    }
  }

  # NOTE(review): the title spelling was corrected from "balence"; the visual
  # snapshots under tests/figs/viz_classbalance will need to be refreshed.
  title <- paste0(
    "Class balance for ",
    format(nrow(data), big.mark = ","),
    " observations"
  )

  ggplot(data, aes(!!enquo_variable)) +
    geom_bar() +
    labs(title = title) +
    theme_minimal()
}

# --------------------------------------------------------------------------------
# /R/viz_decision_boundary.R:
# --------------------------------------------------------------------------------
#' Draw Decision boundary for Classification model
#'
#' This function is mostly useful in an educational setting. Can only be used
#' with trained workflow objects with 2 numeric predictor variables.
#'
#' @param x trained `workflows::workflow` object.
#' @param new_data A data frame or tibble for whom the preprocessing will be
#'   applied.
#' @param resolution Number of squares along each axis of the grid. Defaults
#'   to 100.
#' @param expand Expansion rate. Defaults to 0.1. This means that the width and
#'   height of the shaded area are 10% wider than the rectangle containing the
#'   data.
#'
#' @details
#' The chart has been minimally modified to allow for easier styling.
#'
#' @return `ggplot2::ggplot` object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(workflows)
#' svm_spec <- svm_rbf() %>%
#'   set_mode("classification") %>%
#'   set_engine("kernlab")
#'
#' svm_fit <- workflow() %>%
#'   add_formula(Species ~ Petal.Length + Petal.Width) %>%
#'   add_model(svm_spec) %>%
#'   fit(iris)
#'
#' viz_decision_boundary(svm_fit, iris)
#'
#' viz_decision_boundary(svm_fit, iris, resolution = 20)
#'
#' viz_decision_boundary(svm_fit, iris, expand = 1)
#'
#' svm_multi_fit <- workflow() %>%
#'   add_formula(class ~ umap_1 + umap_2) %>%
#'   add_model(svm_spec) %>%
#'   fit(mnist_sample)
#'
#' viz_decision_boundary(svm_multi_fit, mnist_sample)
viz_decision_boundary <- function(x, new_data, resolution = 100, expand = 0.1) {
  # Guard clauses: only trained workflow objects are supported.
  if (!inherits(x, "workflow")) {
    abort("`viz_decision_boundary()` only works with `workflow` objects.")
  }
  if (!x$trained) {
    abort("`x` must be a trained `workflow` object.")
  }

  var_names <- extract_variable_names(x, new_data, n_pred = 2)
  x_var <- var_names$predictors[1]
  y_var <- var_names$predictors[2]

  # One padded sequence per predictor, crossed into a full prediction grid.
  grid_axes <- lapply(
    select(new_data, all_of(var_names$predictors)),
    expanded_seq, expand, resolution
  )
  predict_area <- expand.grid(grid_axes)
  predicted_grid <- bind_cols(predict_area, predict(x, predict_area))

  # Shaded raster = predicted class on the grid; squares = observed data,
  # filled by the observed response.
  ggplot(predicted_grid, aes_string(x_var, y_var, fill = ".pred_class")) +
    geom_raster(alpha = 0.2) +
    geom_point(
      aes_string(x_var, y_var, fill = var_names$response),
      color = "black", shape = 22, data = new_data, inherit.aes = FALSE
    ) +
    theme_minimal()
}

# Build an evenly spaced sequence of `resolution` values covering the range of
# `x`, padded by `expand / 2` times the range width on each side. Missing
# values in `x` are ignored. For integer `x` the sequence is truncated to
# integers and de-duplicated, so it may be shorter than `resolution`.
expanded_seq <- function(x, expand, resolution) {
  limits <- range(x, na.rm = TRUE)
  padding <- (limits[2] - limits[1]) * expand / 2

  grid <- seq(
    from = limits[1] - padding,
    to = limits[2] + padding,
    length.out = resolution
  )

  if (!is.integer(x)) {
    return(grid)
  }
  unique(as.integer(grid))
}

# --------------------------------------------------------------------------------
# /R/viz_dispersion.R:
# --------------------------------------------------------------------------------
#' Visualize lexical dispersion plot
#'
#' @param data A data.frame.
#' @param var variable that contains the words to be visualized.
#' @param group If present, shows one line per group with the words color
#'   coded.
#' @param words Numerical or character. If numerical it will display the n
#'   most common words. If character it will show the location of said strings.
#' @param symbol The word symbol. Defaults to 18 (filled diamond) when the
#'   number of points is at most 200 and to 108 (vertical line) when there are
#'   more than 200 points.
#' @param alpha color transparency of the word symbols.
#' @return ggplot2 object.
#' @details
#' Rows whose word is not among the selected words are dropped before
#' plotting. The word offset on the x-axis is the row position, counted within
#' each group when `group` is supplied.
#' @examples
#' \dontrun{
#' library(tidytext)
#'
#' text_data <- unnest_tokens(fairy_tales, word, text)
#' viz_dispersion(text_data, word)
#' viz_dispersion(text_data, word, words = c("branches", "not a word"))
#' viz_dispersion(text_data, word, symbol = "2")
#' viz_dispersion(text_data, word, group = book)
#' }
#' @export
viz_dispersion <- function(data, var, group, words = 10, symbol = NULL,
                           alpha = 0.7) {
  var <- ensym(var)
  # Record this before `group` is reassigned below.
  has_group <- !missing(group)
  tokens <- pull(data, !!var)

  # Resolve `words` into the character vector of words to display.
  # `is.numeric()` also accepts integers such as `10L`, which the previous
  # `class(words) == "numeric"` test rejected.
  if (is.numeric(words)) {
    vec <- count(data, !!var, sort = TRUE) %>%
      slice(seq_len(words)) %>%
      pull(!!var) %>%
      as.character()
  } else if (is.character(words)) {
    vec <- words
  } else {
    abort("`words` must be a number or a character vector.")
  }

  # Factor of selected words; NA marks tokens that are not displayed.
  word_factor <- dispersion_factor(tokens, vec)

  if (has_group) {
    group <- ensym(group)
    group_values <- pull(data, !!group)

    # Word offset restarts at 1 within every group, in row order.
    word_offset <- stats::ave(
      seq_along(tokens), group_values,
      FUN = seq_along
    )

    # Use the column passed as `group` for the y-axis (it used to be
    # hard-coded to a column named "book").
    plot_data <- tibble(x = word_offset, color = word_factor, y = group_values)
    # The x-axis spans the longest group.
    x_limit <- max(table(group_values))
  } else {
    plot_data <- tibble(x = seq_along(tokens), y = word_factor)
    x_limit <- nrow(data)
  }

  # Only x, y and color are considered here, so missing values in unrelated
  # columns of `data` no longer remove points.
  plot_data <- drop_na(plot_data)

  if (is.null(symbol)) {
    symbol <- if (nrow(plot_data) > 200) 108 else 18
  }

  mapping <- if (has_group) {
    aes_(~x, ~y, color = ~color)
  } else {
    aes_(~x, ~y)
  }

  ggplot(plot_data) +
    mapping +
    geom_point(shape = symbol, alpha = alpha) +
    scale_y_discrete(drop = FALSE) +
    xlim(c(1, x_limit)) +
    guides(color = guide_legend(override.aes = list(shape = c(18)))) +
    labs(
      x = "Word Offset",
      y = NULL,
      title = "Lexical Dispersion Plot"
    ) +
    theme_minimal()
}

# Turn a vector of tokens into a factor whose levels are the requested words,
# in the order given. Tokens that are not requested become NA; requested words
# that never occur are kept as empty levels. Works for character and factor
# input, de-duplicates `names`, and tolerates an empty `names`.
dispersion_factor <- function(x, names) {
  factor(as.character(x), levels = unique(as.character(names)))
}

# --------------------------------------------------------------------------------
# /R/viz_fitted_line.R:
# --------------------------------------------------------------------------------
#' Draw fitted regression line
#'
#' This function is mostly useful in an educational setting. Can only be used
#' with trained workflow objects with 1 numeric predictor variable.
#'
#' @param x trained `workflows::workflow` object.
#' @param new_data A data frame or tibble for whom the preprocessing will be
#'   applied.
#' @param resolution Number of evenly spaced predictor values the fitted line
#'   is evaluated at. Defaults to 100.
#' @param expand Expansion rate. Defaults to 0.1. This means that the width of
#'   the plotting area is 10 percent wider than the data.
#' @param color Character, color of the fitted line. Passed to `geom_line()`.
#'   Defaults to `"blue"`.
#' @param size Numeric, size of the fitted line. Passed to `geom_line()`.
#'   Defaults to `1`.
#'
#' @details
#' The chart has been minimally modified to allow for easier styling.
#'
#' @return `ggplot2::ggplot` object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(workflows)
#' lm_spec <- linear_reg() %>%
#'   set_mode("regression") %>%
#'   set_engine("lm")
#'
#' lm_fit <- workflow() %>%
#'   add_formula(mpg ~ disp) %>%
#'   add_model(lm_spec) %>%
#'   fit(mtcars)
#'
#' viz_fitted_line(lm_fit, mtcars)
#'
#' viz_fitted_line(lm_fit, mtcars, expand = 1)
viz_fitted_line <- function(x, new_data, resolution = 100, expand = 0.1,
                            color = "blue", size = 1) {
  # Guard clauses: only trained workflow objects are supported.
  if (!inherits(x, "workflow")) {
    abort("`viz_fitted_line()` only works with `workflow` objects.")
  }
  if (!x$trained) {
    abort("`x` must be a trained `workflow` object.")
  }

  var_names <- extract_variable_names(x, new_data, n_pred = 1)
  predictor <- var_names$predictors[1]

  # Padded sequence of predictor values at which the model is evaluated.
  predictor_columns <- select(new_data, all_of(var_names$predictors))
  predict_area <- expand.grid(
    lapply(predictor_columns, expanded_seq, expand, resolution)
  )
  fitted_line <- bind_cols(predict_area, predict(x, predict_area))

  # Observed points first, then the fitted line drawn from its own data.
  ggplot(new_data, aes_string(predictor, var_names$response)) +
    geom_point(alpha = 0.5) +
    geom_line(
      aes_string(predictor, ".pred"),
      color = color,
      size = size,
      data = fitted_line,
      inherit.aes = FALSE
    ) +
    theme_minimal() +
    xlim(range(predict_area[[1]]))
}

# --------------------------------------------------------------------------------
# /R/viz_pca.R:
# --------------------------------------------------------------------------------
#' Visualize principal components
#'
#' @param data A data.frame.
#' @param label variable to color with.
#' @param components principal components to showcase.
#' @param loadings Set this to true if you want to see the PCA loadings.
#' @return ggplot2 object.
#' @examples
#' viz_pca(iris, Species)
#' viz_pca(iris, Species, c(3, 1))
#' viz_pca(iris, Species, loadings = TRUE)
#' @export
viz_pca <- function(data, label, components = c(1, 2), loadings = FALSE) {
  # Two components are needed for the x and y axes; fail early with a clear
  # message instead of an obscure aesthetic error.
  if (length(components) < 2) {
    abort("`components` must contain at least two principal components.")
  }

  label_enquo <- enquo(label)
  pc_names <- paste0("PC", components)

  # PCA on every column except the label.
  pca_obj <- data %>%
    select(-!!label_enquo) %>%
    as.matrix() %>%
    prcomp()

  plot_data <- as_tibble(pca_obj$x) %>%
    mutate(Label = pull(data, !!label_enquo))

  # NOTE(review): `theme_minimal()` used to be applied only when
  # `loadings = TRUE`; it is now applied in both cases, so the visual
  # snapshots for the non-loadings plots need to be refreshed.
  p <- ggplot(plot_data) +
    aes_string(pc_names[1], pc_names[2], color = "Label") +
    geom_point() +
    labs(
      x = glue("Principal Component {components[1]}"),
      y = glue("Principal Component {components[2]}"),
      title = "Principal Component plot"
    ) +
    theme_minimal()

  if (loadings) {
    # One arrow from the origin per original variable, labelled at its tip.
    loadings_data <- pca_obj$rotation[, components] %>%
      as.data.frame() %>%
      rownames_to_column()
    arrow_head <- arrow(length = unit(0.03, "npc"))

    p <- p +
      geom_segment(
        aes_string(
          x = 0, y = 0,
          xend = pc_names[1],
          yend = pc_names[2]
        ),
        data = loadings_data, inherit.aes = FALSE,
        arrow = arrow_head
      ) +
      geom_text(
        aes_string(pc_names[1], pc_names[2], label = "rowname"),
        data = loadings_data, inherit.aes = FALSE
      )
  }

  p
}

# --------------------------------------------------------------------------------
# /R/viz_pcacm.R:
# --------------------------------------------------------------------------------
#' PCA component measures
#'
#' Shows the PCA rotation as a heat map: one column per variable, one row per
#' principal component, filled by the rotation value.
#'
#' @param x a data.frame
#' @param n_pca Number of principal components to show
#' @param n_var Number of variables to show
#'
#' @return a ggplot2 object
#' @export
#'
#' @examples
#' viz_pcacm(mtcars)
#' viz_pcacm(USArrests)
#' viz_pcacm(beaver1)
#'
#' viz_pcacm(mtcars, n_pca = 4)
#' viz_pcacm(mtcars, n_var = 4)
#' viz_pcacm(mtcars, n_pca = 2, n_var = 6)
viz_pcacm <- function(x, n_pca = 10, n_var = 10) {
  # Spell out `rank.` instead of relying on partial argument matching.
  rotation <- prcomp(x, rank. = n_pca)$rotation

  # Fewer than `n_pca` components may come back (e.g. `x` has fewer columns),
  # so only keep the requested names that actually exist. Releveling with
  # non-existent names raised an unknown-levels warning.
  pc_levels <- intersect(paste0("PC", seq_len(n_pca)), colnames(rotation))

  tidy_rotation <- as.data.frame(rotation) %>%
    rownames_to_column("var") %>%
    pivot_longer(-var) %>%
    filter(name %in% pc_levels)

  # Rank variables by absolute rotation, down-weighting later components by
  # 1 / component number, and keep the `n_var` largest.
  top_vars <- tidy_rotation %>%
    group_by(var) %>%
    summarize(contribution = sum(abs(value) * 1 / parse_number(name))) %>%
    arrange(desc(contribution)) %>%
    slice(seq_len(n_var)) %>%
    pull(var)

  # NOTE(review): the y-axis label was corrected from "Principle Component";
  # the visual snapshots under tests/figs/viz_pcacm need to be refreshed.
  tidy_rotation %>%
    filter(var %in% top_vars) %>%
    mutate(
      var = fct_relevel(var, top_vars),
      name = fct_relevel(name, pc_levels),
      name = fct_rev(name)
    ) %>%
    ggplot(aes(var, name, fill = value)) +
    geom_raster() +
    scale_fill_gradient2(
      low = "purple",
      mid = "grey90",
      high = "orange"
    ) +
    labs(
      y = "Principal Component",
      x = "Variable"
    ) +
    theme_minimal()
}

# --------------------------------------------------------------------------------
# /R/viz_prob_distribution.R:
# --------------------------------------------------------------------------------
#' #' Visualize the predicted probabilities for each class
#' #'
#' #' @param model A `model_fit` object from the parsnip package.
#' #' @param new_data A data.frame to run predictions on.
#' #' @param truth A vector of true classes used to color probabilities.
#' #'   Defaults to NULL.
#' #'
#' #' @return ggplot2 object.
#' #' @export
#' viz_prob_distribution <- function(model, new_data, truth) {
#'   UseMethod("viz_prob_distribution")
#' }
#'
#' #' @rdname viz_prob_distribution
#' #' @export
#' viz_prob_distribution.default <- function(model, new_data, truth) {
#'   stop("`model` must be a `model_fit` object from the parsnip package.",
#'     call.
= FALSE 19 | #' ) 20 | #' } 21 | #' 22 | #' #' @rdname viz_prob_distribution 23 | #' #' @export 24 | #' #' 25 | #' #' @examples 26 | #' #' library(parsnip) 27 | #' #' library(ranger) 28 | #' #' 29 | #' #' fit_model <- rand_forest("classification") %>% 30 | #' #' set_engine("ranger") %>% 31 | #' #' fit(Species ~ ., data = iris) 32 | #' #' 33 | #' #' viz_prob_distribution(fit_model, new_data = iris) 34 | #' #' 35 | #' #' viz_prob_distribution(fit_model, new_data = iris, truth = iris$Species) 36 | #' viz_prob_distribution.model_fit <- function(model, new_data, truth = NULL) { 37 | #' if (model$spec$mode != "classification") { 38 | #' stop("`model` must be a classification model.", call. = FALSE) 39 | #' } 40 | #' 41 | #' plotting_data <- predict(model, new_data, type = "prob") 42 | #' 43 | #' if (is.null(truth)) { 44 | #' plotting_data <- plotting_data %>% 45 | #' pivot_longer(cols = everything()) 46 | #' } else { 47 | #' plotting_data <- plotting_data %>% 48 | #' mutate(truth = truth) %>% 49 | #' pivot_longer(cols = -truth) %>% 50 | #' mutate(correct = factor( 51 | #' gsub(".pred_", "", .data$name) == .data$truth, 52 | #' c(TRUE, FALSE), 53 | #' c("Yes", "No") 54 | #' )) %>% 55 | #' filter(gsub(".pred_", "", .data$name) == .data$truth) 56 | #' } 57 | #' 58 | #' out <- plotting_data %>% 59 | #' ggplot(aes(.data$value)) + 60 | #' geom_histogram(bins = 50) + 61 | #' facet_grid(~ .data$name) + 62 | #' theme_minimal() + 63 | #' labs(x = "Predicted probability") 64 | #' 65 | #' if (!is.null(truth)) { 66 | #' out <- out + 67 | #' aes(fill = .data$correct) + 68 | #' labs(fill = "Correctlty predicted") 69 | #' } 70 | #' out 71 | #' } 72 | -------------------------------------------------------------------------------- /R/viz_prob_region.R: -------------------------------------------------------------------------------- 1 | #' Draw Probability regions for Classification model 2 | #' 3 | #' This function is mostly useful in an educational setting. 
#' Can only be used with trained workflow objects with 2 numeric predictor
#' variables.
#'
#' @details
#' The chart has been minimally modified to allow for easier styling.
#'
#' @param x trained `workflows::workflow` object.
#' @param new_data A data frame or tibble for whom the preprocessing will be
#'   applied. The prediction grid is built from its two predictor columns, and
#'   its rows are drawn as points on top of the shaded area.
#' @param resolution Number of squares in grid. Defaults to 100.
#' @param expand Expansion rate. Defaults to 0.1. This means that the width and
#'   height of the shaded area is 10% wider than the rectangle containing the
#'   data.
#' @param facet Logical, whether to facet chart by class. Defaults to FALSE.
#'   When `FALSE`, the response must have exactly 2 levels.
#'
#' @return `ggplot2::ggplot` object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(workflows)
#'
#' iris2 <- iris
#' iris2$Species <- factor(iris2$Species == "setosa",
#'                         labels = c("setosa", "not setosa"))
#'
#' svm_spec <- svm_rbf() %>%
#'   set_mode("classification") %>%
#'   set_engine("kernlab")
#'
#' svm_fit <- workflow() %>%
#'   add_formula(Species ~ Petal.Length + Petal.Width) %>%
#'   add_model(svm_spec) %>%
#'   fit(iris2)
#'
#' viz_prob_region(svm_fit, iris2)
#'
#' viz_prob_region(svm_fit, iris2, resolution = 20)
#'
#' viz_prob_region(svm_fit, iris2, expand = 1)
#'
#' viz_prob_region(svm_fit, iris2, facet = TRUE)
#'
#' knn_spec <- nearest_neighbor() %>%
#'   set_mode("classification") %>%
#'   set_engine("kknn")
#'
#' knn_fit <- workflow() %>%
#'   add_formula(class ~ umap_1 + umap_2) %>%
#'   add_model(knn_spec) %>%
#'   fit(mnist_sample)
#'
#' viz_prob_region(knn_fit, mnist_sample, facet = TRUE)
viz_prob_region <- function(x, new_data, resolution = 100, expand = 0.1,
                            facet = FALSE) {
  if (!inherits(x, "workflow")) {
    # The message previously named `viz_decision_boundary()` (copy-paste slip).
    abort("`viz_prob_region()` only works with `workflow` objects.")
  }

  if (!x$trained) {
    abort("`x` must be a trained `workflow` object.")
  }

  # Aborts unless the workflow has exactly two predictors.
  var_names <- extract_variable_names(x, new_data, n_pred = 2)

  # Scalar condition, so use short-circuiting `&&` rather than elementwise `&`.
  # The unfaceted chart fills by the probability of the first class only, which
  # is why it is restricted to two-level responses.
  if (!facet && length(levels(new_data[[var_names$response]])) != 2) {
    abort("The response must have only 2 levels for unfaceted chart.")
  }

  # One sequence per predictor (via `expanded_seq()`), crossed into a grid.
  predict_area <- new_data %>%
    select(all_of(var_names$predictors)) %>%
    lapply(expanded_seq, expand, resolution) %>%
    expand.grid()

  predicted_prob <- predict(x, predict_area, type = "prob")

  plotting_data <- predict_area %>%
    bind_cols(predicted_prob)

  # NOTE(review): `aes_string()` parses its arguments as R expressions, so
  # non-syntactic column names (e.g. containing spaces) would presumably fail
  # here -- confirm before relying on such names.
  if (facet) {
    # Long format: one row per grid cell and class, so each class gets a panel.
    plotting_data <- pivot_longer(plotting_data, names(predicted_prob),
                                  names_to = ".class",
                                  values_to = "probability")

    # Drop the `.pred_` prefix so panel labels show the bare class names.
    plotting_data$.class <- gsub("^.pred_", "", plotting_data$.class)

    res <- plotting_data %>%
      ggplot(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2],
          fill = ".class"
        )
      ) +
      geom_raster(aes_string(alpha = "probability")) +
      facet_wrap(~ .data$.class) +
      geom_point(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2],
          fill = var_names$response
        ),
        color = "black", shape = 22, data = new_data, inherit.aes = FALSE
      ) +
      theme_minimal() +
      scale_alpha(range = c(0, 1))

  } else {
    res <- plotting_data %>%
      ggplot(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2],
          fill = names(predicted_prob)[1]
        )
      ) +
      geom_raster(alpha = 0.5) +
      scale_fill_gradient2(low = "blue", mid = "white", high = "green",
                           midpoint = 0.5) +
      geom_point(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2]
        ),
        color = "black", shape = 22, data = new_data, inherit.aes = FALSE
      ) +
      theme_minimal()
  }
  res
}
--------------------------------------------------------------------------------
/R/viz_residuals.R:
--------------------------------------------------------------------------------
#' Residual plot of model
#'
#' Plots the `.resid` column against the `.pred` column of the augmented data,
#' together with a reference line of slope 0 and intercept 0.
#'
#' @param fit a parsnip model object
#' @inheritParams parsnip:::augment.model_fit
#'
#' @return a ggplot2 object
#' @export
#'
#' @examples
#' library(parsnip)
#' reg <- linear_reg() %>%
#'   set_engine("lm") %>%
#'   set_mode("regression") %>%
#'   fit(mpg ~ ., data = mtcars)
#'
#' viz_residuals(reg, mtcars)
#'
#' linear_reg() %>%
#'   set_engine("lm") %>%
#'   set_mode("regression") %>%
#'   fit(mpg ~ ., data = mtcars) %>%
#'   viz_residuals(mtcars)
viz_residuals <- function(fit, new_data) {
  # `augment()` supplies the `.pred` and `.resid` columns plotted below.
  augmented <- augment(fit, new_data = new_data)

  ggplot(augmented, aes(.pred, .resid)) +
    geom_point() +
    geom_abline(slope = 0, intercept = 0) +
    theme_minimal()
}
--------------------------------------------------------------------------------
/R/viz_tsne.R:
--------------------------------------------------------------------------------
#' Vizualize t-SNE
#'
#' @param data A data.frame.
#' @param label variable to color with.
#' @return ggplot2 object.
#' @examples
#' library(dplyr)
#' viz_tsne(iris, Species)
#' viz_tsne(select(iris, -Species), Sepal.Length)
#' @export
viz_tsne <- function(data, label) {
  label_enquo <- enquo(label)

  # Named `axis_names` (not `names`) so `base::names()` is not shadowed.
  axis_names <- paste0("V", c(1, 2))

  # Every column except `label` is fed to t-SNE as a matrix.
  tsne_data <- data %>%
    select(-!!label_enquo) %>%
    as.matrix() %>%
    Rtsne(check_duplicates = FALSE)

  # `$Y` is the embedding returned by `Rtsne()`; its columns become V1 and V2.
  plotting_data <- tsne_data$Y

  colnames(plotting_data) <- axis_names

  as_tibble(plotting_data) %>%
    mutate(Label = pull(data, !!label_enquo)) %>%
    ggplot() +
    aes_string(axis_names[1], axis_names[2], color = "Label") +
    geom_point() +
    labs(
      x = "",
      y = "",
      title = "t-SNE Manifold"
    ) +
    theme_minimal()
}
--------------------------------------------------------------------------------
/R/workflow_utils.R:
--------------------------------------------------------------------------------
# Find the predictor and response column names used by a workflow.
#
# x        A workflow; its preprocessor is read from `x$pre$actions`.
# new_data Data frame whose column names the variables are resolved against.
# n_pred   Required number of predictors; any other count aborts.
#
# Returns `list(predictors = <character>, response = <character>)`.
extract_variable_names <- function(x, new_data, n_pred) {
  # `%in%` always yields a single TRUE/FALSE. The previous `names(...) == "..."`
  # test had a length > 1 condition whenever the workflow carried more than one
  # pre-action, which is an error in `if` since R 4.2.
  action_names <- names(x$pre$actions)

  if ("variables" %in% action_names) {
    # The stored selections are evaluated against a name -> name lookup of the
    # columns, so they resolve to column names.
    predictors <- eval_tidy(
      x$pre$actions$variables$predictors,
      set_names(names(new_data))
    )

    response <- eval_tidy(
      x$pre$actions$variables$outcome,
      set_names(names(new_data))
    )
  } else if ("formula" %in% action_names) {
    model_formula <- x$pre$actions$formula$formula

    # `all.vars()` walks the whole expression. `as.character()` on the right-hand
    # side only split the outermost call, so `y ~ a + b + c` produced
    # c("+", "a + b", "c") and silently lost `a` and `b`.
    predictors <- intersect(all.vars(model_formula[[3]]), names(new_data))

    response <- all.vars(model_formula[[2]])
  } else if ("recipe" %in% action_names) {
    var_info <- x$pre$actions$recipe$recipe$var_info
    predictors <- var_info$variable[var_info$role == "predictor"]
    response <- var_info$variable[var_info$role == "outcome"]
  } else {
    # Previously this fell through to "object 'predictors' not found".
    abort("`x` must have a formula, recipe or variables preprocessor.")
  }

  if (length(predictors) != n_pred) {
    abort(glue("`x` must have only {n_pred} predictors."))
  }

  list(predictors = predictors, response = response)
}
--------------------------------------------------------------------------------
/README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | # horus 17 | 18 | 19 | [![R-CMD-check](https://github.com/EmilHvitfeldt/horus/workflows/R-CMD-check/badge.svg)](https://github.com/EmilHvitfeldt/horus/actions) 20 | 21 | 22 | **WIP -- Very early build, things are very likely to change -- WIP** 23 | 24 | The goal of horus is to allow quick visualization methods for common machine learning and modeling tasks. This project is hugely inspired by the Python library [yellowbrick](https://github.com/DistrictDataLabs/yellowbrick). 25 | 26 | ## Installation 27 | 28 | For the time being `horus` is only available on Github, and can be installed 29 | with `devtools`: 30 | 31 | ```{r, eval=FALSE} 32 | # install.packages('devtools') 33 | devtools::install_github('EmilHvitfeldt/horus') 34 | ``` 35 | 36 | In the future the package will be available on CRAN as well. 37 | 38 | ## Example 39 | 40 | There is no reason why a principal component plot of a data set should be as hard as it currently is in R. Using **horus** it is down to a single line! 
41 | 42 | ```{r} 43 | library(horus) 44 | viz_pca(iris, Species) 45 | ``` 46 | 47 | ## Advanced Gallery 48 | 49 | [Random Forest variability Visualization](https://gist.github.com/EmilHvitfeldt/e81f9d462c423978f515f036c8ad0232) 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # horus 5 | 6 | 7 | 8 | [![R-CMD-check](https://github.com/EmilHvitfeldt/horus/workflows/R-CMD-check/badge.svg)](https://github.com/EmilHvitfeldt/horus/actions) 9 | 10 | 11 | **WIP – Very early build, things are very likely to change – WIP** 12 | 13 | The goal of horus is to allow quick visualization methods for common 14 | machine learning and modeling tasks. This project is hugely inspired by 15 | the Python library 16 | [yellowbrick](https://github.com/DistrictDataLabs/yellowbrick). 17 | 18 | ## Installation 19 | 20 | For the time being `horus` is only available on Github, and can be 21 | installed with `devtools`: 22 | 23 | ``` r 24 | # install.packages('devtools') 25 | devtools::install_github('EmilHvitfeldt/horus') 26 | ``` 27 | 28 | In the future the package will be available on CRAN as well. 29 | 30 | ## Example 31 | 32 | There is no reason why a principal component plot of a data set should 33 | be as hard as it currently is in R. Using **horus** it is down to a 34 | single line!
35 | 36 | ``` r 37 | library(horus) 38 | viz_pca(iris, Species) 39 | ``` 40 | 41 | 42 | 43 | ## Advanced Gallary 44 | 45 | [Random Forrest variability 46 | Visualization](https://gist.github.com/EmilHvitfeldt/e81f9d462c423978f515f036c8ad0232) 47 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://emilhvitfeldt.github.io/horus 2 | -------------------------------------------------------------------------------- /data-raw/fairy_tales.R: -------------------------------------------------------------------------------- 1 | set.seed(12345) 2 | library(tidyverse) 3 | # devtools::install_github("EmilHvitfeldt/hcandersenr") 4 | select_books <- hcandersenr::hcandersen_en %>% 5 | pull(book) %>% 6 | unique() %>% 7 | sample(size = 5) 8 | 9 | fairy_tales <- hcandersenr::hcandersen_en %>% 10 | filter(book %in% select_books) 11 | -------------------------------------------------------------------------------- /data-raw/mnist_sample.R: -------------------------------------------------------------------------------- 1 | library(keras) 2 | library(recipes) 3 | library(embed) 4 | library(purrr) 5 | library(dplyr) 6 | set.seed(1234) 7 | 8 | mnist_raw <- dataset_mnist() 9 | 10 | tidy_mnist <- map_dfc(1:28, ~tibble::as_tibble(mnist_raw$train$x[, , .x])) %>% 11 | mutate(class = factor(mnist_raw$train$y)) 12 | 13 | mnist_sample <- recipe(class ~ ., data = slice_sample(tidy_mnist, prop = 0.02)) %>% 14 | step_umap(all_predictors(), num_comp = 2) %>% 15 | prep() %>% 16 | juice() 17 | 18 | usethis::use_data(mnist_sample, overwrite = TRUE) 19 | -------------------------------------------------------------------------------- /data/fairy_tales.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/data/fairy_tales.rda 
-------------------------------------------------------------------------------- /data/mnist_sample.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/data/mnist_sample.rda -------------------------------------------------------------------------------- /horus.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | -------------------------------------------------------------------------------- /man/fairy_tales.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{fairy_tales} 5 | \alias{fairy_tales} 6 | \title{Sample of 5 fairly tales from H.C. Andersen} 7 | \format{ 8 | A tibble with 486 obs. of 2 variables. 9 | } 10 | \usage{ 11 | fairy_tales 12 | } 13 | \description{ 14 | Includes the fairy tales "A leaf from heaven", "A story", 15 | "The bird of folklore", "The flea and the professor" and "The rags". 
16 | } 17 | \keyword{datasets} 18 | -------------------------------------------------------------------------------- /man/figures/README-unnamed-chunk-2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/man/figures/README-unnamed-chunk-2-1.png -------------------------------------------------------------------------------- /man/figures/README-unnamed-chunk-3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/man/figures/README-unnamed-chunk-3-1.png -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/man/figures/logo.png -------------------------------------------------------------------------------- /man/glmnet_coef_plot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/glmnet_coef_plot.R 3 | \name{glmnet_coef_plot} 4 | \alias{glmnet_coef_plot} 5 | \title{Create Lambda chart for glmnet object} 6 | \usage{ 7 | glmnet_coef_plot(fit) 8 | } 9 | \arguments{ 10 | \item{fit}{parsnip fit object} 11 | } 12 | \value{ 13 | ggplot2 object 14 | } 15 | \description{ 16 | Create Lambda chart for glmnet object 17 | } 18 | \examples{ 19 | library(parsnip) 20 | library(glmnet) 21 | linear_reg_glmnet_spec <- 22 | linear_reg(penalty = 0, mixture = 1) \%>\% 23 | set_engine('glmnet') 24 | 25 | lm_fit <- fit(linear_reg_glmnet_spec, data = mtcars, mpg ~ .) 
26 | 27 | glmnet_coef_plot(lm_fit) 28 | } 29 | -------------------------------------------------------------------------------- /man/mnist_sample.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{mnist_sample} 5 | \alias{mnist_sample} 6 | \title{Down-sampled MNIST} 7 | \format{ 8 | A data frame with 1200 rows and 3 variables: 9 | \describe{ 10 | \item{class}{Factor, values 0 through 9.} 11 | \item{umap_1}{Numeric.} 12 | \item{umap_2}{Numeric.} 13 | } 14 | } 15 | \source{ 16 | \url{http://yann.lecun.com/exdb/mnist/} 17 | } 18 | \usage{ 19 | mnist_sample 20 | } 21 | \description{ 22 | this data set is a down-sampled version of the MNIST data set which have been 23 | umap'ed down to 2 dimensions. 24 | } 25 | \keyword{datasets} 26 | -------------------------------------------------------------------------------- /man/pipe.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utils-pipe.R 3 | \name{\%>\%} 4 | \alias{\%>\%} 5 | \title{Pipe operator} 6 | \usage{ 7 | lhs \%>\% rhs 8 | } 9 | \description{ 10 | See \code{magrittr::\link[magrittr]{\%>\%}} for details. 
11 | } 12 | \keyword{internal} 13 | -------------------------------------------------------------------------------- /man/viz_classbalance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_classbalance.R 3 | \name{viz_classbalance} 4 | \alias{viz_classbalance} 5 | \title{Visualise Class Imbalance} 6 | \usage{ 7 | viz_classbalance(data, variable, n_max = 25) 8 | } 9 | \arguments{ 10 | \item{data}{A data.frame.} 11 | 12 | \item{variable}{target variable to show balance for.} 13 | 14 | \item{n_max}{integer, maximum number of classes shown before lumping. 15 | Defaults to 25.} 16 | } 17 | \value{ 18 | ggplot2 object. 19 | } 20 | \description{ 21 | Visualise Class Imbalance 22 | } 23 | \examples{ 24 | viz_classbalance(mnist_sample, class) 25 | } 26 | -------------------------------------------------------------------------------- /man/viz_decision_boundary.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_decision_boundary.R 3 | \name{viz_decision_boundary} 4 | \alias{viz_decision_boundary} 5 | \title{Draw Decision boundary for Classification model} 6 | \usage{ 7 | viz_decision_boundary(x, new_data, resolution = 100, expand = 0.1) 8 | } 9 | \arguments{ 10 | \item{x}{trained `workflows::workflow` object.} 11 | 12 | \item{new_data}{A data frame or tibble for whom the preprocessing will be 13 | applied.} 14 | 15 | \item{resolution}{Number of squared in grid. Defaults to 100.} 16 | 17 | \item{expand}{Expansion rate. Defaults to 0.1. This means that the width and 18 | height of the shaded area is 10% wider then the rectangle containing the 19 | data. 
20 | 21 | The chart have been minimally modified to allow for easier styling.} 22 | } 23 | \value{ 24 | `ggplot2::ggplot` object 25 | } 26 | \description{ 27 | This function is mostly useful in an educational setting. Can only be used 28 | with trained workflow objects with 2 numeric predictor variables. 29 | } 30 | \examples{ 31 | library(parsnip) 32 | library(workflows) 33 | svm_spec <- svm_rbf() \%>\% 34 | set_mode("classification") \%>\% 35 | set_engine("kernlab") 36 | 37 | svm_fit <- workflow() \%>\% 38 | add_formula(Species ~ Petal.Length + Petal.Width) \%>\% 39 | add_model(svm_spec) \%>\% 40 | fit(iris) 41 | 42 | viz_decision_boundary(svm_fit, iris) 43 | 44 | viz_decision_boundary(svm_fit, iris, resolution = 20) 45 | 46 | viz_decision_boundary(svm_fit, iris, expand = 1) 47 | 48 | svm_multi_fit <- workflow() \%>\% 49 | add_formula(class ~ umap_1 + umap_2) \%>\% 50 | add_model(svm_spec) \%>\% 51 | fit(mnist_sample) 52 | 53 | viz_decision_boundary(svm_multi_fit, mnist_sample) 54 | } 55 | -------------------------------------------------------------------------------- /man/viz_dispersion.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_dispersion.R 3 | \name{viz_dispersion} 4 | \alias{viz_dispersion} 5 | \title{Vizualize lexical dispersion plot} 6 | \usage{ 7 | viz_dispersion(data, var, group, words = 10, symbol = NULL, alpha = 0.7) 8 | } 9 | \arguments{ 10 | \item{data}{A data.frame.} 11 | 12 | \item{var}{variable that contains the words to be visualized.} 13 | 14 | \item{group}{If present with show a group for each line with the words color 15 | coded.} 16 | 17 | \item{words}{Numerical or character. If numerical it will display the n 18 | most common words. If character will show the location of said strings.} 19 | 20 | \item{symbol}{The word symbol. 
Default to is 18 (filed diamond) when number 21 | of points are less then 200 and to 108 (vertical line) when there are more 22 | then 200 points.} 23 | 24 | \item{alpha}{color transperency of the word symbols.} 25 | } 26 | \value{ 27 | ggplot2 object. 28 | } 29 | \description{ 30 | Vizualize lexical dispersion plot 31 | } 32 | \examples{ 33 | \dontrun{ 34 | library(tidytext) 35 | 36 | text_data <- unnest_tokens(fairy_tales, word, text) 37 | viz_dispersion(text_data, word) 38 | viz_dispersion(text_data, word, words = c("branches", "not a word")) 39 | viz_dispersion(text_data, word, symbol = "2") 40 | viz_dispersion(text_data, word, group = book) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /man/viz_fitted_line.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_fitted_line.R 3 | \name{viz_fitted_line} 4 | \alias{viz_fitted_line} 5 | \title{Draw fitted regression line} 6 | \usage{ 7 | viz_fitted_line( 8 | x, 9 | new_data, 10 | resolution = 100, 11 | expand = 0.1, 12 | color = "blue", 13 | size = 1 14 | ) 15 | } 16 | \arguments{ 17 | \item{x}{trained `workflows::workflow` object.} 18 | 19 | \item{new_data}{A data frame or tibble for whom the preprocessing will be 20 | applied.} 21 | 22 | \item{resolution}{Number of squared in grid. Defaults to 100.} 23 | 24 | \item{expand}{Expansion rate. Defaults to 0.1. This means that the width of 25 | the plotting area is 10 percent wider then the data.} 26 | 27 | \item{color}{Character, color of the fitted line. Passed to `geom_line()`. 28 | Defaults to `"blue"`.} 29 | 30 | \item{size}{Numeric, size of the fitted line. Passed to `geom_line()`. 31 | Defaults to `1`.} 32 | } 33 | \value{ 34 | `ggplot2::ggplot` object 35 | } 36 | \description{ 37 | This function is mostly useful in an educational setting. 
Can only be used 38 | with trained workflow objects with 1 numeric predictor variable. 39 | } 40 | \details{ 41 | The chart have been minimally modified to allow for easier styling. 42 | } 43 | \examples{ 44 | library(parsnip) 45 | library(workflows) 46 | lm_spec <- linear_reg() \%>\% 47 | set_mode("regression") \%>\% 48 | set_engine("lm") 49 | 50 | lm_fit <- workflow() \%>\% 51 | add_formula(mpg ~ disp) \%>\% 52 | add_model(lm_spec) \%>\% 53 | fit(mtcars) 54 | 55 | viz_fitted_line(lm_fit, mtcars) 56 | 57 | viz_fitted_line(lm_fit, mtcars, expand = 1) 58 | } 59 | -------------------------------------------------------------------------------- /man/viz_pca.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_pca.R 3 | \name{viz_pca} 4 | \alias{viz_pca} 5 | \title{Vizualize principal pomponents} 6 | \usage{ 7 | viz_pca(data, label, components = c(1, 2), loadings = FALSE) 8 | } 9 | \arguments{ 10 | \item{data}{A data.frame.} 11 | 12 | \item{label}{variable to color with.} 13 | 14 | \item{components}{principal components to showcase.} 15 | 16 | \item{loadings}{Set this to true if you want to see the PCA loadings.} 17 | } 18 | \value{ 19 | ggplot2 object. 
20 | } 21 | \description{ 22 | Vizualize principal pomponents 23 | } 24 | \examples{ 25 | viz_pca(iris, Species) 26 | viz_pca(iris, Species, c(3, 1)) 27 | viz_pca(iris, Species, loadings = TRUE) 28 | } 29 | -------------------------------------------------------------------------------- /man/viz_pcacm.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_pcacm.R 3 | \name{viz_pcacm} 4 | \alias{viz_pcacm} 5 | \title{PCA component measures} 6 | \usage{ 7 | viz_pcacm(x, n_pca = 10, n_var = 10) 8 | } 9 | \arguments{ 10 | \item{x}{a data.frame} 11 | 12 | \item{n_pca}{Number of Principle components to show} 13 | 14 | \item{n_var}{Number of Variables to show} 15 | } 16 | \value{ 17 | a ggplot2 object 18 | } 19 | \description{ 20 | PCA component measures 21 | } 22 | \examples{ 23 | viz_pcacm(mtcars) 24 | viz_pcacm(USArrests) 25 | viz_pcacm(beaver1) 26 | 27 | viz_pcacm(mtcars, n_pca = 4) 28 | viz_pcacm(mtcars, n_var = 4) 29 | viz_pcacm(mtcars, n_pca = 2, n_var = 6) 30 | } 31 | -------------------------------------------------------------------------------- /man/viz_prob_region.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_prob_region.R 3 | \name{viz_prob_region} 4 | \alias{viz_prob_region} 5 | \title{Draw Probability regions for Classification model} 6 | \usage{ 7 | viz_prob_region(x, new_data, resolution = 100, expand = 0.1, facet = FALSE) 8 | } 9 | \arguments{ 10 | \item{x}{trained `workflows::workflow` object.} 11 | 12 | \item{new_data}{A data frame or tibble for whom the preprocessing will be 13 | applied.} 14 | 15 | \item{resolution}{Number of squared in grid. Defaults to 100.} 16 | 17 | \item{expand}{Expansion rate. Defaults to 0.1. 
This means that the width and 18 | height of the shaded area is 10% wider then the rectangle containing the 19 | data.} 20 | 21 | \item{facet}{Logical, whether to facet chart by class. Defaults to FALSE. 22 | 23 | The chart have been minimally modified to allow for easier styling.} 24 | } 25 | \value{ 26 | `ggplot2::ggplot` object 27 | } 28 | \description{ 29 | This function is mostly useful in an educational setting. Can only be used 30 | with trained workflow objects with 2 numeric predictor variables. 31 | } 32 | \examples{ 33 | library(parsnip) 34 | library(workflows) 35 | 36 | iris2 <- iris 37 | iris2$Species <- factor(iris2$Species == "setosa", 38 | labels = c("setosa", "not setosa")) 39 | 40 | svm_spec <- svm_rbf() \%>\% 41 | set_mode("classification") \%>\% 42 | set_engine("kernlab") 43 | 44 | svm_fit <- workflow() \%>\% 45 | add_formula(Species ~ Petal.Length + Petal.Width) \%>\% 46 | add_model(svm_spec) \%>\% 47 | fit(iris2) 48 | 49 | viz_prob_region(svm_fit, iris2) 50 | 51 | viz_prob_region(svm_fit, iris2, resolution = 20) 52 | 53 | viz_prob_region(svm_fit, iris2, expand = 1) 54 | 55 | viz_prob_region(svm_fit, iris2, facet = TRUE) 56 | 57 | knn_spec <- nearest_neighbor() \%>\% 58 | set_mode("classification") \%>\% 59 | set_engine("kknn") 60 | 61 | knn_fit <- workflow() \%>\% 62 | add_formula(class ~ umap_1 + umap_2) \%>\% 63 | add_model(knn_spec) \%>\% 64 | fit(mnist_sample) 65 | 66 | viz_prob_region(knn_fit, mnist_sample, facet = TRUE) 67 | } 68 | -------------------------------------------------------------------------------- /man/viz_residuals.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_residuals.R 3 | \name{viz_residuals} 4 | \alias{viz_residuals} 5 | \title{Residual plot of model} 6 | \usage{ 7 | viz_residuals(fit, new_data) 8 | } 9 | \arguments{ 10 | \item{fit}{a parsnip model object} 11 | 12 | \item{new_data}{A data 
frame or matrix.} 13 | } 14 | \value{ 15 | a ggplot2 object 16 | } 17 | \description{ 18 | Residual plot of model 19 | } 20 | \examples{ 21 | library(parsnip) 22 | reg <- linear_reg() \%>\% 23 | set_engine("lm") \%>\% 24 | set_mode("regression") \%>\% 25 | fit(mpg ~ ., data = mtcars) 26 | 27 | viz_residuals(reg, mtcars) 28 | 29 | linear_reg() \%>\% 30 | set_engine("lm") \%>\% 31 | set_mode("regression") \%>\% 32 | fit(mpg ~ ., data = mtcars) \%>\% 33 | viz_residuals(mtcars) 34 | } 35 | -------------------------------------------------------------------------------- /man/viz_tsne.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/viz_tsne.R 3 | \name{viz_tsne} 4 | \alias{viz_tsne} 5 | \title{Vizualize t-SNE} 6 | \usage{ 7 | viz_tsne(data, label) 8 | } 9 | \arguments{ 10 | \item{data}{A data.frame.} 11 | 12 | \item{label}{variable to color with.} 13 | } 14 | \value{ 15 | ggplot2 object. 
16 | } 17 | \description{ 18 | Vizualize t-SNE 19 | } 20 | \examples{ 21 | library(dplyr) 22 | viz_tsne(iris, Species) 23 | viz_tsne(select(iris, -Species), Sepal.Length) 24 | } 25 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-152x152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-152x152.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-180x180.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-76x76.png 
-------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/favicon.ico -------------------------------------------------------------------------------- /tests/figs/deps.txt: -------------------------------------------------------------------------------- 1 | - vdiffr-svg-engine: 1.0 2 | - vdiffr: 0.3.3.9000 3 | - freetypeharfbuzz: 0.2.6 4 | -------------------------------------------------------------------------------- /tests/figs/viz_classbalance/viz-classbalance-n-max-2.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 0 40 | 250 41 | 500 42 | 750 43 | 0 44 | 1 45 | Other 
46 | class 47 | count 48 | Class balence for 1,200 observations 49 | 50 | -------------------------------------------------------------------------------- /tests/figs/viz_classbalance/viz-classbalance-n-max-4-ties.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 0 45 | 200 46 | 400 47 | 600 48 | 0 49 | 1 50 | 3 51 | 2 52 | 8 53 | Other 54 | class 55 | count 56 | Class balence for 1,200 observations 57 | 58 | -------------------------------------------------------------------------------- /tests/figs/viz_classbalance/viz-classbalance-simple.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 0 52 | 50 53 | 100 54 | 0 55 | 1 56 | 2 57 | 3 58 | 4 59 | 5 60 | 6 61 | 7 62 | 8 63 | 9 64 | class 65 | count 66 | Class balence for 1,200 observations 67 | 68 | -------------------------------------------------------------------------------- /tests/figs/viz_fitted_line/viz-fitted-line-expand.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 10 77 | 15 78 | 20 79 | 25 80 | 30 81 | 35 82 | 0 83 | 250 84 | 500 85 | disp 86 | mpg 87 | viz_fitted_line expand 88 | 89 | -------------------------------------------------------------------------------- 
/tests/figs/viz_fitted_line/viz-fitted-line-resolution.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 10 80 | 15 81 | 20 82 | 25 83 | 30 84 | 35 85 | 100 86 | 200 87 | 300 88 | 400 89 | 500 90 | disp 91 | mpg 92 | viz_fitted_line resolution 93 | 94 | -------------------------------------------------------------------------------- /tests/figs/viz_fitted_line/viz-fitted-line-simple.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 10 80 | 15 81 | 20 82 | 25 83 | 30 84 | 35 85 | 100 86 | 200 87 | 300 88 | 400 89 | 500 90 | disp 91 | mpg 92 | viz_fitted_line simple 93 | 94 | -------------------------------------------------------------------------------- /tests/figs/viz_fitted_line/viz-fitted-line-style.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 10 80 | 15 81 | 20 82 | 25 83 | 30 84 | 35 85 | 100 86 | 200 87 | 300 88 | 
400 89 | 500 90 | disp 91 | mpg 92 | viz_fitted_line style 93 | 94 | -------------------------------------------------------------------------------- /tests/figs/viz_pcacm/viz-pcacm-components.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | PC4 41 | PC3 42 | PC2 43 | PC1 44 | disp 45 | hp 46 | mpg 47 | qsec 48 | carb 49 | cyl 50 | wt 51 | am 52 | gear 53 | vs 54 | Variable 55 | Principle Component 56 | 57 | -0.5 58 | 0.0 59 | 0.5 60 | value 61 | 62 | 63 | 64 | 65 | 66 | 67 | viz_pcacm components 68 | 69 | -------------------------------------------------------------------------------- /tests/figs/viz_pcacm/viz-pcacm-loadings-components.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | PC2 37 | PC1 38 | disp 39 | hp 40 | mpg 41 | carb 42 | qsec 43 | cyl 44 | wt 45 | gear 46 | Variable 47 | Principle Component 48 | 49 | -0.5 50 | 0.0 51 | 0.5 52 | value 53 | 54 | 55 | 56 | 57 | 58 | 59 | viz_pcacm loadings components 60 | 61 | -------------------------------------------------------------------------------- /tests/figs/viz_pcacm/viz-pcacm-loadings.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | PC10 42 | PC9 43 | PC8 44 | PC7 45 | PC6 46 | PC5 47 | PC4 48 | PC3 49 | PC2 50 | PC1 51 | disp 52 | hp 53 | mpg 54 | qsec 55 | carb 56 | Variable 57 | Principle Component 58 | 59 | -0.5 60 | 0.0 61 | 0.5 62 | value 63 | 64 | 65 | 66 | 67 | 68 | 69 | viz_pcacm loadings 70 | 71 | 
-------------------------------------------------------------------------------- /tests/figs/viz_pcacm/viz-pcacm-simple.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | PC10 47 | PC9 48 | PC8 49 | PC7 50 | PC6 51 | PC5 52 | PC4 53 | PC3 54 | PC2 55 | PC1 56 | disp 57 | hp 58 | mpg 59 | qsec 60 | carb 61 | cyl 62 | gear 63 | wt 64 | drat 65 | am 66 | Variable 67 | Principle Component 68 | 69 | -0.5 70 | 0.0 71 | 0.5 72 | value 73 | 74 | 75 | 76 | 77 | 78 | 79 | viz_pcacm simple 80 | 81 | -------------------------------------------------------------------------------- /tests/figs/viz_residuals/viz-residuals-simple.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | -2 77 | 0 78 | 2 79 | 4 80 | 10 81 | 15 82 | 20 83 | 25 84 | 30 85 | .pred 86 | .resid 87 | viz_residuals simple 88 | 89 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(horus) 3 | 4 | test_check("horus") 5 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_classbalance.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | 3 | test_that("viz_classbalance works", { 4 | 5 | vdiffr::expect_doppelganger( 6 | "viz_classbalance simple", 7 | 
viz_classbalance(mnist_sample, class), 8 | "viz_classbalance" 9 | ) 10 | 11 | vdiffr::expect_doppelganger( 12 | "viz_classbalance n_max 2", 13 | viz_classbalance(mnist_sample, class, n_max = 2), 14 | "viz_classbalance" 15 | ) 16 | 17 | vdiffr::expect_doppelganger( 18 | "viz_classbalance n_max 4 ties", 19 | expect_message( 20 | viz_classbalance(mnist_sample, class, n_max = 4), 21 | "5" 22 | ), 23 | "viz_classbalance" 24 | ) 25 | 26 | }) 27 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_decision_boundary.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(parsnip) 3 | library(workflows) 4 | 5 | set.seed(1234) 6 | 7 | svm_spec <- nearest_neighbor() %>% 8 | set_mode("classification") %>% 9 | set_engine("kknn") 10 | 11 | svm_fit <- workflow() %>% 12 | add_formula(Species ~ Petal.Length + Petal.Width) %>% 13 | add_model(svm_spec) %>% 14 | fit(iris) 15 | 16 | test_that("viz_decision_boundary works", { 17 | 18 | vdiffr::expect_doppelganger( 19 | "viz_decision_boundary simple", 20 | viz_decision_boundary(svm_fit, iris), 21 | "viz_decision_boundary" 22 | ) 23 | 24 | vdiffr::expect_doppelganger( 25 | "viz_decision_boundary resolution", 26 | viz_decision_boundary(svm_fit, iris, resolution = 20), 27 | "viz_decision_boundary" 28 | ) 29 | 30 | vdiffr::expect_doppelganger( 31 | "viz_decision_boundary expand", 32 | viz_decision_boundary(svm_fit, iris, expand = 1), 33 | "viz_decision_boundary" 34 | ) 35 | }) 36 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_fitted_line.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(parsnip) 3 | library(workflows) 4 | 5 | set.seed(1234) 6 | 7 | knn_spec <- nearest_neighbor() %>% 8 | set_mode("regression") %>% 9 | set_engine("kknn") 10 | 11 | knn_fit <- workflow() %>% 12 | add_formula(mpg ~ disp) 
%>% 13 | add_model(knn_spec) %>% 14 | fit(mtcars) 15 | 16 | test_that("viz_fitted_line works", { 17 | 18 | vdiffr::expect_doppelganger( 19 | "viz_fitted_line simple", 20 | viz_fitted_line(knn_fit, mtcars), 21 | "viz_fitted_line" 22 | ) 23 | 24 | vdiffr::expect_doppelganger( 25 | "viz_fitted_line resolution", 26 | viz_fitted_line(knn_fit, mtcars, resolution = 20), 27 | "viz_fitted_line" 28 | ) 29 | 30 | vdiffr::expect_doppelganger( 31 | "viz_fitted_line expand", 32 | viz_fitted_line(knn_fit, mtcars, expand = 1), 33 | "viz_fitted_line" 34 | ) 35 | 36 | vdiffr::expect_doppelganger( 37 | "viz_fitted_line style", 38 | viz_fitted_line(knn_fit, mtcars, color = "pink", size = 4), 39 | "viz_fitted_line" 40 | ) 41 | 42 | }) 43 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_pca.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | test_that("viz_pca works", { 3 | 4 | vdiffr::expect_doppelganger( 5 | "viz_pca simple", 6 | viz_pca(iris, Species), 7 | "viz_pca" 8 | ) 9 | 10 | vdiffr::expect_doppelganger( 11 | "viz_pca components", 12 | viz_pca(iris, Species, components = c(3, 1)), 13 | "viz_pca" 14 | ) 15 | 16 | vdiffr::expect_doppelganger( 17 | "viz_pca loadings", 18 | viz_pca(iris, Species, loadings = TRUE), 19 | "viz_pca" 20 | ) 21 | }) 22 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_pcacm.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | test_that("viz_pcacm works", { 3 | 4 | vdiffr::expect_doppelganger( 5 | "viz_pcacm simple", 6 | viz_pcacm(mtcars), 7 | "viz_pcacm" 8 | ) 9 | 10 | vdiffr::expect_doppelganger( 11 | "viz_pcacm components", 12 | viz_pcacm(mtcars, n_pca = 4), 13 | "viz_pcacm" 14 | ) 15 | 16 | vdiffr::expect_doppelganger( 17 | "viz_pcacm loadings", 18 | viz_pcacm(mtcars, n_var = 5), 19 | "viz_pcacm" 20 | ) 21 | 22 | 
vdiffr::expect_doppelganger( 23 | "viz_pcacm loadings components", 24 | viz_pcacm(mtcars, n_pca = 2, n_var = 8), 25 | "viz_pcacm" 26 | ) 27 | }) 28 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_prob_region.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(parsnip) 3 | library(workflows) 4 | 5 | set.seed(1234) 6 | 7 | iris2 <- iris 8 | iris2$Species <- factor(iris2$Species == "setosa", 9 | labels = c("setosa", "not setosa")) 10 | 11 | svm_spec <- svm_rbf() %>% 12 | set_mode("classification") %>% 13 | set_engine("kernlab") 14 | 15 | svm_fit <- workflow() %>% 16 | add_formula(Species ~ Petal.Length + Petal.Width) %>% 17 | add_model(svm_spec) %>% 18 | fit(iris2) 19 | 20 | svm_fit_full <- workflow() %>% 21 | add_formula(Species ~ Petal.Length + Petal.Width) %>% 22 | add_model(svm_spec) %>% 23 | fit(iris) 24 | 25 | test_that("viz_prob_region works", { 26 | 27 | vdiffr::expect_doppelganger( 28 | "viz_prob_region simple", 29 | viz_prob_region(svm_fit, iris2), 30 | "viz_prob_region" 31 | ) 32 | 33 | vdiffr::expect_doppelganger( 34 | "viz_prob_region resolution", 35 | viz_prob_region(svm_fit, iris2, resolution = 20), 36 | "viz_prob_region" 37 | ) 38 | 39 | vdiffr::expect_doppelganger( 40 | "viz_prob_region expand", 41 | viz_prob_region(svm_fit, iris2, expand = 1), 42 | "viz_prob_region" 43 | ) 44 | }) 45 | 46 | test_that("viz_prob_region facet works", { 47 | 48 | expect_error( 49 | viz_prob_region(svm_fit_full, iris) 50 | ) 51 | 52 | vdiffr::expect_doppelganger( 53 | "viz_prob_region facet simple", 54 | viz_prob_region(svm_fit_full, iris, facet = TRUE), 55 | "viz_prob_region" 56 | ) 57 | 58 | vdiffr::expect_doppelganger( 59 | "viz_prob_region facet resolution", 60 | viz_prob_region(svm_fit_full, iris, resolution = 20, facet = TRUE), 61 | "viz_prob_region" 62 | ) 63 | 64 | vdiffr::expect_doppelganger( 65 | "viz_prob_region facet expand", 66 | 
viz_prob_region(svm_fit_full, iris, expand = 1, facet = TRUE), 67 | "viz_prob_region" 68 | ) 69 | }) 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_residuals.R: -------------------------------------------------------------------------------- 1 | library(parsnip) 2 | 3 | reg <- linear_reg() %>% 4 | set_engine("lm") %>% 5 | set_mode("regression") %>% 6 | fit(mpg ~ ., data = mtcars) 7 | 8 | test_that("viz_residuals works", { 9 | 10 | vdiffr::expect_doppelganger( 11 | "viz_residuals simple", 12 | viz_residuals(reg, mtcars), 13 | "viz_residuals" 14 | ) 15 | }) 16 | -------------------------------------------------------------------------------- /tests/testthat/test-viz_tsne.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(dplyr) 3 | 4 | set.seed(1234) 5 | 6 | test_that("viz_tsne works", { 7 | 8 | vdiffr::expect_doppelganger( 9 | "viz_tsne simple factor", 10 | viz_tsne(iris, Species), 11 | "viz_tsne" 12 | ) 13 | 14 | vdiffr::expect_doppelganger( 15 | "viz_tsne simple numeric", 16 | viz_tsne(select(iris, -Species), Sepal.Length), 17 | "viz_tsne" 18 | ) 19 | }) 20 | --------------------------------------------------------------------------------