├── .Rbuildignore
├── .github
├── .gitignore
└── workflows
│ ├── R-CMD-check.yaml
│ └── pkgdown.yaml
├── .gitignore
├── DESCRIPTION
├── LICENSE
├── LICENSE.md
├── NAMESPACE
├── R
├── 0_imports.R
├── data.R
├── glmnet_coef_plot.R
├── utils-pipe.R
├── viz_classbalance.R
├── viz_decision_boundary.R
├── viz_dispersion.R
├── viz_fitted_line.R
├── viz_pca.R
├── viz_pcacm.R
├── viz_prob_distribution.R
├── viz_prob_region.R
├── viz_residuals.R
├── viz_tsne.R
└── workflow_utils.R
├── README.Rmd
├── README.md
├── _pkgdown.yml
├── data-raw
├── fairy_tales.R
└── mnist_sample.R
├── data
├── fairy_tales.rda
└── mnist_sample.rda
├── horus.Rproj
├── man
├── fairy_tales.Rd
├── figures
│ ├── README-unnamed-chunk-2-1.png
│ ├── README-unnamed-chunk-3-1.png
│ └── logo.png
├── glmnet_coef_plot.Rd
├── mnist_sample.Rd
├── pipe.Rd
├── viz_classbalance.Rd
├── viz_decision_boundary.Rd
├── viz_dispersion.Rd
├── viz_fitted_line.Rd
├── viz_pca.Rd
├── viz_pcacm.Rd
├── viz_prob_region.Rd
├── viz_residuals.Rd
└── viz_tsne.Rd
├── pkgdown
└── favicon
│ ├── apple-touch-icon-120x120.png
│ ├── apple-touch-icon-152x152.png
│ ├── apple-touch-icon-180x180.png
│ ├── apple-touch-icon-60x60.png
│ ├── apple-touch-icon-76x76.png
│ ├── apple-touch-icon.png
│ ├── favicon-16x16.png
│ ├── favicon-32x32.png
│ └── favicon.ico
└── tests
├── figs
├── deps.txt
├── viz_classbalance
│ ├── viz-classbalance-n-max-2.svg
│ ├── viz-classbalance-n-max-4-ties.svg
│ └── viz-classbalance-simple.svg
├── viz_decision_boundary
│ ├── viz-decision-boundary-expand.svg
│ ├── viz-decision-boundary-resolution.svg
│ └── viz-decision-boundary-simple.svg
├── viz_fitted_line
│ ├── viz-fitted-line-expand.svg
│ ├── viz-fitted-line-resolution.svg
│ ├── viz-fitted-line-simple.svg
│ └── viz-fitted-line-style.svg
├── viz_pca
│ ├── viz-pca-components.svg
│ ├── viz-pca-loadings.svg
│ └── viz-pca-simple.svg
├── viz_pcacm
│ ├── viz-pcacm-components.svg
│ ├── viz-pcacm-loadings-components.svg
│ ├── viz-pcacm-loadings.svg
│ └── viz-pcacm-simple.svg
├── viz_prob_region
│ ├── viz-prob-region-expand.svg
│ ├── viz-prob-region-facet-expand.svg
│ ├── viz-prob-region-facet-resolution.svg
│ ├── viz-prob-region-facet-simple.svg
│ ├── viz-prob-region-resolution.svg
│ └── viz-prob-region-simple.svg
├── viz_residuals
│ └── viz-residuals-simple.svg
└── viz_tsne
│ ├── viz-tsne-simple-factor.svg
│ └── viz-tsne-simple-numeric.svg
├── testthat.R
└── testthat
├── test-viz_classbalance.R
├── test-viz_decision_boundary.R
├── test-viz_fitted_line.R
├── test-viz_pca.R
├── test-viz_pcacm.R
├── test-viz_prob_region.R
├── test-viz_residuals.R
└── test-viz_tsne.R
/.Rbuildignore:
--------------------------------------------------------------------------------
1 | ^codecov\.yml$
2 | ^appveyor\.yml$
3 | ^\.travis\.yml$
4 | ^README\.Rmd$
5 | ^data-raw$
6 | ^LICENSE\.md$
7 | ^horus\.Rproj$
8 | ^\.Rproj\.user$
9 | ^\.github$
10 | ^_pkgdown\.yml$
11 | ^docs$
12 | ^pkgdown$
13 |
--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 |
--------------------------------------------------------------------------------
/.github/workflows/R-CMD-check.yaml:
--------------------------------------------------------------------------------
1 | # For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag.
2 | # https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | pull_request:
9 | branches:
10 | - main
11 | - master
12 |
13 | name: R-CMD-check
14 |
15 | jobs:
16 | R-CMD-check:
17 | runs-on: ${{ matrix.config.os }}
18 |
19 | name: ${{ matrix.config.os }} (${{ matrix.config.r }})
20 |
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | config:
25 | - {os: windows-latest, r: 'release'}
26 | - {os: macOS-latest, r: 'release'}
27 | - {os: ubuntu-20.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"}
28 | - {os: ubuntu-20.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"}
29 |
30 | env:
31 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
32 | RSPM: ${{ matrix.config.rspm }}
33 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
34 |
35 | steps:
36 | - uses: actions/checkout@v2
37 |
38 | - uses: r-lib/actions/setup-r@v1
39 | with:
40 | r-version: ${{ matrix.config.r }}
41 |
42 | - uses: r-lib/actions/setup-pandoc@v1
43 |
44 | - name: Query dependencies
45 | run: |
46 | install.packages('remotes')
47 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
48 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
49 | shell: Rscript {0}
50 |
51 | - name: Cache R packages
52 | if: runner.os != 'Windows'
53 | uses: actions/cache@v2
54 | with:
55 | path: ${{ env.R_LIBS_USER }}
56 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
57 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-
58 |
59 | - name: Install system dependencies
60 | if: runner.os == 'Linux'
61 | run: |
62 | while read -r cmd
63 | do
64 | eval sudo $cmd
65 | done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))')
66 |
67 | - name: Install dependencies
68 | run: |
69 | remotes::install_deps(dependencies = TRUE)
70 | remotes::install_cran("rcmdcheck")
71 | shell: Rscript {0}
72 |
73 | - name: Check
74 | env:
75 | _R_CHECK_CRAN_INCOMING_REMOTE_: false
76 | run: |
77 | options(crayon.enabled = TRUE)
78 | rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "warning", check_dir = "check")
79 | shell: Rscript {0}
80 |
81 | - name: Upload check results
82 | if: failure()
83 | uses: actions/upload-artifact@main
84 | with:
85 | name: ${{ runner.os }}-r${{ matrix.config.r }}-results
86 | path: check
87 |
--------------------------------------------------------------------------------
/.github/workflows/pkgdown.yaml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - main
5 | - master
6 |
7 | name: pkgdown
8 |
9 | jobs:
10 | pkgdown:
11 | runs-on: macOS-latest
12 | env:
13 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
14 | steps:
15 | - uses: actions/checkout@v2
16 |
17 | - uses: r-lib/actions/setup-r@v1
18 |
19 | - uses: r-lib/actions/setup-pandoc@v1
20 |
21 | - name: Query dependencies
22 | run: |
23 | install.packages('remotes')
24 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
25 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
26 | shell: Rscript {0}
27 |
28 | - name: Cache R packages
29 | uses: actions/cache@v2
30 | with:
31 | path: ${{ env.R_LIBS_USER }}
32 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
33 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-
34 |
35 | - name: Install dependencies
36 | run: |
37 | remotes::install_deps(dependencies = TRUE)
38 | install.packages("pkgdown", type = "binary")
39 | shell: Rscript {0}
40 |
41 | - name: Install package
42 | run: R CMD INSTALL .
43 |
44 | - name: Deploy package
45 | run: |
46 | git config --local user.email "actions@github.com"
47 | git config --local user.name "GitHub Actions"
48 | Rscript -e 'pkgdown::deploy_to_branch(new_process = FALSE)'
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .Rproj.user
2 | .Rhistory
3 | .RData
4 | .Ruserdata
5 | docs
6 |
--------------------------------------------------------------------------------
/DESCRIPTION:
--------------------------------------------------------------------------------
1 | Package: horus
2 | Title: Visual Tools to Help Machine Learning Model Selection
3 | Version: 0.0.0.9000
4 | Authors@R:
5 | person(given = "Emil",
6 | family = "Hvitfeldt",
7 | role = c("aut", "cre"),
8 | email = "emilhhvitfeldt@gmail.com",
9 | comment = c(ORCID = "0000-0002-0679-1945"))
10 | Description: Includes a suite of functions that visually aid
11 | the understanding and model selection of a wide array of machine
12 | learning applications.
13 | License: MIT + file LICENSE
14 | URL: https://emilhvitfeldt.github.io/horus
15 | Depends:
16 | R (>= 2.10)
17 | Imports:
18 | dplyr,
19 | forcats,
20 | ggplot2,
21 | glue,
22 | magrittr,
23 | parsnip (>= 0.1.5.9000),
24 | purrr,
25 | readr,
26 | rlang,
27 | Rtsne,
28 | tibble,
29 | tidyr,
30 | tidytext,
31 | workflows
32 | Suggests:
33 | covr,
34 | kernlab,
35 | kknn,
36 | ranger,
37 | testthat,
38 | vdiffr,
39 | glmnet
40 | Remotes:
41 | tidymodels/parsnip
42 | Encoding: UTF-8
43 | LazyData: true
44 | RoxygenNote: 7.1.1.9001
45 | Config/testthat/edition: 3
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | YEAR: 2018
2 | COPYRIGHT HOLDER: Emil Hvitfeldt
3 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | Copyright (c) 2018 Emil Hvitfeldt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 |
3 | export("%>%")
4 | export(glmnet_coef_plot)
5 | export(viz_classbalance)
6 | export(viz_decision_boundary)
7 | export(viz_dispersion)
8 | export(viz_fitted_line)
9 | export(viz_pca)
10 | export(viz_pcacm)
11 | export(viz_prob_region)
12 | export(viz_residuals)
13 | export(viz_tsne)
14 | importFrom(Rtsne,Rtsne)
15 | importFrom(dplyr,all_of)
16 | importFrom(dplyr,arrange)
17 | importFrom(dplyr,bind_cols)
18 | importFrom(dplyr,count)
19 | importFrom(dplyr,desc)
20 | importFrom(dplyr,everything)
21 | importFrom(dplyr,filter)
22 | importFrom(dplyr,group_by)
23 | importFrom(dplyr,mutate)
24 | importFrom(dplyr,mutate_)
25 | importFrom(dplyr,pull)
26 | importFrom(dplyr,recode)
27 | importFrom(dplyr,row_number)
28 | importFrom(dplyr,select)
29 | importFrom(dplyr,select_)
30 | importFrom(dplyr,slice)
31 | importFrom(dplyr,summarize)
32 | importFrom(forcats,fct_infreq)
33 | importFrom(forcats,fct_lump)
34 | importFrom(forcats,fct_relevel)
35 | importFrom(forcats,fct_rev)
36 | importFrom(ggplot2,aes)
37 | importFrom(ggplot2,aes_)
38 | importFrom(ggplot2,aes_string)
39 | importFrom(ggplot2,arrow)
40 | importFrom(ggplot2,facet_grid)
41 | importFrom(ggplot2,facet_wrap)
42 | importFrom(ggplot2,geom_abline)
43 | importFrom(ggplot2,geom_bar)
44 | importFrom(ggplot2,geom_histogram)
45 | importFrom(ggplot2,geom_line)
46 | importFrom(ggplot2,geom_point)
47 | importFrom(ggplot2,geom_raster)
48 | importFrom(ggplot2,geom_segment)
49 | importFrom(ggplot2,geom_text)
50 | importFrom(ggplot2,ggplot)
51 | importFrom(ggplot2,guide_legend)
52 | importFrom(ggplot2,guides)
53 | importFrom(ggplot2,labs)
54 | importFrom(ggplot2,scale_alpha)
55 | importFrom(ggplot2,scale_fill_gradient2)
56 | importFrom(ggplot2,scale_x_log10)
57 | importFrom(ggplot2,scale_y_discrete)
58 | importFrom(ggplot2,theme_minimal)
59 | importFrom(ggplot2,unit)
60 | importFrom(ggplot2,xlim)
61 | importFrom(glue,glue)
62 | importFrom(magrittr,"%>%")
63 | importFrom(parsnip,augment)
64 | importFrom(purrr,map)
65 | importFrom(readr,parse_number)
66 | importFrom(rlang,.data)
67 | importFrom(rlang,abort)
68 | importFrom(rlang,enquo)
69 | importFrom(rlang,ensym)
70 | importFrom(rlang,eval_tidy)
71 | importFrom(rlang,inform)
72 | importFrom(rlang,quo_name)
73 | importFrom(rlang,set_names)
74 | importFrom(stats,prcomp)
75 | importFrom(stats,predict)
76 | importFrom(stats,var)
77 | importFrom(tibble,as_tibble)
78 | importFrom(tibble,rownames_to_column)
79 | importFrom(tibble,tibble)
80 | importFrom(tidyr,drop_na)
81 | importFrom(tidyr,nest)
82 | importFrom(tidyr,pivot_longer)
83 | importFrom(tidyr,unnest)
84 | importFrom(utils,globalVariables)
85 |
--------------------------------------------------------------------------------
/R/0_imports.R:
--------------------------------------------------------------------------------
1 | #' @importFrom dplyr all_of arrange bind_cols count desc everything filter
2 | #' @importFrom dplyr group_by mutate mutate_ pull recode row_number select
3 | #' @importFrom dplyr select_ slice summarize
4 | #' @importFrom forcats fct_infreq fct_lump fct_relevel fct_rev
5 | #' @importFrom ggplot2 aes aes_ aes_string arrow facet_grid facet_wrap
6 | #' @importFrom ggplot2 geom_abline geom_bar geom_histogram geom_line
7 | #' @importFrom ggplot2 geom_point geom_raster geom_segment geom_text ggplot
8 | #' @importFrom ggplot2 guide_legend guides labs scale_alpha scale_fill_gradient2
9 | #' @importFrom ggplot2 scale_y_discrete theme_minimal unit xlim scale_x_log10
10 | #' @importFrom glue glue
11 | #' @importFrom magrittr %>%
12 | #' @importFrom parsnip augment
13 | #' @importFrom purrr map
14 | #' @importFrom readr parse_number
15 | #' @importFrom rlang .data abort enquo ensym eval_tidy inform quo_name set_names
16 | #' @importFrom Rtsne Rtsne
17 | #' @importFrom stats prcomp predict var
18 | #' @importFrom tibble as_tibble rownames_to_column tibble
19 | #' @importFrom tidyr drop_na nest pivot_longer unnest
20 | #' @importFrom utils globalVariables
NULL

