├── .Rbuildignore ├── .gitignore ├── .lintr ├── .travis.yml ├── DESCRIPTION ├── LICENSE ├── NAMESPACE ├── R ├── cca.R ├── ccf.R ├── cct.R ├── data.R └── utilities.R ├── README.Rmd ├── README.md ├── ccf.Rproj ├── data └── spirals.rda ├── demo ├── 00Index ├── cca_demo.R ├── ccf_demo.R ├── dataset_demos.R └── profiling.R ├── man ├── canonical_correlation_analysis.Rd ├── canonical_correlation_tree.Rd ├── ccf.Rd ├── get_missclassification_rate.Rd ├── plot.canonical_correlation_forest.Rd ├── plot_decision_surface.Rd ├── predict.canonical_correlation_forest.Rd └── spirals.Rd └── tests ├── testthat.R └── testthat ├── test_cca.R ├── test_ccf.R └── test_cct.R /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^README\.Rmd$ 4 | ^README-.*\.png$ 5 | ^\.travis\.yml$ 6 | ^\.lintr$ 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # History files 2 | .Rhistory 3 | .Rapp.history 4 | 5 | # Session Data files 6 | .RData 7 | 8 | # Example code in package build process 9 | *-Ex.R 10 | 11 | # Output files from R CMD build 12 | /*.tar.gz 13 | 14 | # Output files from R CMD check 15 | /*.Rcheck/ 16 | 17 | # RStudio files 18 | .Rproj.user 19 | 20 | # produced vignettes 21 | vignettes/*.html 22 | vignettes/*.pdf 23 | 24 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 25 | .httr-oauth 26 | 27 | # knitr and R markdown default cache directories 28 | /*_cache/ 29 | /cache/ 30 | 31 | # Temporary files created by R markdown 32 | *.utf8.md 33 | *.knit.md 34 | -------------------------------------------------------------------------------- /.lintr: -------------------------------------------------------------------------------- 1 | linters: with_defaults(assignment_linter = NULL, camel_case_linter = NULL) 2 | 
-------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # R for travis: see documentation at https://docs.travis-ci.com/user/languages/r 2 | 3 | language: R 4 | sudo: false 5 | cache: packages 6 | 7 | r_packages: 8 | - covr 9 | 10 | after_success: 11 | - Rscript -e 'library(covr); codecov()' -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: ccf 2 | Type: Package 3 | Title: Canonical correlation forest 4 | Version: 0.1.0 5 | Date: 2016-09-04 6 | Authors@R: c(person("Janosch", "Dobler", email = "info@jandob.com", 7 | role = c("aut", "cre")), 8 | person("Stefan", "Feuerriegel", email="stefan.feuerriegel@is.uni-freiburg.de", 9 | role = c("aut"))) 10 | Author: Janosch Dobler [aut, cre], 11 | Stefan Feuerriegel [aut] 12 | Maintainer: Janosch Dobler 13 | Description: The so-called canonical correlation forest (CCF) is a classification 14 | algorithm for prediction tasks in machine learning. It utilizes an ensemble of 15 | decision trees whose splits are chosen in projections found by canonical correlation analysis. 16 | Thereby, it is capable of outperforming many widely-used classification methods.
17 | License: MIT + file LICENSE 18 | URL: https://github.com/jandob/ccf 19 | BugReports: https://github.com/jandob/ccf/issues 20 | Depends: 21 | R (>= 2.10) 22 | Suggests: 23 | testthat, 24 | knitr, 25 | ggplot2, 26 | MASS, 27 | pracma 28 | LazyData: TRUE 29 | RoxygenNote: 6.0.1 30 | VignetteBuilder: knitr 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2016 2 | COPYRIGHT HOLDER: Janosch Dobler 3 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(canonical_correlation_forest,default) 4 | S3method(canonical_correlation_forest,formula) 5 | S3method(plot,canonical_correlation_forest) 6 | S3method(plot,canonical_correlation_tree) 7 | S3method(predict,canonical_correlation_forest) 8 | S3method(predict,canonical_correlation_tree) 9 | export(canonical_correlation_analysis) 10 | export(canonical_correlation_forest) 11 | export(canonical_correlation_tree) 12 | export(get_missclassification_rate) 13 | export(plot_decision_surface) 14 | importFrom(stats,model.frame) 15 | importFrom(stats,model.matrix) 16 | importFrom(stats,model.response) 17 | importFrom(stats,predict) 18 | importFrom(utils,head) 19 | importFrom(utils,tail) 20 | -------------------------------------------------------------------------------- /R/cca.R: -------------------------------------------------------------------------------- 1 | #' @source Idea: \url{http://gastonsanchez.com/how-to/2014/01/15/Center-data-in-R/} 2 | center_colmeans <- function(x) { 3 | x_center <- colMeans(x) 4 | return(x - rep(x_center, rep.int(nrow(x), ncol(x)))) 5 | } 6 | 7 | #' Canonical correlation analysis 8 | #' 9 | #' Canonical correlation analysis (CCA) finds pairs of vectors \eqn{(w,v)} such that
projections 10 | #' \eqn{Xw} and \eqn{Yv} have maximal possible correlations. The pairs are ordered in decreasing 11 | #' order of the correlations. The projection vectors are first scaled such that \eqn{Xw} and 12 | #' \eqn{Yv} have unit variance; the returned \code{x} coefficients are afterwards rescaled to 13 | #' unit Euclidean length, whereas the \code{y} coefficients keep the unit-variance scaling. 14 | #' @param x Matrix of size n-by-p with n observations from p variables. Alternatively, data 15 | #' frames and numeric vectors are supported and automatically converted. 16 | #' @param y Matrix of size n-by-q with n observations from q variables. Alternatively, data 17 | #' frames and numeric vectors are supported and automatically converted. 18 | #' @param epsilon Numeric value used as tolerance threshold for rank reduction of the 19 | #' input matrices. Default is \code{1e-4}. 20 | #' @return A list containing the following components. If \code{x} or \code{y} has numerical rank zero, \code{xcoef} is the first unit vector and \code{ycoef} and \code{cor} are \code{NULL}. 21 | #' \itemize{ 22 | #' \item{xcoef}{Estimated coefficients for the \code{x} variable (one column per canonical pair).} 23 | #' \item{ycoef}{Estimated coefficients for the \code{y} variable (one column per canonical pair).} 24 | #' \item{cor}{Diagonal matrix with the canonical correlations in decreasing order.} 25 | #' } 26 | #' @examples 27 | #' library(MASS) 28 | #' library(pracma) 29 | #' 30 | #' X <- mvrnorm(1000, mu = c(0, 0), Sigma = eye(2)) 31 | #' cca <- canonical_correlation_analysis(X, X) 32 | #' cca 33 | #' 34 | #' X <- mvrnorm(1000, mu = c(1, 2), 35 | #' Sigma = matrix(c(1.5, 0.5, 0.5, 1.5), ncol = 2)) 36 | #' cca <- canonical_correlation_analysis(X, X) 37 | #' cca 38 | #' @export 39 | canonical_correlation_analysis <- function(x, y, epsilon = 1e-4) { 40 | if (is.data.frame(x) || is.vector(x)) { 41 | x <- as.matrix(x) 42 | } 43 | if (is.data.frame(y) || is.vector(y)) { 44 | y <- as.matrix(y) 45 | } 46 | 47 | if (!is.matrix(x) || !is.matrix(y)) { 48 | stop("Arguments 'x' and 'y' must be of type matrix or data frame.") 49 | } 50 | 51 | # mean centering 52 | x <- center_colmeans(x) 53 | y <-
center_colmeans(y) 54 | 55 | # QR decomposition 56 | # (https://cran.r-project.org/doc/contrib/Hiebeler-matlabR.pdf) 57 | qrDecompX <- qr(x, tol = epsilon) 58 | qX <- qr.Q(qrDecompX) 59 | rX <- qr.R(qrDecompX) 60 | pX <- qrDecompX$pivot 61 | rankX <- qrDecompX$rank 62 | 63 | qrDecomp <- qr(y, tol = epsilon) 64 | qY <- qr.Q(qrDecomp) 65 | rY <- qr.R(qrDecomp) 66 | pY <- qrDecomp$pivot 67 | rankY <- qrDecomp$rank 68 | # no projection can be estimated from rank-zero input: fall back to the first x axis 69 | if (rankX == 0 || rankY == 0) { 70 | matrixA = matrix(0, ncol(x), 1) 71 | matrixA[1, ] = 1 72 | return(list(xcoef = matrixA, 73 | ycoef = NULL, 74 | cor = NULL) 75 | ) 76 | } 77 | 78 | # reduce Q and R to full rank (drop = FALSE keeps matrices when the rank is 1) 79 | if (rankX < ncol(x)) { 80 | qX <- qX[, seq_len(rankX), drop = FALSE] 81 | } 82 | rX <- rX[seq_len(rankX), seq_len(rankX), drop = FALSE] 83 | if (rankY < ncol(y)) { 84 | qY <- qY[, seq_len(rankY), drop = FALSE] 85 | } 86 | rY <- rY[seq_len(rankY), seq_len(rankY), drop = FALSE] 87 | 88 | numberOfCoefficientPairs = min(rankX, rankY) 89 | 90 | # select which decomposition is faster 91 | if (rankX >= rankY) { 92 | svd <- svd(t(qX) %*% qY) 93 | L <- svd$u 94 | D <- svd$d 95 | M <- svd$v 96 | } else { 97 | svd <- svd(t(qY) %*% qX) 98 | M <- svd$u 99 | D <- svd$d 100 | L <- svd$v 101 | } 102 | 103 | # Remove meaningless components in L and M 104 | # Note solve(x) == x^-1 105 | A <- solve(rX) %*% L[, 1:numberOfCoefficientPairs] * sqrt(nrow(x) - 1) 106 | B <- solve(rY) %*% M[, 1:numberOfCoefficientPairs] * sqrt(nrow(x) - 1) 107 | 108 | # restore full size 109 | A <- rbind(A, matrix(0, ncol(x) - rankX, numberOfCoefficientPairs)) 110 | B <- rbind(B, matrix(0, ncol(y) - rankY, numberOfCoefficientPairs)) 111 | # explicit nrow: diag() of a single value would build an identity matrix instead 112 | correlations <- diag(D, nrow = length(D)) 113 | # undo the QR column pivoting: qr() factorizes x[, pX], so row i of A holds the 114 | # coefficient of original column pX[i]; order(pX) is the inverse permutation 115 | A = A[order(pX), , drop = FALSE] #nolint 116 | B = B[order(pY), , drop = FALSE] #nolint 117 | 118 | # normalize each x projection vector to unit Euclidean length 119 | A = scale(A, center = FALSE, scale = sqrt(colSums(A ^ 2))) 120 | 121 | # convert to matrices and restore dimension names 122 | matrixA = as.matrix(A) 123 | matrixB = as.matrix(B) 124 | rownames(matrixA) = colnames(x) 125 | rownames(matrixB) = colnames(y) 126 | 127 | return(list(xcoef =
matrixA, 128 | ycoef = matrixB, 129 | cor = as.matrix(correlations)) #TODO add dimnames 130 | ) 131 | 132 | } 133 | -------------------------------------------------------------------------------- /R/ccf.R: -------------------------------------------------------------------------------- 1 | #' Canonical correlation forest 2 | #' 3 | #' This function computes a classifier based on a canonical correlation forest. It 4 | #' expects its input in matrix form or as formula notation. 5 | #' @param x Numeric matrix (n * p) with n observations of p variables, or a formula for the formula method. 6 | #' @param y Response for the n observations; factor, integer, character and logical vectors are treated as class labels and one-hot encoded. For the formula method, the data (data frame or matrix). 7 | #' @param ntree Number of trees the forest will be composed of 8 | #' @param verbose Optional argument to control whether additional information is 9 | #' printed to the output. Default is \code{FALSE}. 10 | #' @param projectionBootstrap Use projection bootstrapping. (default \code{FALSE}) 11 | #' @param ... Further arguments passed to or from other methods. 12 | #' @return returns an object of class "canonical_correlation_forest", 13 | #' where an object of this class is a list containing the following 14 | #' components: 15 | #' \itemize{ 16 | #' \item{x,y}{The original input data} 17 | #' \item{y_encoded}{The encoded \code{y} variable in case of classification tasks.} \item{ntree}{Number of trees.} 18 | #' \item{forest}{a vector of length ntree with objects of class 19 | #' \code{canonical_correlation_tree}.} 20 | #' } 21 | #' @examples 22 | #' data(spirals) 23 | #' 24 | #' d_train <- spirals[1:1000, ] 25 | #' d_test <- spirals[-(1:1000), ] 26 | #' 27 | #' # compute classifier on training data 28 | #' ## variant 1: matrix input 29 | #' m1 <- canonical_correlation_forest(d_train[, c("x", "y")], d_train$class, ntree = 20) 30 | #' ## variant 2: formula notation 31 | #' m2 <- canonical_correlation_forest(class ~ ., d_train) 32 | #' 33 | #' # compute predictive accuracy 34 | #' get_missclassification_rate(m1, d_test) 35 | #' get_missclassification_rate(m2, d_test) 36 | #' @references
Rainforth, T., and Wood, F. (2015): Canonical correlation forest, 37 | #' arXiv preprint, arXiv:1507.05444, \url{https://arxiv.org/pdf/1507.05444.pdf}. 38 | #' @rdname ccf 39 | #' @export 40 | canonical_correlation_forest = function(x, y = NULL, 41 | ntree = 200, verbose = FALSE, ...) { 42 | UseMethod("canonical_correlation_forest", x) 43 | } 44 | 45 | 46 | #' @rdname ccf 47 | #' @export 48 | canonical_correlation_forest.default = 49 | function(x, y = NULL, ntree = 200, verbose = FALSE, 50 | projectionBootstrap = FALSE, ...) { 51 | forest <- vector(mode = "list", length = ntree) 52 | 53 | if (is.null(y)) { 54 | stop("CCF requires y variable.") 55 | } 56 | # class labels are one-hot encoded; the trees then predict column indices 57 | if (is.factor(y)) { 58 | y_encoded <- one_hot_encode(y) 59 | y_use <- y_encoded 60 | } else if (is.integer(y) || is.character(y) || is.logical(y)) { 61 | y_encoded <- one_hot_encode(y) 62 | y_use <- y_encoded 63 | } else { 64 | y_encoded <- NULL 65 | y_use <- y 66 | } 67 | 68 | for (i in seq_len(ntree)) { 69 | if (verbose) { 70 | cat("Training tree", i, "of", ntree, "\n") 71 | } 72 | 73 | if (!projectionBootstrap) { 74 | # use (breiman's) tree bagging 75 | sample_idx <- sample(nrow(x), size = nrow(x), replace = TRUE) 76 | x_bag <- x[sample_idx, , drop = FALSE] #nolint 77 | if (is.vector(y_use)) { 78 | y_bag <- y_use[sample_idx, drop = FALSE] 79 | } else { 80 | y_bag <- y_use[sample_idx, , drop = FALSE] #nolint 81 | } 82 | } else { 83 | # use projection bootstrapping; no sampling needed 84 | x_bag <- x 85 | y_bag <- y_use 86 | } 87 | 88 | forest[[i]] <- canonical_correlation_tree( 89 | x_bag, y_bag, projectionBootstrap = projectionBootstrap) 90 | } 91 | 92 | model <- structure(list(x = x, y = y, y_encoded = y_encoded, 93 | ntree = ntree, forest = forest), 94 | class = "canonical_correlation_forest") 95 | return(model) 96 | } 97 | 98 | #' @importFrom stats model.frame model.response model.matrix 99 | #' @rdname ccf 100 | #' @export 101 | canonical_correlation_forest.formula = function( 102 | x, y = NULL, ntree = 200, verbose = FALSE, ...)
{ 103 | formula <- x 104 | data <- y 105 | 106 | if (is.matrix(data)) { 107 | data <- as.data.frame(data) 108 | } 109 | 110 | model_frame <- model.frame(formula, data = data) 111 | 112 | x <- as.matrix(model.matrix(formula, data = model_frame)) 113 | x <- x[, colnames(x) != "(Intercept)", drop = FALSE] # drop the intercept by name, keep a matrix 114 | y <- model.response(model_frame) 115 | 116 | canonical_correlation_forest.default(x, y, ntree = ntree, verbose = verbose, ...) 117 | } 118 | 119 | #' Prediction from canonical correlation forest 120 | #' 121 | #' Performs predictions on test data for a trained canonical correlation forest and returns the majority-vote class label (as character) for each row of \code{newdata}. 122 | #' @param object An object of class \code{canonical_correlation_forest}, as created 123 | #' by the function \code{\link{canonical_correlation_forest}}. 124 | #' @param newdata A data frame or a matrix containing the test data. 125 | #' @param verbose Optional argument to control whether additional information is 126 | #' printed to the output. Default is \code{FALSE}. 127 | #' @param ... Additional parameters passed on to prediction from individual 128 | #' canonical correlation trees. 129 | #' @export 130 | predict.canonical_correlation_forest = function( 131 | object, newdata, verbose = FALSE, ...) { 132 | if (missing(newdata)) { 133 | stop("Argument 'newdata' is missing.") 134 | } 135 | 136 | # class labels in one-hot column order (NULL if y was not one-hot encoded) 137 | class_labels <- if (is.null(object$y_encoded)) NULL else levels(as.factor(object$y)) 138 | 139 | 140 | if (verbose) { 141 | cat("calculating predictions\n") 142 | } 143 | # list with one set of class indices per tree 144 | treePredictions <- lapply(object$forest, predict, newdata, ...) 145 | # convert to matrix 146 | treePredictions <- do.call(cbind, treePredictions) 147 | if (verbose) { 148 | cat("Majority vote\n") 149 | } 150 | treePredictions <- apply(treePredictions, 1, function(row) { 151 | names(which.max(table(row))) 152 | }) # majority vote over one-hot column indices 153 | if (!is.null(class_labels)) treePredictions <- class_labels[as.integer(treePredictions)] 154 | return(treePredictions) 155 | } 156 | 157 | #' Visualization of canonical correlation forest 158 | #' 159 | #' TODO: document 160 | #' @param ...
Further arguments passed to or from other methods. 161 | #' @export 162 | plot.canonical_correlation_forest = function(...) { 163 | plot.canonical_correlation_tree(...) 164 | } 165 | -------------------------------------------------------------------------------- /R/cct.R: -------------------------------------------------------------------------------- 1 | setupLeaf = function(y) { 2 | # Builds a leaf node; for a matrix y, classIndex is the column with the largest sum 3 | if (is.matrix(y)) { 4 | countsNode <- colSums(y) 5 | equalMaxCounts <- countsNode == max(countsNode) 6 | # ties between equally frequent classes are broken uniformly at random 7 | # NOTE(review): the ancestral class probabilities are not consulted here 8 | classIndex <- random_element(which(equalMaxCounts)) 9 | } else { 10 | countsNode <- sum(y) 11 | classIndex <- 1 12 | } 13 | 14 | # the class attribute lets predict() dispatch even when a whole tree is a leaf 15 | return(structure( 16 | list(isLeaf = TRUE, 17 | classIndex = classIndex, 18 | trainingCounts = countsNode), 19 | class = "canonical_correlation_tree")) 20 | } 21 | 22 | #' @importFrom utils head tail 23 | find_best_split = function(X, Y, xVariationTolerance = 1e-10) { 24 | numberOfProjectionDirections = ncol(X) 25 | splitGains = matrix(NA, numberOfProjectionDirections, 1) 26 | splitIndices = matrix(NA, numberOfProjectionDirections, 1) 27 | # iterate over all dimensions in X 28 | for (i in seq(1, numberOfProjectionDirections)) { 29 | # sort by the value of the current dimension (feature) i 30 | sortOrder = order(X[, i]) 31 | X_sorted = X[sortOrder, i] # vector of values from dimension i 32 | # matrix of labels (also sorted according to current dimension) 33 | Ysorted = Y[sortOrder, , drop = FALSE] #nolint 34 | # So we have something like this (Note that Y is actually one-hot encoded 35 | # and therefore a matrix) 36 | # X_sorted: 64 65 68 69 70 71 72 72 75 75 80 81 83 85 #nolint 37 | # Y_sorted: red blue green green blue red red blue red blue green red red blue #nolint 38 | 39 | # For every possible split_point (corresponds to value in X) 40 | # count the number of classes that occur in the partition 41 | # where X < split_point 42 | # e.g.
for split_point = 72: (2 red, 2 blue, 1 green) 43 | LeftCumCounts = apply(Ysorted, 2, cumsum) 44 | total_counts = utils::tail(LeftCumCounts, 1) 45 | # Do the same for X > split_point. 46 | # We can just subtract each row from total_counts 47 | RightCumCounts = 48 | sweep(LeftCumCounts, MARGIN = 2, total_counts, FUN = "-") * -1 49 | uniquePoints = c(diff(X_sorted) > xVariationTolerance, recursive = FALSE) 50 | # proportion of classes to the left/right of split points 51 | pL = LeftCumCounts / rowSums(LeftCumCounts) 52 | pR = RightCumCounts / rowSums(RightCumCounts) 53 | 54 | pLProd = pL * log2(pL) 55 | pLProd[pL == 0] = 0 56 | metricLeft = -apply(pLProd, 1, sum) 57 | pRProd = pR * log2(pR) 58 | pRProd[pR == 0] = 0 59 | metricRight = -apply(pRProd, 1, sum) 60 | # uniquePoints has N - 1 entries and is recycled below; the last split position is discarded later 61 | metricCurrent = utils::tail(metricLeft, 1) 62 | metricLeft[!uniquePoints] = Inf 63 | metricRight[!uniquePoints] = Inf 64 | N = nrow(X) 65 | metricGain = metricCurrent - (seq(1, N) * metricLeft + 66 | rev(seq(0, N - 1)) * metricRight 67 | ) / N 68 | 69 | # sample from equally best splits 70 | metricGainWOLast = utils::head(metricGain, -1) 71 | maxGain = max(metricGainWOLast) 72 | 73 | # equalMaxIndices correspond to indices of metricGain 74 | equalMaxIndices = which(abs(metricGainWOLast - maxGain) < 10 * eps) 75 | maxIndex = random_element(equalMaxIndices) 76 | 77 | splitGains[i] = metricGainWOLast[maxIndex] 78 | splitIndices[i] = maxIndex 79 | } 80 | maxGain = max(splitGains) 81 | 82 | # equalMaxIndices correspond to indices of splitGains 83 | equalMaxIndices = which(abs(splitGains - maxGain) < 10 * eps) 84 | splitDir = equalMaxIndices[1] 85 | splitIndex = splitIndices[splitDir] 86 | 87 | X = X[, splitDir] 88 | X_sorted = sort(X) 89 | X_sortedLeftPartition = X_sorted[splitIndex] 90 | X_sorted = X_sorted - X_sortedLeftPartition 91 | partitionPoint = X_sorted[splitIndex] * 0.5 + X_sorted[splitIndex + 1] * 0.5 92 | partitionPoint = partitionPoint + X_sortedLeftPartition 93 | lessThanPartPoint = X <=
partitionPoint 94 | return(list( 95 | partitionPoint = partitionPoint, 96 | splitDir = splitDir, 97 | gain = maxGain, 98 | lessThanPartPoint = lessThanPartPoint 99 | )) 100 | } 101 | 102 | #' Computes a canonical correlation tree 103 | #' 104 | #' This function computes a single canonical correlation tree given its input values. 105 | #' @param X Predictor matrix of size \eqn{n \times p} with \eqn{n} observations and \eqn{p} 106 | #' variables. 107 | #' @param Y Response matrix of size \eqn{n \times q} with \eqn{n} observations, typically 108 | #' the one-hot encoded class labels. 109 | #' @param depth Depth of subtree. 110 | #' @param minPointsForSplit Optional parameter setting the threshold when to construct a 111 | #' leaf (default: 2). If the number of data points is smaller than this value, a leaf is 112 | #' constructed. 113 | #' @param maxDepthSplit Optional parameter controlling the construction of leaves after a 114 | #' certain depth (default: \code{Inf}). If the current depth is greater than this value, 115 | #' a leaf is constructed. 116 | #' @param xVariationTolerance Projected features with variance less than this value are not 117 | #' considered for splitting at tree nodes. (default \code{1e-10}) 118 | #' @param projectionBootstrap Use projection bootstrapping: at every node the CCA projection is computed from a bootstrap sample of the node's data. (default \code{FALSE}) 119 | #' @param ancestralProbs Probabilities of ancestors. Default is \code{NULL} as these are 120 | #' then calculated automatically during the recursion. 121 | #' @return Function returns an object of class \code{canonical_correlation_tree}, 122 | #' where the object is a list containing the following components: 123 | #' \itemize{ 124 | #' \item{isLeaf}{Boolean whether the tree is a leaf itself. A leaf only contains \code{isLeaf}, 125 | #' \code{classIndex} (the predicted column of \code{Y}) and \code{trainingCounts}.}
126 | #' \item{trainingCounts}{Per-class counts (column sums of \code{Y}) of the training rows at this node.} \item{depth}{Depth of this (inner) node.} 127 | #' \item{partitionPoint}{Split threshold; rows whose projection is \code{<=} this value are 128 | #' passed to the left subtree.} 129 | #' \item{decisionProjection}{Numeric matrix containing the projection vector that was 130 | #' used to find the best split point.} 131 | #' \item{refLeftChild}{Reference to the left subtree.} 132 | #' \item{refRightChild}{Reference to the right subtree.} 133 | #' } 134 | #' @export 135 | canonical_correlation_tree = function( 136 | X, Y, 137 | depth = 0, 138 | minPointsForSplit = 2, 139 | maxDepthSplit = Inf, 140 | xVariationTolerance = 1e-10, 141 | projectionBootstrap = FALSE, 142 | ancestralProbs = NULL) { 143 | if (is.data.frame(X)) { 144 | X <- as.matrix(X) 145 | } 146 | if (is.data.frame(Y)) { 147 | Y <- as.matrix(Y) 148 | } 149 | 150 | if (nrow(X) == 1 151 | || nrow(X) < minPointsForSplit 152 | || depth > maxDepthSplit) { 153 | # Return if one training point or max tree size options fulfilled 154 | return(setupLeaf(Y)) 155 | } else if (is.matrix(Y) && ncol(Y) > 1) { 156 | # Return if pure node 157 | # TODO the 1e-12 threshold below is arbitrary 158 | # Check if only one class is represented.
159 | if (sum(abs(colSums(Y)) > 1e-12) == 1) { 160 | return(setupLeaf(Y)) 161 | } 162 | } else if (all(Y == 0)) { 163 | # only one column in Y and all zeros 164 | # ie binary classification 165 | return(setupLeaf(Y)) 166 | } 167 | 168 | # TODO feature selection 169 | # TODO other stop conditions 170 | if (nrow(X) == 2) { 171 | # split in the center of the vector between the two points 172 | projection_matrix = t(X[2, , drop = F] - X[1, , drop = F]) #nolint 173 | if (all(Y[1, ] == Y[2, ]) || all(projection_matrix == 0)) { 174 | # same class, or identical points that no split can separate 175 | return(setupLeaf(Y)) 176 | } 177 | partitionPoint = 0.5 * (X[2, ] %*% projection_matrix + 178 | X[1, ] %*% projection_matrix) 179 | lessThanPartPoint = (X %*% projection_matrix) <= partitionPoint[1] 180 | best_split = list(partitionPoint = partitionPoint[1], 181 | splitDir = 1, 182 | lessThanPartPoint = lessThanPartPoint) 183 | } else { 184 | if (projectionBootstrap) { 185 | sampleIndices = sample(nrow(X), size = nrow(X), replace = T) 186 | XSampled = X[sampleIndices, , drop = F] #nolint 187 | YSampled = Y[sampleIndices, , drop = F] #nolint 188 | # Check if only one class is represented.
189 | if (sum( abs(colSums(YSampled)) > 1e-12 ) == 1) { 190 | return(setupLeaf(Y)) 191 | } 192 | cca = canonical_correlation_analysis(XSampled, YSampled) 193 | } else { 194 | cca = canonical_correlation_analysis(X, Y) 195 | } 196 | 197 | projection_matrix = cca$xcoef 198 | 199 | # U are the feature vectors in the projected space 200 | 201 | U = X %*% projection_matrix 202 | 203 | bUVaries = apply(U, 2, stats::var) > xVariationTolerance 204 | if (!any(bUVaries)) { 205 | return(setupLeaf(Y)) 206 | } 207 | projection_matrix = projection_matrix[, bUVaries, drop = FALSE] # splitDir indexes the kept columns 208 | best_split = find_best_split( 209 | U[, bUVaries, drop = F], Y, xVariationTolerance 210 | ) 211 | if (best_split$gain < 0) { 212 | tree = setupLeaf(Y) 213 | return(tree) 214 | } 215 | } 216 | 217 | # each partition can have multiple classes 218 | countsNode = colSums(Y) 219 | nonZeroCounts = sum(countsNode > 0) 220 | uniqueNonZeroCounts = length(unique(countsNode)[unique(countsNode) != 0]) 221 | if (uniqueNonZeroCounts == nonZeroCounts || is.null(ancestralProbs)) { 222 | ancestralProbs = countsNode / sum(countsNode) 223 | } else { 224 | ancestralProbs = rbind(ancestralProbs, countsNode / sum(countsNode)) 225 | } 226 | treeLeft = canonical_correlation_tree( 227 | X[best_split$lessThanPartPoint, , drop = F], #nolint 228 | Y[best_split$lessThanPartPoint, , drop = F], #nolint 229 | depth = depth + 1, 230 | minPointsForSplit = minPointsForSplit, 231 | maxDepthSplit = maxDepthSplit, xVariationTolerance = xVariationTolerance, 232 | projectionBootstrap = projectionBootstrap, ancestralProbs = ancestralProbs 233 | ) 234 | treeRight = canonical_correlation_tree( 235 | X[!best_split$lessThanPartPoint, , drop = F], #nolint 236 | Y[!best_split$lessThanPartPoint, , drop = F], #nolint 237 | depth = depth + 1, 238 | minPointsForSplit = minPointsForSplit, 239 | maxDepthSplit = maxDepthSplit, xVariationTolerance = xVariationTolerance, 240 | projectionBootstrap = projectionBootstrap, ancestralProbs = ancestralProbs 241 | ) 242 | model = structure( 243 | list(isLeaf = F, 244 | trainingCounts = countsNode, 245 | #TODO features that the node got, needed for prediction; 246 | # for now all nodes get all features 247 | #indicesFeatures =
indicesFeatures, 248 | decisionProjection = 249 | projection_matrix[, best_split$splitDir, drop = FALSE], 250 | partitionPoint = best_split$partitionPoint, 251 | depth = depth, 252 | refLeftChild = treeLeft, 253 | refRightChild = treeRight 254 | ) 255 | , class = "canonical_correlation_tree") 256 | return(model) 257 | } 258 | # Predicts, for each row of newData, the leaf's classIndex (a column of the one-hot labels) 259 | #' @export 260 | predict.canonical_correlation_tree = function(object, newData, ...){ 261 | tree = object 262 | if (tree$isLeaf) { 263 | return(rep(unname(tree$classIndex), max(NROW(newData), 1))) # one value per row 264 | } 265 | nr_of_features = length(tree$decisionProjection) 266 | # TODO use formula instead of all but last column 267 | X = as.matrix( 268 | newData[, 1:nr_of_features, drop = FALSE], ncol = nr_of_features) 269 | 270 | # TODO center_colmeans / input processing 271 | 272 | # transform training data corresponding to the node we are in 273 | U = X %*% tree$decisionProjection 274 | lessThanPartPoint = U <= tree$partitionPoint 275 | 276 | currentNodeClasses = matrix(nrow = max(nrow(X), 1)) 277 | if (any(lessThanPartPoint)) { 278 | currentNodeClasses[lessThanPartPoint, ] = 279 | predict.canonical_correlation_tree(tree$refLeftChild, 280 | X[lessThanPartPoint, ,drop = FALSE]) #nolint 281 | } 282 | if (any(!lessThanPartPoint)) { 283 | currentNodeClasses[!lessThanPartPoint, ] = 284 | predict.canonical_correlation_tree(tree$refRightChild, 285 | X[!lessThanPartPoint, ,drop = FALSE]) #nolint 286 | } 287 | return(currentNodeClasses) 288 | } 289 | 290 | #' @export 291 | plot.canonical_correlation_tree = function(x, dataX, dataY, ...) { 292 | TODO("check if plotable", return = T) 293 | plot_decision_surface(x, dataX, dataY) 294 | } 295 | -------------------------------------------------------------------------------- /R/data.R: -------------------------------------------------------------------------------- 1 | #' Spiral dataset 2 | #' 3 | #' A dataset containing 3 intertwined spirals.
4 | #' @usage 5 | #' data(spirals) 6 | #' @docType data 7 | #' @keywords datasets 8 | #' @name spirals 9 | #' @format A data frame with 10000 rows and 3 variables: 10 | #' \describe{ 11 | #' \item{x}{numeric scalar: x-coordinate} 12 | #' \item{y}{numeric scalar: y-coordinate} 13 | #' \item{class}{integer: either 1,2 or 3} 14 | #' } 15 | #' @source Created by T. Rainforth, URL: 16 | #' \url{https://bitbucket.org/twgr/ccf/raw/49d5fce6fc006bc9a8949c7149fc9524535ce418/Datasets/spirals.csv} 17 | "spirals" 18 | -------------------------------------------------------------------------------- /R/utilities.R: -------------------------------------------------------------------------------- 1 | random_element <- function(x) { 2 | if (length(x) > 1) { 3 | return(sample(x, 1)) 4 | } 5 | return(x) 6 | } 7 | 8 | eps <- 1 - 3 * ( (4 / 3) - 1) 9 | 10 | #' @importFrom stats model.matrix 11 | one_hot_encode <- function(data) { 12 | if (!is.data.frame(data)) { 13 | data <- data.frame(data) 14 | colnames(data) <- "class" 15 | } 16 | 17 | data$class <- as.factor(data$class) 18 | 19 | # This trick only works with factors 20 | return(model.matrix(~ class - 1, data = data)) 21 | } 22 | 23 | one_hot_decode <- function(X_one_hot) { 24 | return(apply(X_one_hot, 1, function(row){ 25 | which(row == 1) 26 | })) 27 | } 28 | 29 | TODO <- function(message = "TODO", return = NULL) { 30 | print(paste("TODO", message)) 31 | 32 | if (is.null(return)) { 33 | stop() 34 | } 35 | 36 | return(return) 37 | } 38 | generate_2d_data_plot <- function(data = NULL, 39 | data_raster = NULL, 40 | interpolate = F, 41 | title = "") { 42 | if (!requireNamespace("ggplot2", quietly = TRUE)) { 43 | stop("ggplot2 needed for plotting to work. Please install it.", 44 | call.
= FALSE) 45 | } 46 | blank <- ggplot2::element_blank() 47 | g <- ggplot2::ggplot() + 48 | ggplot2::theme_bw(base_size = 15) + 49 | ggplot2::coord_fixed(ratio = 0.8) + 50 | ggplot2::theme(axis.ticks = blank, 51 | panel.grid.major = blank, 52 | panel.grid.minor = blank, 53 | axis.text = blank, 54 | axis.title = blank, 55 | legend.position = "none") + 56 | ggplot2::geom_raster( 57 | ggplot2::aes(x = data_raster$x, 58 | y = data_raster$y, 59 | fill = as.character(data_raster$z)), 60 | interpolate = interpolate, 61 | alpha = 0.5, 62 | show.legend = F) + 63 | ggplot2::geom_point( 64 | ggplot2::aes(x = data$x, 65 | y = data$y, 66 | color = as.character(data$z)), 67 | size = 2) + 68 | ggplot2::ggtitle(title) 69 | 70 | return(g) 71 | } 72 | #' Helper function to plot classifier decision surface. 73 | #' 74 | #' This function generates a plot (ggplot2) of the decision surface for a 2d classifier and returns the ggplot object. 75 | #' @param model a model object for which prediction is desired. E.g. object of class 76 | #' \code{canonical_correlation_forest}, \code{canonical_correlation_tree}, 77 | #' \code{tree} or \code{randomForest}. 78 | #' @param X Numeric matrix (n * 2) with n observations of 2 variables 79 | #' @param Y Numeric matrix with n observations of 1 variable 80 | #' @param title Title text for the plot. 81 | #' @param interpolate If TRUE interpolate linearly, if FALSE (the default) 82 | #' don't interpolate. 83 | #' @param ... Further arguments passed to \code{predict()} 84 | #' @importFrom stats predict 85 | #' @export 86 | plot_decision_surface <- function(model, X, Y, title = NULL, 87 | interpolate = FALSE, ...) { 88 | data <- data.frame(x = X[, 1], y = X[, 2], z = Y) 89 | 90 | # pad the data range by 20% of each bound's magnitude (TODO-SF: better use expand?)
91 | x_min <- min(data$x) - 0.2 * abs(min(data$x)) 92 | x_max <- max(data$x) + 0.2 * abs(max(data$x)) 93 | y_min <- min(data$y) - 0.2 * abs(min(data$y)) 94 | y_max <- max(data$y) + 0.2 * abs(max(data$y)) 95 | resolution <- 400 96 | grid <- expand.grid(x = seq(x_min, x_max, length.out = resolution), 97 | y = seq(y_min, y_max, length.out = resolution)) 98 | predictions <- predict(model, grid, ...) 99 | 100 | data_raster <- data.frame(x = grid$x, y = grid$y, z = predictions) 101 | plot_object <- generate_2d_data_plot(data, 102 | data_raster, 103 | interpolate = interpolate, 104 | title = title) 105 | return(plot_object) 106 | } 107 | #' Helper function to compute the misclassification rate of a model on test data. 108 | #' @param model a model object for which prediction is desired. E.g. object of class 109 | #' \code{canonical_correlation_forest}, \code{canonical_correlation_tree}, 110 | #' \code{tree} or \code{randomForest}. 111 | #' @param data_test A data frame or a matrix containing the test data; its last column holds the true labels. 112 | #' @param ... Further arguments passed to \code{predict()} 113 | #' @importFrom stats predict 114 | #' @export 115 | get_missclassification_rate <- function(model, data_test, ...)
{ 116 | predictions <- as.matrix(stats::predict(model, data_test, ...)) 117 | # TODO use formula instead of last column 118 | actual <- data_test[, ncol(data_test)] 119 | return(mean(actual != predictions)) 120 | } 121 | 122 | #' @importFrom stats predict 123 | load_csv_data <- function(data_set_path) { 124 | data <- as.matrix(utils::read.csv(data_set_path, header = FALSE, 125 | sep = ",", quote = "\"", 126 | dec = ".", fill = TRUE, comment.char = "")) 127 | nr_of_features <- ncol(data) - 1 128 | 129 | return(list(X = data[, seq_len(nr_of_features), drop = FALSE], 130 | Y = data[, ncol(data), drop = FALSE])) 131 | } 132 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, echo = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "README-" 12 | ) 13 | ``` 14 | [![Build Status](https://travis-ci.org/jandob/ccf.svg?branch=master)](https://travis-ci.org/jandob/ccf) 15 | [![CRAN\_Status\_Badge](http://www.r-pkg.org/badges/version/ccf)](http://cran.r-project.org/package=ccf) 16 | [![Coverage Status](https://img.shields.io/codecov/c/github/jandob/ccf/master.svg)](https://codecov.io/github/jandob/ccf?branch=master) 17 | 18 | The package `ccf` implements canonical correlation forests (CCFs) for use inside R. CCFs are a novel classification algorithm for machine learning tasks that often outperforms common predictive classifiers. The CCF algorithm is based on an ensemble of decision trees together with a canonical correlation analysis. The purpose is to de-correlate individual trees and thus improve the predictive performance. 19 | 20 | ## Brief summary of CCF algorithm 21 | 22 | A decision tree is a predictive model that sequentially partitions the input space into regions, for each of which a local classification or regression model is calculated (e.g.
with a simple majority vote). In doing so, it generates a tree-like structure whose leaves ideally group data points belonging to the same class. One can often achieve a better performance by combining individual trees and averaging over them. This is known as a decision forest or random forest. 23 | 24 | A canonical correlation forest is a new tree ensemble method. While the concept is similar to a random forest, its specific characteristics often achieve a favorable predictive performance. It trains the trees by using a canonical correlation analysis (CCA) in order to find a projection of the features that is maximally correlated with the class labels. It then chooses the best split in this projected space. 25 | 26 | For a thorough explanation and derivation refer to: 27 | 28 | * Rainforth, T., and Wood, F. (2015): [Canonical correlation forest](https://arxiv.org/pdf/1507.05444.pdf), arXiv preprint, arXiv:1507.05444. 29 | 30 | ## Overview 31 | 32 | The most important functions in `ccf` are: 33 | 34 | * `canonical_correlation_forest()` computes a classifier based on canonical correlation forests. It supports both matrix-like input and the common `formula` notation. 35 | 36 | * `predict()` applies the classifier to unseen data and predicts the class outcome. 37 | 38 | * `plot()` visualizes the underlying decision surface. 39 | 40 | To see examples of these functions in use, check out the help pages, the demos and this README.
41 | 42 | ## Installation 43 | 44 | Using the **devtools** package, you can easily install the latest development version of `ccf` with 45 | 46 | ```{r,eval=FALSE} 47 | install.packages("devtools") 48 | 49 | # Option 1: download and install latest version from ‘GitHub’ 50 | devtools::install_github("jandob/ccf") 51 | 52 | # Option 2: install directly from bundled archive 53 | # devtools::install_local("ccf_0.1.0.tar.gz") 54 | ``` 55 | 56 | Notes: 57 | 58 | * In the case of option 2, you have to specify the path either to the directory of `ccf` or to the bundled archive **ccf_0.1.0.tar.gz** 59 | 60 | * A CRAN version has not yet been released, but we are working on it. This also applies to the integration into predictive frameworks such as `caret` or `mlr`. 61 | 62 | ## Usage 63 | 64 | This section shows the basic functionality of how to train a canonical correlation forest and make predictions based on it. First, load the corresponding package `ccf`. 65 | 66 | ```{r, message=FALSE} 67 | library(ccf) 68 | ``` 69 | 70 | The interface follows common R conventions as used by other machine learning routines. Therefore, the usage is fairly straightforward.
71 | 72 | ```{r} 73 | # load sample dataset 74 | data(spirals) 75 | 76 | d_train <- spirals[1:1000, ] 77 | d_test <- tail(spirals, 1000) 78 | 79 | # compute classifier on training data 80 | ## variant 1: matrix input 81 | m1 <- canonical_correlation_forest(d_train[, c("x", "y")], d_train$class, ntree = 10) 82 | ## variant 2: formula notation 83 | m2 <- canonical_correlation_forest(class ~ ., d_train, ntree = 10) 84 | 85 | # compute predictive accuracy 86 | get_missclassification_rate(m1, d_test) 87 | get_missclassification_rate(m2, d_test) 88 | 89 | # plot the decision surface of the classifier 90 | ccf_plot <- plot_decision_surface( 91 | m1, d_test[, c("x", "y")], d_test$class, title = "CCF with 10 trees") 92 | ``` 93 | 94 | ## License 95 | 96 | `ccf` is released under the [MIT License](https://opensource.org/licenses/MIT) 97 | 98 | Copyright (c) 2016 Janosch Dobler & Stefan Feuerriegel 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![Build Status](https://travis-ci.org/jandob/ccf.svg?branch=master)](https://travis-ci.org/jandob/ccf) [![CRAN\_Status\_Badge](http://www.r-pkg.org/badges/version/ccf)](http://cran.r-project.org/package=ccf) [![Coverage Status](https://img.shields.io/codecov/c/github/jandob/ccf/master.svg)](https://codecov.io/github/jandob/ccf?branch=master) 4 | 5 | The package `ccf` implements canonical correlation forests (CCFs) for use inside R. These present a novel classification algorithm for machine learning tasks that is often able to outperform common predictive classifiers. The CCF algorithm is based on an ensemble of decision trees together with a canonical correlation analysis. The purpose is to de-correlate individual trees and thus improve the predictive performance.
6 | 7 | Brief summary of CCF algorithm 8 | ------------------------------ 9 | 10 | A decision tree is a predictive model that sequentially divides the input space into regions, for each of which a local classification or regression model is calculated (e.g. with a simple majority vote). In doing so, it generates a tree-like structure whose leaves ideally group data points belonging to the same class. One can often achieve a better performance by combining individual trees and averaging over them. This is known as a decision forest or random forest. 11 | 12 | A canonical correlation forest is a new tree ensemble method. While the concept is similar to a random forest, its specific characteristics often achieve a favorable predictive performance. It trains the trees by using a canonical correlation analysis (CCA) in order to find a projection of the features that is maximally correlated with the class labels. It then chooses the best split in this projected space. 13 | 14 | For a thorough explanation and derivation refer to: 15 | 16 | - Rainforth, T., and Wood, F. (2015): [Canonical correlation forest](https://arxiv.org/pdf/1507.05444.pdf), arXiv preprint, arXiv:1507.05444. 17 | 18 | Overview 19 | -------- 20 | 21 | The most important functions in `ccf` are: 22 | 23 | - `canonical_correlation_forest()` computes a classifier based on canonical correlation forests. It supports both matrix-like input and the common `formula` notation. 24 | 25 | - `predict()` applies the classifier to unseen data and predicts the class outcome. 26 | 27 | - `plot()` visualizes the underlying decision surface. 28 | 29 | To see examples of these functions in use, check out the help pages, the demos and this README.
30 | 31 | Installation 32 | ------------ 33 | 34 | Using the **devtools** package, you can easily install the latest development version of `ccf` with 35 | 36 | ``` r 37 | install.packages("devtools") 38 | 39 | # Option 1: download and install latest version from ‘GitHub’ 40 | devtools::install_github("jandob/ccf") 41 | 42 | # Option 2: install directly from bundled archive 43 | # devtools::install_local("ccf_0.1.0.tar.gz") 44 | ``` 45 | 46 | Notes: 47 | 48 | - In the case of option 2, you have to specify the path either to the directory of `ccf` or to the bundled archive **ccf\_0.1.0.tar.gz** 49 | 50 | - A CRAN version has not yet been released, but we are working on it. This also applies to the integration into predictive frameworks such as `caret` or `mlr`. 51 | 52 | Usage 53 | ----- 54 | 55 | This section shows the basic functionality of how to train a canonical correlation forest and make predictions based on it. First, load the corresponding package `ccf`. 56 | 57 | ``` r 58 | library(ccf) 59 | ``` 60 | 61 | The interface follows common R conventions as used by other machine learning routines. Therefore, the usage is fairly straightforward.
62 | 63 | ``` r 64 | # load sample dataset 65 | data(spirals) 66 | 67 | d_train <- spirals[1:1000, ] 68 | d_test <- tail(spirals, 1000) 69 | 70 | # compute classifier on training data 71 | ## variant 1: matrix input 72 | m1 <- canonical_correlation_forest(d_train[, c("x", "y")], d_train$class, ntree = 10) 73 | ## variant 2: formula notation 74 | m2 <- canonical_correlation_forest(class ~ ., d_train, ntree = 10) 75 | 76 | # compute predictive accuracy 77 | get_missclassification_rate(m1, d_test) 78 | get_missclassification_rate(m2, d_test) 79 | 80 | # plot the decision surface of the classifier 81 | ccf_plot <- plot_decision_surface( 82 | m1, d_test[, c("x", "y")], d_test$class, title = "CCF with 10 trees") 83 | ``` 84 | 85 | License 86 | ------- 87 | 88 | `ccf` is released under the [MIT License](https://opensource.org/licenses/MIT) 89 | 90 | Copyright (c) 2016 Janosch Dobler & Stefan Feuerriegel 91 | -------------------------------------------------------------------------------- /ccf.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | -------------------------------------------------------------------------------- /data/spirals.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jandob/ccf/5267a21aafa42915053cc274b85cbf53c03b53d9/data/spirals.rda -------------------------------------------------------------------------------- /demo/00Index: -------------------------------------------------------------------------------- 1 |
cca_demo Examples for using the CCA method standalone. 2 | ccf_demo Examples for the ccf. 3 | dataset_demos Examples using datasets. 4 | profiling Shows how to profile the package. 5 | -------------------------------------------------------------------------------- /demo/cca_demo.R: -------------------------------------------------------------------------------- 1 | # CCA demo 2 | library(ccf) 3 | library(ggplot2) 4 | library(pracma) 5 | library(MASS) 6 | library(grid) 7 | 8 | gen_plot <- function(X, Y) { 9 | cca <- canonical_correlation_analysis(X, one_hot_encode(Y)) 10 | data_plot <- generate_2d_data_plot(data.frame(x = X[, 1], y = X[, 2], z = Y)) + 11 | geom_segment(aes(x = 0, y = 0, xend = x, yend = y), 12 | data = data.frame(x = cca$xcoef[1, ], y = cca$xcoef[2, ]), 13 | arrow = arrow(length = unit(0.3, "cm"))) 14 | return(data_plot) 15 | } 16 | 17 | nr_of_points <- 400 18 | X <- mvrnorm(nr_of_points, c(0, 0), eye(2)) 19 | Y <- ones(nr_of_points, m = 1) 20 | Y[X[, 1] < X[, 2]] <- 2 21 | plot_1 <- gen_plot(X, Y) 22 | 23 | Y[X[, 1] < X[, 2] & -X[, 1] > X[, 2]] <- 3 24 | plot_2 <- gen_plot(X, Y) 25 | 26 | Y[X[, 1] > X[, 2] & -X[, 1] > X[, 2]] <- 4 27 | plot_3 <- gen_plot(X, Y) 28 | 29 | #print(plot_1) 30 | #print(plot_2) 31 | #print(plot_3) 32 | 33 | grid.newpage() 34 | pushViewport(viewport(layout = grid.layout(1, 3))) 35 | 36 | print(plot_1, vp = viewport(layout.pos.row = 1, layout.pos.col = 1)) 37 | print(plot_2, vp = viewport(layout.pos.row = 1, layout.pos.col = 2)) 38 | print(plot_3, vp = viewport(layout.pos.row = 1, layout.pos.col = 3)) 39 | -------------------------------------------------------------------------------- /demo/ccf_demo.R: -------------------------------------------------------------------------------- 1 | # CCF demo 2 | data(spirals) 3 | d <- spirals 4 | colnames(d) <- c("x", "y", "z") 5 | d$z <- as.factor(d$z) 6 | 7 | d_train <- d[1:1000,] 8 | d_test <- d[1001:10000,] 9 | 10 | # sample 1000 data points 11 | #d <- d[sample(nrow(d), 1000), ] 12 | 13 |
generate_2d_data_plot(d_train) 14 | 15 | 16 | # convert to matrices 17 | X <- cbind(d_train$x,d_train$y) 18 | Y <- d_train$z 19 | 20 | # classify with a standard binary decision tree 21 | library("tree") 22 | model <- tree(as.factor(z)~., d_train) 23 | error_tree <- get_missclassification_rate(model, 24 | d_test, 25 | type = 'class') 26 | print(paste("tree missclassification rate:", error_tree)) 27 | plot_tree <- plot_decision_surface(model, X, Y, 28 | type = 'class', 29 | title = "Single CART") 30 | 31 | # classify with random forest 32 | library("randomForest") 33 | model <- randomForest(as.factor(z)~., d_train, ntree = 200) 34 | error_rf <- get_missclassification_rate(model, 35 | d_test, 36 | type = 'class') 37 | print(paste("rf missclassification rate:", error_rf)) 38 | plot_rf <- plot_decision_surface(model, X, Y, 39 | type = 'class', 40 | title = "RF with 200 Trees") 41 | 42 | # classify with oblique tree 43 | #library("oblique.tree") 44 | #model <- oblique.tree(z~., data = d_train) 45 | 46 | #predict(model, data.frame(x=X[,1], y=X[,2], z=0), type = "class") 47 | #plot_tree <- plot_decision_surface(model, X, Y, 48 | # type = 'class', 49 | # title = "Oblique Tree") 50 | 51 | # classify with rotation forest 52 | #library("rotationForest") 53 | #model <- rotationForest(X, one_hot_encode(as.factor(Y))) 54 | #data(iris) 55 | #y <- as.factor(one_hot_encode(iris$Species[1:100])) 56 | #x <- iris[1:100,-5] 57 | #rF <- rotationForest(x,y) 58 | #predict(object=rF,newdata=x) 59 | 60 | # classify with single CCT 61 | model <- canonical_correlation_tree(X, one_hot_encode(Y)) 62 | error_cct <- get_missclassification_rate(model, d_test) 63 | print(paste("cct missclassification rate:", error_cct)) 64 | plot_cct <- plot_decision_surface(model, X, Y, title = "Single CCT") 65 | 66 | # classify with CCF 67 | model <- canonical_correlation_forest(X, one_hot_encode(Y), 68 | ntree = 200, verbose = TRUE) 69 | #canonical_correlation_forest(as.factor(z)~., d_train) 70 | error_ccf <-
get_missclassification_rate(model, d_test) 71 | print(paste("ccf missclassification rate:", error_ccf)) 72 | plot_ccf <- plot_decision_surface(model, X, Y, title = "CCF with 200 Trees") 73 | 74 | library(grid) 75 | grid.newpage() 76 | pushViewport(viewport(layout = grid.layout(2, 2))) 77 | 78 | print(plot_tree, vp = viewport(layout.pos.row = 1, layout.pos.col = 1)) 79 | print(plot_rf, vp = viewport(layout.pos.row = 1, layout.pos.col = 2)) 80 | print(plot_cct, vp = viewport(layout.pos.row = 2, layout.pos.col = 1)) 81 | print(plot_ccf, vp = viewport(layout.pos.row = 2, layout.pos.col = 2)) 82 | 83 | 84 | -------------------------------------------------------------------------------- /demo/dataset_demos.R: -------------------------------------------------------------------------------- 1 | # Cross-validated CCF error on each CSV data set in ./data (last column = class label) 2 | k_fold_cross_validation <- function(data, k = 10) { 3 | stopifnot(requireNamespace("caret", quietly = TRUE)) 4 | folds <- caret::createFolds(data$X[, 1], k = k, list = TRUE, returnTrain = FALSE) 5 | errors <- numeric(length(folds)) 6 | for (foldIndex in seq_along(folds)) { 7 | foldTest <- folds[[foldIndex]] 8 | XTest <- data$X[foldTest, , drop = FALSE] 9 | YTest <- data$Y[foldTest, , drop = FALSE] 10 | 11 | foldTrainIndices <- seq_along(folds)[-foldIndex] 12 | foldTrain <- unlist(folds[foldTrainIndices], use.names = FALSE) 13 | XTrain <- data$X[foldTrain, , drop = FALSE] 14 | YTrain <- data$Y[foldTrain, , drop = FALSE] 15 | 16 | model <- canonical_correlation_forest(XTrain, one_hot_encode(YTrain)) 17 | error <- get_missclassification_rate(model, cbind(XTest, YTest)) 18 | print(paste("fold", foldIndex, ": error:", error)) 19 | errors[foldIndex] <- error 20 | } 21 | # errors holds one misclassification rate per fold 22 | print(paste("mean:", mean(errors), "standard deviation:", sd(errors))) 23 | } 24 | 25 | # compatible csv datasets are available at 26 | # https://bitbucket.org/twgr/ccf/src/49d5fce6fc006bc9a8949c7149fc9524535ce418/Datasets/?at=master 27 | files <- list.files(file.path(getwd(), "data"), pattern = "\\.csv$") 28 | for (file_name in files) { 29 | if (file_name == "skinSeg.csv") { 30 | # currently not
working (recursion too deep) 31 | next 32 | } 33 | if (file_name %in% c("letter.csv")) { 34 | # takes long 35 | next 36 | } 37 | file_path <- file.path(getwd(), "data", file_name) 38 | print(paste("dataset: ", file_path)) 39 | data <- ccf:::load_csv_data(file_path) # load_csv_data() is not exported 40 | k_fold_cross_validation(data, k = 10) 41 | } 42 | 43 | 44 | 45 | 46 | 47 | #cross val 48 | -------------------------------------------------------------------------------- /demo/profiling.R: -------------------------------------------------------------------------------- 1 | #profiling 2 | data = ccf::spirals 3 | colnames(data) = c("x", "y", "z") 4 | # sample 1000 data points 5 | data = data[sample(nrow(data), 1000), ] 6 | generate_2d_data_plot(data) 7 | 8 | # convert to matrices 9 | X = cbind(data$x,data$y) 10 | Y = data$z 11 | 12 | library(profvis) 13 | 14 | profvis::profvis({ 15 | model = canonical_correlation_tree(X, one_hot_encode(Y)) 16 | }) 17 | 18 | # benchmark for the right_cum_counts in find_best_split() 19 | v = seq(1,1000) 20 | m = cbind(v,v,v) 21 | total = c(1,2,3) 22 | microbenchmark::microbenchmark( 23 | sweep(m, MARGIN = 2, total, FUN = '-')*-1, 24 | t(apply(m, 1, function(x) {total - x})), 25 | apply((apply(apply(m, 2, rev), 2, cumsum)),2,rev), 26 | (apply(m[nrow(m):1,], 2, cumsum))[nrow(m):1,] 27 | ) 28 | -------------------------------------------------------------------------------- /man/canonical_correlation_analysis.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cca.R 3 | \name{canonical_correlation_analysis} 4 | \alias{canonical_correlation_analysis} 5 | \title{Canonical correlation analysis} 6 | \usage{ 7 | canonical_correlation_analysis(x, y, epsilon = 1e-04) 8 | } 9 | \arguments{ 10 | \item{x}{Matrix of size n-by-p with n observations from p variables.
Alternatively, data 11 | frames and numeric vectors are supported and automatically converted.} 12 | 13 | \item{y}{Matrix of size n-by-q with n observations from q variables. Alternatively, data 14 | frames and numeric vectors are supported and automatically converted.} 15 | 16 | \item{epsilon}{Numeric value used as tolerance threshold for rank reduction of the 17 | input matrices. Default is \code{1e-4}.} 18 | } 19 | \value{ 20 | A list containing the following components 21 | \itemize{ 22 | \item{xcoef}{Estimated coefficients for the \code{x} variable.} 23 | \item{ycoef}{Estimated coefficients for the \code{y} variable.} 24 | \item{cor}{Matrix with correlation coefficients.} 25 | } 26 | } 27 | \description{ 28 | Canonical correlation analysis (CCA) finds pairs of vectors \eqn{(w,v)} such that projections 29 | \eqn{Xw} and \eqn{Yv} have maximal possible correlations. The pairs are ordered in decreasing 30 | order of the correlations. In addition, projection vectors are normalized such that the 31 | variance of \eqn{Xw} and of \eqn{Yv} is equal to \eqn{1}. This means that projections are 32 | not only correlated, but "on the same scale" and hence can be directly compared.
33 | } 34 | \examples{ 35 | library(MASS) 36 | library(pracma) 37 | 38 | X <- mvrnorm(1000, mu = c(0, 0), Sigma = eye(2)) 39 | cca <- canonical_correlation_analysis(X, X) 40 | cca 41 | 42 | X <- mvrnorm(1000, mu = c(1, 2), 43 | Sigma = matrix(c(1.5, 0.5, 0.5, 1.5), ncol = 2)) 44 | cca <- canonical_correlation_analysis(X, X) 45 | cca 46 | } 47 | -------------------------------------------------------------------------------- /man/canonical_correlation_tree.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cct.R 3 | \name{canonical_correlation_tree} 4 | \alias{canonical_correlation_tree} 5 | \title{Computes a canonical correlation tree} 6 | \usage{ 7 | canonical_correlation_tree(X, Y, depth = 0, minPointsForSplit = 2, 8 | maxDepthSplit = Inf, xVariationTolerance = 1e-10, 9 | projectionBootstrap = FALSE, ancestralProbs = NULL) 10 | } 11 | \arguments{ 12 | \item{X}{Predictor matrix of size \eqn{n \times p} with \eqn{n} observations and \eqn{p} 13 | variables.} 14 | 15 | \item{Y}{Response values (e.g. one-hot encoded class labels) as a matrix of size \eqn{n \times q} with \eqn{n} observations 16 | and \eqn{q} variables.} 17 | 18 | \item{depth}{Depth of subtree.} 19 | 20 | \item{minPointsForSplit}{Optional parameter setting the threshold when to construct a 21 | leaf (default: 2). If the number of data points is smaller than this value, a leaf is 22 | constructed.} 23 | 24 | \item{maxDepthSplit}{Optional parameter controlling the construction of leaves after a 25 | certain depth (default: \code{Inf}). If the current depth is greater than this value, 26 | a leaf is constructed.} 27 | 28 | \item{xVariationTolerance}{Features with variance less than this value are not considered 29 | for splitting at tree nodes. (default \code{1e-10})} 30 | 31 | \item{projectionBootstrap}{Use projection bootstrapping. (default \code{FALSE})} 32 | 33 | \item{ancestralProbs}{Probabilities of ancestors.
Default is \code{NULL} as these are 34 | then calculated automatically.} 35 | } 36 | \value{ 37 | Function returns an object of class \code{canonical_correlation_tree}, 38 | where the object is a list containing the following components: 39 | \itemize{ 40 | \item{isLeaf}{Boolean whether the tree is a leaf itself.} 41 | \item{trainingCounts}{Number of training examples for constructing this tree (i.e. 42 | number of rows in input argument \code{X}).} 43 | \item{indicesFeatures}{Feature indices which the node received, as needed for 44 | prediction.} 45 | \item{decisionProjection}{Numeric matrix containing the projection matrix that was 46 | used to find the best split point.} 47 | \item{refLeftChild}{Reference to the left subtree.} 48 | \item{refRightChild}{Reference to the right subtree.} 49 | } 50 | } 51 | \description{ 52 | This function computes a single canonical correlation tree given its input values. 53 | } 54 | -------------------------------------------------------------------------------- /man/ccf.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ccf.R 3 | \name{canonical_correlation_forest} 4 | \alias{canonical_correlation_forest} 5 | \alias{canonical_correlation_forest.default} 6 | \alias{canonical_correlation_forest.formula} 7 | \title{Canonical correlation forest} 8 | \usage{ 9 | canonical_correlation_forest(x, y = NULL, ntree = 200, verbose = FALSE, 10 | ...) 11 | 12 | \method{canonical_correlation_forest}{default}(x, y = NULL, ntree = 200, 13 | verbose = FALSE, projectionBootstrap = FALSE, ...) 14 | 15 | \method{canonical_correlation_forest}{formula}(x, y = NULL, ntree = 200, 16 | verbose = FALSE, ...)
17 | } 18 | \arguments{ 19 | \item{x}{Numeric matrix (n * p) with n observations of p variables} 20 | 21 | \item{y}{Numeric matrix with n observations of q variables} 22 | 23 | \item{ntree}{Number of trees the forest will be composed of} 24 | 25 | \item{verbose}{Optional argument to control if additional information is 26 | printed to the output. Default is \code{FALSE}.} 27 | 28 | \item{...}{Further arguments passed to or from other methods.} 29 | 30 | \item{projectionBootstrap}{Use projection bootstrapping. (default \code{FALSE})} 31 | } 32 | \value{ 33 | Returns an object of class "canonical_correlation_forest", 34 | where an object of this class is a list containing the following 35 | components: 36 | \itemize{ 37 | \item{x,y}{The original input data} 38 | \item{y_encoded}{The encoded \code{y} variable in case of classification tasks.} 39 | \item{forest}{a vector of length ntree with objects of class 40 | \code{canonical_correlation_tree}.} 41 | } 42 | } 43 | \description{ 44 | This function computes a classifier based on a canonical correlation forest. It 45 | expects its input in matrix form or as formula notation. 46 | } 47 | \examples{ 48 | data(spirals) 49 | 50 | d_train <- spirals[1:1000, ] 51 | d_test <- spirals[-(1:1000), ] 52 | 53 | # compute classifier on training data 54 | ## variant 1: matrix input 55 | m1 <- canonical_correlation_forest(d_train[, c("x", "y")], d_train$class, ntree = 20) 56 | ## variant 2: formula notation 57 | m2 <- canonical_correlation_forest(class ~ ., d_train) 58 | 59 | # compute predictive accuracy 60 | get_missclassification_rate(m1, d_test) 61 | get_missclassification_rate(m2, d_test) 62 | } 63 | \references{ 64 | Rainforth, T., and Wood, F. (2015): Canonical correlation forest, 65 | arXiv preprint, arXiv:1507.05444, \url{https://arxiv.org/pdf/1507.05444.pdf}.
66 | } 67 | -------------------------------------------------------------------------------- /man/get_missclassification_rate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utilities.R 3 | \name{get_missclassification_rate} 4 | \alias{get_missclassification_rate} 5 | \title{Helper function to compute the misclassification rate of a model.} 6 | \usage{ 7 | get_missclassification_rate(model, data_test, ...) 8 | } 9 | \arguments{ 10 | \item{model}{a model object for which prediction is desired. E.g. object of class 11 | \code{canonical_correlation_forest}, \code{canonical_correlation_tree}, 12 | \code{tree} or \code{randomForest}.} 13 | 14 | \item{data_test}{A data frame or a matrix containing the test data. The last 15 | column must hold the true class labels.} 16 | 17 | \item{...}{Further arguments passed to \code{predict()}.} 18 | } 19 | \value{ 20 | The proportion of predictions that differ from the labels in the last 21 | column of \code{data_test}. 22 | } 23 | \description{ 24 | Helper function to compute the misclassification rate of a model. 25 | } 26 | -------------------------------------------------------------------------------- /man/plot.canonical_correlation_forest.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ccf.R 3 | \name{plot.canonical_correlation_forest} 4 | \alias{plot.canonical_correlation_forest} 5 | \title{Visualization of canonical correlation forest} 6 | \usage{ 7 | \method{plot}{canonical_correlation_forest}(...)
8 | } 9 | \arguments{ 10 | \item{...}{Further arguments passed to or from other methods.} 11 | } 12 | \description{ 13 | Visualizes the underlying decision surface of a canonical correlation forest. 14 | } 15 | -------------------------------------------------------------------------------- /man/plot_decision_surface.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/utilities.R 3 | \name{plot_decision_surface} 4 | \alias{plot_decision_surface} 5 | \title{Helper function to plot classifier decision surface.} 6 | \usage{ 7 | plot_decision_surface(model, X, Y, title = NULL, interpolate = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{model}{a model object for which prediction is desired. E.g. object of class 11 | \code{canonical_correlation_forest}, \code{canonical_correlation_tree}, 12 | \code{tree} or \code{randomForest}.} 13 | 14 | \item{X}{Numeric matrix (n * 2) with n observations of 2 variables} 15 | 16 | \item{Y}{Numeric matrix with n observations of 1 variable} 17 | 18 | \item{title}{Title text for the plot.} 19 | 20 | \item{interpolate}{If TRUE interpolate linearly, if FALSE (the default) 21 | don't interpolate.} 22 | 23 | \item{...}{Further arguments passed to \code{predict()}.} 24 | } 25 | \description{ 26 | This function generates a plot (ggplot2) of the decision surface for a 2d classifier. 27 | } 28 | -------------------------------------------------------------------------------- /man/predict.canonical_correlation_forest.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ccf.R 3 | \name{predict.canonical_correlation_forest} 4 | \alias{predict.canonical_correlation_forest} 5 | \title{Prediction from canonical correlation forest} 6 | \usage{ 7 | \method{predict}{canonical_correlation_forest}(object, newdata, 8 | verbose = FALSE, ...)
9 | } 10 | \arguments{ 11 | \item{object}{An object of class \code{canonical_correlation_forest}, as created 12 | by the function \code{\link{canonical_correlation_forest}}.} 13 | 14 | \item{newdata}{A data frame or a matrix containing the test data.} 15 | 16 | \item{verbose}{Optional argument to control if additional information is 17 | printed to the output. Default is \code{FALSE}.} 18 | 19 | \item{...}{Additional parameters passed on to prediction from individual 20 | canonical correlation trees.} 21 | } 22 | \description{ 23 | Performs predictions on test data for a trained canonical correlation forest. 24 | } 25 | -------------------------------------------------------------------------------- /man/spirals.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{spirals} 5 | \alias{spirals} 6 | \title{Spiral dataset} 7 | \format{A data frame with 10000 rows and 3 variables: 8 | \describe{ 9 | \item{x}{numeric scalar: x-coordinate} 10 | \item{y}{numeric scalar: y-coordinate} 11 | \item{class}{integer: either 1, 2 or 3} 12 | }} 13 | \source{ 14 | Created by T. Rainforth, URL: 15 | \url{https://bitbucket.org/twgr/ccf/raw/49d5fce6fc006bc9a8949c7149fc9524535ce418/Datasets/spirals.csv} 16 | } 17 | \usage{ 18 | data(spirals) 19 | } 20 | \description{ 21 | A dataset containing 3 intertwined spirals.
22 | } 23 | \keyword{datasets} 24 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(ccf) 3 | 4 | test_check("ccf") 5 | -------------------------------------------------------------------------------- /tests/testthat/test_cca.R: -------------------------------------------------------------------------------- 1 | context("Canonical Correlation Analysis") 2 | 3 | library(pracma) 4 | library(MASS) 5 | 6 | test_that("cca 2d", { 7 | X = mvrnorm(1000, c(0, 0), eye(2)) 8 | cca = canonical_correlation_analysis(X, X) 9 | expect_equal(cca$cor, eye(2)) 10 | }) 11 | 12 | test_that("cca 3d", { 13 | X = mvrnorm(1000, c(0, 0, 0), eye(3)) 14 | cca = canonical_correlation_analysis(X, X) 15 | expect_equal(cca$cor, eye(3)) 16 | }) 17 | 18 | test_that("cca 10d", { 19 | X = mvrnorm(1000, c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), eye(10)) 20 | cca = canonical_correlation_analysis(X, X) 21 | expect_equal(cca$cor, eye(10)) 22 | }) 23 | 24 | test_that("projection to 0", { 25 | X = rbind(c(0, 0, 0), 26 | c(1, 1, 1), 27 | c(2, 2, 2), 28 | c(0, 0, 1), 29 | c(1, 1, 2), 30 | c(2, 2, 3)) 31 | Y = rbind(matrix(rep(c(1, 0, 0, 0), each = 3), nrow = 3), 32 | matrix(rep(c(0, 0, 0, 1), each = 3), nrow = 3)) 33 | cca = canonical_correlation_analysis(X, Y) 34 | expect_equal(cca$cor, eye(1)) 35 | }) 36 | -------------------------------------------------------------------------------- /tests/testthat/test_ccf.R: -------------------------------------------------------------------------------- 1 | context("Canonical Correlation Forest") 2 | library(pracma) 3 | 4 | test_that("ccf 2d", { 5 | diagonal = as.data.frame(rbind(c(0, 0), c(1, 1), c(2, 2))) 6 | colnames(diagonal) = c("x", "y") 7 | diagonal_upper = diagonal_lower = diagonal 8 | diagonal_upper$y = diagonal$y + 1 9 | diagonal_lower$x = diagonal$x + 1 10 | X = rbind(diagonal, diagonal_upper,
diagonal_lower) 11 | Y = rbind(1 * ones(nrow(diagonal), 1), 12 | 2 * ones(nrow(diagonal), 1), 13 | 3 * ones(nrow(diagonal), 1)) 14 | X = as.matrix(X) 15 | set.seed(42) 16 | ccf = canonical_correlation_forest(X, one_hot_encode(Y), ntree = 10) 17 | error_ccf = get_missclassification_rate(ccf, cbind(X, Y)) 18 | expect_equal(error_ccf, 0) 19 | }) 20 | test_that("ccf spiral with projection bootstrap", { 21 | data(spirals) 22 | d <- spirals 23 | colnames(d) <- c("x", "y", "z") 24 | d$z <- as.factor(d$z) 25 | 26 | d_train <- d[1:100, ] 27 | d_test <- d[101:1000, ] 28 | 29 | # convert to matrices 30 | X <- cbind(d_train$x, d_train$y) 31 | Y <- d_train$z 32 | 33 | set.seed(42) 34 | 35 | ccf = canonical_correlation_forest( 36 | X, one_hot_encode(Y), projectionBootstrap = TRUE) 37 | error_ccf <- get_missclassification_rate(ccf, d_test) 38 | expect_lt(error_ccf, 0.29) 39 | }) 40 | 41 | test_that("ccf formula", { 42 | diagonal = as.data.frame(rbind(c(0, 0), c(1, 1), c(2, 2))) 43 | colnames(diagonal) = c("x", "y") 44 | diagonal_upper = diagonal_lower = diagonal 45 | diagonal_upper$y = diagonal$y + 1 46 | diagonal_lower$x = diagonal$x + 1 47 | X = rbind(diagonal, diagonal_upper, diagonal_lower) 48 | Y = rbind(1 * ones(nrow(diagonal), 1), 49 | 2 * ones(nrow(diagonal), 1), 50 | 3 * ones(nrow(diagonal), 1)) 51 | set.seed(42) 52 | Y = as.factor(Y) 53 | d_train = data.frame(Y, X) 54 | ccf = canonical_correlation_forest(Y ~ ., d_train, ntree = 10) 55 | error_ccf = get_missclassification_rate(ccf, cbind(X, Y)) 56 | expect_equal(error_ccf, 0) 57 | }) 58 | -------------------------------------------------------------------------------- /tests/testthat/test_cct.R: -------------------------------------------------------------------------------- 1 | context("Canonical Correlation Tree") 2 | 3 | library(pracma) 4 | 5 | test_that("cct 2d", { 6 | diagonal = as.data.frame(rbind(c(0, 0), c(1, 1), c(2, 2))) 7 | colnames(diagonal) = c("x", "y") 8 | diagonal_upper =
diagonal_lower = diagonal 9 | diagonal_upper$y = diagonal$y + 1 10 | diagonal_lower$x = diagonal$x + 1 11 | X = rbind(diagonal, diagonal_upper, diagonal_lower) 12 | Y = rbind(1 * ones(nrow(diagonal), 1), 13 | 2 * ones(nrow(diagonal), 1), 14 | 3 * ones(nrow(diagonal), 1)) 15 | X = as.matrix(X) 16 | cct = canonical_correlation_tree(X, one_hot_encode(Y)) 17 | error_cct = get_missclassification_rate(cct, cbind(X, Y)) 18 | expect_that(error_cct, equals(0)) 19 | }) 20 | 21 | test_that("cct 3d", { 22 | diagonal = as.data.frame(rbind(c(0, 0, 0), c(1, 1, 1), c(2, 2, 2))) 23 | colnames(diagonal) = c("x", "y", "z") 24 | diagonal_upper = diagonal_lower = diagonal_front = diagonal 25 | diagonal_upper$y = diagonal$y + 1 26 | diagonal_lower$x = diagonal$x + 1 27 | diagonal_front$z = diagonal$z + 1 28 | X = rbind(diagonal, diagonal_upper, diagonal_lower, diagonal_front) 29 | Y = rbind(1 * ones(nrow(diagonal), 1), 30 | 2 * ones(nrow(diagonal), 1), 31 | 3 * ones(nrow(diagonal), 1), 32 | 4 * ones(nrow(diagonal), 1)) 33 | X = as.matrix(X) 34 | cct = canonical_correlation_tree(X, one_hot_encode(Y)) 35 | error_cct = get_missclassification_rate(cct, cbind(X, Y)) 36 | expect_that(error_cct, equals(0)) 37 | }) 38 | test_that("cct 3d with projection bootstrap", { 39 | diagonal = as.data.frame(rbind(c(0, 0, 0), c(1, 1, 1), c(2, 2, 2))) 40 | colnames(diagonal) = c("x", "y", "z") 41 | diagonal_upper = diagonal_lower = diagonal_front = diagonal 42 | diagonal_upper$y = diagonal$y + 1 43 | diagonal_lower$x = diagonal$x + 1 44 | diagonal_front$z = diagonal$z + 1 45 | X = rbind(diagonal, diagonal_upper, diagonal_lower, diagonal_front) 46 | Y = rbind(1 * ones(nrow(diagonal), 1), 47 | 2 * ones(nrow(diagonal), 1), 48 | 3 * ones(nrow(diagonal), 1), 49 | 4 * ones(nrow(diagonal), 1)) 50 | X = as.matrix(X) 51 | set.seed(42) 52 | cct = canonical_correlation_tree( 53 | X, one_hot_encode(Y), projectionBootstrap = TRUE) 54 | error_cct = get_missclassification_rate(cct, cbind(X, Y)) 55 | expect_that(error_cct, 
equals(0)) 56 | }) 57 | test_that("cct spiral with projection bootstrap", { 58 | data(spirals) 59 | d <- spirals 60 | colnames(d) <- c("x", "y", "z") 61 | d$z <- as.factor(d$z) 62 | 63 | d_train <- d[1:100, ] 64 | d_test <- d[101:1000, ] 65 | 66 | # convert to matrices 67 | X <- cbind(d_train$x, d_train$y) 68 | Y <- d_train$z 69 | 70 | set.seed(42) 71 | 72 | cct = canonical_correlation_tree( 73 | X, one_hot_encode(Y), projectionBootstrap = TRUE) 74 | error_cct <- get_missclassification_rate(cct, d_test) 75 | expect_true(error_cct < 0.38) 76 | }) 77 | --------------------------------------------------------------------------------