├── README.md
├── examples
│   ├── 3_clusters_2d.png
│   ├── 3_clusters_2d_ps.png
│   ├── 4_clusters_3d.png
│   └── 4_clusters_3d_ps.png
├── prediction-strength.R
└── simulations.R

/README.md:
--------------------------------------------------------------------------------
# About

An implementation of the prediction strength algorithm from Tibshirani, Walther, Botstein, and Brown's "Cluster validation by prediction strength". A description of the algorithm can be found [here](http://echen.me/posts/counting-clusters).
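
To run the examples below, one minimal setup (assuming the commands are run from the repository root so that `prediction-strength.R` can be sourced directly; the file loads `plyr` and `ggplot2`) is to source the script and call `prediction_strength` on a numeric matrix whose rows are observations. The toy data here is purely illustrative:

    source("prediction-strength.R")
    # Toy data: two well-separated groups of 25 points each.
    x = c(rnorm(25, mean = 0), rnorm(25, mean = 5))
    y = c(rnorm(25, mean = 0), rnorm(25, mean = 5))
    data = cbind(x, y)
    # Plots prediction strength against the number of clusters and prints the estimate.
    prediction_strength(data)

The 3-dimensional example below additionally uses the `scatterplot3d` package for its scatterplot.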

# Examples

    # Three clusters in 2 dimensions
    x = c(rnorm(20, mean = 0), rnorm(20, mean = 3), rnorm(20, mean = 5))
    y = c(rnorm(20, mean = 0), rnorm(20, mean = 5), rnorm(20, mean = 0))
    data = cbind(x, y)

    png("examples/3_clusters_2d.png")
    qplot(x, y)
    dev.off()

![3 clusters in 2 dimensions](https://github.com/echen/prediction-strength/raw/master/examples/3_clusters_2d.png)

    png("examples/3_clusters_2d_ps.png")
    prediction_strength(data)
    dev.off()

![Prediction strength for 3 clusters in 2 dimensions](https://github.com/echen/prediction-strength/raw/master/examples/3_clusters_2d_ps.png)

    # Four clusters in 3 dimensions
    x = c(rnorm(20, mean = 0), rnorm(20, mean = 3), rnorm(20, mean = 5), rnorm(20, mean = -10))
    y = rnorm(80, mean = 0)
    z = c(rnorm(40, mean = -5), rnorm(40, mean = 0))
    data = cbind(x, y, z)

    png("examples/4_clusters_3d.png")
    scatterplot3d(x, y, z)
    dev.off()

![4 clusters in 3 dimensions](https://github.com/echen/prediction-strength/raw/master/examples/4_clusters_3d.png)

    png("examples/4_clusters_3d_ps.png")
    prediction_strength(data)
    dev.off()

![Prediction strength for 4 clusters in 3 dimensions](https://github.com/echen/prediction-strength/raw/master/examples/4_clusters_3d_ps.png)
--------------------------------------------------------------------------------
/examples/3_clusters_2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/echen/prediction-strength/eec37dab2fb647c004c51bbc4d1ffb5d1a189e3f/examples/3_clusters_2d.png
--------------------------------------------------------------------------------
/examples/3_clusters_2d_ps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/echen/prediction-strength/eec37dab2fb647c004c51bbc4d1ffb5d1a189e3f/examples/3_clusters_2d_ps.png
--------------------------------------------------------------------------------
/examples/4_clusters_3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/echen/prediction-strength/eec37dab2fb647c004c51bbc4d1ffb5d1a189e3f/examples/4_clusters_3d.png
--------------------------------------------------------------------------------
/examples/4_clusters_3d_ps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/echen/prediction-strength/eec37dab2fb647c004c51bbc4d1ffb5d1a189e3f/examples/4_clusters_3d_ps.png
--------------------------------------------------------------------------------
/prediction-strength.R:
--------------------------------------------------------------------------------
# An implementation of the prediction strength algorithm from Tibshirani, Walther, Botstein, and Brown's "Cluster validation by prediction strength".

library(plyr)    # for maply() and aaply()
library(ggplot2)

# Given a matrix `data`, where rows are observations and columns are individual dimensions,
# compute and plot the prediction strength for each candidate number of clusters.
prediction_strength = function(data, min_num_clusters = 1, max_num_clusters = 10, num_trials = 5) {
  num_clusters = min_num_clusters:max_num_clusters
  prediction_strengths = c()
  for (i in 1:num_trials) {
    y = maply(num_clusters, function(n) calculate_prediction_strength(data, n))
    prediction_strengths = cbind(prediction_strengths, y)
  }

  means = aaply(prediction_strengths, 1, mean)
  stddevs = aaply(prediction_strengths, 1, sd)

  print(plot_prediction_strength(means, stddevs, num_clusters))

  # We use 0.8 as our prediction strength threshold. The estimate of the number of clusters
  # is the largest number of clusters whose mean prediction strength exceeds this threshold
  # (falling back to the number of clusters with the highest mean if none do).
  if (any(means > 0.8)) {
    estimated = max(num_clusters[means > 0.8])
  } else {
    estimated = num_clusters[which.max(means)]
  }
  print(paste("The estimated number of clusters is ", estimated, ".", sep = ""))
}

plot_prediction_strength = function(means, stddevs, num_clusters) {
  qplot(num_clusters, means, xlab = "# clusters", ylab = "prediction strength",
        geom = "line", main = "Estimating the number of clusters via the prediction strength") +
    geom_errorbar(aes(num_clusters, ymin = means - stddevs, ymax = means + stddevs),
                  size = 0.3, width = 0.2, colour = "darkblue")
}

calculate_prediction_strength = function(data, num_clusters) {
  # The prediction strength is always 1 for a single cluster (every pair of test points is
  # trivially co-assigned), so we skip clustering in that case.
  if (num_clusters == 1) {
    1
  } else {
    # Randomly split the data into a training set and a test set.
    rands = runif(nrow(data), min = 0, max = 1)
    training_set = data[rands <= 0.5, ]
    test_set = data[rands > 0.5, ]

    # Run k-means on each half, restarting 10 times and keeping the best clustering.
    kmeans_training = kmeans(training_set, centers = num_clusters, nstart = 10)
    kmeans_test = kmeans(test_set, centers = num_clusters, nstart = 10)

    # The prediction strength is the minimum prediction strength among all test clusters.
    prediction_strengths = maply(1:num_clusters, function(n) prediction_strength_of_cluster(test_set, kmeans_test, kmeans_training$centers, n))
    min(prediction_strengths)
  }
}

# Calculate the proportion of pairs of points in test cluster `k` that would again be assigned
# to the same cluster, if each point were assigned to its closest training cluster mean.
prediction_strength_of_cluster = function(test_set, kmeans_test, training_centers, k) {
  if (sum(kmeans_test$cluster == k) <= 1) {
    1 # The cluster contains at most one point, so there are no pairs to check.
  } else {
    test_cluster = test_set[kmeans_test$cluster == k, ]
    count = 0
    for (i in 1:(nrow(test_cluster)-1)) {
      for (j in (i+1):nrow(test_cluster)) {
        p1 = test_cluster[i, ]
        p2 = test_cluster[j, ]
        if (closest_center(training_centers, p1) == closest_center(training_centers, p2)) {
          count = count + 1
        }
      }
    }
    # Return the proportion of pairs that stayed in the same cluster.
    count / (nrow(test_cluster) * (nrow(test_cluster) - 1) / 2)
  }
}
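
# An equivalent, possibly faster formulation of the pair count above (illustrative sketch
# only; nothing else in this file calls it). Each point of the test cluster is assigned to
# its closest training center once, and the number of co-assigned pairs is then
# sum(choose(group_size, 2)) out of choose(n, 2) total pairs. Relies on closest_center(),
# defined below.
prediction_strength_of_cluster_fast = function(test_set, kmeans_test, training_centers, k) {
  test_cluster = test_set[kmeans_test$cluster == k, , drop = FALSE]
  n = nrow(test_cluster)
  if (n <= 1) {
    return(1)  # At most one point in the cluster: no pairs to check.
  }
  assignments = apply(test_cluster, 1, function(p) closest_center(training_centers, p))
  sum(choose(table(assignments), 2)) / choose(n, 2)
}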

# Returns the index of the center that x is closest to (in squared Euclidean distance).
# (TODO: Vectorize...)
closest_center = function(centers, x) {
  min_index = -1
  min_distance = Inf
  for (i in 1:nrow(centers)) {
    center = centers[i, ]
    d = sum((x - center)^2)
    if (d < min_distance) {
      min_index = i
      min_distance = d
    }
  }
  min_index
}
--------------------------------------------------------------------------------
/simulations.R:
--------------------------------------------------------------------------------
library(scatterplot3d)

# Provides prediction_strength() and loads plyr and ggplot2 (for qplot).
source("prediction-strength.R")

# Three clusters in 2 dimensions
x = c(rnorm(20, mean = 0), rnorm(20, mean = 3), rnorm(20, mean = 5))
y = c(rnorm(20, mean = 0), rnorm(20, mean = 5), rnorm(20, mean = 0))
data = cbind(x, y)

png("3_clusters_2d.png")
print(qplot(x, y))  # print() so the ggplot is drawn when this file is source()d
dev.off()

png("3_clusters_2d_ps.png")
prediction_strength(data)
dev.off()

# Four clusters in 3 dimensions
x = c(rnorm(20, mean = 0), rnorm(20, mean = 3), rnorm(20, mean = 5), rnorm(20, mean = -10))
y = rnorm(80, mean = 0)
z = c(rnorm(40, mean = -5), rnorm(40, mean = 0))
data = cbind(x, y, z)

png("4_clusters_3d.png")
scatterplot3d(x, y, z)
dev.off()

png("4_clusters_3d_ps.png")
prediction_strength(data)
dev.off()
--------------------------------------------------------------------------------