# ------------------------------------------------------------------------------
# Column names referenced through non-standard evaluation (for example in
# `aes()`, `mutate()` and `pivot_longer()` calls) are declared here so that
# `R CMD check` does not report "no visible binding for global variable".
# nocov start

utils::globalVariables(
  c(".pred", ".resid", "name", "value", "contribution", "lambda")
)

# nocov end
32 |
--------------------------------------------------------------------------------
/R/data.R:
--------------------------------------------------------------------------------
#' Sample of 5 fairy tales by H.C. Andersen
#'
#' Includes the fairy tales "A leaf from heaven", "A story",
#' "The bird of folklore", "The flea and the professor" and "The rags".
#'
#' @format A tibble with 486 obs. of 2 variables.
"fairy_tales"
8 |
#' Down-sampled MNIST
#'
#' This data set is a down-sampled version of the MNIST data set that has
#' been reduced to 2 dimensions using UMAP.
#'
#' @source \url{http://yann.lecun.com/exdb/mnist/}
#' @format A data frame with 1200 rows and 3 variables:
#' \describe{
#'   \item{class}{Factor, values 0 through 9.}
#'   \item{umap_1}{Numeric.}
#'   \item{umap_2}{Numeric.}
#' }
"mnist_sample"
22 |
--------------------------------------------------------------------------------
/R/glmnet_coef_plot.R:
--------------------------------------------------------------------------------
#' Create Lambda chart for glmnet object
#'
#' Plots the coefficient path of a fitted glmnet model: one line per
#' coefficient, showing its value across the sequence of penalty values
#' (`lambda`) that glmnet fitted. `lambda` is shown on a log10 scale.
#'
#' @param fit parsnip fit object using the `"glmnet"` engine. Only models
#'   with a single coefficient matrix are supported; multi-response fits
#'   (such as multinomial models, where glmnet stores one matrix per class)
#'   are rejected with an error.
#'
#' @return ggplot2 object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(glmnet)
#' linear_reg_glmnet_spec <-
#'   linear_reg(penalty = 0, mixture = 1) %>%
#'   set_engine('glmnet')
#'
#' lm_fit <- fit(linear_reg_glmnet_spec, data = mtcars, mpg ~ .)
#'
#' glmnet_coef_plot(lm_fit)
glmnet_coef_plot <- function(fit) {
  # Validate up front so users get a clear message instead of an error from
  # deep inside as.matrix()/as_tibble().
  if (!is.list(fit) || !inherits(fit$fit, "glmnet")) {
    abort("`fit` must be a parsnip model fit using the \"glmnet\" engine.")
  }
  glmnet_fit <- fit$fit

  # Multi-response glmnet fits store `beta` as a list of matrices.
  if (is.list(glmnet_fit$beta)) {
    abort("Multi-response (e.g. multinomial) glmnet models are not supported.")
  }

  # `beta` has one row per coefficient and one column per lambda; transpose
  # so each lambda becomes a row, then reshape to long format for plotting.
  as_tibble(t(as.matrix(glmnet_fit$beta))) %>%
    mutate(lambda = glmnet_fit$lambda) %>%
    pivot_longer(-lambda) %>%
    ggplot(aes(lambda, value, group = name)) +
    geom_line() +
    scale_x_log10() +
    theme_minimal()
}
26 |
--------------------------------------------------------------------------------
/R/utils-pipe.R:
--------------------------------------------------------------------------------
#' Pipe operator
#'
#' See \code{magrittr::\link[magrittr]{\%>\%}} for details.
#'
#' @name %>%
#' @rdname pipe
#' @keywords internal
#' @export
#' @usage lhs \%>\% rhs
#' @param lhs A value or the magrittr placeholder.
#' @param rhs A function call using the magrittr semantics.
#' @return The result of calling `rhs(lhs)`.
NULL
11 |
--------------------------------------------------------------------------------
/R/viz_classbalance.R:
--------------------------------------------------------------------------------
#' Visualise Class Imbalance
#'
#' Draws a bar chart with the number of observations in each level of a
#' factor. When the factor has more than `n_max` levels, the levels are
#' ordered by frequency, the least frequent ones are lumped together with
#' `forcats::fct_lump()`, and a message reports how many are shown.
#'
#' @param data A data.frame.
#' @param variable target variable to show balance for. Must be a factor.
#' @param n_max integer, maximum number of classes shown before lumping.
#'   Defaults to 25.
#'
#' @return ggplot2 object.
#' @export
#'
#' @examples
#' viz_classbalance(mnist_sample, class)
viz_classbalance <- function(data, variable, n_max = 25) {
  enquo_variable <- enquo(variable)
  var_name <- quo_name(enquo_variable)
  original <- data[[var_name]]

  if (!is.factor(original)) {
    abort("`variable` must be a factor")
  }

  n_vars <- nlevels(original)
  if (n_vars > n_max) {
    lumped <- fct_lump(fct_infreq(original), n = n_max)
    data[[var_name]] <- lumped

    # Only discount the lumping level if fct_lump() actually introduced it;
    # an existing level named "Other" is a real class and counts as shown.
    added_levels <- setdiff(levels(lumped), levels(original))
    n_shown <- nlevels(lumped) - length(added_levels)
    inform(glue("The number of categories is {n_vars}; only the first ",
                "{n_shown} are shown."))
  }

  title <- paste0(
    "Class balance for ",
    format(nrow(data), big.mark = ","),
    " observations"
  )

  ggplot(data, aes(!!enquo_variable)) +
    geom_bar() +
    labs(title = title) +
    theme_minimal()
}
43 |
--------------------------------------------------------------------------------
/R/viz_decision_boundary.R:
--------------------------------------------------------------------------------
#' Draw Decision boundary for Classification model
#'
#' This function is mostly useful in an educational setting. Can only be used
#' with trained workflow objects with 2 numeric predictor variables.
#'
#' @param x trained `workflows::workflow` object.
#' @param new_data A data frame or tibble to which the preprocessing will be
#'   applied. Its observations are drawn as points, and the range of its two
#'   predictors determines the area covered by the predicted classes.
#' @param resolution Number of grid points along each predictor axis, so the
#'   shaded area is made of up to `resolution` x `resolution` squares.
#'   Defaults to 100.
#' @param expand Expansion rate. Defaults to 0.1. This means that the width and
#'   height of the shaded area are 10 percent larger than the rectangle
#'   containing the data.
#'
#' @details
#' The chart has only minimal styling applied (`theme_minimal()`), which makes
#' it easy to restyle.
#'
#' @return `ggplot2::ggplot` object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(workflows)
#' svm_spec <- svm_rbf() %>%
#'   set_mode("classification") %>%
#'   set_engine("kernlab")
#'
#' svm_fit <- workflow() %>%
#'   add_formula(Species ~ Petal.Length + Petal.Width) %>%
#'   add_model(svm_spec) %>%
#'   fit(iris)
#'
#' viz_decision_boundary(svm_fit, iris)
#'
#' viz_decision_boundary(svm_fit, iris, resolution = 20)
#'
#' viz_decision_boundary(svm_fit, iris, expand = 1)
#'
#' svm_multi_fit <- workflow() %>%
#'   add_formula(class ~ umap_1 + umap_2) %>%
#'   add_model(svm_spec) %>%
#'   fit(mnist_sample)
#'
#' viz_decision_boundary(svm_multi_fit, mnist_sample)
viz_decision_boundary <- function(x, new_data, resolution = 100, expand = 0.1) {
  if (!inherits(x, "workflow")) {
    abort("`viz_decision_boundary()` only works with `workflow` objects.")
  }
  # isTRUE() also guards against a missing/NULL `trained` field.
  if (!isTRUE(x$trained)) {
    abort("`x` must be a trained `workflow` object.")
  }

  var_names <- extract_variable_names(x, new_data, n_pred = 2)

  # Regular grid spanning the (expanded) range of both predictors.
  predict_area <- new_data %>%
    select(all_of(var_names$predictors)) %>%
    lapply(expanded_seq, expand, resolution) %>%
    expand.grid()

  predict_area %>%
    bind_cols(predict(x, predict_area)) %>%
    ggplot(
      aes_string(
        var_names$predictors[1],
        var_names$predictors[2],
        fill = ".pred_class"
      )
    ) +
    geom_raster(alpha = 0.2) +
    # Observed data on top, filled by the true class.
    geom_point(
      aes_string(
        var_names$predictors[1],
        var_names$predictors[2],
        fill = var_names$response
      ),
      color = "black", shape = 22, data = new_data, inherit.aes = FALSE
    ) +
    theme_minimal()
}
79 |
# Evenly spaced sequence covering the range of `x`, widened by `expand` times
# the range width (half of the extra width on each side). Used to build the
# prediction grids in viz_decision_boundary() and viz_fitted_line().
#
# x: numeric vector; NA values are ignored, but at least one value must be
#   non-missing.
# expand: fraction of the range width to add, split evenly over both ends.
# resolution: number of points requested (the `length.out` of the sequence).
#
# Integer inputs are truncated to integers. Duplicated values, from integer
# truncation or from a zero-width (constant) range, are dropped. The result
# can therefore be shorter than `resolution`.
expanded_seq <- function(x, expand, resolution) {
  if (all(is.na(x))) {
    abort("Cannot build a grid for a predictor with no non-missing values.")
  }

  x_range <- range(x, na.rm = TRUE)
  x_range_width <- x_range[2] - x_range[1]

  sequence <- seq(
    from = x_range[1] - x_range_width * expand / 2,
    to = x_range[2] + x_range_width * expand / 2,
    length.out = resolution
  )

  if (is.integer(x)) {
    sequence <- as.integer(sequence)
  }
  # A constant column gives `resolution` copies of the same value; keep one
  # so callers do not predict on repeated grid points.
  unique(sequence)
}
96 |
--------------------------------------------------------------------------------
/R/viz_dispersion.R:
--------------------------------------------------------------------------------
#' Visualize lexical dispersion plot
#'
#' Shows where selected words occur in a tokenized text. Each occurrence is
#' drawn as a point at its word offset (row position).
#'
#' @param data A data.frame with one token per row.
#' @param var variable that contains the words to be visualized.
#' @param group Optional grouping variable. When present, each group gets its
#'   own line (word offsets are counted within the group) and the words are
#'   color coded.
#' @param words Numerical or character. If numerical it will display the n
#'   most common words. If character will show the location of said strings.
#' @param symbol The word symbol. Defaults to 18 (filled diamond) when there
#'   are at most 200 points and to 108 (vertical line) when there are more
#'   than 200 points.
#' @param alpha color transparency of the word symbols.
#' @return ggplot2 object.
#' @examples
#' \dontrun{
#' library(tidytext)
#'
#' text_data <- unnest_tokens(fairy_tales, word, text)
#' viz_dispersion(text_data, word)
#' viz_dispersion(text_data, word, words = c("branches", "not a word"))
#' viz_dispersion(text_data, word, symbol = "2")
#' viz_dispersion(text_data, word, group = book)
#' }
#' @export
viz_dispersion <- function(data, var, group, words = 10, symbol = NULL,
                           alpha = 0.7) {
  var <- ensym(var)
  has_group <- !missing(group)

  # Resolve which words to display. is.numeric() accepts integers too.
  if (is.numeric(words)) {
    vec <- count(data, !!var, sort = TRUE) %>%
      slice(seq_len(words)) %>%
      pull(!!var)
  } else if (is.character(words)) {
    vec <- words
  } else {
    abort("`words` must be a number or a character vector.")
  }

  if (has_group) {
    group <- ensym(group)

    # Word offsets restart at 1 within each group; the group becomes the y axis.
    plot_data <- data %>%
      group_by(!!group) %>%
      mutate(
        x = row_number(),
        color = dispersion_factor(!!var, !!vec)
      ) %>%
      dplyr::ungroup() %>%
      filter(!is.na(.data$color)) %>%
      select("x", "color", y = !!group)

    x_limit <- max(count(data, !!group)[["n"]])
  } else {
    plot_data <- data %>%
      mutate(
        x = row_number(),
        y = dispersion_factor(!!var, !!vec)
      ) %>%
      filter(!is.na(.data$y)) %>%
      select("x", "y")

    x_limit <- nrow(data)
  }

  if (is.null(symbol)) {
    symbol <- if (nrow(plot_data) > 200) 108 else 18
  }

  mapping <- if (has_group) {
    aes(x = .data$x, y = .data$y, color = .data$color)
  } else {
    aes(x = .data$x, y = .data$y)
  }

  ggplot(plot_data, mapping) +
    geom_point(shape = symbol, alpha = alpha) +
    scale_y_discrete(drop = FALSE) +
    xlim(c(1, x_limit)) +
    guides(color = guide_legend(override.aes = list(shape = c(18)))) +
    labs(
      x = "Word Offset",
      y = NULL,
      title = "Lexical Dispersion Plot"
    ) +
    theme_minimal()
}
96 |
# Encode `x` as a factor whose levels are `names`, in the order given.
# Values of `x` that are not in `names` become NA. Both arguments are compared
# as character, so factor input works too. Duplicated entries in `names` are
# collapsed, and an empty `names` yields an all-NA factor.
dispersion_factor <- function(x, names) {
  factor(as.character(x), levels = unique(as.character(names)))
}
105 |
--------------------------------------------------------------------------------
/R/viz_fitted_line.R:
--------------------------------------------------------------------------------
#' Draw fitted regression line
#'
#' This function is mostly useful in an educational setting. Can only be used
#' with trained workflow objects with 1 numeric predictor variable.
#'
#' @param x trained `workflows::workflow` object.
#' @param new_data A data frame or tibble to which the preprocessing will be
#'   applied. Its observations are drawn as points, and the range of its
#'   predictor determines where the fitted line is evaluated.
#' @param resolution Number of points along the predictor at which the fitted
#'   line is evaluated. Defaults to 100.
#' @param expand Expansion rate. Defaults to 0.1. This means that the width of
#'   the plotting area is 10 percent wider than the data.
#' @param color Character, color of the fitted line. Passed to `geom_line()`.
#'   Defaults to `"blue"`.
#' @param size Numeric, size of the fitted line. Passed to `geom_line()`.
#'   Defaults to `1`.
#'
#' @details
#' The chart has only minimal styling applied (`theme_minimal()`), which makes
#' it easy to restyle.
#'
#' @return `ggplot2::ggplot` object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(workflows)
#' lm_spec <- linear_reg() %>%
#'   set_mode("regression") %>%
#'   set_engine("lm")
#'
#' lm_fit <- workflow() %>%
#'   add_formula(mpg ~ disp) %>%
#'   add_model(lm_spec) %>%
#'   fit(mtcars)
#'
#' viz_fitted_line(lm_fit, mtcars)
#'
#' viz_fitted_line(lm_fit, mtcars, expand = 1)
viz_fitted_line <- function(x, new_data, resolution = 100, expand = 0.1,
                            color = "blue", size = 1) {
  if (!inherits(x, "workflow")) {
    abort("`viz_fitted_line()` only works with `workflow` objects.")
  }
  if (!x$trained) {
    abort("`x` must be a trained `workflow` object.")
  }

  var_names <- extract_variable_names(x, new_data, n_pred = 1)
  predictor <- var_names$predictors[1]

  # Evenly spaced (expanded) values of the predictor to evaluate the model on.
  grid_values <- lapply(
    select(new_data, all_of(var_names$predictors)),
    expanded_seq,
    expand,
    resolution
  )
  grid <- expand.grid(grid_values)
  line_data <- bind_cols(grid, predict(x, grid))

  point_layer <- geom_point(alpha = 0.5)
  line_layer <- geom_line(
    aes_string(predictor, ".pred"),
    color = color,
    size = size,
    data = line_data,
    inherit.aes = FALSE
  )

  ggplot(new_data, aes_string(predictor, var_names$response)) +
    point_layer +
    line_layer +
    theme_minimal() +
    xlim(range(grid[[1]]))
}
78 |
--------------------------------------------------------------------------------
/R/viz_pca.R:
--------------------------------------------------------------------------------
#' Visualize principal components
#'
#' Runs `prcomp()` on all columns of `data` except `label` (these must be
#' numeric), and draws a scatter plot of two of the resulting principal
#' components, colored by `label`.
#'
#' @param data A data.frame.
#' @param label variable to color with. It is excluded from the PCA.
#' @param components Numeric vector of length 2 giving the principal components
#'   to show on the x and y axis. Defaults to `c(1, 2)`.
#' @param loadings Set this to `TRUE` to overlay the PCA loadings as arrows.
#' @return ggplot2 object.
#' @examples
#' viz_pca(iris, Species)
#' viz_pca(iris, Species, c(3, 1))
#' viz_pca(iris, Species, loadings = TRUE)
#' @export
viz_pca <- function(data, label, components = c(1, 2), loadings = FALSE) {
  label_enquo <- enquo(label)

  if (length(components) != 2) {
    abort("`components` must be of length 2.")
  }

  pca_obj <- data %>%
    select(-!!label_enquo) %>%
    as.matrix() %>%
    prcomp()

  n_pc <- ncol(pca_obj$x)
  if (any(components < 1 | components > n_pc)) {
    abort(glue("`components` must be between 1 and {n_pc}, ",
               "the number of principal components."))
  }

  pc_names <- paste0("PC", components)

  plot_data <- as_tibble(pca_obj$x) %>%
    mutate(Label = pull(data, !!label_enquo))

  # The theme is applied regardless of `loadings`, so both variants of the
  # chart look the same apart from the loading arrows.
  p <- plot_data %>%
    ggplot() +
    aes_string(pc_names[1], pc_names[2], color = "Label") +
    geom_point() +
    labs(
      x = glue("Principal Component {components[1]}"),
      y = glue("Principal Component {components[2]}"),
      title = "Principal Component plot"
    ) +
    theme_minimal()

  if (loadings) {
    loadings_data <- pca_obj$rotation[, components] %>%
      as.data.frame() %>%
      rownames_to_column()
    loadings_arrow <- arrow(length = unit(0.03, "npc"))
    p <- p +
      geom_segment(
        aes_string(x = 0, y = 0, xend = pc_names[1], yend = pc_names[2]),
        data = loadings_data, inherit.aes = FALSE,
        arrow = loadings_arrow
      ) +
      geom_text(
        aes_string(pc_names[1], pc_names[2], label = "rowname"),
        data = loadings_data, inherit.aes = FALSE
      )
  }

  p
}
58 |
--------------------------------------------------------------------------------
/R/viz_pcacm.R:
--------------------------------------------------------------------------------
#' PCA component measures
#'
#' Draws a heatmap of the PCA rotation (loadings) for the first `n_pca`
#' principal components. Variables are ranked by the sum, over the shown
#' components, of the absolute loading divided by the component number, so
#' earlier components weigh more. Only the top `n_var` variables are shown.
#'
#' @param x a data.frame (or matrix) of numeric variables, passed to
#'   `prcomp()`.
#' @param n_pca Number of principal components to show. Reduced to the number
#'   of components `prcomp()` returns when that is smaller.
#' @param n_var Number of variables to show
#'
#' @return a ggplot2 object
#' @export
#'
#' @examples
#' viz_pcacm(mtcars)
#' viz_pcacm(USArrests)
#' viz_pcacm(beaver1)
#'
#' viz_pcacm(mtcars, n_pca = 4)
#' viz_pcacm(mtcars, n_var = 4)
#' viz_pcacm(mtcars, n_pca = 2, n_var = 6)
viz_pcacm <- function(x, n_pca = 10, n_var = 10) {
  # `rank.` is spelled out rather than relying on partial matching of `rank`.
  rotation <- prcomp(x, rank. = n_pca)$rotation

  # prcomp() returns at most min(nrow, ncol) components; clamp n_pca so that
  # fct_relevel() is never given levels that do not exist.
  n_pca <- min(n_pca, ncol(rotation))
  pc_levels <- paste0("PC", seq_len(n_pca))

  tidy_rotation <- as.data.frame(rotation) %>%
    rownames_to_column("var") %>%
    pivot_longer(-var) %>%
    filter(name %in% pc_levels)

  var_order <- tidy_rotation %>%
    group_by(var) %>%
    summarize(contribution = sum(abs(value) / parse_number(name))) %>%
    arrange(desc(contribution)) %>%
    slice(seq_len(n_var)) %>%
    pull(var)

  tidy_rotation %>%
    filter(var %in% var_order) %>%
    mutate(
      var = fct_relevel(var, var_order),
      # Reverse so PC1 ends up at the top of the y axis.
      name = fct_rev(fct_relevel(name, pc_levels))
    ) %>%
    ggplot(aes(var, name, fill = value)) +
    geom_raster() +
    scale_fill_gradient2(
      low = "purple",
      mid = "grey90",
      high = "orange"
    ) +
    labs(
      y = "Principal Component",
      x = "Variable"
    ) +
    theme_minimal()
}
49 |
--------------------------------------------------------------------------------
/R/viz_prob_distribution.R:
--------------------------------------------------------------------------------
1 | #' #' Vizualize the predicted probalities for each class
2 | #' #'
3 | #' #' @param model A `model_fit` object from the parsnip package.
4 | #' #' @param new_data A data.frame to run predictions on.
5 | #' #' @param truth A vector of true classes used to color probalities.
6 | #' #' Defaults to NULL.
7 | #' #'
8 | #' #' @return ggplot2 object.
9 | #' #' @export
10 | #' viz_prob_distribution <- function(model, new_data, truth) {
11 | #' UseMethod("viz_prob_distribution")
12 | #' }
13 | #'
14 | #' #' @rdname viz_prob_distribution
15 | #' #' @export
16 | #' viz_prob_distribution.default <- function(model, new_data, truth) {
17 | #' stop("`model` must be a `model_fit` object from the parsnip package.",
18 | #' call. = FALSE
19 | #' )
20 | #' }
21 | #'
22 | #' #' @rdname viz_prob_distribution
23 | #' #' @export
24 | #' #'
25 | #' #' @examples
26 | #' #' library(parsnip)
27 | #' #' library(ranger)
28 | #' #'
29 | #' #' fit_model <- rand_forest("classification") %>%
30 | #' #' set_engine("ranger") %>%
31 | #' #' fit(Species ~ ., data = iris)
32 | #' #'
33 | #' #' viz_prob_distribution(fit_model, new_data = iris)
34 | #' #'
35 | #' #' viz_prob_distribution(fit_model, new_data = iris, truth = iris$Species)
36 | #' viz_prob_distribution.model_fit <- function(model, new_data, truth = NULL) {
37 | #' if (model$spec$mode != "classification") {
38 | #' stop("`model` must be a classification model.", call. = FALSE)
39 | #' }
40 | #'
41 | #' plotting_data <- predict(model, new_data, type = "prob")
42 | #'
43 | #' if (is.null(truth)) {
44 | #' plotting_data <- plotting_data %>%
45 | #' pivot_longer(cols = everything())
46 | #' } else {
47 | #' plotting_data <- plotting_data %>%
48 | #' mutate(truth = truth) %>%
49 | #' pivot_longer(cols = -truth) %>%
50 | #' mutate(correct = factor(
51 | #' gsub(".pred_", "", .data$name) == .data$truth,
52 | #' c(TRUE, FALSE),
53 | #' c("Yes", "No")
54 | #' )) %>%
55 | #' filter(gsub(".pred_", "", .data$name) == .data$truth)
56 | #' }
57 | #'
58 | #' out <- plotting_data %>%
59 | #' ggplot(aes(.data$value)) +
60 | #' geom_histogram(bins = 50) +
61 | #' facet_grid(~ .data$name) +
62 | #' theme_minimal() +
63 | #' labs(x = "Predicted probability")
64 | #'
65 | #' if (!is.null(truth)) {
66 | #' out <- out +
67 | #' aes(fill = .data$correct) +
68 | #' labs(fill = "Correctly predicted")
69 | #' }
70 | #' out
71 | #' }
72 |
--------------------------------------------------------------------------------
/R/viz_prob_region.R:
--------------------------------------------------------------------------------
#' Draw Probability regions for Classification model
#'
#' This function is mostly useful in an educational setting. Can only be used
#' with trained workflow objects with 2 numeric predictor variables.
#'
#' @param x trained `workflows::workflow` object.
#' @param new_data A data frame or tibble for whom the preprocessing will be
#'   applied.
#' @param resolution Number of squares in grid. Defaults to 100.
#' @param expand Expansion rate. Defaults to 0.1. This means that the width and
#'   height of the shaded area is 10 percent wider than the rectangle
#'   containing the data.
#' @param facet Logical, whether to facet chart by class. Defaults to FALSE.
#'   Must be TRUE when the response does not have exactly 2 levels.
#'
#' @details
#' The chart has been minimally modified to allow for easier styling.
#'
#' @return `ggplot2::ggplot` object
#' @export
#'
#' @examples
#' library(parsnip)
#' library(workflows)
#'
#' iris2 <- iris
#' iris2$Species <- factor(iris2$Species == "setosa",
#'                         labels = c("setosa", "not setosa"))
#'
#' svm_spec <- svm_rbf() %>%
#'   set_mode("classification") %>%
#'   set_engine("kernlab")
#'
#' svm_fit <- workflow() %>%
#'   add_formula(Species ~ Petal.Length + Petal.Width) %>%
#'   add_model(svm_spec) %>%
#'   fit(iris2)
#'
#' viz_prob_region(svm_fit, iris2)
#'
#' viz_prob_region(svm_fit, iris2, resolution = 20)
#'
#' viz_prob_region(svm_fit, iris2, expand = 1)
#'
#' viz_prob_region(svm_fit, iris2, facet = TRUE)
#'
#' knn_spec <- nearest_neighbor() %>%
#'   set_mode("classification") %>%
#'   set_engine("kknn")
#'
#' knn_fit <- workflow() %>%
#'   add_formula(class ~ umap_1 + umap_2) %>%
#'   add_model(knn_spec) %>%
#'   fit(mnist_sample)
#'
#' viz_prob_region(knn_fit, mnist_sample, facet = TRUE)
viz_prob_region <- function(x, new_data, resolution = 100, expand = 0.1,
                            facet = FALSE) {
  if (!inherits(x, "workflow")) {
    abort("`viz_prob_region()` only works with `workflow` objects.")
  }
  if (!x$trained) {
    abort("`x` must be a trained `workflow` object.")
  }

  var_names <- extract_variable_names(x, new_data, n_pred = 2)

  # The unfaceted chart shades only the first probability column on a
  # diverging scale, which is only meaningful for a binary response.
  n_levels <- length(levels(new_data[[var_names$response]]))
  if (!facet && n_levels != 2) {
    abort("The response must have only 2 levels for unfaceted chart.")
  }

  # expanded_seq() (defined elsewhere in the package) builds one sequence per
  # predictor, presumably spanning the data range widened by `expand`;
  # expand.grid() crosses them into the grid of points to predict on.
  predict_area <- new_data %>%
    select(all_of(var_names$predictors)) %>%
    lapply(expanded_seq, expand, resolution) %>%
    expand.grid()

  predicted_prob <- predict(x, predict_area, type = "prob")

  plotting_data <- predict_area %>%
    bind_cols(predicted_prob)

  if (facet) {
    # Long format: one row per grid point and class, so each class gets its
    # own panel with opacity driven by its predicted probability.
    plotting_data <- pivot_longer(plotting_data, names(predicted_prob),
                                  names_to = ".class",
                                  values_to = "probability")

    plotting_data$.class <- gsub("^.pred_", "", plotting_data$.class)

    res <- plotting_data %>%
      ggplot(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2],
          fill = ".class"
        )
      ) +
      geom_raster(aes_string(alpha = "probability")) +
      facet_wrap(~ .data$.class) +
      geom_point(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2],
          fill = var_names$response
        ),
        color = "black", shape = 22, data = new_data, inherit.aes = FALSE
      ) +
      theme_minimal() +
      scale_alpha(range = c(0, 1))

  } else {
    # Fill by the probability of the first class, centred at 0.5.
    res <- plotting_data %>%
      ggplot(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2],
          fill = names(predicted_prob)[1]
        )
      ) +
      geom_raster(alpha = 0.5) +
      scale_fill_gradient2(low = "blue", mid = "white", high = "green",
                           midpoint = 0.5) +
      geom_point(
        aes_string(
          var_names$predictors[1],
          var_names$predictors[2]
        ),
        color = "black", shape = 22, data = new_data, inherit.aes = FALSE
      ) +
      theme_minimal()
  }
  res
}
130 |
--------------------------------------------------------------------------------
/R/viz_residuals.R:
--------------------------------------------------------------------------------
#' Residual plot of model
#'
#' Plots the residuals of a fitted model against its predictions, with a
#' horizontal reference line at zero.
#'
#' @param fit a parsnip model object
#' @inheritParams parsnip:::augment.model_fit
#'
#' @return a ggplot2 object
#' @export
#'
#' @examples
#' library(parsnip)
#' reg <- linear_reg() %>%
#'   set_engine("lm") %>%
#'   set_mode("regression") %>%
#'   fit(mpg ~ ., data = mtcars)
#'
#' viz_residuals(reg, mtcars)
#'
#' linear_reg() %>%
#'   set_engine("lm") %>%
#'   set_mode("regression") %>%
#'   fit(mpg ~ ., data = mtcars) %>%
#'   viz_residuals(mtcars)
viz_residuals <- function(fit, new_data) {
  # augment() is expected to supply the `.pred` and `.resid` columns used below.
  augmented <- augment(fit, new_data = new_data)

  ggplot(augmented, aes(.pred, .resid)) +
    geom_point() +
    geom_abline(slope = 0, intercept = 0) +
    theme_minimal()
}
30 |
--------------------------------------------------------------------------------
/R/viz_tsne.R:
--------------------------------------------------------------------------------
#' Visualize t-SNE
#'
#' Runs t-SNE (via `Rtsne()`) on every column of `data` except `label` and
#' plots the two-dimensional embedding, coloured by `label`. The embedding is
#' stochastic, so set a seed for a reproducible layout.
#'
#' @param data A data.frame.
#' @param label variable to color with.
#' @return ggplot2 object.
#' @examples
#' library(dplyr)
#' viz_tsne(iris, Species)
#' viz_tsne(select(iris, -Species), Sepal.Length)
#' @export
viz_tsne <- function(data, label) {
  label_quo <- enquo(label)
  axis_cols <- c("V1", "V2")

  # Every non-label column goes into the matrix handed to Rtsne() (presumably
  # these must all be numeric); duplicate rows are tolerated.
  feature_matrix <- as.matrix(select(data, -!!label_quo))
  embedding <- Rtsne(feature_matrix, check_duplicates = FALSE)$Y
  colnames(embedding) <- axis_cols

  embedding_df <- mutate(as_tibble(embedding), Label = pull(data, !!label_quo))

  ggplot(embedding_df) +
    aes_string(axis_cols[1], axis_cols[2], color = "Label") +
    geom_point() +
    labs(x = "", y = "", title = "t-SNE Manifold") +
    theme_minimal()
}
37 |
--------------------------------------------------------------------------------
/R/workflow_utils.R:
--------------------------------------------------------------------------------
# Extract predictor and response column names from a trained workflow.
#
# Supports the three preprocessors a workflow can hold: variables
# (add_variables()), a formula (add_formula()) or a recipe (add_recipe()).
#
# x: a `workflows::workflow` object.
# new_data: data frame whose column names are used to resolve the selections.
# n_pred: exact number of predictors the caller requires.
#
# Returns a list with `predictors` and `response` (character vectors).
# Aborts if no supported preprocessor is found, or if the number of
# predictors differs from `n_pred`.
extract_variable_names <- function(x, new_data, n_pred) {
  # `%in%` (not `==`) so a workflow holding more than one action still gives
  # a scalar condition.
  action_names <- names(x$pre$actions)

  if ("variables" %in% action_names) {
    variables <- x$pre$actions$variables
    # NOTE(review): newer workflows releases nest the quosures in a
    # `workflow_variables` object stored as `$variables`, while older ones
    # keep `outcomes`/`predictors` on the action itself. Handle both.
    if (!is.null(variables[["variables"]])) {
      variables <- variables[["variables"]]
    }
    column_lookup <- set_names(names(new_data))
    predictors <- eval_tidy(variables[["predictors"]], column_lookup)
    response <- eval_tidy(variables[["outcomes"]], column_lookup)
  } else if ("formula" %in% action_names) {
    fml <- x$pre$actions$formula$formula
    # terms() expands `.` against new_data; all.vars() then collects the
    # variable names of every term, so `a + b + c` and `log(a)` both resolve
    # correctly (as.character() on the RHS call would flatten only one level).
    rhs <- delete.response(terms(fml, data = new_data))
    predictors <- intersect(all.vars(rhs), names(new_data))
    response <- all.vars(fml[[2]])
  } else if ("recipe" %in% action_names) {
    var_info <- x$pre$actions$recipe$recipe$var_info
    predictors <- var_info$variable[var_info$role == "predictor"]
    response <- var_info$variable[var_info$role == "outcome"]
  } else {
    abort("`x` must use a formula, a recipe, or variables as preprocessor.")
  }

  if (length(predictors) != n_pred) {
    abort(glue("`x` must have only {n_pred} predictors."))
  }

  list(predictors = predictors, response = response)
}
31 |
--------------------------------------------------------------------------------
/README.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | output: github_document
3 | ---
4 |
5 |
6 |
7 | ```{r, include = FALSE}
8 | knitr::opts_chunk$set(
9 | collapse = TRUE,
10 | comment = "#>",
11 | fig.path = "man/figures/README-",
12 | out.width = "100%"
13 | )
14 | ```
15 |
16 | # horus
17 |
18 |
19 | [![R build status](https://github.com/EmilHvitfeldt/horus/workflows/R-CMD-check/badge.svg)](https://github.com/EmilHvitfeldt/horus/actions)
20 |
21 |
22 | **WIP -- Very early build, things are very likely to change -- WIP**
23 |
24 | The goal of horus is to allow quick visualization methods for common machine learning and modeling tasks. This project is hugely inspired by the Python library [yellowbrick](https://github.com/DistrictDataLabs/yellowbrick).
25 |
26 | ## Installation
27 |
28 | For the time being `horus` is only available on GitHub, and can be installed
29 | with `devtools`:
30 |
31 | ```{r, eval=FALSE}
32 | # install.packages('devtools')
33 | devtools::install_github('EmilHvitfeldt/horus')
34 | ```
35 |
36 | In the future the package will be available on CRAN as well.
37 |
38 | ## Example
39 |
40 | There is no reason why a principal component plot of a data set should be as hard as it currently is in R. Using **horus** it is down to a single line!
41 |
42 | ```{r}
43 | library(horus)
44 | viz_pca(iris, Species)
45 | ```
46 |
47 | ## Advanced Gallery
48 |
49 | [Random Forest variability Visualization](https://gist.github.com/EmilHvitfeldt/e81f9d462c423978f515f036c8ad0232)
50 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # horus
5 |
6 |
7 |
8 | [![R build status](https://github.com/EmilHvitfeldt/horus/workflows/R-CMD-check/badge.svg)](https://github.com/EmilHvitfeldt/horus/actions)
9 |
10 |
11 | **WIP – Very early build, things are very likely to change – WIP**
12 |
13 | The goal of horus is to allow quick visualization methods for common
14 | machine learning and modeling tasks. This project is hugely inspired by
15 | the Python library
16 | [yellowbrick](https://github.com/DistrictDataLabs/yellowbrick).
17 |
18 | ## Installation
19 |
20 | For the time being `horus` is only available on GitHub, and can be
21 | installed with `devtools`:
22 |
23 | ``` r
24 | # install.packages('devtools')
25 | devtools::install_github('EmilHvitfeldt/horus')
26 | ```
27 |
28 | In the future the package will be available on CRAN as well.
29 |
30 | ## Example
31 |
32 | There is no reason why a principal component plot of a data set should
33 | be as hard as it currently is in R. Using **horus** it is down to a
34 | single line!
35 |
36 | ``` r
37 | library(horus)
38 | viz_pca(iris, Species)
39 | ```
40 |
41 |
42 |
43 | ## Advanced Gallery
44 |
45 | [Random Forest variability
46 | Visualization](https://gist.github.com/EmilHvitfeldt/e81f9d462c423978f515f036c8ad0232)
47 |
--------------------------------------------------------------------------------
/_pkgdown.yml:
--------------------------------------------------------------------------------
1 | url: https://emilhvitfeldt.github.io/horus
2 |
--------------------------------------------------------------------------------
/data-raw/fairy_tales.R:
--------------------------------------------------------------------------------
# Builds the `fairy_tales` data set: all rows of hcandersenr::hcandersen_en
# that belong to 5 randomly chosen books, saved to data/ via usethis.
# NOTE(review): sample() output depends on R's RNG kind (changed in R 3.6.0);
# confirm the selected books still match those documented in R/data.R.
set.seed(12345)
library(tidyverse)
# devtools::install_github("EmilHvitfeldt/hcandersenr")
select_books <- hcandersenr::hcandersen_en %>%
  pull(book) %>%
  unique() %>%
  sample(size = 5)

fairy_tales <- hcandersenr::hcandersen_en %>%
  filter(book %in% select_books)

# Persist the result; without this the script computes the data but never
# writes data/fairy_tales.rda.
usethis::use_data(fairy_tales, overwrite = TRUE)
11 |
--------------------------------------------------------------------------------
/data-raw/mnist_sample.R:
--------------------------------------------------------------------------------
# Regenerates `mnist_sample`: a 2% random sample of the MNIST training set,
# reduced to two UMAP components and saved to data/ via usethis.
library(keras)
library(recipes)
library(embed)
library(purrr)
library(dplyr)
set.seed(1234)

mnist_raw <- dataset_mnist()
train_images <- mnist_raw$train$x

# Bind the 28 slices along the third array dimension side by side, giving one
# row per image, then attach the digit labels as a factor outcome.
tidy_mnist <- map_dfc(
  seq_len(28),
  function(k) tibble::as_tibble(train_images[, , k])
)
tidy_mnist <- mutate(tidy_mnist, class = factor(mnist_raw$train$y))

mnist_subset <- slice_sample(tidy_mnist, prop = 0.02)

umap_recipe <- recipe(class ~ ., data = mnist_subset)
umap_recipe <- step_umap(umap_recipe, all_predictors(), num_comp = 2)
mnist_sample <- juice(prep(umap_recipe))

usethis::use_data(mnist_sample, overwrite = TRUE)
19 |
--------------------------------------------------------------------------------
/data/fairy_tales.rda:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/data/fairy_tales.rda
--------------------------------------------------------------------------------
/data/mnist_sample.rda:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/data/mnist_sample.rda
--------------------------------------------------------------------------------
/horus.Rproj:
--------------------------------------------------------------------------------
1 | Version: 1.0
2 |
3 | RestoreWorkspace: No
4 | SaveWorkspace: No
5 | AlwaysSaveHistory: Default
6 |
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 |
12 | RnwWeave: knitr
13 | LaTeX: pdfLaTeX
14 |
15 | AutoAppendNewline: Yes
16 | StripTrailingWhitespace: Yes
17 |
18 | BuildType: Package
19 | PackageUseDevtools: Yes
20 | PackageInstallArgs: --no-multiarch --with-keep.source
21 | PackageRoxygenize: rd,collate,namespace
22 |
--------------------------------------------------------------------------------
/man/fairy_tales.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{fairy_tales}
5 | \alias{fairy_tales}
6 | \title{Sample of 5 fairy tales from H.C. Andersen}
7 | \format{
8 | A tibble with 486 obs. of 2 variables.
9 | }
10 | \usage{
11 | fairy_tales
12 | }
13 | \description{
14 | Includes the fairy tales "A leaf from heaven", "A story",
15 | "The bird of folklore", "The flea and the professor" and "The rags".
16 | }
17 | \keyword{datasets}
18 |
--------------------------------------------------------------------------------
/man/figures/README-unnamed-chunk-2-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/man/figures/README-unnamed-chunk-2-1.png
--------------------------------------------------------------------------------
/man/figures/README-unnamed-chunk-3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/man/figures/README-unnamed-chunk-3-1.png
--------------------------------------------------------------------------------
/man/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/man/figures/logo.png
--------------------------------------------------------------------------------
/man/glmnet_coef_plot.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/glmnet_coef_plot.R
3 | \name{glmnet_coef_plot}
4 | \alias{glmnet_coef_plot}
5 | \title{Create Lambda chart for glmnet object}
6 | \usage{
7 | glmnet_coef_plot(fit)
8 | }
9 | \arguments{
10 | \item{fit}{parsnip fit object}
11 | }
12 | \value{
13 | ggplot2 object
14 | }
15 | \description{
16 | Create Lambda chart for glmnet object
17 | }
18 | \examples{
19 | library(parsnip)
20 | library(glmnet)
21 | linear_reg_glmnet_spec <-
22 | linear_reg(penalty = 0, mixture = 1) \%>\%
23 | set_engine('glmnet')
24 |
25 | lm_fit <- fit(linear_reg_glmnet_spec, data = mtcars, mpg ~ .)
26 |
27 | glmnet_coef_plot(lm_fit)
28 | }
29 |
--------------------------------------------------------------------------------
/man/mnist_sample.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/data.R
3 | \docType{data}
4 | \name{mnist_sample}
5 | \alias{mnist_sample}
6 | \title{Down-sampled MNIST}
7 | \format{
8 | A data frame with 1200 rows and 3 variables:
9 | \describe{
10 | \item{class}{Factor, values 0 through 9.}
11 | \item{umap_1}{Numeric.}
12 | \item{umap_2}{Numeric.}
13 | }
14 | }
15 | \source{
16 | \url{http://yann.lecun.com/exdb/mnist/}
17 | }
18 | \usage{
19 | mnist_sample
20 | }
21 | \description{
22 | This data set is a down-sampled version of the MNIST data set which has been
23 | umap'ed down to 2 dimensions.
24 | }
25 | \keyword{datasets}
26 |
--------------------------------------------------------------------------------
/man/pipe.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/utils-pipe.R
3 | \name{\%>\%}
4 | \alias{\%>\%}
5 | \title{Pipe operator}
6 | \usage{
7 | lhs \%>\% rhs
8 | }
9 | \description{
10 | See \code{magrittr::\link[magrittr]{\%>\%}} for details.
11 | }
12 | \keyword{internal}
13 |
--------------------------------------------------------------------------------
/man/viz_classbalance.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_classbalance.R
3 | \name{viz_classbalance}
4 | \alias{viz_classbalance}
5 | \title{Visualise Class Imbalance}
6 | \usage{
7 | viz_classbalance(data, variable, n_max = 25)
8 | }
9 | \arguments{
10 | \item{data}{A data.frame.}
11 |
12 | \item{variable}{target variable to show balance for.}
13 |
14 | \item{n_max}{integer, maximum number of classes shown before lumping.
15 | Defaults to 25.}
16 | }
17 | \value{
18 | ggplot2 object.
19 | }
20 | \description{
21 | Visualise Class Imbalance
22 | }
23 | \examples{
24 | viz_classbalance(mnist_sample, class)
25 | }
26 |
--------------------------------------------------------------------------------
/man/viz_decision_boundary.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_decision_boundary.R
3 | \name{viz_decision_boundary}
4 | \alias{viz_decision_boundary}
5 | \title{Draw Decision boundary for Classification model}
6 | \usage{
7 | viz_decision_boundary(x, new_data, resolution = 100, expand = 0.1)
8 | }
9 | \arguments{
10 | \item{x}{trained `workflows::workflow` object.}
11 |
12 | \item{new_data}{A data frame or tibble for whom the preprocessing will be
13 | applied.}
14 |
15 | \item{resolution}{Number of squares in grid. Defaults to 100.}
16 |
17 | \item{expand}{Expansion rate. Defaults to 0.1. This means that the width and
18 | height of the shaded area is 10 percent wider than the rectangle containing the
19 | data.
20 |
21 | The chart has been minimally modified to allow for easier styling.}
22 | }
23 | \value{
24 | `ggplot2::ggplot` object
25 | }
26 | \description{
27 | This function is mostly useful in an educational setting. Can only be used
28 | with trained workflow objects with 2 numeric predictor variables.
29 | }
30 | \examples{
31 | library(parsnip)
32 | library(workflows)
33 | svm_spec <- svm_rbf() \%>\%
34 | set_mode("classification") \%>\%
35 | set_engine("kernlab")
36 |
37 | svm_fit <- workflow() \%>\%
38 | add_formula(Species ~ Petal.Length + Petal.Width) \%>\%
39 | add_model(svm_spec) \%>\%
40 | fit(iris)
41 |
42 | viz_decision_boundary(svm_fit, iris)
43 |
44 | viz_decision_boundary(svm_fit, iris, resolution = 20)
45 |
46 | viz_decision_boundary(svm_fit, iris, expand = 1)
47 |
48 | svm_multi_fit <- workflow() \%>\%
49 | add_formula(class ~ umap_1 + umap_2) \%>\%
50 | add_model(svm_spec) \%>\%
51 | fit(mnist_sample)
52 |
53 | viz_decision_boundary(svm_multi_fit, mnist_sample)
54 | }
55 |
--------------------------------------------------------------------------------
/man/viz_dispersion.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_dispersion.R
3 | \name{viz_dispersion}
4 | \alias{viz_dispersion}
5 | \title{Visualize lexical dispersion plot}
6 | \usage{
7 | viz_dispersion(data, var, group, words = 10, symbol = NULL, alpha = 0.7)
8 | }
9 | \arguments{
10 | \item{data}{A data.frame.}
11 |
12 | \item{var}{variable that contains the words to be visualized.}
13 |
14 | \item{group}{If present, will show a group for each line with the words color
15 | coded.}
16 |
17 | \item{words}{Numerical or character. If numerical it will display the n
18 | most common words. If character will show the location of said strings.}
19 |
20 | \item{symbol}{The word symbol. Defaults to 18 (filled diamond) when the number
21 | of points is less than 200 and to 108 (vertical line) when there are more
22 | than 200 points.}
23 |
24 | \item{alpha}{color transparency of the word symbols.}
25 | }
26 | \value{
27 | ggplot2 object.
28 | }
29 | \description{
30 | Visualize lexical dispersion plot
31 | }
32 | \examples{
33 | \dontrun{
34 | library(tidytext)
35 |
36 | text_data <- unnest_tokens(fairy_tales, word, text)
37 | viz_dispersion(text_data, word)
38 | viz_dispersion(text_data, word, words = c("branches", "not a word"))
39 | viz_dispersion(text_data, word, symbol = "2")
40 | viz_dispersion(text_data, word, group = book)
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/man/viz_fitted_line.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_fitted_line.R
3 | \name{viz_fitted_line}
4 | \alias{viz_fitted_line}
5 | \title{Draw fitted regression line}
6 | \usage{
7 | viz_fitted_line(
8 | x,
9 | new_data,
10 | resolution = 100,
11 | expand = 0.1,
12 | color = "blue",
13 | size = 1
14 | )
15 | }
16 | \arguments{
17 | \item{x}{trained `workflows::workflow` object.}
18 |
19 | \item{new_data}{A data frame or tibble for whom the preprocessing will be
20 | applied.}
21 |
22 | \item{resolution}{Number of squared in grid. Defaults to 100.}
23 |
24 | \item{expand}{Expansion rate. Defaults to 0.1. This means that the width of
25 | the plotting area is 10 percent wider than the data.}
26 |
27 | \item{color}{Character, color of the fitted line. Passed to `geom_line()`.
28 | Defaults to `"blue"`.}
29 |
30 | \item{size}{Numeric, size of the fitted line. Passed to `geom_line()`.
31 | Defaults to `1`.}
32 | }
33 | \value{
34 | `ggplot2::ggplot` object
35 | }
36 | \description{
37 | This function is mostly useful in an educational setting. Can only be used
38 | with trained workflow objects with 1 numeric predictor variable.
39 | }
40 | \details{
41 | The chart has been minimally modified to allow for easier styling.
42 | }
43 | \examples{
44 | library(parsnip)
45 | library(workflows)
46 | lm_spec <- linear_reg() \%>\%
47 | set_mode("regression") \%>\%
48 | set_engine("lm")
49 |
50 | lm_fit <- workflow() \%>\%
51 | add_formula(mpg ~ disp) \%>\%
52 | add_model(lm_spec) \%>\%
53 | fit(mtcars)
54 |
55 | viz_fitted_line(lm_fit, mtcars)
56 |
57 | viz_fitted_line(lm_fit, mtcars, expand = 1)
58 | }
59 |
--------------------------------------------------------------------------------
/man/viz_pca.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_pca.R
3 | \name{viz_pca}
4 | \alias{viz_pca}
5 | \title{Visualize principal components}
6 | \usage{
7 | viz_pca(data, label, components = c(1, 2), loadings = FALSE)
8 | }
9 | \arguments{
10 | \item{data}{A data.frame.}
11 |
12 | \item{label}{variable to color with.}
13 |
14 | \item{components}{principal components to showcase.}
15 |
16 | \item{loadings}{Set this to true if you want to see the PCA loadings.}
17 | }
18 | \value{
19 | ggplot2 object.
20 | }
21 | \description{
22 | Visualize principal components
23 | }
24 | \examples{
25 | viz_pca(iris, Species)
26 | viz_pca(iris, Species, c(3, 1))
27 | viz_pca(iris, Species, loadings = TRUE)
28 | }
29 |
--------------------------------------------------------------------------------
/man/viz_pcacm.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_pcacm.R
3 | \name{viz_pcacm}
4 | \alias{viz_pcacm}
5 | \title{PCA component measures}
6 | \usage{
7 | viz_pcacm(x, n_pca = 10, n_var = 10)
8 | }
9 | \arguments{
10 | \item{x}{a data.frame}
11 |
12 | \item{n_pca}{Number of principal components to show}
13 |
14 | \item{n_var}{Number of Variables to show}
15 | }
16 | \value{
17 | a ggplot2 object
18 | }
19 | \description{
20 | PCA component measures
21 | }
22 | \examples{
23 | viz_pcacm(mtcars)
24 | viz_pcacm(USArrests)
25 | viz_pcacm(beaver1)
26 |
27 | viz_pcacm(mtcars, n_pca = 4)
28 | viz_pcacm(mtcars, n_var = 4)
29 | viz_pcacm(mtcars, n_pca = 2, n_var = 6)
30 | }
31 |
--------------------------------------------------------------------------------
/man/viz_prob_region.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_prob_region.R
3 | \name{viz_prob_region}
4 | \alias{viz_prob_region}
5 | \title{Draw Probability regions for Classification model}
6 | \usage{
7 | viz_prob_region(x, new_data, resolution = 100, expand = 0.1, facet = FALSE)
8 | }
9 | \arguments{
10 | \item{x}{trained `workflows::workflow` object.}
11 |
12 | \item{new_data}{A data frame or tibble for whom the preprocessing will be
13 | applied.}
14 |
15 | \item{resolution}{Number of squares in grid. Defaults to 100.}
16 |
17 | \item{expand}{Expansion rate. Defaults to 0.1. This means that the width and
18 | height of the shaded area is 10 percent wider than the rectangle containing the
19 | data.}
20 |
21 | \item{facet}{Logical, whether to facet chart by class. Defaults to FALSE.
22 |
23 | The chart has been minimally modified to allow for easier styling.}
24 | }
25 | \value{
26 | `ggplot2::ggplot` object
27 | }
28 | \description{
29 | This function is mostly useful in an educational setting. Can only be used
30 | with trained workflow objects with 2 numeric predictor variables.
31 | }
32 | \examples{
33 | library(parsnip)
34 | library(workflows)
35 |
36 | iris2 <- iris
37 | iris2$Species <- factor(iris2$Species == "setosa",
38 | labels = c("setosa", "not setosa"))
39 |
40 | svm_spec <- svm_rbf() \%>\%
41 | set_mode("classification") \%>\%
42 | set_engine("kernlab")
43 |
44 | svm_fit <- workflow() \%>\%
45 | add_formula(Species ~ Petal.Length + Petal.Width) \%>\%
46 | add_model(svm_spec) \%>\%
47 | fit(iris2)
48 |
49 | viz_prob_region(svm_fit, iris2)
50 |
51 | viz_prob_region(svm_fit, iris2, resolution = 20)
52 |
53 | viz_prob_region(svm_fit, iris2, expand = 1)
54 |
55 | viz_prob_region(svm_fit, iris2, facet = TRUE)
56 |
57 | knn_spec <- nearest_neighbor() \%>\%
58 | set_mode("classification") \%>\%
59 | set_engine("kknn")
60 |
61 | knn_fit <- workflow() \%>\%
62 | add_formula(class ~ umap_1 + umap_2) \%>\%
63 | add_model(knn_spec) \%>\%
64 | fit(mnist_sample)
65 |
66 | viz_prob_region(knn_fit, mnist_sample, facet = TRUE)
67 | }
68 |
--------------------------------------------------------------------------------
/man/viz_residuals.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_residuals.R
3 | \name{viz_residuals}
4 | \alias{viz_residuals}
5 | \title{Residual plot of model}
6 | \usage{
7 | viz_residuals(fit, new_data)
8 | }
9 | \arguments{
10 | \item{fit}{a parsnip model object}
11 |
12 | \item{new_data}{A data frame or matrix.}
13 | }
14 | \value{
15 | a ggplot2 object
16 | }
17 | \description{
18 | Residual plot of model
19 | }
20 | \examples{
21 | library(parsnip)
22 | reg <- linear_reg() \%>\%
23 | set_engine("lm") \%>\%
24 | set_mode("regression") \%>\%
25 | fit(mpg ~ ., data = mtcars)
26 |
27 | viz_residuals(reg, mtcars)
28 |
29 | linear_reg() \%>\%
30 | set_engine("lm") \%>\%
31 | set_mode("regression") \%>\%
32 | fit(mpg ~ ., data = mtcars) \%>\%
33 | viz_residuals(mtcars)
34 | }
35 |
--------------------------------------------------------------------------------
/man/viz_tsne.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/viz_tsne.R
3 | \name{viz_tsne}
4 | \alias{viz_tsne}
5 | \title{Visualize t-SNE}
6 | \usage{
7 | viz_tsne(data, label)
8 | }
9 | \arguments{
10 | \item{data}{A data.frame.}
11 |
12 | \item{label}{variable to color with.}
13 | }
14 | \value{
15 | ggplot2 object.
16 | }
17 | \description{
18 | Visualize t-SNE
19 | }
20 | \examples{
21 | library(dplyr)
22 | viz_tsne(iris, Species)
23 | viz_tsne(select(iris, -Species), Sepal.Length)
24 | }
25 |
--------------------------------------------------------------------------------
/pkgdown/favicon/apple-touch-icon-120x120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-120x120.png
--------------------------------------------------------------------------------
/pkgdown/favicon/apple-touch-icon-152x152.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-152x152.png
--------------------------------------------------------------------------------
/pkgdown/favicon/apple-touch-icon-180x180.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-180x180.png
--------------------------------------------------------------------------------
/pkgdown/favicon/apple-touch-icon-60x60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-60x60.png
--------------------------------------------------------------------------------
/pkgdown/favicon/apple-touch-icon-76x76.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon-76x76.png
--------------------------------------------------------------------------------
/pkgdown/favicon/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/apple-touch-icon.png
--------------------------------------------------------------------------------
/pkgdown/favicon/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/favicon-16x16.png
--------------------------------------------------------------------------------
/pkgdown/favicon/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/favicon-32x32.png
--------------------------------------------------------------------------------
/pkgdown/favicon/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmilHvitfeldt/horus/86a9b9367d29be768e7573062cd2792fe2275027/pkgdown/favicon/favicon.ico
--------------------------------------------------------------------------------
/tests/figs/deps.txt:
--------------------------------------------------------------------------------
1 | - vdiffr-svg-engine: 1.0
2 | - vdiffr: 0.3.3.9000
3 | - freetypeharfbuzz: 0.2.6
4 |
--------------------------------------------------------------------------------
/tests/figs/viz_classbalance/viz-classbalance-n-max-2.svg:
--------------------------------------------------------------------------------
1 |
2 |
50 |
--------------------------------------------------------------------------------
/tests/figs/viz_classbalance/viz-classbalance-n-max-4-ties.svg:
--------------------------------------------------------------------------------
1 |
2 |
58 |
--------------------------------------------------------------------------------
/tests/figs/viz_classbalance/viz-classbalance-simple.svg:
--------------------------------------------------------------------------------
1 |
2 |
68 |
--------------------------------------------------------------------------------
/tests/figs/viz_fitted_line/viz-fitted-line-expand.svg:
--------------------------------------------------------------------------------
1 |
2 |
89 |
--------------------------------------------------------------------------------
/tests/figs/viz_fitted_line/viz-fitted-line-resolution.svg:
--------------------------------------------------------------------------------
1 |
2 |
94 |
--------------------------------------------------------------------------------
/tests/figs/viz_fitted_line/viz-fitted-line-simple.svg:
--------------------------------------------------------------------------------
1 |
2 |
94 |
--------------------------------------------------------------------------------
/tests/figs/viz_fitted_line/viz-fitted-line-style.svg:
--------------------------------------------------------------------------------
1 |
2 |
94 |
--------------------------------------------------------------------------------
/tests/figs/viz_pcacm/viz-pcacm-components.svg:
--------------------------------------------------------------------------------
1 |
2 |
69 |
--------------------------------------------------------------------------------
/tests/figs/viz_pcacm/viz-pcacm-loadings-components.svg:
--------------------------------------------------------------------------------
1 |
2 |
61 |
--------------------------------------------------------------------------------
/tests/figs/viz_pcacm/viz-pcacm-loadings.svg:
--------------------------------------------------------------------------------
1 |
2 |
71 |
--------------------------------------------------------------------------------
/tests/figs/viz_pcacm/viz-pcacm-simple.svg:
--------------------------------------------------------------------------------
1 |
2 |
81 |
--------------------------------------------------------------------------------
/tests/figs/viz_residuals/viz-residuals-simple.svg:
--------------------------------------------------------------------------------
1 |
2 |
89 |
--------------------------------------------------------------------------------
/tests/testthat.R:
--------------------------------------------------------------------------------
# Test entry point used by R CMD check: attach testthat and the horus
# package (so its exports and datasets are visible to the tests), then run
# the test files in tests/testthat/.
library(testthat)
library(horus)

test_check("horus")
5 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_classbalance.R:
--------------------------------------------------------------------------------
library(testthat)

# Snapshot tests for viz_classbalance() on mnist_sample. The title of each
# vdiffr::expect_doppelganger() call names the SVG stored under
# tests/figs/viz_classbalance/, so titles must not change without
# re-validating the snapshots.
test_that("viz_classbalance works", {

  vdiffr::expect_doppelganger(
    "viz_classbalance simple",
    viz_classbalance(mnist_sample, class),
    "viz_classbalance"
  )

  vdiffr::expect_doppelganger(
    "viz_classbalance n_max 2",
    viz_classbalance(mnist_sample, class, n_max = 2),
    "viz_classbalance"
  )

  # With n_max = 4 a message matching "5" is expected. The plot is captured
  # by assignment inside expect_message() instead of using its return value:
  # under testthat's 3rd edition expect_message() returns the captured
  # condition, not the value of the expression, which would hand a condition
  # object to the snapshot comparison.
  expect_message(
    ties_plot <- viz_classbalance(mnist_sample, class, n_max = 4),
    "5"
  )
  vdiffr::expect_doppelganger(
    "viz_classbalance n_max 4 ties",
    ties_plot,
    "viz_classbalance"
  )

})
27 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_decision_boundary.R:
--------------------------------------------------------------------------------
library(testthat)
library(parsnip)
library(workflows)

set.seed(1234)

# Classification workflow: kknn nearest-neighbour model of Species on the two
# petal measurements. Objects are named knn_* to reflect the model type.
knn_spec <- nearest_neighbor() %>%
  set_mode("classification") %>%
  set_engine("kknn")

knn_fit <- workflow() %>%
  add_formula(Species ~ Petal.Length + Petal.Width) %>%
  add_model(knn_spec) %>%
  fit(iris)

# Snapshot tests; each title names the SVG stored under
# tests/figs/viz_decision_boundary/.
test_that("viz_decision_boundary works", {

  vdiffr::expect_doppelganger(
    "viz_decision_boundary simple",
    viz_decision_boundary(knn_fit, iris),
    "viz_decision_boundary"
  )

  vdiffr::expect_doppelganger(
    "viz_decision_boundary resolution",
    viz_decision_boundary(knn_fit, iris, resolution = 20),
    "viz_decision_boundary"
  )

  vdiffr::expect_doppelganger(
    "viz_decision_boundary expand",
    viz_decision_boundary(knn_fit, iris, expand = 1),
    "viz_decision_boundary"
  )
})
36 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_fitted_line.R:
--------------------------------------------------------------------------------
library(testthat)
library(parsnip)
library(workflows)

set.seed(1234)

# Regression workflow: kknn nearest-neighbour model of mpg on disp, built
# with nested calls rather than a pipe chain.
knn_regression <- set_engine(
  set_mode(nearest_neighbor(), "regression"),
  "kknn"
)
knn_fit <- fit(
  add_model(add_formula(workflow(), mpg ~ disp), knn_regression),
  mtcars
)

# Draw viz_fitted_line() for knn_fit on mtcars (extra arguments are passed
# through) and compare it with the SVG snapshot identified by `title`.
expect_fitted_line_snapshot <- function(title, ...) {
  vdiffr::expect_doppelganger(
    title,
    viz_fitted_line(knn_fit, mtcars, ...),
    "viz_fitted_line"
  )
}

test_that("viz_fitted_line works", {
  expect_fitted_line_snapshot("viz_fitted_line simple")
  expect_fitted_line_snapshot("viz_fitted_line resolution", resolution = 20)
  expect_fitted_line_snapshot("viz_fitted_line expand", expand = 1)
  expect_fitted_line_snapshot(
    "viz_fitted_line style",
    color = "pink",
    size = 4
  )
})
43 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_pca.R:
--------------------------------------------------------------------------------
library(testthat)

# Draw viz_pca() for iris coloured by Species, forwarding any extra
# arguments, and compare it with the SVG snapshot identified by `title`.
expect_pca_snapshot <- function(title, ...) {
  vdiffr::expect_doppelganger(title, viz_pca(iris, Species, ...), "viz_pca")
}

test_that("viz_pca works", {
  expect_pca_snapshot("viz_pca simple")
  expect_pca_snapshot("viz_pca components", components = c(3, 1))
  expect_pca_snapshot("viz_pca loadings", loadings = TRUE)
})
22 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_pcacm.R:
--------------------------------------------------------------------------------
library(testthat)

# Snapshot cases for viz_pcacm() on mtcars: each title (which names the
# stored SVG) maps to the extra arguments passed to viz_pcacm().
pcacm_cases <- list(
  "viz_pcacm simple" = list(),
  "viz_pcacm components" = list(n_pca = 4),
  "viz_pcacm loadings" = list(n_var = 5),
  "viz_pcacm loadings components" = list(n_pca = 2, n_var = 8)
)

# Draw viz_pcacm(mtcars, ...) and compare it with the snapshot `title`.
expect_pcacm_snapshot <- function(title, ...) {
  vdiffr::expect_doppelganger(title, viz_pcacm(mtcars, ...), "viz_pcacm")
}

test_that("viz_pcacm works", {
  for (case_title in names(pcacm_cases)) {
    do.call(
      expect_pcacm_snapshot,
      c(list(case_title), pcacm_cases[[case_title]])
    )
  }
})
28 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_prob_region.R:
--------------------------------------------------------------------------------
library(testthat)
library(parsnip)
library(workflows)

set.seed(1234)

# Binary version of iris: setosa vs. everything else. factor() on a logical
# vector sorts its levels as FALSE, TRUE, so the levels are given explicitly
# here; without them the label "setosa" ends up on the non-setosa flowers.
# NOTE(review): this changes the data behind the binary snapshots, so the
# stored viz_prob_region SVGs must be re-validated (vdiffr::manage_cases()).
iris2 <- iris
iris2$Species <- factor(
  iris2$Species == "setosa",
  levels = c(TRUE, FALSE),
  labels = c("setosa", "not setosa")
)

svm_spec <- svm_rbf() %>%
  set_mode("classification") %>%
  set_engine("kernlab")

# Two-class fit on iris2 and three-class fit on the original iris.
svm_fit <- workflow() %>%
  add_formula(Species ~ Petal.Length + Petal.Width) %>%
  add_model(svm_spec) %>%
  fit(iris2)

svm_fit_full <- workflow() %>%
  add_formula(Species ~ Petal.Length + Petal.Width) %>%
  add_model(svm_spec) %>%
  fit(iris)

test_that("viz_prob_region works", {

  vdiffr::expect_doppelganger(
    "viz_prob_region simple",
    viz_prob_region(svm_fit, iris2),
    "viz_prob_region"
  )

  vdiffr::expect_doppelganger(
    "viz_prob_region resolution",
    viz_prob_region(svm_fit, iris2, resolution = 20),
    "viz_prob_region"
  )

  vdiffr::expect_doppelganger(
    "viz_prob_region expand",
    viz_prob_region(svm_fit, iris2, expand = 1),
    "viz_prob_region"
  )
})

test_that("viz_prob_region facet works", {

  # The three-class fit is expected to error unless facet = TRUE.
  # NOTE(review): consider pinning the expected message or condition class
  # so that an unrelated error does not make this expectation pass.
  expect_error(
    viz_prob_region(svm_fit_full, iris)
  )

  vdiffr::expect_doppelganger(
    "viz_prob_region facet simple",
    viz_prob_region(svm_fit_full, iris, facet = TRUE),
    "viz_prob_region"
  )

  vdiffr::expect_doppelganger(
    "viz_prob_region facet resolution",
    viz_prob_region(svm_fit_full, iris, resolution = 20, facet = TRUE),
    "viz_prob_region"
  )

  vdiffr::expect_doppelganger(
    "viz_prob_region facet expand",
    viz_prob_region(svm_fit_full, iris, expand = 1, facet = TRUE),
    "viz_prob_region"
  )
})

--------------------------------------------------------------------------------
/tests/testthat/test-viz_residuals.R:
--------------------------------------------------------------------------------
library(parsnip)

# Linear regression ("lm" engine) of mpg on every other mtcars column,
# assembled step by step instead of through a pipe.
lm_spec <- linear_reg()
lm_spec <- set_engine(lm_spec, "lm")
lm_spec <- set_mode(lm_spec, "regression")
lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars)

test_that("viz_residuals works", {
  vdiffr::expect_doppelganger(
    "viz_residuals simple",
    viz_residuals(lm_fit, mtcars),
    "viz_residuals"
  )
})
16 |
--------------------------------------------------------------------------------
/tests/testthat/test-viz_tsne.R:
--------------------------------------------------------------------------------
library(testthat)
library(dplyr)

# Seed once for the whole file; presumably this makes the (randomly
# initialised) t-SNE embeddings reproducible. Both plots draw from the same
# RNG stream, so they are built in the same order as their snapshots.
set.seed(1234)

test_that("viz_tsne works", {
  # Label with a factor column (Species).
  factor_plot <- viz_tsne(iris, Species)
  vdiffr::expect_doppelganger(
    "viz_tsne simple factor",
    factor_plot,
    "viz_tsne"
  )

  # Label with a numeric column, after dropping Species from the data.
  numeric_plot <- viz_tsne(select(iris, -Species), Sepal.Length)
  vdiffr::expect_doppelganger(
    "viz_tsne simple numeric",
    numeric_plot,
    "viz_tsne"
  )
})
20 |
--------------------------------------------------------------------------------