├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── ann ├── README.md ├── img │ ├── rm_vs_medv-1.png │ ├── train_boston-1.png │ └── train_breast_cancer_data-1.png └── readme.Rmd ├── caret ├── README.md ├── img │ ├── auroc_test-1.png │ ├── bwplot-1.png │ ├── bwplot_diff_values-1.png │ ├── heatmap_grid_res-1.png │ ├── line_plot_grid_res-1.png │ ├── ntree_mtry_heatmap-1.png │ ├── plot_my_rf-1.png │ ├── plot_my_rf_full-1.png │ ├── resamples_dotplot-1.png │ ├── roc_density-1.png │ └── spam_rf_tune-1.png ├── readme.Rmd ├── rf.Rmd ├── rf.md ├── xgboost.Rmd └── xgboost.md ├── data ├── README.md ├── Seasons_Stats.csv.gz ├── breast_cancer_data.Rmd ├── breast_cancer_data.csv ├── breast_cancer_data.md ├── spambase.Rmd ├── spambase.csv ├── spambase.md └── titanic.csv.gz ├── deep_learning ├── README.md ├── ae.md ├── cnn.md ├── img │ ├── digit_heatmap-1.png │ ├── model_plot-1.png │ ├── unnamed-chunk-10-1.png │ ├── unnamed-chunk-12-1.png │ ├── unnamed-chunk-14-1.png │ ├── unnamed-chunk-16-1.png │ ├── unnamed-chunk-19-1.png │ ├── unnamed-chunk-23-1.png │ ├── unnamed-chunk-3-1.png │ ├── unnamed-chunk-5-1.png │ ├── unnamed-chunk-7-1.png │ └── unnamed-chunk-8-1.png ├── readme.Rmd ├── rnn.md └── transformer.md ├── evaluation ├── README.md ├── cutpointr.Rmd ├── cutpointr.md ├── formulas.Rmd ├── formulas.md ├── img │ ├── confusion_matrix.png │ ├── cross_validation.png │ ├── dendrogram-1.png │ ├── ig_df-1.png │ ├── mouse_dendrogram.png │ ├── plot_cutpointr-1.png │ ├── plot_metric-1.png │ ├── plot_spam_youden-1.png │ ├── precision_recall-1.png │ ├── random_forest_roc-1.png │ ├── random_predictor-1.png │ ├── rmse.png │ ├── roc_verification-1.png │ ├── roc_verification.png │ ├── roc_verification_ci-1.png │ ├── roc_verification_ci.png │ ├── roc_versicolor.png │ ├── try_k-1.png │ ├── unnamed-chunk-1-1.png │ ├── unnamed-chunk-2-1.png │ ├── unnamed-chunk-3-1.png │ └── unnamed-chunk-4-1.png └── readme.Rmd ├── gmm ├── gmm.Rmd ├── gmm.md └── img │ ├── clearly_distinct-1.png │ ├── close-1.png │ ├── distinct-1.png │ ├── faithful_scatter-1.png │ └── plot_density-1.png ├── hclust ├── README.md ├── img │ └── unnamed-chunk-6-1.png └── readme.Rmd ├── kmeans ├── README.md ├── img │ ├── elbow_plot-1.png │ ├── fviz_cluster-1.png │ ├── kmeans_k_2-1.png │ ├── pca-1.png │ ├── pca_figure-1.png │ ├── plot_hist-1.png │ └── silhouette_analysis-1.png └── readme.Rmd ├── knn ├── README.md ├── img │ └── unnamed-chunk-1-1.png └── readme.Rmd ├── logit_regression ├── README.md ├── img │ ├── mpg_vs_hp-1.png │ ├── unnamed-chunk-3-1.png │ └── unnamed-chunk-5-1.png └── readme.Rmd ├── machine_learning.Rproj ├── naive_bayes ├── README.md └── readme.Rmd ├── pca ├── README.md ├── img │ └── plot-1.png └── readme.Rmd ├── proximus ├── README.md ├── img │ └── unnamed-chunk-1-1.png └── proximus.Rmd ├── random_forest ├── README.md ├── img │ ├── tune_rf-1.png │ └── var_imp_plot-1.png └── readme.Rmd ├── ref ├── OUCS-2002-12.pdf ├── README.md └── fdws02.pdf ├── script ├── chown.sh ├── rmd_to_md.sh └── run_rstudio.sh ├── som ├── README.md ├── img │ ├── check_convergence-1.png │ ├── code_plot-1.png │ ├── code_plot-2.png │ ├── heatmap_sepal_width-1.png │ ├── node_count-1.png │ └── species_vs_sepal_width-1.png └── readme.Rmd ├── svm ├── README.md ├── example.Rmd ├── img │ └── SVM_Example_of_Hyperplanes.png └── readme.Rmd ├── tabnet ├── README.md └── readme.Rmd ├── template ├── README.md ├── img │ └── plot-1.png └── readme.Rmd ├── tidymodels ├── README.md ├── img │ ├── imbalance_pr_curve-1.png │ ├── imbalance_roc_curve-1.png │ ├── pr_curve-1.png │ ├── 
roc_curve-1.png │ ├── rocr_pr_curve-1.png │ └── rocr_roc_curve-1.png ├── readme.Rmd └── xgboost.Rmd ├── tree ├── README.md ├── img │ ├── unnamed-chunk-1-1.png │ ├── unnamed-chunk-2-1.png │ ├── unnamed-chunk-3-1.png │ ├── unnamed-chunk-4-1.png │ ├── unnamed-chunk-5-1.png │ ├── unnamed-chunk-5-2.png │ ├── unnamed-chunk-6-1.png │ └── unnamed-chunk-9-1.png └── readme.Rmd ├── variant ├── README.md ├── image │ ├── swissvar_cadd.png │ └── swissvar_gerp.png ├── kabuki_hg18.bed ├── kabuki_hg19_dbnsfp.out ├── miller_hg18.bed ├── miller_hg19_dbnsfp.out ├── myvariant.Rmd ├── myvariant.pdf ├── random_forest │ ├── README.md │ ├── analysis.R │ ├── download.sh │ ├── image │ │ └── CFTR.png │ └── stratify.pl └── run.sh └── xgboost ├── README.md ├── img ├── arthritis_imp_plot-1.png ├── breast_cancer_feature_importance-1.png ├── feature_importance-1.png └── feature_importance_plot-1.png └── readme.Rmd /.gitignore: -------------------------------------------------------------------------------- 1 | variant/*_hg19.bed 2 | variant/*_hg19.tsv 3 | variant/*_hg19.vcf 4 | variant/*_hg19_dbnsfp.out.err 5 | variant/clinvar_20160831.vcf.gz* 6 | *.swp 7 | DataS1* 8 | __* 9 | .Rhistory 10 | positive* 11 | negative* 12 | clinvar.vcf.gz* 13 | clinvar.out* 14 | .DS_Store 15 | Pfeiffer.vcf 16 | *.out 17 | clinvar_* 18 | .Rproj.user 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Dave Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: template som ann tree hclust logit svm kmeans knn naive_bayes xgboost random_forest eval pca caret tidymodels deep_learning tabnet 2 | 3 | template: template/README.md 4 | som: som/README.md 5 | ann: ann/README.md 6 | tree: tree/README.md 7 | hclust: hclust/README.md 8 | logit: logit_regression/README.md 9 | svm: svm/README.md 10 | kmeans: kmeans/README.md 11 | knn: knn/README.md 12 | naive_bayes: naive_bayes/README.md 13 | xgboost: xgboost/README.md 14 | random_forest: random_forest/README.md 15 | eval: evaluation/README.md 16 | pca: pca/README.md 17 | caret: caret/README.md 18 | tidymodels: tidymodels/README.md 19 | deep_learning: deep_learning/README.md 20 | tabnet: tabnet/README.md 21 | 22 | template/README.md: script/rmd_to_md.sh template/readme.Rmd 23 | $^ 24 | 25 | som/README.md: script/rmd_to_md.sh som/readme.Rmd 26 | $^ 27 | 28 | ann/README.md: script/rmd_to_md.sh ann/readme.Rmd 29 | $^ 30 | 31 | tree/README.md: script/rmd_to_md.sh tree/readme.Rmd 32 | $^ 33 | 34 | hclust/README.md: script/rmd_to_md.sh hclust/readme.Rmd 35 | $^ 36 | 37 | logit_regression/README.md: script/rmd_to_md.sh logit_regression/readme.Rmd 38 | $^ 39 | 40 | svm/README.md: script/rmd_to_md.sh svm/readme.Rmd 41 | $^ 42 | 43 | kmeans/README.md: script/rmd_to_md.sh kmeans/readme.Rmd 44 | $^ 45 | 46 | knn/README.md: script/rmd_to_md.sh knn/readme.Rmd 47 | $^ 48 | 49 | naive_bayes/README.md: script/rmd_to_md.sh naive_bayes/readme.Rmd 50 | $^ 51 | 52 | xgboost/README.md: script/rmd_to_md.sh xgboost/readme.Rmd 53 | $^ 54 | 55 | random_forest/README.md: script/rmd_to_md.sh random_forest/readme.Rmd 56 | $^ 57 | 58 | evaluation/README.md: script/rmd_to_md.sh evaluation/readme.Rmd 59 | $^ 60 | 61 | pca/README.md: script/rmd_to_md.sh pca/readme.Rmd 62 | $^ 63 | 64 | caret/README.md: script/rmd_to_md.sh caret/readme.Rmd 65 | $^ 66 | 67 | tidymodels/README.md: script/rmd_to_md.sh tidymodels/readme.Rmd 68 | $^ 69 | 70 | deep_learning/README.md: script/rmd_to_md.sh deep_learning/readme.Rmd 71 | $^ 72 | 73 | tabnet/README.md: script/rmd_to_md.sh tabnet/readme.Rmd 74 | $^ 75 | -------------------------------------------------------------------------------- /ann/img/rm_vs_medv-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/ann/img/rm_vs_medv-1.png -------------------------------------------------------------------------------- /ann/img/train_boston-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/ann/img/train_boston-1.png -------------------------------------------------------------------------------- /ann/img/train_breast_cancer_data-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/ann/img/train_breast_cancer_data-1.png -------------------------------------------------------------------------------- /ann/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Artificial Neural Networks" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, 
include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | This notebook is adapted from [this tutorial](https://datascienceplus.com/fitting-neural-network-in-r/). 19 | 20 | Install packages if missing and load. 21 | 22 | ```{r load_package, message=FALSE, warning=FALSE} 23 | .libPaths('/packages') 24 | my_packages <- c('MASS', 'neuralnet') 25 | 26 | for (my_package in my_packages){ 27 | if(!require(my_package, character.only = TRUE)){ 28 | install.packages(my_package, '/packages') 29 | library(my_package, character.only = TRUE) 30 | } 31 | } 32 | ``` 33 | 34 | ## Housing values in Boston 35 | 36 | The `Boston` data set from the `MASS` package contains the following features: 37 | 38 | * `crim` - per capita crime rate by town. 39 | * `zn` - proportion of residential land zoned for lots over 25,000 sq.ft. 40 | * `indus` - proportion of non-retail business acres per town. 41 | * `chas` - Charles River dummy variable (= 1 if tract bounds river; 0 otherwise). 42 | * `nox` - nitrogen oxides concentration (parts per 10 million). 43 | * `rm` - average number of rooms per dwelling. 44 | * `age` - proportion of owner-occupied units built prior to 1940. 45 | * `dis` - weighted mean of distances to five Boston employment centres. 46 | * `rad` - index of accessibility to radial highways. 47 | * `tax` - full-value property-tax rate per $10,000. 48 | * `ptratio` - pupil-teacher ratio by town. 49 | * `black` - 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town. 50 | * `lstat` - lower status of the population (percent). 51 | * `medv` - median value of owner-occupied homes in $1000s. 52 | 53 | ```{r boston_data} 54 | str(Boston) 55 | any(is.na(Boston)) 56 | ``` 57 | 58 | ### Multiple linear regression 59 | 60 | Carry out multiple linear regression by regressing the median value onto all other features. 61 | 62 | ```{r glm} 63 | set.seed(500) 64 | index <- sample(1:nrow(Boston), round(0.75*nrow(Boston))) 65 | train <- Boston[index,] 66 | test <- Boston[-index,] 67 | lm.fit <- glm(medv ~ ., data=train) 68 | summary(lm.fit) 69 | ``` 70 | 71 | The number of rooms has the highest _t_-statistic. 72 | 73 | ```{r rm_vs_medv} 74 | ggplot(Boston, aes(rm, medv)) + 75 | geom_point() + 76 | labs(x = "Average number of rooms per dwelling", y = "Median value in $1,000") 77 | ``` 78 | 79 | Predict prices and calculate the mean squared error (MSE). 80 | 81 | ```{r glm_predict} 82 | pr.lm <- predict(lm.fit, test) 83 | MSE.lm <- sum((pr.lm - test$medv)^2)/nrow(test) 84 | MSE.lm 85 | ``` 86 | 87 | ### Neural network 88 | 89 | First we will carry out [feature scaling](https://en.wikipedia.org/wiki/Feature_scaling) using: 90 | 91 | ![](https://latex.codecogs.com/png.image?\large&space;\dpi{110}\bg{white}&space;x'&space;=&space;\frac{x&space;-&space;min(x)}{max(x)&space;-&space;min(x)}) 92 | 93 | Manually perform min-max scaling and compare it with the `scale` approach (just for fun). 94 | 95 | ```{r compare_min_max} 96 | x <- 1:20 97 | x_a <- (x - min(x)) / (max(x) - min(x)) 98 | x_b <- as.vector(scale(x, center = min(x), scale = max(x) - min(x))) 99 | identical(x_a, x_b) 100 | ``` 101 | 102 | Carry out scaling on the Boston data set.
103 | 104 | ```{r boston_scaled} 105 | maxs <- apply(Boston, 2, max) 106 | mins <- apply(Boston, 2, min) 107 | 108 | scaled <- as.data.frame( 109 | scale(Boston, center = mins, scale = maxs - mins) 110 | ) 111 | 112 | train_scaled <- scaled[index,] 113 | test_scaled <- scaled[-index,] 114 | ``` 115 | 116 | Manually create the formula `f` since `neuralnet` does not accept the `y ~ .` formula shorthand. 117 | 118 | ```{r build_formula} 119 | n <- names(train_scaled) 120 | f <- as.formula(paste("medv ~", paste(n[!n %in% "medv"], collapse = " + "))) 121 | ``` 122 | 123 | Train a neural network using two hidden layers with 5 and 3 neurons, respectively. 124 | 125 | ```{r train_boston, fig.width = 8, fig.height=6} 126 | nn <- neuralnet(f, data = train_scaled, hidden=c(5,3), linear.output = TRUE) 127 | plot(nn, rep = "best") 128 | ``` 129 | 130 | Predict the (scaled) values. 131 | 132 | ```{r nn_predict} 133 | pr.nn <- compute(nn, test_scaled[,1:13]) 134 | ``` 135 | 136 | We need to unscale the data before calculating the MSE. 137 | 138 | ```{r mse_nn} 139 | pr.nn_unscaled <- pr.nn$net.result * (max(Boston$medv) - min(Boston$medv)) + min(Boston$medv) 140 | test.r <- (test_scaled$medv) * (max(Boston$medv) - min(Boston$medv)) + min(Boston$medv) 141 | 142 | MSE.nn <- sum((test.r - pr.nn_unscaled)^2)/nrow(test_scaled) 143 | ``` 144 | 145 | Compare the MSEs. 146 | 147 | ```{r mse_comp} 148 | print(paste0("MSE of multiple linear regression: ", MSE.lm)) 149 | print(paste0("MSE of neural network regression: ", MSE.nn)) 150 | ``` 151 | 152 | ## Breast cancer data 153 | 154 | Classify breast cancer samples using the [Breast Cancer Wisconsin (Diagnostic) Data Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 155 | 156 | ```{r breast_cancer_data} 157 | data <- read.table( 158 | "../data/breast_cancer_data.csv", 159 | stringsAsFactors = FALSE, 160 | sep = ',', 161 | header = TRUE 162 | ) 163 | data$class <- factor(data$class) 164 | data <- data[,-1] 165 | ``` 166 | 167 | Separate into training (80%) and testing (20%) sets. 168 | 169 | ```{r split_breast_cancer_data} 170 | set.seed(31) 171 | my_prob <- 0.8 172 | my_split <- as.logical( 173 | rbinom( 174 | n = nrow(data), 175 | size = 1, 176 | p = my_prob 177 | ) 178 | ) 179 | 180 | train <- data[my_split,] 181 | test <- data[!my_split,] 182 | ``` 183 | 184 | Train the neural network. 185 | 186 | ```{r train_breast_cancer_data, fig.width = 8, fig.height=6} 187 | n <- names(train) 188 | f <- as.formula(paste("class ~", paste(n[!n %in% "class"], collapse = " + "))) 189 | nn <- neuralnet(f, data = train, hidden=c(5,3), linear.output = FALSE) 190 | plot(nn, rep = "best") 191 | ``` 192 | 193 | Predict and check results. 194 | 195 | ```{r predict_breast_cancer_data} 196 | result <- compute(nn, test[,-10]) 197 | result <- apply(result$net.result, 1, function(x) ifelse(x[1] > x[2], yes = 2, no = 4)) 198 | 199 | # test$class are the rows and nn result are the columns 200 | table(test$class, result) 201 | ``` 202 | 203 | ## Further reading 204 | 205 | The neuralnet [reference manual](https://cran.r-project.org/web/packages/neuralnet/neuralnet.pdf). 206 | 207 | ## Session info 208 | 209 | Time built. 210 | 211 | ```{r time, echo=FALSE} 212 | Sys.time() 213 | ``` 214 | 215 | Session info.
216 | 217 | ```{r session_info, echo=FALSE} 218 | sessionInfo() 219 | ``` 220 | -------------------------------------------------------------------------------- /caret/img/auroc_test-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/auroc_test-1.png -------------------------------------------------------------------------------- /caret/img/bwplot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/bwplot-1.png -------------------------------------------------------------------------------- /caret/img/bwplot_diff_values-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/bwplot_diff_values-1.png -------------------------------------------------------------------------------- /caret/img/heatmap_grid_res-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/heatmap_grid_res-1.png -------------------------------------------------------------------------------- /caret/img/line_plot_grid_res-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/line_plot_grid_res-1.png -------------------------------------------------------------------------------- /caret/img/ntree_mtry_heatmap-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/ntree_mtry_heatmap-1.png -------------------------------------------------------------------------------- /caret/img/plot_my_rf-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/plot_my_rf-1.png -------------------------------------------------------------------------------- /caret/img/plot_my_rf_full-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/plot_my_rf_full-1.png -------------------------------------------------------------------------------- /caret/img/resamples_dotplot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/resamples_dotplot-1.png -------------------------------------------------------------------------------- /caret/img/roc_density-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/roc_density-1.png -------------------------------------------------------------------------------- /caret/img/spam_rf_tune-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/caret/img/spam_rf_tune-1.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | 4 | # Titanic 5 | 6 | File downloaded from 7 | 8 | ``` 9 | # Survived Survival (0 = No; 1 = Yes) 10 | # Pclass Passenger Class (1 = 1st; 2 = 2nd; 3 = 3rd) 11 | # Name Name 12 | # Sex Sex 13 | # Age Age 14 | # SibSp Number of Siblings/Spouses Aboard 15 | # Parch Number of Parents/Children Aboard 16 | # Ticket Ticket Number 17 | # Fare Passenger Fare 18 | # Cabin Cabin 19 | # Embarked Port of Embarkation (C = Cherbourg; Q = Queenstown; S = Southampton) 20 | ``` 21 | 22 | # NBA 23 | 24 | `Seasons_Stats.csv.gz` downloaded from 25 | 26 | -------------------------------------------------------------------------------- /data/Seasons_Stats.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/data/Seasons_Stats.csv.gz -------------------------------------------------------------------------------- /data/breast_cancer_data.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Breast cancer data" 3 | output: md_document 4 | --- 5 | 6 | ```{r setup, include=FALSE} 7 | knitr::opts_chunk$set(cache = FALSE) 8 | knitr::opts_chunk$set(echo = TRUE) 9 | ``` 10 | 11 | Using the [Breast Cancer Wisconsin (Diagnostic) Data Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 12 | 13 | ```{r save_data} 14 | my_link <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data' 15 | data <- read.table(url(my_link), stringsAsFactors = FALSE, header = FALSE, sep = ',') 16 | names(data) <- c('id','ct','ucsize','ucshape','ma','secs','bn','bc','nn','miti','class') 17 | head(data) 18 | ``` 19 | 20 | Any missing data? 21 | 22 | ```{r any_missing_data} 23 | any(is.na(data)) 24 | ``` 25 | 26 | Data structure. 27 | 28 | ```{r str} 29 | str(data) 30 | ``` 31 | 32 | Bare nuclei (`bn`) was stored as a character vector because of the question marks. 33 | 34 | ```{r bn} 35 | table(data$bn) 36 | ``` 37 | 38 | Change the question marks into `NA`s and then into the median value. 39 | 40 | ```{r convert_na} 41 | data$bn <- as.integer(gsub(pattern = '\\?', replacement = NA, x = data$bn)) 42 | 43 | data$bn[is.na(data$bn)] <- median(data$bn, na.rm = TRUE) 44 | 45 | str(data) 46 | ``` 47 | 48 | Convert `class` into a factor with levels `2` for benign and `4` for malignant. 49 | 50 | ```{r factor_class} 51 | data$class <- factor(data$class) 52 | ``` 53 | 54 | Some IDs in the `id` column are duplicated and it is not clear why. 55 | 56 | ```{r dup_data} 57 | dup_idx <- which(duplicated(data$id)) 58 | dup_data <- data[data$id %in% data[dup_idx, 'id'], ] 59 | 60 | head(dup_data[order(dup_data$id), ]) 61 | ``` 62 | 63 | We will remove the duplicated IDs. 64 | 65 | ```{r remove_dup_id} 66 | dim(data) 67 | 68 | data <- data[! data$id %in% data[dup_idx, 'id'], ] 69 | 70 | dim(data) 71 | 72 | write.csv( 73 | x = data, 74 | file = "breast_cancer_data.csv", 75 | quote = FALSE, 76 | row.names = FALSE 77 | ) 78 | ``` 79 | 80 | ## Session info 81 | 82 | Time built. 83 | 84 | ```{r time, echo=FALSE} 85 | Sys.time() 86 | ``` 87 | 88 | Session info.
89 | 90 | ```{r session_info, echo=FALSE} 91 | sessionInfo() 92 | ``` 93 | 94 | -------------------------------------------------------------------------------- /data/breast_cancer_data.md: -------------------------------------------------------------------------------- 1 | Using the [Breast Cancer Wisconsin (Diagnostic) Data 2 | Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 3 | 4 | my_link <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data' 5 | data <- read.table(url(my_link), stringsAsFactors = FALSE, header = FALSE, sep = ',') 6 | names(data) <- c('id','ct','ucsize','ucshape','ma','secs','bn','bc','nn','miti','class') 7 | head(data) 8 | 9 | ## id ct ucsize ucshape ma secs bn bc nn miti class 10 | ## 1 1000025 5 1 1 1 2 1 3 1 1 2 11 | ## 2 1002945 5 4 4 5 7 10 3 2 1 2 12 | ## 3 1015425 3 1 1 1 2 2 3 1 1 2 13 | ## 4 1016277 6 8 8 1 3 4 3 7 1 2 14 | ## 5 1017023 4 1 1 3 2 1 3 1 1 2 15 | ## 6 1017122 8 10 10 8 7 10 9 7 1 4 16 | 17 | Any missing data? 18 | 19 | any(is.na(data)) 20 | 21 | ## [1] FALSE 22 | 23 | Data structure. 24 | 25 | str(data) 26 | 27 | ## 'data.frame': 699 obs. of 11 variables: 28 | ## $ id : int 1000025 1002945 1015425 1016277 1017023 1017122 1018099 1018561 1033078 1033078 ... 29 | ## $ ct : int 5 5 3 6 4 8 1 2 2 4 ... 30 | ## $ ucsize : int 1 4 1 8 1 10 1 1 1 2 ... 31 | ## $ ucshape: int 1 4 1 8 1 10 1 2 1 1 ... 32 | ## $ ma : int 1 5 1 1 3 8 1 1 1 1 ... 33 | ## $ secs : int 2 7 2 3 2 7 2 2 2 2 ... 34 | ## $ bn : chr "1" "10" "2" "4" ... 35 | ## $ bc : int 3 3 3 3 3 9 3 3 1 2 ... 36 | ## $ nn : int 1 2 1 7 1 7 1 1 1 1 ... 37 | ## $ miti : int 1 1 1 1 1 1 1 1 5 1 ... 38 | ## $ class : int 2 2 2 2 2 4 2 2 2 2 ... 39 | 40 | Bare nuclei (`bn`) was stored as characters because of question marks. 41 | 42 | table(data$bn) 43 | 44 | ## 45 | ## ? 1 10 2 3 4 5 6 7 8 9 46 | ## 16 402 132 30 28 19 30 4 8 21 9 47 | 48 | Change the question marks into NA’s and then into median values. 49 | 50 | data$bn <- as.integer(gsub(pattern = '\\?', replacement = NA, x = data$bn)) 51 | 52 | data$bn[is.na(data$bn)] <- median(data$bn, na.rm = TRUE) 53 | 54 | str(data) 55 | 56 | ## 'data.frame': 699 obs. of 11 variables: 57 | ## $ id : int 1000025 1002945 1015425 1016277 1017023 1017122 1018099 1018561 1033078 1033078 ... 58 | ## $ ct : int 5 5 3 6 4 8 1 2 2 4 ... 59 | ## $ ucsize : int 1 4 1 8 1 10 1 1 1 2 ... 60 | ## $ ucshape: int 1 4 1 8 1 10 1 2 1 1 ... 61 | ## $ ma : int 1 5 1 1 3 8 1 1 1 1 ... 62 | ## $ secs : int 2 7 2 3 2 7 2 2 2 2 ... 63 | ## $ bn : int 1 10 2 4 1 10 10 1 1 1 ... 64 | ## $ bc : int 3 3 3 3 3 9 3 3 1 2 ... 65 | ## $ nn : int 1 2 1 7 1 7 1 1 1 1 ... 66 | ## $ miti : int 1 1 1 1 1 1 1 1 5 1 ... 67 | ## $ class : int 2 2 2 2 2 4 2 2 2 2 ... 68 | 69 | Convert `class` into a factor with levels `2` for benign and `4` for 70 | malignant. 71 | 72 | data$class <- factor(data$class) 73 | 74 | The `id` column is duplicated and it is not clear why they are 75 | duplicated. 76 | 77 | dup_idx <- which(duplicated(data$id)) 78 | dup_data <- data[data$id %in% data[dup_idx, 'id'], ] 79 | 80 | head(dup_data[order(dup_data$id), ]) 81 | 82 | ## id ct ucsize ucshape ma secs bn bc nn miti class 83 | ## 268 320675 3 3 5 2 3 10 7 1 1 4 84 | ## 273 320675 3 3 5 2 3 10 7 1 1 4 85 | ## 270 385103 1 1 1 1 2 1 3 1 1 2 86 | ## 576 385103 5 1 2 1 2 1 3 1 1 2 87 | ## 272 411453 5 1 1 1 2 1 3 1 1 2 88 | ## 608 411453 1 1 1 1 2 1 1 1 1 2 89 | 90 | We will remove duplicated IDs. 
91 | 92 | dim(data) 93 | 94 | ## [1] 699 11 95 | 96 | data <- data[! data$id %in% data[dup_idx, 'id'], ] 97 | 98 | dim(data) 99 | 100 | ## [1] 599 11 101 | 102 | write.csv( 103 | x = data, 104 | file = "breast_cancer_data.csv", 105 | quote = FALSE, 106 | row.names = FALSE 107 | ) 108 | 109 | Session info 110 | ------------ 111 | 112 | Time built. 113 | 114 | ## [1] "2022-04-11 03:40:46 UTC" 115 | 116 | Session info. 117 | 118 | ## R version 4.1.3 (2022-03-10) 119 | ## Platform: x86_64-pc-linux-gnu (64-bit) 120 | ## Running under: Ubuntu 20.04.4 LTS 121 | ## 122 | ## Matrix products: default 123 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 124 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 125 | ## 126 | ## locale: 127 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 128 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 129 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 130 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 131 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 132 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 133 | ## 134 | ## attached base packages: 135 | ## [1] stats graphics grDevices utils datasets methods base 136 | ## 137 | ## loaded via a namespace (and not attached): 138 | ## [1] compiler_4.1.3 magrittr_2.0.3 fastmap_1.1.0 cli_3.2.0 139 | ## [5] tools_4.1.3 htmltools_0.5.2 yaml_2.3.5 stringi_1.7.6 140 | ## [9] rmarkdown_2.13 knitr_1.38 stringr_1.4.0 xfun_0.30 141 | ## [13] digest_0.6.29 rlang_1.0.2 evaluate_0.15 142 | -------------------------------------------------------------------------------- /data/spambase.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Spambase Data Set" 3 | output: md_document 4 | --- 5 | 6 | ```{r setup, include=FALSE} 7 | knitr::opts_chunk$set(cache = FALSE) 8 | knitr::opts_chunk$set(echo = TRUE) 9 | ``` 10 | 11 | From 12 | 13 | >Our collection of spam e-mails came from our postmaster and individuals who had filed spam. Our collection of non-spam e-mails came from filed work and personal e-mails, and hence the word 'george' and the area code '650' are indicators of non-spam. These are useful when constructing a personalized spam filter. One would either have to blind such non-spam indicators or get a very wide collection of non-spam to generate a general purpose spam filter. 14 | 15 | The last column of `spambase.data` denotes whether the e-mail was considered spam (1) or not (0). Most of the attributes indicate whether a particular word or character was frequently occurring in the e-mail. The run-length attributes (55-57) measure the length of sequences of consecutive capital letters. 16 | 17 | ```{r save_data} 18 | data_url <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.data' 19 | data <- read.table(url(data_url), stringsAsFactors = FALSE, header = FALSE, sep = ',') 20 | 21 | data_col_url <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.names' 22 | data_col <- read.table(url(data_col_url), stringsAsFactors = FALSE, header = FALSE, comment.char = "|") 23 | 24 | my_cols <- c(gsub(":$", "", data_col$V1[-1]), 'class') 25 | 26 | colnames(data) <- my_cols 27 | dim(data) 28 | ``` 29 | 30 | Any missing data? 31 | 32 | ```{r any_missing_data} 33 | any(is.na(data)) 34 | ``` 35 | 36 | Data structure. 37 | 38 | ```{r str} 39 | str(data) 40 | ``` 41 | 42 | Convert `class` into a factor with levels `0` for ham and `1` for spam. 
43 | 44 | ```{r factor_class} 45 | data$class <- factor(data$class) 46 | ``` 47 | 48 | Save as CSV. 49 | 50 | ```{r save_csv} 51 | write.csv( 52 | x = data, 53 | file = "spambase.csv", 54 | quote = FALSE, 55 | row.names = FALSE 56 | ) 57 | ``` 58 | 59 | ## Session info 60 | 61 | Time built. 62 | 63 | ```{r time, echo=FALSE} 64 | Sys.time() 65 | ``` 66 | 67 | Session info. 68 | 69 | ```{r session_info, echo=FALSE} 70 | sessionInfo() 71 | ``` 72 | -------------------------------------------------------------------------------- /data/spambase.md: -------------------------------------------------------------------------------- 1 | From 2 | https://archive.ics.uci.edu/ml/datasets/spambase 3 | 4 | > Our collection of spam e-mails came from our postmaster and 5 | > individuals who had filed spam. Our collection of non-spam e-mails 6 | > came from filed work and personal e-mails, and hence the word ‘george’ 7 | > and the area code ‘650’ are indicators of non-spam. These are useful 8 | > when constructing a personalized spam filter. One would either have to 9 | > blind such non-spam indicators or get a very wide collection of 10 | > non-spam to generate a general purpose spam filter. 11 | 12 | The last column of `spambase.data` denotes whether the e-mail was 13 | considered spam (1) or not (0). Most of the attributes indicate whether 14 | a particular word or character was frequently occurring in the e-mail. 15 | The run-length attributes (55-57) measure the length of sequences of 16 | consecutive capital letters. 17 | 18 | data_url <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.data' 19 | data <- read.table(url(data_url), stringsAsFactors = FALSE, header = FALSE, sep = ',') 20 | 21 | data_col_url <- 'https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.names' 22 | data_col <- read.table(url(data_col_url), stringsAsFactors = FALSE, header = FALSE, comment.char = "|") 23 | 24 | my_cols <- c(gsub(":$", "", data_col$V1[-1]), 'class') 25 | 26 | colnames(data) <- my_cols 27 | dim(data) 28 | 29 | ## [1] 4601 58 30 | 31 | Any missing data? 32 | 33 | any(is.na(data)) 34 | 35 | ## [1] FALSE 36 | 37 | Data structure. 38 | 39 | str(data) 40 | 41 | ## 'data.frame': 4601 obs. of 58 variables: 42 | ## $ word_freq_make : num 0 0.21 0.06 0 0 0 0 0 0.15 0.06 ... 43 | ## $ word_freq_address : num 0.64 0.28 0 0 0 0 0 0 0 0.12 ... 44 | ## $ word_freq_all : num 0.64 0.5 0.71 0 0 0 0 0 0.46 0.77 ... 45 | ## $ word_freq_3d : num 0 0 0 0 0 0 0 0 0 0 ... 46 | ## $ word_freq_our : num 0.32 0.14 1.23 0.63 0.63 1.85 1.92 1.88 0.61 0.19 ... 47 | ## $ word_freq_over : num 0 0.28 0.19 0 0 0 0 0 0 0.32 ... 48 | ## $ word_freq_remove : num 0 0.21 0.19 0.31 0.31 0 0 0 0.3 0.38 ... 49 | ## $ word_freq_internet : num 0 0.07 0.12 0.63 0.63 1.85 0 1.88 0 0 ... 50 | ## $ word_freq_order : num 0 0 0.64 0.31 0.31 0 0 0 0.92 0.06 ... 51 | ## $ word_freq_mail : num 0 0.94 0.25 0.63 0.63 0 0.64 0 0.76 0 ... 52 | ## $ word_freq_receive : num 0 0.21 0.38 0.31 0.31 0 0.96 0 0.76 0 ... 53 | ## $ word_freq_will : num 0.64 0.79 0.45 0.31 0.31 0 1.28 0 0.92 0.64 ... 54 | ## $ word_freq_people : num 0 0.65 0.12 0.31 0.31 0 0 0 0 0.25 ... 55 | ## $ word_freq_report : num 0 0.21 0 0 0 0 0 0 0 0 ... 56 | ## $ word_freq_addresses : num 0 0.14 1.75 0 0 0 0 0 0 0.12 ... 57 | ## $ word_freq_free : num 0.32 0.14 0.06 0.31 0.31 0 0.96 0 0 0 ... 58 | ## $ word_freq_business : num 0 0.07 0.06 0 0 0 0 0 0 0 ... 59 | ## $ word_freq_email : num 1.29 0.28 1.03 0 0 0 0.32 0 0.15 0.12 ... 
60 | ## $ word_freq_you : num 1.93 3.47 1.36 3.18 3.18 0 3.85 0 1.23 1.67 ... 61 | ## $ word_freq_credit : num 0 0 0.32 0 0 0 0 0 3.53 0.06 ... 62 | ## $ word_freq_your : num 0.96 1.59 0.51 0.31 0.31 0 0.64 0 2 0.71 ... 63 | ## $ word_freq_font : num 0 0 0 0 0 0 0 0 0 0 ... 64 | ## $ word_freq_000 : num 0 0.43 1.16 0 0 0 0 0 0 0.19 ... 65 | ## $ word_freq_money : num 0 0.43 0.06 0 0 0 0 0 0.15 0 ... 66 | ## $ word_freq_hp : num 0 0 0 0 0 0 0 0 0 0 ... 67 | ## $ word_freq_hpl : num 0 0 0 0 0 0 0 0 0 0 ... 68 | ## $ word_freq_george : num 0 0 0 0 0 0 0 0 0 0 ... 69 | ## $ word_freq_650 : num 0 0 0 0 0 0 0 0 0 0 ... 70 | ## $ word_freq_lab : num 0 0 0 0 0 0 0 0 0 0 ... 71 | ## $ word_freq_labs : num 0 0 0 0 0 0 0 0 0 0 ... 72 | ## $ word_freq_telnet : num 0 0 0 0 0 0 0 0 0 0 ... 73 | ## $ word_freq_857 : num 0 0 0 0 0 0 0 0 0 0 ... 74 | ## $ word_freq_data : num 0 0 0 0 0 0 0 0 0.15 0 ... 75 | ## $ word_freq_415 : num 0 0 0 0 0 0 0 0 0 0 ... 76 | ## $ word_freq_85 : num 0 0 0 0 0 0 0 0 0 0 ... 77 | ## $ word_freq_technology : num 0 0 0 0 0 0 0 0 0 0 ... 78 | ## $ word_freq_1999 : num 0 0.07 0 0 0 0 0 0 0 0 ... 79 | ## $ word_freq_parts : num 0 0 0 0 0 0 0 0 0 0 ... 80 | ## $ word_freq_pm : num 0 0 0 0 0 0 0 0 0 0 ... 81 | ## $ word_freq_direct : num 0 0 0.06 0 0 0 0 0 0 0 ... 82 | ## $ word_freq_cs : num 0 0 0 0 0 0 0 0 0 0 ... 83 | ## $ word_freq_meeting : num 0 0 0 0 0 0 0 0 0 0 ... 84 | ## $ word_freq_original : num 0 0 0.12 0 0 0 0 0 0.3 0 ... 85 | ## $ word_freq_project : num 0 0 0 0 0 0 0 0 0 0.06 ... 86 | ## $ word_freq_re : num 0 0 0.06 0 0 0 0 0 0 0 ... 87 | ## $ word_freq_edu : num 0 0 0.06 0 0 0 0 0 0 0 ... 88 | ## $ word_freq_table : num 0 0 0 0 0 0 0 0 0 0 ... 89 | ## $ word_freq_conference : num 0 0 0 0 0 0 0 0 0 0 ... 90 | ## $ char_freq_; : num 0 0 0.01 0 0 0 0 0 0 0.04 ... 91 | ## $ char_freq_( : num 0 0.132 0.143 0.137 0.135 0.223 0.054 0.206 0.271 0.03 ... 92 | ## $ char_freq_[ : num 0 0 0 0 0 0 0 0 0 0 ... 93 | ## $ char_freq_! : num 0.778 0.372 0.276 0.137 0.135 0 0.164 0 0.181 0.244 ... 94 | ## $ char_freq_$ : num 0 0.18 0.184 0 0 0 0.054 0 0.203 0.081 ... 95 | ## $ char_freq_# : num 0 0.048 0.01 0 0 0 0 0 0.022 0 ... 96 | ## $ capital_run_length_average: num 3.76 5.11 9.82 3.54 3.54 ... 97 | ## $ capital_run_length_longest: int 61 101 485 40 40 15 4 11 445 43 ... 98 | ## $ capital_run_length_total : int 278 1028 2259 191 191 54 112 49 1257 749 ... 99 | ## $ class : int 1 1 1 1 1 1 1 1 1 1 ... 100 | 101 | Convert `class` into a factor with levels `0` for ham and `1` for spam. 102 | 103 | data$class <- factor(data$class) 104 | 105 | Save as CSV. 106 | 107 | write.csv( 108 | x = data, 109 | file = "spambase.csv", 110 | quote = FALSE, 111 | row.names = FALSE 112 | ) 113 | 114 | Session info 115 | ------------ 116 | 117 | Time built. 118 | 119 | ## [1] "2022-04-12 06:54:12 UTC" 120 | 121 | Session info. 
122 | 123 | ## R version 4.1.3 (2022-03-10) 124 | ## Platform: x86_64-pc-linux-gnu (64-bit) 125 | ## Running under: Ubuntu 20.04.4 LTS 126 | ## 127 | ## Matrix products: default 128 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 129 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 130 | ## 131 | ## locale: 132 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 133 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 134 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 135 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 136 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 137 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 138 | ## 139 | ## attached base packages: 140 | ## [1] stats graphics grDevices utils datasets methods base 141 | ## 142 | ## loaded via a namespace (and not attached): 143 | ## [1] compiler_4.1.3 magrittr_2.0.3 fastmap_1.1.0 cli_3.2.0 144 | ## [5] tools_4.1.3 htmltools_0.5.2 yaml_2.3.5 stringi_1.7.6 145 | ## [9] rmarkdown_2.13 knitr_1.38 stringr_1.4.0 xfun_0.30 146 | ## [13] digest_0.6.29 rlang_1.0.2 evaluate_0.15 147 | -------------------------------------------------------------------------------- /data/titanic.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/data/titanic.csv.gz -------------------------------------------------------------------------------- /deep_learning/ae.md: -------------------------------------------------------------------------------- 1 | # Autoencoders 2 | 3 | From ChatGPT. 4 | 5 | Autoencoder models are a class of neural networks used for unsupervised learning tasks, particularly for dimensionality reduction, feature learning, and data compression. The primary objective of an autoencoder is to learn a compact representation of the input data by encoding it into a lower-dimensional space and then reconstructing the original data from this representation as accurately as possible. 6 | 7 | The basic structure of an autoencoder consists of two main components: 8 | 9 | 1. **Encoder**: The encoder network maps the input data into a lower-dimensional latent space representation. It typically consists of one or more layers of neurons that gradually reduce the dimensionality of the input data. 10 | 11 | 2. **Decoder**: The decoder network reconstructs the original input data from the latent space representation produced by the encoder. It mirrors the structure of the encoder but in reverse, gradually expanding the dimensionality of the latent representation back to the original input space. 12 | 13 | During training, autoencoders are optimised to minimise the reconstruction error, which measures the difference between the input data and its reconstruction. By learning to reconstruct the input data accurately, autoencoders implicitly learn meaningful features and representations of the data. 14 | 15 | There are several families or variants of autoencoder models, each with its own characteristics and use cases. Some of the main families of autoencoders include: 16 | 17 | 1. **Vanilla Autoencoders**: Also known as undercomplete autoencoders, vanilla autoencoders have a bottleneck structure in the latent space, forcing the model to learn a compressed representation of the input data. 18 | 19 | 2. **Denoising Autoencoders**: Denoising autoencoders are trained to reconstruct clean input data from noisy or corrupted input samples. 
They learn to denoise the input data by capturing the underlying structure and removing noise. 20 | 21 | 3. **Sparse Autoencoders**: Sparse autoencoders introduce sparsity constraints on the activations of the hidden layers. They encourage the model to learn sparse representations of the input data, which can help disentangle and extract meaningful features. 22 | 23 | 4. **Variational Autoencoders (VAEs)**: VAEs are probabilistic autoencoder models that learn a probabilistic model of the input data and use variational inference techniques to approximate the latent space distribution. They enable the generation of new data samples by sampling from the learned latent space distribution. 24 | 25 | 5. **Contractive Autoencoders**: Contractive autoencoders incorporate penalty terms in the loss function to enforce smoothness and local linearity in the learned latent space. They are robust to small variations in the input data and can learn invariant representations. 26 | 27 | 6. **Generative Adversarial Networks (GANs)**: Although not traditional autoencoders, GANs consist of a generator network that learns to generate realistic data samples from random noise and a discriminator network that learns to distinguish between real and generated samples. GANs can be considered as implicitly learning a compressed representation of the input data through the generator network. 28 | 29 | Each family of autoencoder models has its own advantages and is suitable for different types of data and tasks. Choosing the appropriate autoencoder variant depends on the specific requirements of the problem at hand, such as the nature of the input data, the desired properties of the learned representations, and the intended application of the model. 30 | 31 | ## Variational Autoencoder 32 | 33 | A Variational Autoencoder (VAE) is a type of generative model that learns to generate new data samples by capturing the underlying structure and distribution of the input data. VAEs belong to the family of autoencoder models, which are neural networks trained to learn efficient representations of input data by compressing it into a lower-dimensional latent space and then reconstructing the original data from this representation. 34 | 35 | The key innovation of VAEs compared to traditional autoencoders is the introduction of probabilistic modeling and variational inference techniques. VAEs learn a probabilistic model of the data and use variational inference to approximate the posterior distribution of the latent variables given the observed data. This allows VAEs to generate new data samples by sampling from the learned latent space distribution. 36 | 37 | Here's how a Variational Autoencoder works: 38 | 39 | 1. **Encoder Network**: 40 | - The encoder network takes an input data sample and maps it to a distribution in the latent space. Instead of directly outputting the values of the latent variables, the encoder network outputs the parameters (mean and variance) of a Gaussian distribution that represents the approximate posterior distribution of the latent variables given the input data. 41 | 42 | 2. **Sampling Latent Variables**: 43 | - During training, a sample is drawn from the approximate posterior distribution (using the reparameterisation trick) to obtain a latent representation for the input data sample. 44 | - This sampled latent vector serves as a compressed representation of the input data in the latent space. 45 | 46 | 3. 
**Decoder Network**: 47 | - The decoder network takes the sampled latent vector as input and reconstructs the original data sample from it. 48 | - The decoder network is trained to produce a reconstruction that closely matches the input data sample. 49 | 50 | 4. **Training Objective**: 51 | - VAEs are trained to minimise a loss function that consists of two components: a reconstruction loss and a regularisation term. 52 | - The reconstruction loss measures the discrepancy between the input data sample and its reconstruction. 53 | - The regularisation term, often expressed as the Kullback-Leibler (KL) divergence between the approximate posterior distribution and a prior distribution (typically a standard Gaussian), encourages the learned latent space to be structured and interpretable. 54 | 55 | 5. **Generating New Data**: 56 | - After training, VAEs can generate new data samples by sampling from the learned latent space distribution. 57 | - By sampling from the latent space and feeding the samples through the decoder network, VAEs can generate new data samples that resemble the training data. 58 | 59 | Variational Autoencoders have applications in various domains, including image generation, text generation, and molecular design. They can learn meaningful latent representations of complex data distributions and generate diverse and realistic new samples from these distributions. 60 | -------------------------------------------------------------------------------- /deep_learning/cnn.md: -------------------------------------------------------------------------------- 1 | # Convolutional Neural Network 2 | 3 | From ChatGPT. 4 | 5 | A Convolutional Neural Network (CNN) is a type of deep learning model specifically designed for processing structured grid-like data, such as images. CNNs are particularly effective for tasks involving visual perception, such as image classification, object detection, and image segmentation. 6 | 7 | CNNs are inspired by the structure and functioning of the visual cortex in animals, where neurons in different layers respond to different features of the visual stimuli. Similarly, CNNs consist of multiple layers of interconnected neurons that learn hierarchical representations of features present in the input data. 8 | 9 | Here's an overview of the key components and operations in a typical CNN architecture: 10 | 11 | 1. **Convolutional Layers**: 12 | - Convolutional layers are the core building blocks of CNNs. They apply a set of learnable filters (also called kernels or convolutional kernels) to the input data to extract features. 13 | - Each filter performs a convolution operation, which involves sliding the filter over the input data and computing dot products between the filter weights and the local regions of the input. 14 | - The output of each filter is a feature map that represents the presence of a particular feature or pattern in the input data. 15 | - Multiple filters are typically used in each convolutional layer to capture different types of features. 16 | 17 | 2. **Activation Functions**: 18 | - After the convolution operation, an activation function (such as ReLU, sigmoid, or tanh) is applied element-wise to the feature maps to introduce non-linearity into the network. 19 | - The non-linear activation helps the network learn complex patterns and relationships in the input data. 20 | 21 | 3. 
**Pooling Layers**: 22 | - Pooling layers are used to downsample the feature maps and reduce their spatial dimensions while retaining the most important information. 23 | - Common pooling operations include max pooling and average pooling, which take the maximum or average value, respectively, within each pooling region. 24 | - Pooling helps make the representations more invariant to small spatial translations and reduces the computational complexity of the network. 25 | 26 | 4. **Fully Connected Layers**: 27 | - After several convolutional and pooling layers, the feature maps are flattened into a one-dimensional vector and passed through one or more fully connected (dense) layers. 28 | - Fully connected layers perform a linear transformation followed by a non-linear activation function to generate the final output of the network. 29 | - They enable the network to learn complex mappings from the extracted features to the output classes or predictions. 30 | 31 | 5. **Training**: 32 | - CNNs are trained using gradient-based optimisation algorithms, such as stochastic gradient descent (SGD) or Adam, to minimise a loss function that measures the difference between the predicted outputs and the ground truth labels. 33 | - The weights of the filters and fully connected layers are updated iteratively using backpropagation, which calculates gradients of the loss function with respect to the network parameters and adjusts the parameters accordingly. 34 | 35 | Overall, CNNs are powerful and versatile models for visual perception tasks, capable of automatically learning hierarchical representations of features from raw input data. They have demonstrated state-of-the-art performance in various computer vision tasks and are widely used in both academic research and industry applications. 
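To make the components above concrete, here is a minimal sketch of a small image classifier in R using the `keras` package (this assumes `keras` is installed with a working TensorFlow backend; the layer sizes are illustrative, not tuned):

```r
library(keras)

# A small CNN for 28x28 grayscale images: two convolution + pooling
# stages extract features, then a dense head performs classification.
model <- keras_model_sequential() %>%
  layer_conv_2d(filters = 32, kernel_size = c(3, 3), activation = "relu",
                input_shape = c(28, 28, 1)) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_flatten() %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 10, activation = "softmax")

model %>% compile(
  optimizer = "adam",
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

summary(model)
```

Each convolution layer maps to point 1 above, each pooling layer to point 3, and the final dense layers to point 4.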
36 | -------------------------------------------------------------------------------- /deep_learning/img/digit_heatmap-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/digit_heatmap-1.png -------------------------------------------------------------------------------- /deep_learning/img/model_plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/model_plot-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-10-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-10-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-12-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-12-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-14-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-14-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-16-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-16-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-19-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-19-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-23-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-23-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-3-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-5-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-5-1.png -------------------------------------------------------------------------------- 
/deep_learning/img/unnamed-chunk-7-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-7-1.png -------------------------------------------------------------------------------- /deep_learning/img/unnamed-chunk-8-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/deep_learning/img/unnamed-chunk-8-1.png -------------------------------------------------------------------------------- /deep_learning/rnn.md: -------------------------------------------------------------------------------- 1 | # Recurrent Neural Network 2 | 3 | From ChatGPT. 4 | 5 | A Recurrent Neural Network (RNN) is a type of artificial neural network designed to model sequential data and handle input of arbitrary length. Unlike feedforward neural networks, which process input data in a single pass without any memory, RNNs have connections that allow them to exhibit temporal dynamic behavior. This memory property makes them well-suited for tasks involving sequences, such as time series prediction, natural language processing, speech recognition, and handwriting recognition. 6 | 7 | The key characteristic of RNNs is their ability to maintain a state or memory of previous inputs while processing the current input. This is achieved through recurrent connections, where the output of the network at a given time step is fed back as input to the network at the next time step. 8 | 9 | Here's a high-level overview of the structure and functioning of a simple RNN: 10 | 11 | 1. **Recurrent Connections**: 12 | - At each time step $t$, an RNN takes an input $x_t$ and produces an output $y_t$ and an internal state $h_t$. 13 | - The internal state $h_t$ is computed based on the current input $x_t$ and the previous internal state $h_{t-1}$. 14 | - Mathematically, the internal state $h_t$ is computed as $h_t = f(x_t, h_{t-1})$, where $f$ is a non-linear activation function applied to a combination of the current input and the previous state. 15 | 16 | 2. **Sequence Processing**: 17 | - RNNs can process input sequences of arbitrary length. They operate recursively, with each time step processing one element of the sequence. 18 | - The output of the RNN at each time step can be used for prediction, classification, or further processing. 19 | 20 | 3. **Training**: 21 | - RNNs are typically trained using backpropagation through time (BPTT), an extension of the backpropagation algorithm adapted for sequences. 22 | - BPTT involves unfolding the network over time and calculating gradients through time, allowing the network to learn from the entire sequence. 23 | 24 | While RNNs have demonstrated success in modeling sequential data, they suffer from some limitations, such as difficulty in capturing long-term dependencies and vanishing or exploding gradients during training, especially in deep networks. To address these issues, various RNN variants have been developed, such as Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs), which incorporate mechanisms to better capture long-term dependencies and mitigate the vanishing gradient problem. These variants have become widely used in practice for tasks involving sequential data. 
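To make the recurrence $h_t = f(x_t, h_{t-1})$ concrete, below is a minimal forward pass of a vanilla RNN in base R; it is a sketch of the update rule only, with randomly initialised weights and no training loop (the names `W_xh`, `W_hh` and `b_h` are my own):

```r
set.seed(1984)
input_dim  <- 3  # size of each input vector x_t
hidden_dim <- 4  # size of the hidden state h_t

# Weight matrices and bias, normally learned via backpropagation through time.
W_xh <- matrix(rnorm(hidden_dim * input_dim, sd = 0.1), hidden_dim, input_dim)
W_hh <- matrix(rnorm(hidden_dim * hidden_dim, sd = 0.1), hidden_dim, hidden_dim)
b_h  <- rep(0, hidden_dim)

# A toy sequence of five time steps, one column per x_t.
x_seq <- matrix(rnorm(input_dim * 5), input_dim, 5)

h <- rep(0, hidden_dim)  # initial hidden state h_0
for (t in seq_len(ncol(x_seq))) {
  # h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h): the state carries memory forward.
  h <- tanh(W_xh %*% x_seq[, t] + W_hh %*% h + b_h)
}
h  # final hidden state after processing the whole sequence
```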
25 | -------------------------------------------------------------------------------- /deep_learning/transformer.md: -------------------------------------------------------------------------------- 1 | # Transformer 2 | 3 | From ChatGPT. 4 | 5 | The Transformer architecture is a type of deep learning model that was introduced in the paper "[Attention is All You Need](https://arxiv.org/abs/1706.03762)" by Vaswani et al. in 2017. It has become widely used in various natural language processing (NLP) tasks and has also been adapted for other sequence-to-sequence tasks like image captioning and speech recognition. 6 | 7 | The Transformer architecture relies heavily on the attention mechanism, which enables the model to focus on different parts of the input sequence when processing it. Here's an overview of the key components of the Transformer architecture: 8 | 9 | 1. **Self-Attention Mechanism**: 10 | - The self-attention mechanism allows the model to weigh the importance of different words in the input sequence when generating each output word. 11 | - At each position in the sequence, the model computes attention scores between that position and every other position in the sequence. 12 | - These attention scores are used to compute a weighted sum of the input embeddings, where the weights are determined by the attention scores. 13 | - This mechanism allows the model to capture dependencies between words in the input sequence and generate contextually relevant representations. 14 | 15 | 2. **Multi-Head Attention**: 16 | - To capture different types of information and attend to different parts of the input sequence simultaneously, the Transformer architecture uses multiple attention heads. 17 | - Each attention head independently computes attention scores and produces its own output representation. 18 | - The outputs of the attention heads are concatenated and linearly transformed to produce the final output of the multi-head attention layer. 19 | 20 | 3. **Positional Encoding**: 21 | - Since the Transformer architecture does not inherently capture the order of words in the input sequence, positional encodings are added to the input embeddings to provide information about the position of each word. 22 | - Positional encodings are learned embeddings that encode the position of each word in the input sequence using sinusoidal functions. 23 | 24 | 4. **Encoder and Decoder Stacks**: 25 | - The Transformer architecture consists of a stack of encoder layers and a stack of decoder layers. 26 | - The encoder stack processes the input sequence and generates a sequence of hidden representations. 27 | - The decoder stack takes the hidden representations generated by the encoder and produces the output sequence, one word at a time. 28 | 29 | 5. **Feed-Forward Neural Networks**: 30 | - Each encoder and decoder layer in the Transformer architecture contains feed-forward neural networks (FFNs). 31 | - The FFNs apply a linear transformation followed by a non-linear activation function (usually ReLU) to each position in the sequence independently. 32 | 33 | 6. **Layer Normalisation and Residual Connections**: 34 | - To stabilise the training process and facilitate the flow of gradients, layer normalisation and residual connections are applied after each sub-layer in the encoder and decoder stacks. 
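To ground the self-attention mechanism from point 1 above, here is a minimal scaled dot-product attention function in base R, applied to toy query/key/value matrices; it omits the multi-head projections and masking, and the name `attention` is mine:

```r
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
attention <- function(Q, K, V) {
  d_k <- ncol(K)
  scores <- Q %*% t(K) / sqrt(d_k)   # attention scores between positions
  # Row-wise softmax turns scores into weights that sum to one per position.
  weights <- exp(scores - apply(scores, 1, max))
  weights <- weights / rowSums(weights)
  weights %*% V                      # weighted sum of the value vectors
}

set.seed(1984)
n_pos <- 4  # sequence length
d_k   <- 8  # dimensionality of queries, keys and values
Q <- matrix(rnorm(n_pos * d_k), n_pos, d_k)
K <- matrix(rnorm(n_pos * d_k), n_pos, d_k)
V <- matrix(rnorm(n_pos * d_k), n_pos, d_k)
attention(Q, K, V)  # one output row per input position
```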
35 | 36 | Overall, the Transformer architecture has revolutionised the field of natural language processing by providing a more parallelisable and scalable approach to sequence modeling compared to recurrent neural networks (RNNs) and convolutional neural networks (CNNs). Its attention-based mechanism allows it to capture long-range dependencies in sequences and achieve state-of-the-art performance on various NLP tasks. 37 | -------------------------------------------------------------------------------- /evaluation/cutpointr.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Calculating cutpoints" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | knitr::opts_chunk$set(cache = FALSE) 10 | knitr::opts_chunk$set(echo = TRUE) 11 | knitr::opts_chunk$set(fig.path = "img/") 12 | ``` 13 | 14 | ## Introduction 15 | 16 | [cutpointr](https://cran.r-project.org/web/packages/cutpointr/vignettes/cutpointr.html) can be used to calculate optimal cut offs. 17 | 18 | Install packages if missing and load. 19 | 20 | ```{r load_package, message=FALSE, warning=FALSE} 21 | .libPaths('/packages') 22 | my_packages <- c('cutpointr', 'randomForest') 23 | 24 | for (my_package in my_packages){ 25 | if(!require(my_package, character.only = TRUE)){ 26 | install.packages(my_package, '/packages') 27 | } 28 | library(my_package, character.only = TRUE) 29 | } 30 | 31 | library(tidyverse) 32 | theme_set(theme_bw()) 33 | ``` 34 | 35 | ## Example 36 | 37 | The `suicide` data is a data frame with 532 rows and 4 variables: 38 | 39 | 1. `age` - (numeric) Age of participants in years 40 | 2. `gender` - (factor) Gender 41 | 3. `dsi` - (numeric) Sum-score (0 = low suicidality, 12 = high suicidality) 42 | 4. `suicide` - (factor) Past suicide attempt (no = no attempt, yes = at least one attempt) 43 | 44 | ```{r load_suicide} 45 | data(suicide) 46 | head(suicide) 47 | ``` 48 | The key arguments of `cutpointr` are: 49 | * `data` - A data.frame with the data needed for x, class and optionally subgroup. 50 | * `x` - The variable name to be used for classification, e.g. predictions. The raw vector of values if the data argument is unused. 51 | * `class` - The variable name indicating class membership. If the data argument is unused, the vector of raw numeric values. 52 | * `method` - (function) A function for determining cutpoints. Can be user supplied or use some of the built-in methods. See details. 53 | * `metric` - (function) The function for computing a metric when using maximize_metric or minimize_metric as method, and for the out-of-bag values during bootstrapping. A way of internally validating the performance. User defined functions can be supplied, see details. 54 | 55 | ```{r cutpointr_suicide} 56 | cp <- cutpointr( 57 | data = suicide, 58 | x = dsi, 59 | class = suicide, 60 | method = maximize_metric, 61 | metric = sum_sens_spec 62 | ) 63 | 64 | summary(cp) 65 | ``` 66 | 67 | Optimal cut off. 68 | 69 | ```{r optimal_cutoff} 70 | cp$optimal_cutpoint 71 | ``` 72 | 73 | Plot. 74 | 75 | ```{r plot_cutpointr} 76 | plot(cp) 77 | ``` 78 | 79 | [Youden's J statistic](https://en.wikipedia.org/wiki/Youden%27s_J_statistic) = sensitivity + specificity - 1.
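As a quick worked example before maximising it below, a hypothetical cutpoint with sensitivity 0.85 and specificity 0.75 would give J = 0.85 + 0.75 - 1 = 0.6 (these values are made up purely for illustration).

```{r youden_by_hand}
# Youden's J statistic for a hypothetical cutpoint.
youden_j <- function(sensitivity, specificity){
  sensitivity + specificity - 1
}
youden_j(0.85, 0.75)
```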
80 | 81 | ```{r cutpointr_youden} 82 | cp_youden <- cutpointr( 83 | data = suicide, 84 | x = dsi, 85 | class = suicide, 86 | method = maximize_metric, 87 | metric = youden 88 | ) 89 | 90 | summary(cp_youden) 91 | ``` 92 | 93 | ## Random Forests 94 | 95 | Use [spam data](https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.names) to train a Random Forest model to test with `cutpointr`. Class 0 and 1 are ham (non-spam) and spam, respectively. 96 | 97 | ```{r random_forest_spam} 98 | spam_data <- read.csv(file = "../data/spambase.csv") 99 | spam_data$class <- factor(spam_data$class) 100 | set.seed(1984) 101 | rf <- randomForest(class ~ ., data = spam_data) 102 | ``` 103 | 104 | Optimal cut off using Youden's J statistic. 105 | 106 | ```{r spam_youden} 107 | cp_rf_youden <- cutpointr( 108 | x = rf$votes[, 2], 109 | class = as.integer(spam_data$class)-1, 110 | method = maximize_metric, 111 | metric = youden 112 | ) 113 | 114 | summary(cp_rf_youden) 115 | ``` 116 | 117 | Plot. 118 | 119 | ```{r plot_spam_youden} 120 | plot(cp_rf_youden) 121 | ``` 122 | 123 | Plot metric. 124 | 125 | ```{r plot_metric} 126 | plot_metric(cp_rf_youden) 127 | ``` 128 | 129 | ## Session info 130 | 131 | Time built. 132 | 133 | ```{r time, echo=FALSE} 134 | Sys.time() 135 | ``` 136 | 137 | Session info. 138 | 139 | ```{r session_info, echo=FALSE} 140 | sessionInfo() 141 | ``` 142 | -------------------------------------------------------------------------------- /evaluation/formulas.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Evaluation formulae" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | knitr::opts_chunk$set(cache = FALSE) 10 | knitr::opts_chunk$set(echo = TRUE) 11 | knitr::opts_chunk$set(fig.path = "img/") 12 | ``` 13 | 14 | ## Introduction 15 | 16 | Mathematical formulae elaborated. 17 | 18 | ## Root Mean Square Error 19 | 20 | The Root Mean Square Error (RMSE) is often used in regression problems to indicate how much error was made by the predictions, with a higher weight given to larger errors. 21 | 22 | $$ RMSE(X, h) = \sqrt{ \frac{1}{m} \sum^m_{i = 1} ( h(x^{(i)}) - y^{(i)} )^2 } $$ 23 | 24 | * $m$ is the number of instances/cases in the dataset 25 | * $x^{(i)}$ is a vector of all the feature values of the $i^{th}$ instance in the dataset and $y^{(i)}$ is the label 26 | * $X$ is a matrix containing all the feature values of all instances in the dataset. There is one row per instance, and the $i^{th}$ row is equal to the transpose of $x^{(i)}$, noted $(x^{(i)})^T$. 27 | * $h$ is the prediction function, also called a hypothesis. When the function is given an instance's feature vector $x^{(i)}$, it outputs a predicted value $\hat{y}^{(i)} = h(x^{(i)})$ for that instance. 28 | * $RMSE(X,h)$ is the cost function measured on the set using hypothesis $h$. 29 | 30 | ## Mean Absolute Error 31 | 32 | If there are many outliers, the Mean Absolute Error (MAE, also called the average absolute deviation) can be considered. 33 | 34 | $$ MAE(X, h) = \frac{1}{m} \sum^m_{i = 1} |h(x^{(i)}) - y^{(i)}| $$ 35 | 36 | ## Shannon Entropy 37 | 38 | [Entropy is a measure of randomness](https://medium.com/udacity/shannon-entropy-information-gain-and-picking-balls-from-buckets-5810d35d54b4) (or variance), where high entropy == more randomness/variance and low entropy == less randomness/variance.
The general formula is: 39 | 40 | $$ Entropy = - \sum^n_{i=1} p_i\ log_2\ p_i$$ 41 | 42 | * $n$ is the number of classes/labels 43 | * $p_i$ is the probability of the $i^{th}$ class 44 | 45 | The `entropy` function will take a vector of classes/labels and return the entropy. 46 | 47 | ```{r entropy} 48 | eg1 <- c('A', 'A', 'A', 'A', 'A', 'A', 'A', 'A') 49 | eg2 <- c('A', 'A', 'A', 'A', 'B', 'B', 'C', 'D') 50 | eg3 <- c('A', 'A', 'B', 'B', 'C', 'C', 'D', 'D') 51 | 52 | entropy <- function(x){ 53 | probs <- table(x) / length(x) 54 | -sum(probs * log2(probs)) 55 | } 56 | 57 | entropy(eg1) 58 | entropy(eg2) 59 | entropy(eg3) 60 | ``` 61 | 62 | ## Information Gain 63 | 64 | Consider the [following dataset](https://victorzhou.com/blog/information-gain/). 65 | 66 | ```{r ig_df} 67 | set.seed(1984) 68 | y <- runif(n = 10, min = 0, max = 3) 69 | x1 <- runif(n = 5, min = 0, max = 2) 70 | x2 <- runif(n = 5, min = 2, max = 3) 71 | 72 | df <- data.frame( 73 | x = c(x1, x2), 74 | y = y, 75 | label = rep(c("blue", "green"), each = 5) 76 | ) 77 | 78 | plot(df$x, df$y, col = df$label, pch = 16) 79 | abline(v = 1.6, lty = 3) 80 | ``` 81 | 82 | Before the split, the entropy was: 83 | 84 | ```{r entropy_before} 85 | entropy(df$label) 86 | ``` 87 | 88 | After the split. 89 | 90 | ```{r entropy_split} 91 | my_split <- df$x < 1.6 92 | 93 | left_split <- df$label[my_split] 94 | entropy(left_split) 95 | 96 | right_split <- df$label[!my_split] 97 | entropy(right_split) 98 | ``` 99 | 100 | Weigh by number of elements and calculate entropy after split. 101 | 102 | ```{r entropy_after} 103 | entropy_after <- entropy(left_split) * (length(left_split) / length(df$label)) + entropy(right_split) * (length(right_split) / length(df$label)) 104 | entropy_after 105 | ``` 106 | 107 | Information gain == how much entropy we removed. 108 | 109 | ```{r information_gain} 110 | information_gain <- entropy(df$label) - entropy_after 111 | information_gain 112 | ``` 113 | 114 | Information gain is calculated for a split by subtracting the weighted entropies of each branch from the original entropy. 115 | 116 | ## Session info 117 | 118 | Time built. 119 | 120 | ```{r time, echo=FALSE} 121 | Sys.time() 122 | ``` 123 | 124 | Session info. 125 | 126 | ```{r session_info, echo=FALSE} 127 | sessionInfo() 128 | ``` 129 | -------------------------------------------------------------------------------- /evaluation/formulas.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | Mathematical formulae elaborated. 4 | 5 | ## Root Mean Square Error 6 | 7 | The Root Mean Square Error (RMSE) is often used in regression problems 8 | to indicate how much error was made by the predictions, with a higher 9 | weight given to larger errors. 10 | 11 | $$ RMSE(X, h) = \sqrt{ \frac{1}{m} \sum^m_{i = 1} ( h(x^{(i)}) - y^{(i)} )^2 } $$ 12 | 13 | - $m$ is the number of instances/cases in the dataset 14 | - $x^{(i)}$ is a vector of all the feature values of the $i^{th}$ 15 | instance in the dataset and $y^{(i)}$ is the label 16 | - $X$ is a matrix containing all the feature values of all instances 17 | in the dataset. There is one row per instance, and the $i^{th}$ row 18 | is equal to the transpose of $x^{(i)}$, noted $(x^{(i)})^T$. 19 | - $h$ is the prediction function, also called a hypothesis. When the 20 | function is given an instance's feature vector $x^{(i)}$, it outputs 21 | a predicted value $\hat{y}^{(i)} = h(x^{(i)})$ for that instance.
22 | - $RMSE(X,h)$ is the cost function measured on the set using 23 | hypothesis $h$. 24 | 25 | ## Mean Absolute Error 26 | 27 | If there are many outliers, the Mean Absolute Error (MAE, also called 28 | the average absolute deviation) can be considered. 29 | 30 | $$ MAE(X, h) = \frac{1}{m} \sum^m_{i = 1} |h(x^{(i)}) - y^{(i)}| $$ 31 | 32 | ## Shannon Entropy 33 | 34 | [Entropy is a measure of 35 | randomness](https://medium.com/udacity/shannon-entropy-information-gain-and-picking-balls-from-buckets-5810d35d54b4) 36 | (or variance), where high entropy == more randomness/variance and low 37 | entropy == less randomness/variance. The general formula is: 38 | 39 | $$ Entropy = - \sum^n_{i=1} p_i\ log_2\ p_i$$ 40 | 41 | - $n$ is the number of classes/labels 42 | - $p_i$ is the probability of the $i^{th}$ class 43 | 44 | The `entropy` function will take a vector of classes/labels and return 45 | the entropy. 46 | 47 | ``` r 48 | eg1 <- c('A', 'A', 'A', 'A', 'A', 'A', 'A', 'A') 49 | eg2 <- c('A', 'A', 'A', 'A', 'B', 'B', 'C', 'D') 50 | eg3 <- c('A', 'A', 'B', 'B', 'C', 'C', 'D', 'D') 51 | 52 | entropy <- function(x){ 53 | probs <- table(x) / length(x) 54 | -sum(probs * log2(probs)) 55 | } 56 | 57 | entropy(eg1) 58 | ``` 59 | 60 | ## [1] 0 61 | 62 | ``` r 63 | entropy(eg2) 64 | ``` 65 | 66 | ## [1] 1.75 67 | 68 | ``` r 69 | entropy(eg3) 70 | ``` 71 | 72 | ## [1] 2 73 | 74 | ## Information Gain 75 | 76 | Consider the [following 77 | dataset](https://victorzhou.com/blog/information-gain/). 78 | 79 | ``` r 80 | set.seed(1984) 81 | y <- runif(n = 10, min = 0, max = 3) 82 | x1 <- runif(n = 5, min = 0, max = 2) 83 | x2 <- runif(n = 5, min = 2, max = 3) 84 | 85 | df <- data.frame( 86 | x = c(x1, x2), 87 | y = y, 88 | label = rep(c("blue", "green"), each = 5) 89 | ) 90 | 91 | plot(df$x, df$y, col = df$label, pch = 16) 92 | abline(v = 1.6, lty = 3) 93 | ``` 94 | 95 | ![](img/ig_df-1.png) 96 | 97 | Before the split, the entropy was: 98 | 99 | ``` r 100 | entropy(df$label) 101 | ``` 102 | 103 | ## [1] 1 104 | 105 | After the split. 106 | 107 | ``` r 108 | my_split <- df$x < 1.6 109 | 110 | left_split <- df$label[my_split] 111 | entropy(left_split) 112 | ``` 113 | 114 | ## [1] 0 115 | 116 | ``` r 117 | right_split <- df$label[!my_split] 118 | entropy(right_split) 119 | ``` 120 | 121 | ## [1] 0.6500224 122 | 123 | Weigh by number of elements and calculate entropy after split. 124 | 125 | ``` r 126 | entropy_after <- entropy(left_split) * (length(left_split) / length(df$label)) + entropy(right_split) * (length(right_split) / length(df$label)) 127 | entropy_after 128 | ``` 129 | 130 | ## [1] 0.3900135 131 | 132 | Information gain == how much entropy we removed. 133 | 134 | ``` r 135 | information_gain <- entropy(df$label) - entropy_after 136 | information_gain 137 | ``` 138 | 139 | ## [1] 0.6099865 140 | 141 | Information gain is calculated for a split by subtracting the weighted 142 | entropies of each branch from the original entropy. 143 | 144 | ## Session info 145 | 146 | Time built. 147 | 148 | ## [1] "2022-11-22 07:45:47 UTC" 149 | 150 | Session info. 
151 | 152 | ## R version 4.2.0 (2022-04-22) 153 | ## Platform: x86_64-pc-linux-gnu (64-bit) 154 | ## Running under: Ubuntu 20.04.4 LTS 155 | ## 156 | ## Matrix products: default 157 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 158 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 159 | ## 160 | ## locale: 161 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 162 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 163 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 164 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 165 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 166 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 167 | ## 168 | ## attached base packages: 169 | ## [1] stats graphics grDevices utils datasets methods base 170 | ## 171 | ## loaded via a namespace (and not attached): 172 | ## [1] compiler_4.2.0 magrittr_2.0.3 fastmap_1.1.0 cli_3.4.1 173 | ## [5] tools_4.2.0 htmltools_0.5.3 rstudioapi_0.14 yaml_2.3.6 174 | ## [9] stringi_1.7.8 rmarkdown_2.17 highr_0.9 knitr_1.40 175 | ## [13] stringr_1.4.1 xfun_0.34 digest_0.6.30 rlang_1.0.6 176 | ## [17] evaluate_0.17 177 | -------------------------------------------------------------------------------- /evaluation/img/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/confusion_matrix.png -------------------------------------------------------------------------------- /evaluation/img/cross_validation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/cross_validation.png -------------------------------------------------------------------------------- /evaluation/img/dendrogram-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/dendrogram-1.png -------------------------------------------------------------------------------- /evaluation/img/ig_df-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/ig_df-1.png -------------------------------------------------------------------------------- /evaluation/img/mouse_dendrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/mouse_dendrogram.png -------------------------------------------------------------------------------- /evaluation/img/plot_cutpointr-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/plot_cutpointr-1.png -------------------------------------------------------------------------------- /evaluation/img/plot_metric-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/plot_metric-1.png -------------------------------------------------------------------------------- 
/evaluation/img/plot_spam_youden-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/plot_spam_youden-1.png -------------------------------------------------------------------------------- /evaluation/img/precision_recall-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/precision_recall-1.png -------------------------------------------------------------------------------- /evaluation/img/random_forest_roc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/random_forest_roc-1.png -------------------------------------------------------------------------------- /evaluation/img/random_predictor-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/random_predictor-1.png -------------------------------------------------------------------------------- /evaluation/img/rmse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/rmse.png -------------------------------------------------------------------------------- /evaluation/img/roc_verification-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/roc_verification-1.png -------------------------------------------------------------------------------- /evaluation/img/roc_verification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/roc_verification.png -------------------------------------------------------------------------------- /evaluation/img/roc_verification_ci-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/roc_verification_ci-1.png -------------------------------------------------------------------------------- /evaluation/img/roc_verification_ci.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/roc_verification_ci.png -------------------------------------------------------------------------------- /evaluation/img/roc_versicolor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/roc_versicolor.png -------------------------------------------------------------------------------- /evaluation/img/try_k-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/try_k-1.png -------------------------------------------------------------------------------- /evaluation/img/unnamed-chunk-1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/unnamed-chunk-1-1.png -------------------------------------------------------------------------------- /evaluation/img/unnamed-chunk-2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/unnamed-chunk-2-1.png -------------------------------------------------------------------------------- /evaluation/img/unnamed-chunk-3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/unnamed-chunk-3-1.png -------------------------------------------------------------------------------- /evaluation/img/unnamed-chunk-4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/evaluation/img/unnamed-chunk-4-1.png -------------------------------------------------------------------------------- /gmm/gmm.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Gaussian Mixture Models" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | knitr::opts_chunk$set(cache = FALSE) 10 | knitr::opts_chunk$set(echo = TRUE) 11 | knitr::opts_chunk$set(fig.path = "img/") 12 | options(scipen = 999) 13 | library(tidyverse) 14 | theme_set(theme_bw()) 15 | ``` 16 | 17 | ## Introduction 18 | 19 | Following these blog posts: 20 | 21 | * 22 | * 23 | * 24 | 25 | Install and load [mixtools](https://cran.r-project.org/web/packages/mixtools/index.html). 26 | 27 | ```{r load_package, message=FALSE, warning=FALSE} 28 | .libPaths('/packages') 29 | my_packages <- c('mixtools') 30 | 31 | for (my_package in my_packages){ 32 | if(!require(my_package, character.only = TRUE)){ 33 | install.packages(my_package, '/packages') 34 | } 35 | library(my_package, character.only = TRUE) 36 | } 37 | ``` 38 | 39 | Waiting time between eruptions and the duration of the eruption for the Old Faithful geyser in Yellowstone National Park, Wyoming, USA. 40 | 41 | A data frame with 272 observations on 2 variables. 42 | 43 | * `eruptions` - Eruption time in minutes 44 | * `waiting` - Waiting time to next eruption in minutes 45 | 46 | ```{r faithful_scatter} 47 | ggplot(faithful, aes(waiting, eruptions)) + 48 | geom_point() 49 | ``` 50 | 51 | The `normalmixEM` function builds a 2-component GMM (k = 2 indicates to use 2 components). 52 | 53 | ```{r mixmdl} 54 | set.seed(1984) 55 | mixmdl <- normalmixEM(faithful$waiting, k = 2) 56 | summary(mixmdl) 57 | ``` 58 | 59 | The red and blue lines indicate the two fitted Gaussian distributions. 
60 | 61 | ```{r plot_density} 62 | #' Plot a Mixture Component 63 | #' 64 | #' @param x Input data 65 | #' @param mu Mean of component 66 | #' @param sigma Standard deviation of component 67 | #' @param lam Mixture weight of component 68 | plot_mix_comps <- function(x, mu, sigma, lam) { 69 | lam * dnorm(x, mu, sigma) 70 | } 71 | 72 | data.frame(x = mixmdl$x) %>% 73 | ggplot() + 74 | geom_histogram(aes(x, ..density..), binwidth = 1, colour = "black", 75 | fill = "white") + 76 | stat_function(geom = "line", fun = plot_mix_comps, 77 | args = list(mixmdl$mu[1], mixmdl$sigma[1], lam = mixmdl$lambda[1]), 78 | colour = "red", lwd = 1.5) + 79 | stat_function(geom = "line", fun = plot_mix_comps, 80 | args = list(mixmdl$mu[2], mixmdl$sigma[2], lam = mixmdl$lambda[2]), 81 | colour = "blue", lwd = 1.5) + 82 | ylab("Density") 83 | ``` 84 | 85 | Mean and standard deviation. 86 | 87 | ```{r parameters} 88 | mixmdl$mu 89 | mixmdl$sigma 90 | ``` 91 | 92 | `lambda` indicates the proportion of the data assigned to each component; formally these values are referred to as the mixing weights (or mixing proportions or mixing coefficients). Here they can be interpreted as the red component representing 36% of the input data and the blue component representing 64%. 93 | 94 | ```{r lambda} 95 | mixmdl$lambda 96 | sum(mixmdl$lambda) 97 | ``` 98 | 99 | Each input data point is assigned a posterior probability of belonging to each of the two components. 100 | 101 | ```{r pp} 102 | head(cbind(mixmdl$x, mixmdl$posterior)) 103 | ``` 104 | 105 | ## Predict 106 | 107 | When two distributions are close to each other. 108 | 109 | ```{r close} 110 | set.seed(1984) 111 | k1 <- rnorm(n = 123, mean = 0.4, sd = 0.2) 112 | k2 <- rnorm(n = 156, mean = 0.6, sd = 0.2) 113 | 114 | plot(density(c(k1, k2)), main = '') 115 | ``` 116 | 117 | EM. 118 | 119 | ```{r normalmixem_close} 120 | set.seed(1984) 121 | mm <- normalmixEM(c(k1, k2), k = 2, maxit = 10000) 122 | summary(mm) 123 | ``` 124 | 125 | Two more distinguishable distributions. 126 | 127 | ```{r distinct} 128 | set.seed(1984) 129 | k1 <- rnorm(n = 123, mean = 0.3, sd = 0.2) 130 | k2 <- rnorm(n = 156, mean = 0.7, sd = 0.2) 131 | 132 | df <- tibble( 133 | label = rep(c('k1', 'k2'), c(length(k1), length(k2))), 134 | value = c(k1, k2) 135 | ) 136 | 137 | ggplot(df, aes(value, fill = label)) + 138 | geom_histogram(bins = 25, position = "dodge") 139 | ``` 140 | 141 | EM. 142 | 143 | ```{r normalmixem_distinct} 144 | set.seed(1984) 145 | mm2 <- normalmixEM(df$value, k = 2, maxit = 10000) 146 | summary(mm2) 147 | ``` 148 | 149 | Component 1 is the distribution with the higher mean (`k2`). 150 | 151 | ```{r confusion_matrix} 152 | table( 153 | real = df$label, 154 | predicted = ifelse(mm2$posterior[, 1] < 0.5, yes = 'k1', no = 'k2') 155 | ) 156 | ``` 157 | 158 | Two clearly distinguishable distributions. 159 | 160 | ```{r clearly_distinct} 161 | set.seed(1984) 162 | k1 <- rnorm(n = 123, mean = 0.2, sd = 0.2) 163 | k2 <- rnorm(n = 156, mean = 0.8, sd = 0.2) 164 | 165 | df2 <- tibble( 166 | label = rep(c('k1', 'k2'), c(length(k1), length(k2))), 167 | value = c(k1, k2) 168 | ) 169 | 170 | ggplot(df2, aes(value, fill = label)) + 171 | geom_histogram(bins = 25, position = "dodge") 172 | ``` 173 | 174 | EM and predict. 175 | 176 | ```{r normalmixem_clearly_distinct} 177 | set.seed(1984) 178 | mm3 <- normalmixEM(df2$value, k = 2, maxit = 10000) 179 | table( 180 | real = df2$label, 181 | predicted = ifelse(mm3$posterior[, 2] < 0.5, yes = 'k1', no = 'k2') 182 | ) 183 | ``` 184 | 185 | Do we do better by simply setting a threshold at 0.5?
In this case, yes. 186 | 187 | ```{r hard_filter} 188 | table( 189 | real = df2$label, 190 | predicted = ifelse(df2$value < 0.5, yes = 'k1', no = 'k2') 191 | ) 192 | ``` 193 | 194 | ## Session info 195 | 196 | Time built. 197 | 198 | ```{r time, echo=FALSE} 199 | Sys.time() 200 | ``` 201 | 202 | Session info. 203 | 204 | ```{r session_info, echo=FALSE} 205 | sessionInfo() 206 | ``` 207 | -------------------------------------------------------------------------------- /gmm/img/clearly_distinct-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/gmm/img/clearly_distinct-1.png -------------------------------------------------------------------------------- /gmm/img/close-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/gmm/img/close-1.png -------------------------------------------------------------------------------- /gmm/img/distinct-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/gmm/img/distinct-1.png -------------------------------------------------------------------------------- /gmm/img/faithful_scatter-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/gmm/img/faithful_scatter-1.png -------------------------------------------------------------------------------- /gmm/img/plot_density-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/gmm/img/plot_density-1.png -------------------------------------------------------------------------------- /hclust/README.md: -------------------------------------------------------------------------------- 1 | Introduction 2 | ------------ 3 | 4 | Here's a [good 5 | post](http://datascienceplus.com/hierarchical-clustering-in-r/) on 6 | hierarchical clustering. Using data from the `datamicroarray` package. 7 | 8 | ``` {.r} 9 | .libPaths('/packages') 10 | my_packages <- c('dendextend', 'remotes', 'datamicroarray') 11 | 12 | for (my_package in my_packages){ 13 | if(!require(my_package, character.only = TRUE)){ 14 | if (my_package == 'datamicroarray'){ 15 | install_github('ramhiser/datamicroarray') 16 | } else { 17 | install.packages(my_package, '/packages') 18 | } 19 | library(my_package, character.only = TRUE) 20 | } 21 | } 22 | ``` 23 | 24 | Using microarray data 25 | --------------------- 26 | 27 | I will use the `yeoh` data set. 28 | 29 | ``` {.r} 30 | data('yeoh', package = "datamicroarray") 31 | 32 | dim(yeoh$x) 33 | ``` 34 | 35 | ## [1] 248 12625 36 | 37 | ``` {.r} 38 | table(yeoh$y) 39 | ``` 40 | 41 | ## 42 | ## BCR E2A Hyperdip MLL T TEL 43 | ## 15 27 64 20 43 79 44 | 45 | Calculate distance between all samples. 46 | 47 | ``` {.r} 48 | choose(248, 2) 49 | ``` 50 | 51 | ## [1] 30628 52 | 53 | ``` {.r} 54 | my_dist <- dist(yeoh$x) 55 | 56 | summary(my_dist) 57 | ``` 58 | 59 | ## Min. 1st Qu. Median Mean 3rd Qu. Max. 60 | ## 10.65 19.32 21.93 22.40 24.87 48.88 61 | 62 | Perform hierarchical clustering using complete (maximum) linkage, which 63 | is the default. 
64 | 65 | ``` {.r} 66 | my_hclust <- hclust(my_dist) 67 | ``` 68 | 69 | Form six clusters based on the clustering. 70 | 71 | ``` {.r} 72 | my_clus <- cutree(my_hclust, k = 6) 73 | 74 | table(my_clus, yeoh$y) 75 | ``` 76 | 77 | ## 78 | ## my_clus BCR E2A Hyperdip MLL T TEL 79 | ## 1 13 23 32 3 24 35 80 | ## 2 2 0 4 1 0 18 81 | ## 3 0 4 9 5 5 10 82 | ## 4 0 0 18 10 0 14 83 | ## 5 0 0 1 1 3 2 84 | ## 6 0 0 0 0 11 0 85 | 86 | ``` {.r} 87 | cluster_one <- yeoh$y[my_clus == 1] 88 | ``` 89 | 90 | Form `n` clusters based on arbitrary distance. 91 | 92 | ``` {.r} 93 | my_clus_two <- cutree(my_hclust, h = 25) 94 | # much more homogeneous 95 | table(my_clus_two, yeoh$y) 96 | ``` 97 | 98 | ## 99 | ## my_clus_two BCR E2A Hyperdip MLL T TEL 100 | ## 1 12 0 10 0 0 2 101 | ## 2 1 0 22 0 0 33 102 | ## 3 2 0 3 0 0 17 103 | ## 4 0 23 0 3 0 0 104 | ## 5 0 4 8 5 1 10 105 | ## 6 0 0 16 0 0 0 106 | ## 7 0 0 1 0 0 0 107 | ## 8 0 0 1 0 0 0 108 | ## 9 0 0 1 10 0 0 109 | ## 10 0 0 1 1 0 1 110 | ## 11 0 0 1 0 0 0 111 | ## 12 0 0 0 1 3 2 112 | ## 13 0 0 0 0 6 0 113 | ## 14 0 0 0 0 18 0 114 | ## 15 0 0 0 0 8 0 115 | ## 16 0 0 0 0 3 0 116 | ## 17 0 0 0 0 4 0 117 | ## 18 0 0 0 0 0 14 118 | 119 | Plot. 120 | 121 | ``` {.r} 122 | my_hclust_mod <- my_hclust 123 | my_hclust_mod$labels <- as.vector(yeoh$y) 124 | plot(color_branches(my_hclust_mod, h = 25, groupLabels = TRUE)) 125 | ``` 126 | 127 | ![](img/unnamed-chunk-6-1.png) 128 | 129 | Session info 130 | ------------ 131 | 132 | Time built. 133 | 134 | ## [1] "2022-10-20 06:51:24 UTC" 135 | 136 | Session info. 137 | 138 | ## R version 4.2.1 (2022-06-23) 139 | ## Platform: x86_64-pc-linux-gnu (64-bit) 140 | ## Running under: Ubuntu 20.04.4 LTS 141 | ## 142 | ## Matrix products: default 143 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 144 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 145 | ## 146 | ## locale: 147 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 148 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 149 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 150 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 151 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 152 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 153 | ## 154 | ## attached base packages: 155 | ## [1] stats graphics grDevices utils datasets methods base 156 | ## 157 | ## other attached packages: 158 | ## [1] datamicroarray_0.2.3 remotes_2.4.2 dendextend_1.16.0 159 | ## [4] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.9 160 | ## [7] purrr_0.3.4 readr_2.1.2 tidyr_1.2.0 161 | ## [10] tibble_3.1.7 ggplot2_3.3.6 tidyverse_1.3.1 162 | ## 163 | ## loaded via a namespace (and not attached): 164 | ## [1] tidyselect_1.1.2 xfun_0.31 haven_2.5.0 colorspace_2.0-3 165 | ## [5] vctrs_0.4.1 generics_0.1.3 viridisLite_0.4.0 htmltools_0.5.2 166 | ## [9] yaml_2.3.5 utf8_1.2.2 rlang_1.0.3 pillar_1.7.0 167 | ## [13] glue_1.6.2 withr_2.5.0 DBI_1.1.3 dbplyr_2.2.1 168 | ## [17] modelr_0.1.8 readxl_1.4.0 lifecycle_1.0.1 munsell_0.5.0 169 | ## [21] gtable_0.3.0 cellranger_1.1.0 rvest_1.0.2 evaluate_0.15 170 | ## [25] knitr_1.39 tzdb_0.3.0 fastmap_1.1.0 fansi_1.0.3 171 | ## [29] highr_0.9 broom_1.0.0 scales_1.2.0 backports_1.4.1 172 | ## [33] jsonlite_1.8.0 fs_1.5.2 gridExtra_2.3 hms_1.1.1 173 | ## [37] digest_0.6.29 stringi_1.7.6 grid_4.2.1 cli_3.3.0 174 | ## [41] tools_4.2.1 magrittr_2.0.3 crayon_1.5.1 pkgconfig_2.0.3 175 | ## [45] ellipsis_0.3.2 xml2_1.3.3 reprex_2.0.1 lubridate_1.8.0 176 | ## [49] viridis_0.6.2 rstudioapi_0.13 assertthat_0.2.1 rmarkdown_2.14 177 | ## [53] httr_1.4.3 R6_2.5.1 
compiler_4.2.1 178 | -------------------------------------------------------------------------------- /hclust/img/unnamed-chunk-6-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/hclust/img/unnamed-chunk-6-1.png -------------------------------------------------------------------------------- /hclust/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Hierarchical clustering" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | Here's a [good post](http://datascienceplus.com/hierarchical-clustering-in-r/) on hierarchical clustering. Using data from the `datamicroarray` package. 19 | 20 | ```{r load_package, message=FALSE, warning=FALSE} 21 | .libPaths('/packages') 22 | my_packages <- c('dendextend', 'remotes', 'datamicroarray') 23 | 24 | for (my_package in my_packages){ 25 | if(!require(my_package, character.only = TRUE)){ 26 | if (my_package == 'datamicroarray'){ 27 | install_github('ramhiser/datamicroarray') 28 | } else { 29 | install.packages(my_package, '/packages') 30 | } 31 | library(my_package, character.only = TRUE) 32 | } 33 | } 34 | ``` 35 | 36 | ## Using microarray data 37 | 38 | I will use the `yeoh` data set. 39 | 40 | ```{r} 41 | data('yeoh', package = "datamicroarray") 42 | 43 | dim(yeoh$x) 44 | 45 | table(yeoh$y) 46 | ``` 47 | 48 | Calculate distance between all samples. 49 | 50 | ```{r} 51 | choose(248, 2) 52 | 53 | my_dist <- dist(yeoh$x) 54 | 55 | summary(my_dist) 56 | ``` 57 | 58 | Perform hierarchical clustering using complete (maximum) linkage, which is the default. 59 | 60 | ```{r} 61 | my_hclust <- hclust(my_dist) 62 | ``` 63 | 64 | Form six clusters based on the clustering. 65 | 66 | ```{r} 67 | my_clus <- cutree(my_hclust, k = 6) 68 | 69 | table(my_clus, yeoh$y) 70 | 71 | cluster_one <- yeoh$y[my_clus == 1] 72 | ``` 73 | 74 | Form `n` clusters based on arbitrary distance. 75 | 76 | ```{r} 77 | my_clus_two <- cutree(my_hclust, h = 25) 78 | # much more homogeneous 79 | table(my_clus_two, yeoh$y) 80 | ``` 81 | 82 | Plot. 83 | 84 | ```{r} 85 | my_hclust_mod <- my_hclust 86 | my_hclust_mod$labels <- as.vector(yeoh$y) 87 | plot(color_branches(my_hclust_mod, h = 25, groupLabels = TRUE)) 88 | ``` 89 | 90 | ## Session info 91 | 92 | Time built. 93 | 94 | ```{r time, echo=FALSE} 95 | Sys.time() 96 | ``` 97 | 98 | Session info. 
99 | 100 | ```{r session_info, echo=FALSE} 101 | sessionInfo() 102 | ``` 103 | 104 | -------------------------------------------------------------------------------- /kmeans/img/elbow_plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/elbow_plot-1.png -------------------------------------------------------------------------------- /kmeans/img/fviz_cluster-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/fviz_cluster-1.png -------------------------------------------------------------------------------- /kmeans/img/kmeans_k_2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/kmeans_k_2-1.png -------------------------------------------------------------------------------- /kmeans/img/pca-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/pca-1.png -------------------------------------------------------------------------------- /kmeans/img/pca_figure-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/pca_figure-1.png -------------------------------------------------------------------------------- /kmeans/img/plot_hist-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/plot_hist-1.png -------------------------------------------------------------------------------- /kmeans/img/silhouette_analysis-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/kmeans/img/silhouette_analysis-1.png -------------------------------------------------------------------------------- /kmeans/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "K-means clustering" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | Install packages if missing and load. 
19 | 20 | ```{r load_package, message=FALSE, warning=FALSE} 21 | .libPaths('/packages') 22 | my_packages <- c("cluster", "ggrepel", "factoextra") 23 | 24 | for (my_package in my_packages){ 25 | if(!require(my_package, character.only = TRUE)){ 26 | install.packages(my_package, '/packages') 27 | library(my_package, character.only = TRUE) 28 | } 29 | } 30 | ``` 31 | 32 | ## Data set 33 | 34 | We will use an NBA [data set](https://www.kaggle.com/drgilermo/nba-players-stats) downloaded from Kaggle and create a subsetted data set using stats from the year 2017 and consisting of: 35 | 36 | * FG - Field Goals 37 | * FGA - Field Goal Attempts 38 | * FT - Free Throws 39 | * FTA - Free Throw Attempts 40 | * 3P - 3-Point Field Goals 41 | * 3PA - 3-Point Field Goal Attempts 42 | * PTS - Points 43 | * TRB - Total Rebounds 44 | * AST - Assists 45 | * STL - Steals 46 | * BLK - Blocks 47 | * TOV - Turnovers 48 | 49 | Some players were traded during the season and thus have stats for different teams. The "TOT" team is a summation of stats from different teams but since we will perform our own summation, we will exclude that row of data. 50 | 51 | We will scale the statistics by the number of games each player has played, since not everyone has played the same number of games. We will also standardise the data so that all the stats are on the same scale. 52 | 53 | ```{r prepare_data} 54 | season_stat <- read.csv("../data/Seasons_Stats.csv.gz") 55 | 56 | # some players played in different teams 57 | season_stat %>% 58 | filter(Year == 2017, Tm != "TOT", G > 50) %>% 59 | select(Player, Pos, G, FG, FGA, FT, FTA, X3P, X3PA, PTS, TRB, AST, STL, BLK, TOV) -> data_subset 60 | 61 | data_subset %>% 62 | group_by(Player, Pos) %>% 63 | summarise_all(sum) -> data_subset 64 | 65 | # scale stats by number of games played and normalise 66 | data_subset[, -(1:3)] <- data_subset[, -(1:3)] / data_subset$G 67 | data_subset[, -(1:3)] <- scale(data_subset[, -(1:3)]) 68 | 69 | str(data_subset) 70 | ``` 71 | 72 | We'll plot histograms of all standardised statistics to visualise the distributions. 73 | 74 | ```{r plot_hist} 75 | data_subset[, -(1:3)] %>% 76 | gather() %>% 77 | ggplot(., aes(value)) + 78 | geom_histogram(bins = 20) + 79 | facet_wrap(~key) 80 | ``` 81 | 82 | ## K-means 83 | 84 | The idea behind k-means clustering is to define clusters such that the total within-cluster variation is minimised. The within-cluster variation is calculated as the sum of squared Euclidean distances between observations and the centroid of a cluster. The total within-cluster variation is the sum of all within-cluster calculations for _k_ clusters. 85 | 86 | We will use `kmeans` to perform k-means clustering with a _k_ of 5 since there are 5 positions in basketball. 87 | 88 | ```{r kmeans} 89 | my_kmeans <- kmeans(x = data_subset[, -(1:3)], centers = 5) 90 | my_kmeans 91 | ``` 92 | 93 | The cluster assignments are in `cluster` and since we set _k_ to 5 each player is assigned to 1 of 5 possible clusters. 94 | 95 | ```{r kmeans_cluster_table} 96 | table(my_kmeans$cluster) 97 | ``` 98 | 99 | The total within-cluster variation is stored in `tot.withinss`. 100 | 101 | ```{r kmeans_tot_withinss} 102 | my_kmeans$tot.withinss 103 | ``` 104 | 105 | We can use `fviz_cluster` to visualise the clusters in a scatter plot of the first two principal components.
106 | 107 | ```{r fviz_cluster} 108 | fviz_cluster(my_kmeans, data = data_subset[, -(1:3)]) 109 | ``` 110 | 111 | In our example above, we chose a _k_ of 5 simply because we assume that each player position produces distinctive statistics. For example, a centre will have more rebounds and blocks, and a guard will have more assists and steals. However, this may not be the ideal number of clusters. 112 | 113 | One way of determining an optimal number of clusters is to plot the total within-cluster variation for a range of _k_ values and find the point where the decrease levels off, forming a visual "elbow" in the plot. 114 | 115 | ```{r elbow_plot} 116 | # Use map_dbl to run many models with varying value of k (centers) 117 | tot_withinss <- map_dbl(2:30, function(k){ 118 | model <- kmeans(x = data_subset[, -(1:3)], centers = k) 119 | model$tot.withinss 120 | }) 121 | 122 | # Generate a data frame containing both k and tot_withinss 123 | elbow_df <- data.frame( 124 | k = 2:30, 125 | tot_withinss = tot_withinss 126 | ) 127 | 128 | ggplot(elbow_df, aes(x = k, y = tot_withinss)) + 129 | geom_line() + 130 | geom_point(aes(x = k, y = tot_withinss)) + 131 | scale_x_continuous(breaks = 2:30) 132 | ``` 133 | 134 | Another method for determining a suitable _k_ is the silhouette approach, which compares the average distance from an observation to the other observations in its cluster against the average distance to observations in the closest neighbouring cluster. A value close to 1 indicates that an observation is well matched to its cluster; a value of 0 indicates that the observation is on the border between two clusters; and a value of -1 indicates that the observation has a better fit in the neighbouring cluster. 135 | 136 | ```{r silhouette_analysis} 137 | # Use map_dbl to run many models with varying value of k 138 | sil_width <- map_dbl(2:30, function(k){ 139 | model <- pam(x = data_subset[, -(1:3)], k = k) 140 | model$silinfo$avg.width 141 | }) 142 | 143 | # Generate a data frame containing both k and sil_width 144 | sil_df <- data.frame( 145 | k = 2:30, 146 | sil_width = sil_width 147 | ) 148 | 149 | # Plot the relationship between k and sil_width 150 | ggplot(sil_df, aes(x = k, y = sil_width)) + 151 | geom_line() + 152 | geom_point(aes(x = k, y = sil_width)) + 153 | scale_x_continuous(breaks = 2:30) 154 | ``` 155 | 156 | The silhouette approach suggests that a _k_ of 2 is optimal. 157 | 158 | ```{r kmeans_k_2} 159 | my_kmeans_k_2 <- kmeans(data_subset[, -(1:3)], centers = 2) 160 | fviz_cluster(my_kmeans_k_2, data = data_subset[, -(1:3)]) 161 | ``` 162 | 163 | ## Extra 164 | 165 | Below I perform a Principal Component Analysis and plot the PCs. 166 | 167 | ```{r pca} 168 | my_pca <- prcomp(data_subset[, -(1:3)], center = FALSE, scale = FALSE) 169 | 170 | summary(my_pca) 171 | 172 | my_pca_df <- as.data.frame(my_pca$x) 173 | my_pca_df$pos <- data_subset$Pos 174 | my_pca_df$name <- data_subset$Player 175 | 176 | ggplot(my_pca_df, aes(x = PC1, y = PC2, colour = pos, text = name)) + 177 | geom_point() 178 | ``` 179 | 180 | If we label the points, we can clearly see that the players with more variable statistics consist of many NBA All-Stars.
181 | 182 | ```{r pca_figure, fig.width=6, fig.height=5} 183 | ggplot(my_pca_df, aes(x = PC1, y = PC2, colour = pos, label = name)) + 184 | geom_text_repel( 185 | data = my_pca_df %>% filter(PC1 > 5 | PC2 < -3.7) 186 | ) + 187 | geom_point() + 188 | theme_classic() 189 | ``` 190 | 191 | ## Further reading 192 | 193 | * https://uc-r.github.io/kmeans_clustering 194 | * https://www.datacamp.com/community/tutorials/k-means-clustering-r 195 | 196 | ## Session info 197 | 198 | Time built. 199 | 200 | ```{r time, echo=FALSE} 201 | Sys.time() 202 | ``` 203 | 204 | Session info. 205 | 206 | ```{r session_info, echo=FALSE} 207 | sessionInfo() 208 | ``` 209 | 210 | -------------------------------------------------------------------------------- /knn/README.md: -------------------------------------------------------------------------------- 1 | Introduction 2 | ------------ 3 | 4 | In pattern recognition, the [k-Nearest Neighbours 5 | algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm) 6 | (k-NN) is a non-parametric method used for classification and 7 | regression. In both cases, the input consists of the k closest training 8 | examples in the feature space. 9 | 10 | Install packages if missing and load. 11 | 12 | ``` {.r} 13 | .libPaths('/packages') 14 | my_packages <- 'FNN' 15 | 16 | for (my_package in my_packages){ 17 | if(!require(my_package, character.only = TRUE)){ 18 | install.packages(my_package, '/packages') 19 | library(my_package, character.only = TRUE) 20 | } 21 | } 22 | ``` 23 | 24 | Regression 25 | ---------- 26 | 27 | We'll use the `women` dataset to demonstrate how k-NN performs 28 | regression. The dataset contains height and weight measurements for 15 29 | American women aged between 30--39. 30 | 31 | ``` {.r} 32 | data(women) 33 | str(women) 34 | ``` 35 | 36 | ## 'data.frame': 15 obs. of 2 variables: 37 | ## $ height: num 58 59 60 61 62 63 64 65 66 67 ... 38 | ## $ weight: num 115 117 120 123 126 129 132 135 139 142 ... 39 | 40 | ``` {.r} 41 | plot(women, xlab = "Height (in)", ylab = "Weight (lb)", main = "women data: American women aged 30-39", pch = 16) 42 | ``` 43 | 44 | ![](img/unnamed-chunk-1-1.png) 45 | 46 | In the example below, we want to *predict the weight* of a female who is 47 | 60 inches tall based on data in the `women` dataset. 48 | 49 | ``` {.r} 50 | knn <- function(x, x_train, y_train, k){ 51 | d <- abs(x - x_train) 52 | s <- order(d) 53 | return(mean(y_train[s[1:k]])) 54 | } 55 | 56 | # using four neighbours 57 | knn(60, women$height, women$weight, 4) 58 | ``` 59 | 60 | ## [1] 118.75 61 | 62 | ``` {.r} 63 | # using five neighbours 64 | knn(60, women$height, women$weight, 5) 65 | ``` 66 | 67 | ## [1] 120.2 68 | 69 | - The `knn` algorithm first calculates the absolute distance of an 70 | input to a known set of data points for the same variable (height). 71 | - These distances are then sorted, with the closest data points ranked 72 | first. 73 | - The k-nearest distances of heights are used to obtain the 74 | corresponding weights 75 | - Finally, the k weights are averaged (mean) and returned 76 | 77 | Use `sapply` to predict several values. 78 | 79 | ``` {.r} 80 | sapply(c(60,70), knn, x_train = women$height, y_train = women$weight, k = 4) 81 | ``` 82 | 83 | ## [1] 118.75 152.25 84 | 85 | You can also use `knn.reg` in the `FNN` package. 
86 | 87 | ``` {.r} 88 | knn.reg(women$height, 60, women$weight, 4) 89 | ``` 90 | 91 | ## Prediction: 92 | ## [1] 118.75 93 | 94 | ``` {.r} 95 | knn.reg(women$height, 60, women$weight, 5) 96 | ``` 97 | 98 | ## Prediction: 99 | ## [1] 120.2 100 | 101 | Session info 102 | ------------ 103 | 104 | Time built. 105 | 106 | ## [1] "2022-10-20 06:51:49 UTC" 107 | 108 | Session info. 109 | 110 | ## R version 4.2.1 (2022-06-23) 111 | ## Platform: x86_64-pc-linux-gnu (64-bit) 112 | ## Running under: Ubuntu 20.04.4 LTS 113 | ## 114 | ## Matrix products: default 115 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 116 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 117 | ## 118 | ## locale: 119 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 120 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 121 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 122 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 123 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 124 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 125 | ## 126 | ## attached base packages: 127 | ## [1] stats graphics grDevices utils datasets methods base 128 | ## 129 | ## other attached packages: 130 | ## [1] FNN_1.1.3.1 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.9 131 | ## [5] purrr_0.3.4 readr_2.1.2 tidyr_1.2.0 tibble_3.1.7 132 | ## [9] ggplot2_3.3.6 tidyverse_1.3.1 133 | ## 134 | ## loaded via a namespace (and not attached): 135 | ## [1] tidyselect_1.1.2 xfun_0.31 haven_2.5.0 colorspace_2.0-3 136 | ## [5] vctrs_0.4.1 generics_0.1.3 htmltools_0.5.2 yaml_2.3.5 137 | ## [9] utf8_1.2.2 rlang_1.0.3 pillar_1.7.0 glue_1.6.2 138 | ## [13] withr_2.5.0 DBI_1.1.3 dbplyr_2.2.1 modelr_0.1.8 139 | ## [17] readxl_1.4.0 lifecycle_1.0.1 munsell_0.5.0 gtable_0.3.0 140 | ## [21] cellranger_1.1.0 rvest_1.0.2 evaluate_0.15 knitr_1.39 141 | ## [25] tzdb_0.3.0 fastmap_1.1.0 fansi_1.0.3 highr_0.9 142 | ## [29] broom_1.0.0 scales_1.2.0 backports_1.4.1 jsonlite_1.8.0 143 | ## [33] fs_1.5.2 hms_1.1.1 digest_0.6.29 stringi_1.7.6 144 | ## [37] grid_4.2.1 cli_3.3.0 tools_4.2.1 magrittr_2.0.3 145 | ## [41] crayon_1.5.1 pkgconfig_2.0.3 ellipsis_0.3.2 xml2_1.3.3 146 | ## [45] reprex_2.0.1 lubridate_1.8.0 rstudioapi_0.13 assertthat_0.2.1 147 | ## [49] rmarkdown_2.14 httr_1.4.3 R6_2.5.1 compiler_4.2.1 148 | -------------------------------------------------------------------------------- /knn/img/unnamed-chunk-1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/knn/img/unnamed-chunk-1-1.png -------------------------------------------------------------------------------- /knn/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "k-Nearest Neighbours" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | In pattern recognition, the [k-Nearest Neighbours algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm) (k-NN) is a non-parametric method used for classification and regression. In both cases, the input consists of the k closest training examples in the feature space. 19 | 20 | Install packages if missing and load. 
21 | 22 | ```{r load_package, message=FALSE, warning=FALSE} 23 | .libPaths('/packages') 24 | my_packages <- 'FNN' 25 | 26 | for (my_package in my_packages){ 27 | if(!require(my_package, character.only = TRUE)){ 28 | install.packages(my_package, '/packages') 29 | library(my_package, character.only = TRUE) 30 | } 31 | } 32 | ``` 33 | 34 | ## Regression 35 | 36 | We'll use the `women` dataset to demonstrate how k-NN performs regression. The dataset contains height and weight measurements for 15 American women aged between 30–39. 37 | 38 | ```{r} 39 | data(women) 40 | str(women) 41 | plot(women, xlab = "Height (in)", ylab = "Weight (lb)", main = "women data: American women aged 30-39", pch = 16) 42 | ``` 43 | 44 | In the example below, we want to *predict the weight* of a female who is 60 inches tall based on data in the `women` dataset. 45 | 46 | ```{r} 47 | knn <- function(x, x_train, y_train, k){ 48 | d <- abs(x - x_train) 49 | s <- order(d) 50 | return(mean(y_train[s[1:k]])) 51 | } 52 | 53 | # using four neighbours 54 | knn(60, women$height, women$weight, 4) 55 | 56 | # using five neighbours 57 | knn(60, women$height, women$weight, 5) 58 | ``` 59 | 60 | * The `knn` algorithm first calculates the absolute distance of an input to a known set of data points for the same variable (height). 61 | * These distances are then sorted, with the closest data points ranked first. 62 | * The k-nearest distances of heights are used to obtain the corresponding weights 63 | * Finally, the k weights are averaged (mean) and returned 64 | 65 | Use `sapply` to predict several values. 66 | 67 | ```{r} 68 | sapply(c(60,70), knn, x_train = women$height, y_train = women$weight, k = 4) 69 | ``` 70 | 71 | You can also use `knn.reg` in the `FNN` package. 72 | 73 | ```{r} 74 | knn.reg(women$height, 60, women$weight, 4) 75 | 76 | knn.reg(women$height, 60, women$weight, 5) 77 | ``` 78 | 79 | ## Session info 80 | 81 | Time built. 82 | 83 | ```{r time, echo=FALSE} 84 | Sys.time() 85 | ``` 86 | 87 | Session info. 
88 | 89 | ```{r session_info, echo=FALSE} 90 | sessionInfo() 91 | ``` 92 | 93 | -------------------------------------------------------------------------------- /logit_regression/img/mpg_vs_hp-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/logit_regression/img/mpg_vs_hp-1.png -------------------------------------------------------------------------------- /logit_regression/img/unnamed-chunk-3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/logit_regression/img/unnamed-chunk-3-1.png -------------------------------------------------------------------------------- /logit_regression/img/unnamed-chunk-5-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/logit_regression/img/unnamed-chunk-5-1.png -------------------------------------------------------------------------------- /logit_regression/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Logistic regression" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | Logistic regression is a regression model where the dependent variable (DV) is categorical. Example from Wikipedia: 19 | 20 | > "Logistic regression may be used to predict whether a patient has a given disease (e.g. diabetes; coronary heart disease), based on observed characteristics of the patient (age, sex, body mass index, results of various blood tests, etc.)." 21 | 22 | Install packages if missing and load. 23 | 24 | ```{r load_package, message=FALSE, warning=FALSE} 25 | .libPaths('/packages') 26 | my_packages <- c('Amelia', 'ROCR') 27 | 28 | for (my_package in my_packages){ 29 | if(!require(my_package, character.only = TRUE)){ 30 | install.packages(my_package, '/packages') 31 | library(my_package, character.only = TRUE) 32 | } 33 | } 34 | ``` 35 | 36 | ## mtcars 37 | 38 | The data `mtcars` was extracted from the 1974 Motor Trend US magazine, and comprises fuel consumption and 10 aspects of automobile design and performance for 32 automobiles (1973–74 models). 39 | 40 | For this example, we’ll predict whether a car has high or low miles per gallon (MPG) based on gross horsepower; if `mpg` is higher than the median, we'll use a 1 to indicate that it is high and a 0 otherwise. 41 | 42 | ```{r mpg_high} 43 | mtcars$mpg_high <- factor(with(mtcars, ifelse(mpg > median(mpg), 1, 0))) 44 | head(mtcars) 45 | ``` 46 | 47 | Miles/(US) gallon versus gross horsepower. 48 | 49 | ```{r mpg_vs_hp} 50 | boxplot(hp ~ mpg_high, data = mtcars) 51 | ``` 52 | 53 | 54 | Use `glm()` to fit a logistic regression model that predicts `mpg_high` using `hp`. 55 | 56 | ```{r mtcars_fit} 57 | fit <- glm(mpg_high ~ hp, data = mtcars, family = binomial) 58 | summary(fit) 59 | ``` 60 | 61 | The negative coefficient for `hp` (-0.05901) suggests that as horsepower (hp) increases, the likelihood of high MPG decreases. 
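Exponentiating the coefficient converts it to an odds ratio, which is easier to interpret; a quick check, reusing the `fit` object from above (the chunk name is ours):

```{r hp_odds_ratio}
# odds ratio for hp: multiplicative change in the odds of high MPG
# per one-unit increase in horsepower
exp(coef(fit)["hp"])
```

Each additional unit of horsepower multiplies the odds of high MPG by exp(-0.05901), roughly 0.943, i.e. about a 5.7% decrease in the odds.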
62 | 
63 | ## Hours of study
64 | 
65 | Using the example from Wikipedia: [Probability of passing an exam versus hours of study](https://en.wikipedia.org/wiki/Logistic_regression#Example:_Probability_of_passing_an_exam_versus_hours_of_study)
66 | 
67 | ```{r}
68 | d <- data.frame(
69 |   hours = c(0.50,0.75,1.00,1.25,1.50,1.75,1.75,2.00,2.25,2.50,2.75,3.00,3.25,3.50,4.00,4.25,4.50,4.75,5.00,5.50),
70 |   pass = factor(c(0,0,0,0,0,0,1,0,1,0,1,0,1,0,1,1,1,1,1,1))
71 | )
72 | 
73 | model <- glm(
74 |   pass ~ hours,
75 |   family=binomial(link='logit'),
76 |   data=d
77 | )
78 | summary(model)
79 | ```
80 | 
81 | The output indicates that hours of study is significantly associated with the probability of passing the exam (p=0.0167, Wald test). The output also provides the coefficients: Intercept = -4.0777 and Hours = 1.5046.
82 | 
83 | Probability of passing as a function of hours of study.
84 | 
85 | ```{r}
86 | prob_passing <- function(hours){
87 |   1 / (1 + exp(-(-4.0777 + 1.5046 * hours)))
88 | }
89 | 
90 | prob_passing(4)
91 | ```
92 | 
93 | ## Survival on the Titanic
94 | 
95 | Adapted from [How to Perform a Logistic Regression in R](https://datascienceplus.com/perform-logistic-regression-in-r/).
96 | 
97 | ```{r}
98 | data <- read.csv("../data/titanic.csv.gz", na.strings = '')
99 | 
100 | # missing data
101 | sapply(data, function(x) sum(is.na(x)))
102 | 
103 | missmap(data)
104 | ```
105 | 
106 | Remove some features.
107 | 
108 | ```{r}
109 | data_subset <- select(data, -PassengerId, -Ticket, -Cabin, -Name)
110 | 
111 | # remove the two cases with missing embarked data
112 | data_subset <- filter(data_subset, !is.na(Embarked))
113 | 
114 | # use the mean age for the missing ages
115 | data_subset$Age[is.na(data_subset$Age)] <- mean(data_subset$Age, na.rm=TRUE)
116 | 
117 | # subset into training and testing sets
118 | train <- data_subset[1:800,]
119 | test <- data_subset[801:nrow(data_subset),]
120 | 
121 | model <- glm(Survived ~.,
122 |              family=binomial(link='logit'),
123 |              data=train)
124 | 
125 | fitted <- predict(model,
126 |                   newdata=test[,-1],
127 |                   type='response')
128 | ```
129 | 
130 | ```{r}
131 | pr <- prediction(fitted, test$Survived)
132 | prf <- performance(pr, measure = "tpr", x.measure = "fpr")
133 | auc <- performance(pr, measure = "auc")
134 | plot(prf)
135 | legend(x = 0.75, y = 0.05, legend = paste("AUC = ", auc@y.values), bty = 'n')
136 | ```
137 | 
138 | ## Links
139 | 
140 | * [Simple logistic regression](http://www.biostathandbook.com/simplelogistic.html)
141 | * [Multiple logistic regression](http://www.biostathandbook.com/multiplelogistic.html)
142 | 
143 | ## Session info
144 | 
145 | Time built.
146 | 
147 | ```{r time, echo=FALSE}
148 | Sys.time()
149 | ```
150 | 
151 | Session info.
152 | 
153 | ```{r session_info, echo=FALSE}
154 | sessionInfo()
155 | ```
156 | 
157 | 
-------------------------------------------------------------------------------- /machine_learning.Rproj: --------------------------------------------------------------------------------
1 | Version: 1.0
2 | 
3 | RestoreWorkspace: Default
4 | SaveWorkspace: Default
5 | AlwaysSaveHistory: Default
6 | 
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 | 
12 | RnwWeave: Sweave
13 | LaTeX: pdfLaTeX
14 | 
15 | BuildType: Makefile
16 | 
-------------------------------------------------------------------------------- /naive_bayes/readme.Rmd: --------------------------------------------------------------------------------
1 | ---
2 | title: "Naive Bayes"
3 | output:
4 |   md_document:
5 |     variant: markdown
6 | ---
7 | 
8 | ```{r setup, include=FALSE}
9 | library(tidyverse)
10 | theme_set(theme_bw())
11 | knitr::opts_chunk$set(cache = FALSE)
12 | knitr::opts_chunk$set(echo = TRUE)
13 | knitr::opts_chunk$set(fig.path = "img/")
14 | ```
15 | 
16 | ## Introduction
17 | 
18 | Naive Bayes is a machine learning approach based on [Bayes' theorem](https://en.wikipedia.org/wiki/Bayes%27_theorem), which describes the probability of an event, based on prior knowledge of conditions that might be related to the event. For example, if the risk of developing health problems is known to increase with age, Bayes' theorem allows the risk to an individual of a known age to be assessed more accurately (by conditioning it on their age) than simply assuming that the individual is typical of the population as a whole. The theorem is stated mathematically as:
19 | 
20 | $$ P(A|B)=\frac{P(B|A)P(A)}{P(B)} $$
21 | 
22 | where A and B are events and P(B) is not equal to 0.
23 | 
24 | * P(A|B) is a conditional probability: the likelihood of event A occurring given that B is true; this is the posterior probability
25 | * P(B|A) is also a conditional probability: the likelihood of event B occurring given that A is true
26 | * P(A) and P(B) are the probabilities of observing A and B independently of each other; they are known as marginal probabilities and are the prior probabilities
27 | 
28 | The formula can be interpreted as the chance of a true positive result, divided by the chance of any positive result (true positive + false positive).
29 | 
30 | A typical example for illustrating Bayes' theorem is calculating the probability that a person is a drug user given a positive [drug test](https://en.wikipedia.org/wiki/Bayes%27_theorem#Drug_testing) result. Suppose a particular drug test is correct 90% of the time at detecting drug users, i.e. it returns a positive result for 90% of users, and is correct 80% of the time at detecting non-users, i.e. it returns a negative result for 80% of non-users (and a false positive for the remaining 20%). We also know that 5% of the population use this drug. Given what we know, what is the probability that a person is a drug user when the test turns up positive? Isn't this just 90% or 0.9? Not quite, since we have prior knowledge that only 5% of people are drug users and that the test can turn up positive even with non-users.
If we apply Bayes' theorem and let $P(User|Positive)$ mean the probability that someone is a drug user given that they tested positive, then:
31 | 
32 | $$ P(User | Positive) = \frac{P(Positive | User) P(User)}{P(Positive)} $$
33 | 
34 | The numerator is $0.90 \times 0.05 = 0.045$, since
35 | 
36 | * $P(Positive|User) = 0.90$
37 | * $P(User) = 0.05$
38 | 
39 | For working out $P(Positive)$, there are two scenarios in which a test can turn up positive (testing a user or testing a non-user), so we sum over both:
40 | 
41 | * $P(Positive|User)P(User) = 0.90 \times 0.05 = 0.045$
42 | * $P(Positive|Non-user)P(Non-user) = 0.20 \times 0.95 = 0.19$
43 | 
44 | Therefore, the probability that a person is a drug user given a positive test result is:
45 | 
46 | $$ \frac{0.045}{0.045 + 0.19} = 0.1914 \approx 19\% $$
47 | 
48 | Classifiers based on Bayesian methods utilise training data to calculate an observed probability of each class based on feature values. An example is using Bayes' theorem to classify emails; we can use known ham and spam emails to tally up the occurrence of words to obtain prior and conditional probabilities, and use these probabilities to classify new emails. Bayesian classifiers utilise all available evidence to come up with a classification; thus they are best applied to problems where the information from numerous attributes should be considered simultaneously.
49 | 
50 | * The "naive" part of the method refers to an assumption that all features in a dataset are equally important and independent, which is usually not true; despite this, the method performs quite well in certain applications like spam classification.
51 | 
52 | Install packages if missing and load.
53 | 
54 | ```{r load_package, message=FALSE, warning=FALSE}
55 | .libPaths('/packages')
56 | my_packages <- 'e1071'
57 | 
58 | for (my_package in my_packages){
59 |   if(!require(my_package, character.only = TRUE)){
60 |     install.packages(my_package, '/packages')
61 |     library(my_package, character.only = TRUE)
62 |   }
63 | }
64 | ```
65 | 
66 | ## Example
67 | 
68 | As noted above, Bayesian classifiers use training data to calculate an observed probability of each class based on feature values; the trained classifier can then be applied to unlabelled data to predict the most likely class.
69 | 
70 | Use the `naiveBayes()` function from the `e1071` package to perform [Gaussian naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes). Divide the dataset into 80% training and 20% testing.
71 | 
72 | ```{r naive_bayes}
73 | set.seed(1984)
74 | my_index <- sample(x = 1:nrow(iris), size = .8*(nrow(iris)))
75 | 
76 | my_train <- iris[my_index, -5]
77 | my_train_label <- iris[my_index, 5]
78 | 
79 | my_test <- iris[!1:nrow(iris) %in% my_index, -5]
80 | my_test_label <- iris[!1:nrow(iris) %in% my_index, 5]
81 | 
82 | m <- naiveBayes(my_train, my_train_label)
83 | 
84 | m
85 | ```
86 | 
87 | The values are the mean and standard deviation for each feature stratified by class.
88 | 
89 | ```{r mean_and_sd, message = FALSE, warnings = FALSE}
90 | iris[my_index,] %>%
91 |   group_by(Species) %>%
92 |   summarise(mean = mean(Petal.Width), sd = sd(Petal.Width))
93 | ```
94 | 
95 | Classify the test set and tabulate based on the real labels; there is only one misclassification, a virginica classified as a versicolor.
96 | 
97 | ```{r classify}
98 | table(predict(m, my_test), my_test_label)
99 | ```
100 | 
101 | ## Further reading
102 | 
103 | * [An Intuitive (and Short) Explanation of Bayes' Theorem](https://betterexplained.com/articles/an-intuitive-and-short-explanation-of-bayes-theorem/)
104 | * [Naive Bayes for Machine Learning](https://machinelearningmastery.com/naive-bayes-for-machine-learning/)
105 | 
106 | ## Session info
107 | 
108 | Time built.
109 | 
110 | ```{r time, echo=FALSE}
111 | Sys.time()
112 | ```
113 | 
114 | Session info.
115 | 
116 | ```{r session_info, echo=FALSE}
117 | sessionInfo()
118 | ```
119 | 
120 | 
-------------------------------------------------------------------------------- /pca/img/plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/pca/img/plot-1.png
-------------------------------------------------------------------------------- /pca/readme.Rmd: --------------------------------------------------------------------------------
1 | ---
2 | title: "Principal Components Analysis"
3 | output:
4 |   md_document:
5 |     variant: markdown
6 | ---
7 | 
8 | ```{r setup, include=FALSE}
9 | library(tidyverse)
10 | theme_set(theme_bw())
11 | knitr::opts_chunk$set(cache = FALSE)
12 | knitr::opts_chunk$set(echo = TRUE)
13 | knitr::opts_chunk$set(fig.path = "img/")
14 | ```
15 | 
16 | ## Introduction
17 | 
18 | This README was generated by running the following from the root directory of this repository:
19 | 
20 |     script/rmd_to_md.sh pca/readme.Rmd
21 | 
22 | ## PCA background
23 | 
24 | Sample data as per the PCA tutorial by Lindsay I Smith (see the `ref` directory).
25 | 
26 | ```{r x}
27 | X = c(1, 2, 4, 6, 12, 15, 25, 45, 68, 67, 65, 98)
28 | ```
29 | 
30 | The sample mean ($\bar{X}$) is calculated by adding up all the numbers and then dividing by how many there are.
31 | 
32 | $$ \bar{X} = \frac{\sum^n_{i=1} X_i}{n}$$
33 | 
34 | ```{r x_bar}
35 | mean(X)
36 | ```
37 | 
38 | The Standard Deviation (SD) of a data set is a measure of how spread out the data is. The way to calculate it is to compute the squares of the distance from each data point to the mean of the set, add them all up, divide by $n-1$, and take the positive square root.
39 | 
40 | $$ s = \sqrt\frac{\sum^n_{i=1} (X_i - \bar{X})^2}{(n - 1)} $$
41 | 
42 | ```{r x_sd}
43 | sd(X)
44 | ```
45 | Variance is another measure of the spread of data in a data set, and it is simply the standard deviation squared.
46 | 
47 | $$ s^2 = \frac{\sum^n_{i=1} (X_i - \bar{X})^2}{(n - 1)} $$
48 | 
49 | Many data sets have more than one dimension and the aim of the statistical analysis of these data sets is usually to see if there is any relationship between the dimensions. Standard deviation and variance only operate on one dimension, so you can only calculate the standard deviation for each dimension of the data set _independently_ of the other dimensions. However, it is useful to have a similar measure to find out how much the dimensions vary from the mean _with respect to each other_, which is what covariance measures; it is calculated _between_ two dimensions.
50 | 
51 | If you calculate the covariance between one dimension and _itself_, you end up with the variance. For a 3D data set (x, y, z), you could measure the covariance between the x and y dimensions, the x and z dimensions, and the y and z dimensions. The formula for covariance is very similar to the formula for variance.
52 | 
53 | 
54 | $$ cov(X, Y) = \frac{\sum^n_{i=1} (X_i-\bar{X})(Y_i-\bar{Y})}{(n - 1)} $$
55 | 
56 | How does the covariance work?
Imagine a 2D data set containing the hours students spent studying for an exam (`hours`) and the marks they received (`mark`).
57 | 
58 | ```{r hours_vs_mark}
59 | my_df <- data.frame(
60 |   hours = c(9, 15, 25, 14, 10, 18, 0, 16, 5, 19, 16, 20),
61 |   mark = c(39, 56, 93, 61, 50, 75, 32, 85, 42, 70, 66, 80)
62 | )
63 | 
64 | cov(my_df)
65 | ```
66 | 
67 | The exact value is not as important as its sign (i.e. positive or negative). If the value is positive, then it indicates that both dimensions _increase together_, meaning that as the number of hours of study increased, so did the mark.
68 | 
69 | If the value is negative, such as in the example below, then one dimension increases as the other decreases; the windier it gets, the lower the temperature.
70 | 
71 | ```{r airquality}
72 | cov(airquality[, c('Wind', 'Temp')])
73 | ```
74 | 
75 | If the covariance is zero (or close to zero), then it indicates that there is no linear relationship between the two dimensions.
76 | 
77 | ```{r random_cov}
78 | my_random <- data.frame(
79 |   x = rnorm(20),
80 |   y = rnorm(20)
81 | )
82 | 
83 | cov(my_random)
84 | ```
85 | 
86 | A useful way to represent all possible covariance values is to calculate them and store them in a matrix.
87 | 
88 | ```{r cov_matrix}
89 | cov(airquality[, 1:4], use = "complete.obs")
90 | ```
91 | 
92 | You can multiply two matrices together provided that they are of compatible sizes, and eigenvectors are a special case of this.
93 | 
94 | The following multiplication results in a vector that is not an integer multiple of the original vector, i.e. a non-eigenvector.
95 | 
96 | $$ \begin{bmatrix} 2 & 3 \\ 2 & 1 \end{bmatrix} \times \begin{bmatrix} 1 \\ 3 \end{bmatrix} = \begin{bmatrix} 11 \\ 5 \end{bmatrix} $$
97 | 
98 | This second example results in a vector that is exactly four times the vector we began with, i.e. an eigenvector.
99 | 
100 | $$ \begin{bmatrix} 2 & 3 \\ 2 & 1 \end{bmatrix} \times \begin{bmatrix} 3 \\ 2 \end{bmatrix} = \begin{bmatrix} 12 \\ 8 \end{bmatrix} = 4 \times \begin{bmatrix} 3 \\ 2 \end{bmatrix} $$
101 | 
102 | In R.
103 | 
104 | ```{r matrix_mut}
105 | matrix(c(2, 3, 2, 1), byrow = TRUE, nrow = 2) %*% matrix(c(3, 2))
106 | ```
107 | 
108 | The vector $ \begin{bmatrix} 3 \\ 2 \end{bmatrix} $ represents an arrow pointing from the origin, $ (0,0) $, to the point $ (3,2) $. The square matrix $ \begin{bmatrix} 2 & 3 \\ 2 & 1 \end{bmatrix} $ can be considered as a transformation matrix. If you multiply this matrix on the left of a vector (as per the example), the answer is another vector that is transformed from its original position. It is the nature of the transformation that the eigenvectors arise from.
109 | 
110 | Now imagine a transformation matrix that, when multiplied on the left, reflected vectors in the line $y = x$. Then you can see that if there were a vector that lay on the line $y = x$, its reflection would be itself. This vector (and all multiples of it) would be an eigenvector of that transformation matrix.
111 | 
112 | Eigenvectors can only be found for square matrices and not every square matrix has eigenvectors. If an $n \times n$ matrix does have eigenvectors, there are $n$ of them. Lastly, all eigenvectors of a _symmetric_ matrix, such as a covariance matrix, are perpendicular, i.e. at right angles to each other, no matter how many dimensions there are. Another word for perpendicular is orthogonal, and being orthogonal is important because it means that you can express the data in terms of these perpendicular eigenvectors, instead of expressing them in terms of the $x$ and $y$ axes.
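We can check this orthogonality numerically; a quick sketch using the `airquality` covariance matrix from above (the dot product of two perpendicular vectors is zero, up to floating-point error).

```{r check_orthogonality}
# covariance matrices are symmetric, so their eigenvectors are orthogonal
e <- eigen(cov(airquality[, c('Wind', 'Temp')]))

# the dot product of the two eigenvectors should be (numerically) zero
sum(e$vectors[, 1] * e$vectors[, 2])
```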
113 | 
114 | In addition, each eigenvector is scaled to have a length of 1, so that all eigenvectors have the same length. The vector $ \begin{bmatrix} 3 \\ 2 \end{bmatrix} $ has length $ \sqrt{3^2 + 2^2} = \sqrt{13} $, so we divide the original vector by this length to give it a length of 1.
115 | 
116 | ```{r standardise_eigenvector}
117 | matrix(c(3,2) / sqrt(13))
118 | ```
119 | 
120 | The value 4 is the eigenvalue associated with the eigenvector $ \begin{bmatrix} 3 \\ 2 \end{bmatrix} $, and the `eigen()` function can be used to find the eigenvalues and eigenvectors of a square matrix.
121 | 
122 | ```{r eigen}
123 | my_mat <- matrix(c(2, 3, 2, 1), byrow = TRUE, nrow = 2)
124 | eigen(my_mat)
125 | ```
126 | 
127 | ## PCA
128 | 
129 | Principal Components Analysis (PCA) is a way of identifying patterns in data and expressing the data in such a way as to highlight their similarities and differences.
130 | 
131 | ```{r pca_data}
132 | x <- c(2.5, 0.5, 2.2, 1.9, 3.1, 2.3, 2, 1, 1.5, 1.1)
133 | y <- c(2.4, 0.7, 2.9, 2.2, 3.0, 2.7, 1.6, 1.1, 1.6, 0.9)
134 | ```
135 | 
136 | ## Breast cancer data
137 | 
138 | Using the [Breast Cancer Wisconsin (Diagnostic) Data Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)).
139 | 
140 | ```{r prepare_data}
141 | data <- read.table(
142 |   "../data/breast_cancer_data.csv",
143 |   stringsAsFactors = FALSE,
144 |   sep = ',',
145 |   header = TRUE
146 | )
147 | data$class <- factor(data$class)
148 | data <- data[,-1]
149 | ```
150 | 
151 | Separate into training (80%) and testing (20%).
152 | 
153 | ```{r split_data}
154 | set.seed(31)
155 | my_prob <- 0.8
156 | my_split <- as.logical(
157 |   rbinom(
158 |     n = nrow(data),
159 |     size = 1,
160 |     p = my_prob
161 |   )
162 | )
163 | 
164 | train <- data[my_split,]
165 | test <- data[!my_split,]
166 | ```
167 | 
168 | ## Results
169 | 
170 | ```{r plot}
171 | ggplot(data, aes(class, ucsize)) +
172 |   geom_boxplot()
173 | ```
174 | 
175 | ## Session info
176 | 
177 | Time built.
178 | 
179 | ```{r time, echo=FALSE}
180 | Sys.time()
181 | ```
182 | 
183 | Session info.
184 | 
185 | ```{r session_info, echo=FALSE}
186 | sessionInfo()
187 | ```
188 | 
189 | 
-------------------------------------------------------------------------------- /proximus/README.md: --------------------------------------------------------------------------------
1 | Introduction
2 | ------------
3 | 
4 | https://en.wikibooks.org/wiki/Data_Mining_Algorithms_In_R/Clustering/Proximus
5 | 
6 | Install packages if missing and load.
7 | 
8 |     .libPaths('/packages')
9 |     my_packages <- 'cba'
10 | 
11 |     for (my_package in my_packages){
12 |       if(!require(my_package, character.only = TRUE)){
13 |         install.packages(my_package, '/packages')
14 |         library(my_package, character.only = TRUE)
15 |       }
16 |     }
17 | 
18 | Results
19 | -------
20 | 
21 | https://www.rdocumentation.org/packages/cba/versions/0.2-21/topics/proximus
22 | 
23 | -   `x`: a logical matrix.
24 | -   `max.radius`: the maximum number of bits a member in a row set may deviate from its dominant pattern.
25 | -   `min.size`: the minimum split size of a row set.
26 | -   `min.retry`: number of retries to split a pure rank-one approximation (translates into a resampling rate).
27 | -   `max.iter`: the maximum number of iterations for finding a local rank-one approximation.
28 | -   `debug`: optional debugging output.
29 | 30 | x <- rlbmat() 31 | pr <- proximus(x, max.radius=8, debug=TRUE) 32 | 33 | ## Non-Zero: 1078 34 | ## Sparsity: 0.26 35 | ## 0 [80,20,7] 1 > 36 | ## 1 [20,20,7] 1 * 1 37 | ## 1 [60,20,6] 1 > 38 | ## 2 [20,20,6] 1 * 2 39 | ## 2 [40,20,7] 1 > 40 | ## 3 [20,20,7] 1 * 3 41 | ## 3 [20,20,6] 1 * 4 42 | ## 2 < 43 | ## 1 < 44 | ## 0 < 45 | 46 | op <- par(mfrow=c(1,2), pty="s") 47 | lmplot(x, main="Data") 48 | box() 49 | lmplot(fitted(pr)$x, main="Approximation") 50 | box() 51 | 52 | ![](img/unnamed-chunk-1-1.png) 53 | 54 | par(op) 55 | 56 | Session info 57 | ------------ 58 | 59 | Time built. 60 | 61 | ## [1] "2022-04-10 09:20:35 UTC" 62 | 63 | Session info. 64 | 65 | ## R version 4.1.3 (2022-03-10) 66 | ## Platform: x86_64-pc-linux-gnu (64-bit) 67 | ## Running under: Ubuntu 20.04.4 LTS 68 | ## 69 | ## Matrix products: default 70 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 71 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 72 | ## 73 | ## locale: 74 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 75 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 76 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 77 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 78 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 79 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 80 | ## 81 | ## attached base packages: 82 | ## [1] grid stats graphics grDevices utils datasets methods 83 | ## [8] base 84 | ## 85 | ## other attached packages: 86 | ## [1] cba_0.2-21 proxy_0.4-26 forcats_0.5.1 stringr_1.4.0 87 | ## [5] dplyr_1.0.8 purrr_0.3.4 readr_2.1.2 tidyr_1.2.0 88 | ## [9] tibble_3.1.6 ggplot2_3.3.5 tidyverse_1.3.1 89 | ## 90 | ## loaded via a namespace (and not attached): 91 | ## [1] tidyselect_1.1.2 xfun_0.30 haven_2.4.3 colorspace_2.0-3 92 | ## [5] vctrs_0.4.0 generics_0.1.2 htmltools_0.5.2 yaml_2.3.5 93 | ## [9] utf8_1.2.2 rlang_1.0.2 pillar_1.7.0 glue_1.6.2 94 | ## [13] withr_2.5.0 DBI_1.1.2 dbplyr_2.1.1 modelr_0.1.8 95 | ## [17] readxl_1.4.0 lifecycle_1.0.1 munsell_0.5.0 gtable_0.3.0 96 | ## [21] cellranger_1.1.0 rvest_1.0.2 evaluate_0.15 knitr_1.38 97 | ## [25] tzdb_0.3.0 fastmap_1.1.0 fansi_1.0.3 highr_0.9 98 | ## [29] broom_0.7.12 scales_1.1.1 backports_1.4.1 jsonlite_1.8.0 99 | ## [33] fs_1.5.2 hms_1.1.1 digest_0.6.29 stringi_1.7.6 100 | ## [37] cli_3.2.0 tools_4.1.3 magrittr_2.0.3 crayon_1.5.1 101 | ## [41] pkgconfig_2.0.3 ellipsis_0.3.2 xml2_1.3.3 reprex_2.0.1 102 | ## [45] lubridate_1.8.0 rstudioapi_0.13 assertthat_0.2.1 rmarkdown_2.13 103 | ## [49] httr_1.4.2 R6_2.5.1 compiler_4.1.3 104 | -------------------------------------------------------------------------------- /proximus/img/unnamed-chunk-1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/proximus/img/unnamed-chunk-1-1.png -------------------------------------------------------------------------------- /proximus/proximus.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Proximus" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | https://en.wikibooks.org/wiki/Data_Mining_Algorithms_In_R/Clustering/Proximus 19 | 20 | Install packages if missing and load. 
21 | 
22 | ```{r load_package, message=FALSE, warning=FALSE}
23 | .libPaths('/packages')
24 | my_packages <- 'cba'
25 | 
26 | for (my_package in my_packages){
27 |   if(!require(my_package, character.only = TRUE)){
28 |     install.packages(my_package, '/packages')
29 |     library(my_package, character.only = TRUE)
30 |   }
31 | }
32 | ```
33 | 
34 | ## Results
35 | 
36 | https://www.rdocumentation.org/packages/cba/versions/0.2-21/topics/proximus
37 | 
38 | * `x`: a logical matrix.
39 | * `max.radius`: the maximum number of bits a member in a row set may deviate from its dominant pattern.
40 | * `min.size`: the minimum split size of a row set.
41 | * `min.retry`: number of retries to split a pure rank-one approximation (translates into a resampling rate).
42 | * `max.iter`: the maximum number of iterations for finding a local rank-one approximation.
43 | * `debug`: optional debugging output.
44 | 
45 | ```{r}
46 | x <- rlbmat()
47 | pr <- proximus(x, max.radius=8, debug=TRUE)
48 | op <- par(mfrow=c(1,2), pty="s")
49 | lmplot(x, main="Data")
50 | box()
51 | lmplot(fitted(pr)$x, main="Approximation")
52 | box()
53 | par(op)
54 | ```
55 | 
56 | ## Session info
57 | 
58 | Time built.
59 | 
60 | ```{r time, echo=FALSE}
61 | Sys.time()
62 | ```
63 | 
64 | Session info.
65 | 
66 | ```{r session_info, echo=FALSE}
67 | sessionInfo()
68 | ```
69 | 
70 | 
-------------------------------------------------------------------------------- /random_forest/img/tune_rf-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/random_forest/img/tune_rf-1.png
-------------------------------------------------------------------------------- /random_forest/img/var_imp_plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/random_forest/img/var_imp_plot-1.png
-------------------------------------------------------------------------------- /random_forest/readme.Rmd: --------------------------------------------------------------------------------
1 | ---
2 | title: "Random Forest"
3 | output:
4 |   md_document:
5 |     variant: markdown
6 | ---
7 | 
8 | ```{r setup, include=FALSE}
9 | library(tidyverse)
10 | theme_set(theme_bw())
11 | knitr::opts_chunk$set(cache = FALSE)
12 | knitr::opts_chunk$set(echo = TRUE)
13 | knitr::opts_chunk$set(fig.path = "img/")
14 | ```
15 | 
16 | ## Introduction
17 | 
18 | Install packages if missing and load.
19 | 
20 | ```{r load_package, message=FALSE, warning=FALSE}
21 | .libPaths('/packages')
22 | my_packages <- c('randomForest')
23 | 
24 | for (my_package in my_packages){
25 |   if(!require(my_package, character.only = TRUE)){
26 |     install.packages(my_package, '/packages')
27 |     library(my_package, character.only = TRUE)
28 |   }
29 | }
30 | ```
31 | 
32 | ## Breast cancer data
33 | 
34 | Using the [Breast Cancer Wisconsin (Diagnostic) Data Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)).
35 | 
36 | ```{r prepare_data}
37 | data <- read.table(
38 |   "../data/breast_cancer_data.csv",
39 |   stringsAsFactors = FALSE,
40 |   sep = ',',
41 |   header = TRUE
42 | )
43 | data$class <- factor(data$class)
44 | data <- data[,-1]
45 | ```
46 | 
47 | Separate into training (80%) and testing (20%).
48 | 
49 | ```{r split_data}
50 | set.seed(31)
51 | my_prob <- 0.8
52 | my_split <- as.logical(
53 |   rbinom(
54 |     n = nrow(data),
55 |     size = 1,
56 |     p = my_prob
57 |   )
58 | )
59 | 
60 | train <- data[my_split,]
61 | test <- data[!my_split,]
62 | ```
63 | 
64 | ## Analysis
65 | 
66 | Parameters:
67 | 
68 | * `data` = an optional data frame containing the variables in the model
69 | * `importance` = calculate the importance of predictors
70 | * `do.trace` = give a more verbose output as randomForest is running
71 | * `proximity` = calculate the proximity measure among the rows
72 | 
73 | ```{r train_rf}
74 | set.seed(31)
75 | r <- randomForest(class ~ ., data=train, importance=TRUE, do.trace=100, proximity = TRUE)
76 | ```
77 | 
78 | Error rate.
79 | 
80 | ```{r error_rate}
81 | r
82 | ```
83 | 
84 | Find the best `mtry`.
85 | 
86 | ```{r tune_rf}
87 | set.seed(31)
88 | my_tuning <- tuneRF(train[, -ncol(train)], train[, ncol(train)])
89 | my_tuning
90 | ```
91 | 
92 | Use the best `mtry` value to re-train another model.
93 | 
94 | ```{r best_mtry}
95 | best_mtry <- my_tuning[order(my_tuning[, 2])[1], 1]
96 | 
97 | set.seed(31)
98 | r_best_mtry <- randomForest(class ~ ., data=train, importance=TRUE, do.trace=100, proximity = TRUE, mtry = best_mtry)
99 | r_best_mtry
100 | ```
101 | 
102 | ## Predict
103 | 
104 | Predict the testing data.
105 | 
106 | ```{r predict}
107 | prop.table(
108 |   table(
109 |     test$class, predict(object = r_best_mtry, newdata = test[, -ncol(test)])
110 |   )
111 | )
112 | ```
113 | 
114 | Error rate on the testing set.
115 | 
116 | ```{r test_error_rate}
117 | my_pred <- predict(object = r_best_mtry, newdata = test[, -ncol(test)])
118 | 1 - (sum(my_pred == test[, ncol(test)]) / length(my_pred))
119 | ```
120 | 
121 | ## Plots
122 | 
123 | Variable importance.
124 | 
125 | ```{r var_imp_plot}
126 | varImpPlot(r)
127 | ```
128 | 
129 | ## Random Forest object
130 | 
131 | ```{r class}
132 | class(r)
133 | ```
134 | 
135 | Names.
136 | 
137 | ```{r names}
138 | names(r)
139 | ```
140 | 
141 | The original call to `randomForest`.
142 | 
143 | ```{r call}
144 | r$call
145 | ```
146 | 
147 | One of regression, classification, or unsupervised.
148 | 
149 | ```{r type}
150 | r$type
151 | ```
152 | 
153 | The predicted values of the input data based on out-of-bag samples.
154 | 
155 | ```{r predicted}
156 | table(r$predicted, train$class)
157 | ```
158 | 
159 | A matrix with number of classes + 2 (for classification) or two (for regression) columns. For classification:
160 | 
161 | * the first two columns are the class-specific measures computed as the mean decrease in accuracy
162 | * the `MeanDecreaseAccuracy` column is the mean decrease in accuracy over all classes
163 | * the `MeanDecreaseGini` is the mean decrease in the Gini index
164 | 
165 | ```{r importance}
166 | r$importance
167 | ```
168 | 
169 | The "standard errors" of the permutation-based importance measure.
170 | 
171 | ```{r importance_sd}
172 | r$importanceSD
173 | ```
174 | 
175 | Number of trees grown.
176 | 
177 | ```{r ntree}
178 | r$ntree
179 | ```
180 | 
181 | Number of predictors sampled for splitting at each node.
182 | 
183 | ```{r mtry}
184 | r$mtry
185 | ```
186 | 
187 | A list that contains the entire forest.
188 | 
189 | ```{r forest}
190 | r$forest[[1]]
191 | ```
192 | 
193 | Use `getTree` to obtain an individual tree.
194 | 
195 | ```{r get_tree}
196 | head(getTree(r, k = 1))
197 | ```
198 | 
199 | A vector of error rates (classification only) of the prediction on the input data, the i-th element being the (OOB) error rate for all trees up to the i-th.
200 | 
201 | ```{r err_rate}
202 | head(r$err.rate)
203 | ```
204 | 
205 | The confusion matrix (classification only) of the prediction (based on OOB data).
206 | 
207 | ```{r confusion}
208 | r$confusion
209 | ```
210 | 
211 | A matrix with one row for each input data point and one column for each class, giving the fraction or number of (OOB) ‘votes’ from the random forest (classification only).
212 | 
213 | ```{r votes}
214 | head(r$votes)
215 | ```
216 | 
217 | Number of times cases are "out-of-bag" (and thus used in computing the OOB error estimate).
218 | 
219 | ```{r oob_times}
220 | r$oob.times
221 | ```
222 | 
223 | If `proximity=TRUE` when `randomForest` is called, a matrix of proximity measures among the input (based on the frequency that pairs of data points are in the same terminal nodes).
224 | 
225 | ```{r proximity}
226 | dim(r$proximity)
227 | ```
228 | 
229 | ## On importance
230 | 
231 | Notes from [Stack Exchange](http://stats.stackexchange.com/questions/92419/relative-importance-of-a-set-of-predictors-in-a-random-forests-classification-in):
232 | 
233 | MeanDecreaseGini is a measure of variable importance based on the Gini impurity index used for the calculation of splits during training. A common misconception is that the variable importance metric refers to the Gini used for asserting model performance, which is closely related to AUC, but this is wrong. Here is the explanation from the randomForest package written by Breiman and Cutler:
234 | 
235 | > Every time a split of a node is made on variable m the gini impurity criterion for the two descendent nodes is less than the parent node. Adding up the gini decreases for each individual variable over all trees in the forest gives a fast variable importance that is often very consistent with the permutation importance measure.
236 | 
237 | The Gini impurity index is defined as:
238 | 
239 | ![](https://latex.codecogs.com/png.image?\large&space;\bg{white}G&space;=&space;\sum^{n_c}_{i=1}&space;p_i&space;(1&space;-&space;p_i))
240 | 
241 | where $n_c$ is the number of classes in the target variable and $p_i$ is the proportion of class $i$.
242 | 
243 | ## Session info
244 | 
245 | Time built.
246 | 
247 | ```{r time, echo=FALSE}
248 | Sys.time()
249 | ```
250 | 
251 | Session info.
252 | 
253 | ```{r session_info, echo=FALSE}
254 | sessionInfo()
255 | ```
256 | 
-------------------------------------------------------------------------------- /ref/OUCS-2002-12.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/ref/OUCS-2002-12.pdf
-------------------------------------------------------------------------------- /ref/README.md: --------------------------------------------------------------------------------
1 | ## README
2 | 
3 | A list of useful references (saved here for posterity).
4 | 
5 | * [Feature Selection, Extraction, and Construction](fdws02.pdf) by Hiroshi Motoda and Huan Liu; a small sketch illustrating the three concepts appears below.
6 |     * Feature selection is a process that chooses a subset of `M` features from the original set of `N` features.
7 |     * Feature extraction is a process that extracts a set of new features from the original features through functional mapping.
8 |     * Feature construction is a process that discovers missing information about the relationships between features and augments the space of features by inferring or creating new features.
9 | * [A tutorial on Principal Components Analysis](OUCS-2002-12.pdf) by Lindsay I Smith.
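To make the three feature-engineering concepts concrete, here is a small sketch in R using the built-in `iris` data (the particular column choices, the `prcomp` mapping, and the constructed ratio are just illustrative):

```r
# feature selection: choose a subset of the original features
selected <- iris[, c("Petal.Length", "Petal.Width")]

# feature extraction: map the original features to new ones through a
# functional mapping, e.g. the first two principal components
extracted <- prcomp(iris[, 1:4], scale. = TRUE)$x[, 1:2]

# feature construction: augment the feature space by creating a new
# feature from existing ones
iris$Petal.Ratio <- iris$Petal.Length / iris$Petal.Width
```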
10 | 
11 | 
-------------------------------------------------------------------------------- /ref/fdws02.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/ref/fdws02.pdf
-------------------------------------------------------------------------------- /script/chown.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Files built by R inside a Docker container are owned by root.
4 | # Use this script to change the permissions back to your user and group.
5 | #
6 | 
7 | set -euo pipefail
8 | 
9 | path=$(dirname $0)
10 | cd ${path}/..
11 | 
12 | check_depend (){
13 |   tool=$1
14 |   if [[ ! -x $(command -v ${tool}) ]]; then
15 |     >&2 echo Could not find ${tool}
16 |     exit 1
17 |   fi
18 | }
19 | 
20 | dependencies=(docker)
21 | for tool in ${dependencies[@]}; do
22 |   check_depend ${tool}
23 | done
24 | 
25 | now(){
26 |   date '+%Y/%m/%d %H:%M:%S'
27 | }
28 | 
29 | SECONDS=0
30 | 
31 | >&2 printf "[ %s %s ] Start job\n" $(now)
32 | 
33 | r_version=4.1.3
34 | docker_image=davetang/r_build:${r_version}
35 | USERID=$(id -u)
36 | GROUPID=$(id -g)
37 | 
38 | docker run \
39 |   --rm \
40 |   -v $(pwd):$(pwd) \
41 |   -w $(pwd) \
42 |   ${docker_image} \
43 |   find . -user root -exec chown ${USERID}:${GROUPID} {} \;
44 | 
45 | >&2 printf "\n[ %s %s ] Work complete\n" $(now)
46 | 
47 | duration=$SECONDS
48 | >&2 echo "$(($duration / 60)) minutes and $(($duration % 60)) seconds elapsed."
49 | 
50 | exit 0
51 | 
52 | 
-------------------------------------------------------------------------------- /script/rmd_to_md.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | set -euo pipefail
4 | 
5 | num_param=1
6 | 
7 | usage(){
8 |   echo "Usage: $0 <infile.Rmd> [outfile.md]"
9 |   exit 1
10 | }
11 | 
12 | if [[ $# -lt ${num_param} ]]; then
13 |   usage
14 | fi
15 | 
16 | infile=$1
17 | if [[ ! -e ${infile} ]]; then
18 |   >&2 echo ${infile} does not exist
19 |   exit 1
20 | fi
21 | 
22 | outfile=README.md
23 | if [[ $# -ge 2 ]]; then
24 |   outfile=$2
25 | fi
26 | 
27 | check_depend (){
28 |   tool=$1
29 |   if [[ ! -x $(command -v ${tool}) ]]; then
30 |     >&2 echo Could not find ${tool}
31 |     exit 1
32 |   fi
33 | }
34 | 
35 | dependencies=(docker)
36 | for tool in ${dependencies[@]}; do
37 |   check_depend ${tool}
38 | done
39 | 
40 | now(){
41 |   date '+%Y/%m/%d %H:%M:%S'
42 | }
43 | 
44 | SECONDS=0
45 | 
46 | >&2 printf "[ %s %s ] Start job\n\n" $(now)
47 | 
48 | RVER=4.4.0
49 | docker_image=davetang/r_tensorflow:${RVER}
50 | package_dir=${HOME}/r_packages_${RVER}
51 | 
52 | if [[ ! -d ${package_dir} ]]; then
53 |   mkdir ${package_dir}
54 | fi
55 | 
56 | docker run \
57 |   --rm \
58 |   -v ${package_dir}:/packages \
59 |   -v $(pwd):$(pwd) \
60 |   -w $(pwd) \
61 |   -u $(id -u):$(id -g) \
62 |   ${docker_image} \
63 |   Rscript -e ".libPaths('/packages'); rmarkdown::render('${infile}', output_file = '${outfile}')"
64 | 
65 | >&2 printf "\n[ %s %s ] Work complete\n" $(now)
66 | 
67 | duration=$SECONDS
68 | >&2 echo "$(($duration / 60)) minutes and $(($duration % 60)) seconds elapsed."
69 | exit 0 70 | -------------------------------------------------------------------------------- /script/run_rstudio.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | VER=4.4.0 6 | IMG=davetang/r_tensorflow:${VER} 7 | CONTAINER=rstudio_server_r_tensorflow 8 | PORT=8883 9 | RLIB=${HOME}/r_packages_${VER} 10 | 11 | if [[ ! -d ${RLIB} ]]; then 12 | mkdir ${RLIB} 13 | fi 14 | 15 | docker run \ 16 | --name ${CONTAINER} \ 17 | -d \ 18 | --rm \ 19 | -p ${PORT}:8787 \ 20 | -v ${RLIB}:/packages \ 21 | -v ${HOME}/github/:/home/rstudio/work \ 22 | -e PASSWORD=password \ 23 | -e USERID=$(id -u) \ 24 | -e GROUPID=$(id -g) \ 25 | ${IMG} 26 | 27 | >&2 echo ${CONTAINER} listening on port ${PORT} 28 | 29 | exit 0 30 | -------------------------------------------------------------------------------- /som/img/check_convergence-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/som/img/check_convergence-1.png -------------------------------------------------------------------------------- /som/img/code_plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/som/img/code_plot-1.png -------------------------------------------------------------------------------- /som/img/code_plot-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/som/img/code_plot-2.png -------------------------------------------------------------------------------- /som/img/heatmap_sepal_width-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/som/img/heatmap_sepal_width-1.png -------------------------------------------------------------------------------- /som/img/node_count-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/som/img/node_count-1.png -------------------------------------------------------------------------------- /som/img/species_vs_sepal_width-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/som/img/species_vs_sepal_width-1.png -------------------------------------------------------------------------------- /som/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Self-organising map" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | Self Organising Maps (SOMs) consist of nodes that are first initialised with random weights that match the length of the input vector. Over many iterations, each node is adjusted to more closely resemble an input vector. The algorithm is described below: 19 | 20 | 1. 
Select the number of nodes and the type of node (e.g. square, hexagon, circle, etc.)
21 | 2. Initialise all node weight vectors randomly
22 | 3. Choose a random data point from your training data and compare it to all nodes (e.g. using Euclidean distance)
23 | 4. Find the Best Matching Unit (BMU) in the map, which is the most similar node based on the distance metric used
24 | 5. Determine the nodes within the neighbourhood of the BMU
25 | 6. Adjust the weights of nodes in the BMU neighbourhood towards the chosen data point; weights are adjusted according to the distance of the node to the BMU
26 | 
27 | We'll use the [kohonen](https://cran.r-project.org/web/packages/kohonen/index.html) package.
28 | 
29 | ```{r load_package, message=FALSE, warning=FALSE}
30 | .libPaths('/packages')
31 | my_packages <- 'kohonen'
32 | 
33 | for (my_package in my_packages){
34 |   if(!require(my_package, character.only = TRUE)){
35 |     install.packages(my_package, '/packages')
36 |     library(my_package, character.only = TRUE)
37 |   }
38 | }
39 | ```
40 | 
41 | ## Getting started
42 | 
43 | Build a SOM using the iris dataset; we'll use 64 nodes and perform 1,000 iterations using the default learning rate.
44 | 
45 | ```{r build_som}
46 | data(iris)
47 | 
48 | # normalise and convert to matrix
49 | data_train_matrix <- as.matrix(scale(iris[, -5]))
50 | 
51 | data_train <- list(measurement = data_train_matrix,
52 |                    species = iris[, 5])
53 | 
54 | som_grid <- somgrid(xdim = 8, ydim = 8, topo="hexagonal")
55 | 
56 | # som_model <- som(data_train_matrix,
57 | som_model <- supersom(data_train,
58 |                       grid = som_grid,
59 |                       rlen = 1000,
60 |                       alpha = c(0.05, 0.01),
61 |                       keep.data = TRUE)
62 | 
63 | names(som_model)
64 | 
65 | summary(som_model)
66 | ```
67 | 
68 | Explore `som_model`.
69 | 
70 | ```{r explore_som_model}
71 | str(som_model)
72 | 
73 | str(som_model$grid)
74 | ```
75 | 
76 | First, we'll check for convergence.
77 | 
78 | ```{r check_convergence}
79 | plot(som_model, type = "changes")
80 | ```
81 | 
82 | We can check the number of samples that are mapped to each node. (I tested different node numbers and 64 nodes gave a good uniform distribution of samples mapped to each node.)
83 | 
84 | ```{r node_count}
85 | plot(som_model, type="count", main="Node Counts")
86 | ```
87 | 
88 | The code plot is a nice visualisation of the weight (codebook) vectors across all nodes.
89 | 
90 | ```{r code_plot}
91 | plot(som_model, type="codes")
92 | ```
93 | 
94 | The setosa species has larger sepal widths.
95 | 
96 | ```{r species_vs_sepal_width}
97 | library(ggplot2)
98 | ggplot(iris, aes(x = Species, y = Sepal.Width)) +
99 |   geom_violin() +
100 |   theme_bw()
101 | ```
102 | 
103 | We can create a heatmap of the sepal width weights across all nodes.
104 | 
105 | ```{r heatmap_sepal_width}
106 | my_var <- "Sepal.Width"
107 | 
108 | plot(som_model,
109 |      type = "property",
110 |      property = som_model$codes$measurement[, my_var],
111 |      main = my_var)
112 | ```
113 | 
114 | ## Further reading
115 | 
116 | * [Tutorial](https://www.shanelynn.ie/self-organising-maps-for-customer-segmentation-using-r/) for building SOMs in R
117 | * [Tutorial on SOMs and on their implementation](http://www.ai-junkie.com/ann/som/som1.html)
118 | 
119 | ## Session info
120 | 
121 | Time built.
122 | 
123 | ```{r time, echo=FALSE}
124 | Sys.time()
125 | ```
126 | 
127 | Session info.
128 | 129 | ```{r session_info, echo=FALSE} 130 | sessionInfo() 131 | ``` 132 | 133 | -------------------------------------------------------------------------------- /svm/README.md: -------------------------------------------------------------------------------- 1 | Introduction 2 | ------------ 3 | 4 | A support vector machine (SVM) is a supervised machine learning 5 | algorithm that can be used for classification and regression. The 6 | essence of SVM classification is broken down into four main concepts: 7 | 8 | - The separating hyperplane (a plane that can separate cases into 9 | their respective classes) 10 | - The maximum-margin hyperplane or maximum-margin linear discriminants 11 | (the hyperplane that has maximal distance from the different 12 | classes) 13 | 14 | ![Example of the maximum-margin 15 | hyperplane](img/SVM_Example_of_Hyperplanes.png) 16 | 17 | - The soft margin (allowing cases from another class to fall into the 18 | opposite class) 19 | - The kernel function (adding an additional dimension) 20 | 21 | SVMs rely on preprocessing the data to represent patterns in a high 22 | dimension using a kernel function, typically much higher than the 23 | original feature space. 24 | 25 | In essence, the kernel function is a mathematical trick that allows the 26 | SVM to perform a "two-dimensional" classification of a set of originally 27 | one-dimensional data. In general, a kernel function projects data from a 28 | low-dimensional space to a space of higher dimension. It is possible to 29 | prove that, for any given data set with consistent labels (where 30 | consistent simply means that the data set does not contain two identical 31 | objects with opposite labels) there exists a kernel function that will 32 | allow the data to be linearly separated ([Noble, Nature Biotechnology 33 | 2006](https://www.ncbi.nlm.nih.gov/pubmed/17160063)). 34 | 35 | Using a hyperplane from an SVM that uses a very high-dimensional kernel 36 | function will result in overfitting. An optimal kernel function can be 37 | selected from a fixed set of kernels in a statistically rigorous fashion 38 | by using cross-validation. Kernels also allow us to combine different 39 | data sets. 40 | 41 | Install packages if missing and load. 42 | 43 | ``` {.r} 44 | .libPaths('/packages') 45 | my_packages <- 'e1071' 46 | 47 | for (my_package in my_packages){ 48 | if(!require(my_package, character.only = TRUE)){ 49 | install.packages(my_package, '/packages') 50 | library(my_package, character.only = TRUE) 51 | } 52 | } 53 | ``` 54 | 55 | Breast cancer data 56 | ------------------ 57 | 58 | Using the [Breast Cancer Wisconsin (Diagnostic) Data 59 | Set](http://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 60 | 61 | ``` {.r} 62 | data <- read.table("../data/breast_cancer_data.csv", stringsAsFactors = FALSE, sep = ',', header = TRUE) 63 | ``` 64 | 65 | The class should be a factor; 2 is benign and 4 is malignant. 66 | 67 | ``` {.r} 68 | data$class <- factor(data$class) 69 | ``` 70 | 71 | Finally remove id column. 72 | 73 | ``` {.r} 74 | data <- data[,-1] 75 | ``` 76 | 77 | Separate into training (80%) and testing (20%). 78 | 79 | ``` {.r} 80 | set.seed(31) 81 | my_decider <- rbinom(n=nrow(data),size=1,p=0.8) 82 | table(my_decider) 83 | ``` 84 | 85 | ## my_decider 86 | ## 0 1 87 | ## 122 477 88 | 89 | ``` {.r} 90 | train <- data[as.logical(my_decider),] 91 | test <- data[!as.logical(my_decider),] 92 | ``` 93 | 94 | Using the `e1071` package. 
95 | 96 | ``` {.r} 97 | tuned <- tune.svm(class ~ ., data = train, gamma = 10^(-6:-1), cost = 10^(-1:1)) 98 | summary(tuned) 99 | ``` 100 | 101 | ## 102 | ## Parameter tuning of 'svm': 103 | ## 104 | ## - sampling method: 10-fold cross validation 105 | ## 106 | ## - best parameters: 107 | ## gamma cost 108 | ## 0.01 1 109 | ## 110 | ## - best performance: 0.03554965 111 | ## 112 | ## - Detailed performance results: 113 | ## gamma cost error dispersion 114 | ## 1 1e-06 0.1 0.38151596 0.07135654 115 | ## 2 1e-05 0.1 0.38151596 0.07135654 116 | ## 3 1e-04 0.1 0.38151596 0.07135654 117 | ## 4 1e-03 0.1 0.36471631 0.07013290 118 | ## 5 1e-02 0.1 0.04184397 0.02770554 119 | ## 6 1e-01 0.1 0.03767730 0.02146151 120 | ## 7 1e-06 1.0 0.38151596 0.07135654 121 | ## 8 1e-05 1.0 0.38151596 0.07135654 122 | ## 9 1e-04 1.0 0.36471631 0.07013290 123 | ## 10 1e-03 1.0 0.03971631 0.02845194 124 | ## 11 1e-02 1.0 0.03554965 0.02197298 125 | ## 12 1e-01 1.0 0.03559397 0.02203474 126 | ## 13 1e-06 10.0 0.38151596 0.07135654 127 | ## 14 1e-05 10.0 0.36263298 0.06657450 128 | ## 15 1e-04 10.0 0.03971631 0.02845194 129 | ## 16 1e-03 10.0 0.03554965 0.02197298 130 | ## 17 1e-02 10.0 0.03971631 0.02670322 131 | ## 18 1e-01 10.0 0.05447695 0.02631057 132 | 133 | Train model using the best values for gamma and cost. 134 | 135 | ``` {.r} 136 | svm_model <- svm(class ~ ., data = train, kernel="radial", gamma=0.01, cost=1) 137 | summary(svm_model) 138 | ``` 139 | 140 | ## 141 | ## Call: 142 | ## svm(formula = class ~ ., data = train, kernel = "radial", gamma = 0.01, 143 | ## cost = 1) 144 | ## 145 | ## 146 | ## Parameters: 147 | ## SVM-Type: C-classification 148 | ## SVM-Kernel: radial 149 | ## cost: 1 150 | ## 151 | ## Number of Support Vectors: 72 152 | ## 153 | ## ( 36 36 ) 154 | ## 155 | ## 156 | ## Number of Classes: 2 157 | ## 158 | ## Levels: 159 | ## 2 4 160 | 161 | Predict test cases. 162 | 163 | ``` {.r} 164 | svm_predict <- predict(svm_model, test) 165 | table(svm_predict, test$class) 166 | ``` 167 | 168 | ## 169 | ## svm_predict 2 4 170 | ## 2 77 2 171 | ## 4 2 41 172 | 173 | Further reading 174 | --------------- 175 | 176 | - [Data Mining Algorithms In 177 | R/Classification/SVM](https://en.wikibooks.org/wiki/Data_Mining_Algorithms_In_R/Classification/SVM) 178 | 179 | Session info 180 | ------------ 181 | 182 | Time built. 183 | 184 | ## [1] "2022-10-20 06:51:34 UTC" 185 | 186 | Session info. 
187 | 188 | ## R version 4.2.1 (2022-06-23) 189 | ## Platform: x86_64-pc-linux-gnu (64-bit) 190 | ## Running under: Ubuntu 20.04.4 LTS 191 | ## 192 | ## Matrix products: default 193 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 194 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 195 | ## 196 | ## locale: 197 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 198 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 199 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 200 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 201 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 202 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 203 | ## 204 | ## attached base packages: 205 | ## [1] stats graphics grDevices utils datasets methods base 206 | ## 207 | ## other attached packages: 208 | ## [1] e1071_1.7-11 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.9 209 | ## [5] purrr_0.3.4 readr_2.1.2 tidyr_1.2.0 tibble_3.1.7 210 | ## [9] ggplot2_3.3.6 tidyverse_1.3.1 211 | ## 212 | ## loaded via a namespace (and not attached): 213 | ## [1] tidyselect_1.1.2 xfun_0.31 haven_2.5.0 colorspace_2.0-3 214 | ## [5] vctrs_0.4.1 generics_0.1.3 htmltools_0.5.2 yaml_2.3.5 215 | ## [9] utf8_1.2.2 rlang_1.0.3 pillar_1.7.0 glue_1.6.2 216 | ## [13] withr_2.5.0 DBI_1.1.3 dbplyr_2.2.1 modelr_0.1.8 217 | ## [17] readxl_1.4.0 lifecycle_1.0.1 munsell_0.5.0 gtable_0.3.0 218 | ## [21] cellranger_1.1.0 rvest_1.0.2 evaluate_0.15 knitr_1.39 219 | ## [25] tzdb_0.3.0 fastmap_1.1.0 class_7.3-20 fansi_1.0.3 220 | ## [29] broom_1.0.0 scales_1.2.0 backports_1.4.1 jsonlite_1.8.0 221 | ## [33] fs_1.5.2 hms_1.1.1 digest_0.6.29 stringi_1.7.6 222 | ## [37] grid_4.2.1 cli_3.3.0 tools_4.2.1 magrittr_2.0.3 223 | ## [41] proxy_0.4-27 crayon_1.5.1 pkgconfig_2.0.3 ellipsis_0.3.2 224 | ## [45] xml2_1.3.3 reprex_2.0.1 lubridate_1.8.0 rstudioapi_0.13 225 | ## [49] assertthat_0.2.1 rmarkdown_2.14 httr_1.4.3 R6_2.5.1 226 | ## [53] compiler_4.2.1 227 | -------------------------------------------------------------------------------- /svm/example.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Support Vector Machine" 3 | output: 4 | pdf_document: default 5 | html_notebook: default 6 | --- 7 | 8 | Using the [Breast Cancer Wisconsin (Diagnostic) Data Set](http://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 9 | 10 | ```{r} 11 | my_link <- 'http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data' 12 | data <- read.table(url(my_link), stringsAsFactors = FALSE, header = FALSE, sep = ',') 13 | names(data) <- c('id','ct','ucsize','ucshape','ma','secs','bn','bc','nn','miti','class') 14 | head(data) 15 | ``` 16 | 17 | Any missing data? 18 | 19 | ```{r} 20 | dim(data)[1] * dim(data)[2] 21 | table(is.na.data.frame(data)) 22 | ``` 23 | 24 | What's the structure? 25 | 26 | ```{r} 27 | str(data) 28 | ``` 29 | 30 | Why is the bare nuclei (bn) stored as characters instead of integers? 31 | 32 | ```{r} 33 | table(data$bn) 34 | ``` 35 | 36 | Change the question marks into NA's and then into median values. 37 | 38 | ```{r} 39 | data$bn <- gsub(pattern = '\\?', replacement = NA, x = data$bn) 40 | data$bn <- as.integer(data$bn) 41 | my_median <- median(data$bn, na.rm = TRUE) 42 | data$bn[is.na(data$bn)] <- my_median 43 | str(data) 44 | ``` 45 | 46 | The class should be a factor; 2 is benign and 4 is malignant. 
47 | 
48 | ```{r}
49 | data$class <- factor(data$class)
50 | ```
51 | 
52 | Finally, remove the id column, which was not unique anyway.
53 | 
54 | ```{r}
55 | data <- data[,-1]
56 | ```
57 | 
58 | Separate into training (80%) and testing (20%).
59 | 
60 | ```{r}
61 | set.seed(31)
62 | my_decider <- rbinom(n=nrow(data),size=1,p=0.8)
63 | table(my_decider)
64 | train <- data[as.logical(my_decider),]
65 | test <- data[!as.logical(my_decider),]
66 | ```
67 | 
68 | Using the `e1071` package.
69 | 
70 | ```{r}
71 | library(e1071)
72 | 
73 | tuned <- tune.svm(class ~ ., data = train, gamma = 10^(-6:-1), cost = 10^(-1:1))
74 | summary(tuned)
75 | ```
76 | 
77 | Train model using the best values for gamma and cost.
78 | 
79 | ```{r}
80 | svm_model <- svm(class ~ ., data = train, kernel="radial", gamma=0.01, cost=1)
81 | summary(svm_model)
82 | ```
83 | 
84 | Predict test cases.
85 | 
86 | ```{r}
87 | svm_predict <- predict(svm_model, test)
88 | table(svm_predict, test$class)
89 | ```
90 | 
91 | 
92 | 
-------------------------------------------------------------------------------- /svm/img/SVM_Example_of_Hyperplanes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/svm/img/SVM_Example_of_Hyperplanes.png
-------------------------------------------------------------------------------- /svm/readme.Rmd: --------------------------------------------------------------------------------
1 | ---
2 | title: "Support Vector Machine"
3 | output:
4 |   md_document:
5 |     variant: markdown
6 | ---
7 | 
8 | ```{r setup, include=FALSE}
9 | library(tidyverse)
10 | theme_set(theme_bw())
11 | knitr::opts_chunk$set(cache = FALSE)
12 | knitr::opts_chunk$set(echo = TRUE)
13 | knitr::opts_chunk$set(fig.path = "img/")
14 | ```
15 | 
16 | ## Introduction
17 | 
18 | A support vector machine (SVM) is a supervised machine learning algorithm that can be used for classification and regression. The essence of SVM classification is broken down into four main concepts:
19 | 
20 | * The separating hyperplane (a plane that can separate cases into their respective classes)
21 | * The maximum-margin hyperplane or maximum-margin linear discriminants (the hyperplane that has maximal distance from the different classes)
22 | 
23 | ![Example of the maximum-margin hyperplane](img/SVM_Example_of_Hyperplanes.png)
24 | 
25 | * The soft margin (allowing cases from another class to fall into the opposite class)
26 | * The kernel function (adding an additional dimension)
27 | 
28 | SVMs rely on preprocessing the data with a kernel function to represent patterns in a dimension that is typically much higher than the original feature space.
29 | 
30 | In essence, the kernel function is a mathematical trick that allows the SVM to perform a "two-dimensional" classification of a set of originally one-dimensional data. In general, a kernel function projects data from a low-dimensional space to a space of higher dimension. It is possible to prove that, for any given data set with consistent labels (where consistent simply means that the data set does not contain two identical objects with opposite labels) there exists a kernel function that will allow the data to be linearly separated ([Noble, Nature Biotechnology 2006](https://www.ncbi.nlm.nih.gov/pubmed/17160063)).
31 | 
32 | Using a hyperplane from an SVM that uses a very high-dimensional kernel function will result in overfitting.
An optimal kernel function can be selected from a fixed set of kernels in a statistically rigorous fashion by using cross-validation. Kernels also allow us to combine different data sets. 33 | 34 | Install packages if missing and load. 35 | 36 | ```{r load_package, message=FALSE, warning=FALSE} 37 | .libPaths('/packages') 38 | my_packages <- 'e1071' 39 | 40 | for (my_package in my_packages){ 41 | if(!require(my_package, character.only = TRUE)){ 42 | install.packages(my_package, '/packages') 43 | library(my_package, character.only = TRUE) 44 | } 45 | } 46 | ``` 47 | 48 | ## Breast cancer data 49 | 50 | Using the [Breast Cancer Wisconsin (Diagnostic) Data Set](http://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 51 | 52 | ```{r} 53 | data <- read.table("../data/breast_cancer_data.csv", stringsAsFactors = FALSE, sep = ',', header = TRUE) 54 | ``` 55 | 56 | The class should be a factor; 2 is benign and 4 is malignant. 57 | 58 | ```{r} 59 | data$class <- factor(data$class) 60 | ``` 61 | 62 | Finally remove id column. 63 | 64 | ```{r} 65 | data <- data[,-1] 66 | ``` 67 | 68 | Separate into training (80%) and testing (20%). 69 | 70 | ```{r} 71 | set.seed(31) 72 | my_decider <- rbinom(n=nrow(data),size=1,p=0.8) 73 | table(my_decider) 74 | train <- data[as.logical(my_decider),] 75 | test <- data[!as.logical(my_decider),] 76 | ``` 77 | 78 | Using the `e1071` package. 79 | 80 | ```{r} 81 | tuned <- tune.svm(class ~ ., data = train, gamma = 10^(-6:-1), cost = 10^(-1:1)) 82 | summary(tuned) 83 | ``` 84 | 85 | Train model using the best values for gamma and cost. 86 | 87 | ```{r} 88 | svm_model <- svm(class ~ ., data = train, kernel="radial", gamma=0.01, cost=1) 89 | summary(svm_model) 90 | ``` 91 | 92 | Predict test cases. 93 | 94 | ```{r} 95 | svm_predict <- predict(svm_model, test) 96 | table(svm_predict, test$class) 97 | ``` 98 | 99 | ## Further reading 100 | 101 | * [Data Mining Algorithms In R/Classification/SVM](https://en.wikibooks.org/wiki/Data_Mining_Algorithms_In_R/Classification/SVM) 102 | 103 | ## Session info 104 | 105 | Time built. 106 | 107 | ```{r time, echo=FALSE} 108 | Sys.time() 109 | ``` 110 | 111 | Session info. 112 | 113 | ```{r session_info, echo=FALSE} 114 | sessionInfo() 115 | ``` 116 | 117 | -------------------------------------------------------------------------------- /tabnet/README.md: -------------------------------------------------------------------------------- 1 | Notes from [TabNet: Attentive Interpretable Tabular 2 | Learning](https://arxiv.org/abs/1908.07442). 3 | 4 | ## Abstract 5 | 6 | - We propose a novel high-performance and **interpretable** canonical 7 | deep tabular data learning architecture, TabNet. 8 | - TabNet uses **sequential attention** to choose which features to 9 | reason from at each decision step, enabling interpretability and 10 | more efficient learning as the learning capacity is used for the 11 | most salient features. 12 | - We demonstrate that TabNet outperforms other neural network and 13 | decision tree variants on a wide range of non-performance-saturated 14 | tabular datasets and yields interpretable feature attributions plus 15 | insights into the global model behavior. 16 | - Finally, for the first time to our knowledge, we demonstrate 17 | **self-supervised learning for tabular data**, significantly 18 | improving performance with unsupervised representation learning when 19 | unlabeled data is abundant. 
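As a rough illustration of how TabNet can be fitted from R — a minimal
sketch assuming the mlverse `tabnet` package (built on torch), which is
not used elsewhere in this repository; the data and epoch count are
arbitrary choices for illustration:

``` {.r}
library(tabnet)

# fit a TabNet classifier on the iris data; epochs kept small
# so that the sketch runs quickly
fit <- tabnet_fit(Species ~ ., data = iris, epochs = 10)

# predictions on the training data
predict(fit, iris)
```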
20 | 21 | ## ChatGPT summary 22 | 23 | TabNet is a deep learning model designed specifically for tabular data, 24 | which consists of structured data organised into rows and columns, 25 | commonly found in databases and spreadsheets. TabNet combines ideas from 26 | both neural networks and decision trees to achieve state-of-the-art 27 | performance on tabular data while maintaining interpretability. The key 28 | features of TabNet include: 29 | 30 | 1. **TabNet Architecture**: 31 | - TabNet is a neural network architecture based on the transformer 32 | architecture, which is known for its success in natural language 33 | processing tasks. 34 | - It consists of a series of repeated encoder-decoder blocks, 35 | where each block performs feature transformation and 36 | attention-based feature selection. 37 | - The encoder blocks extract features from the input data, while 38 | the decoder blocks reconstruct the original features from the 39 | selected features. 40 | 2. **Attention Mechanism**: 41 | - TabNet utilises an attention mechanism to select informative 42 | features at each step of the training process. 43 | - At each layer of the network, a sparse attention mask is 44 | computed based on the feature importance scores obtained from 45 | the previous layer. 46 | - This attention mask is used to select a subset of features that 47 | are most relevant for prediction, allowing the model to focus on 48 | the most informative features while ignoring irrelevant ones. 49 | 3. **Sparse Feature Selection**: 50 | - Unlike traditional neural networks that use all input features 51 | for prediction, TabNet employs sparse feature selection to 52 | enhance interpretability and efficiency. 53 | - By selecting only a subset of features at each layer based on 54 | their importance scores, TabNet reduces the computational 55 | overhead and improves the model's ability to learn from 56 | high-dimensional tabular data. 57 | 4. **Decision-Tree-Like Properties**: 58 | - TabNet exhibits decision-tree-like properties, where each layer 59 | of the network selects a subset of features to make predictions. 60 | - This hierarchical feature selection process resembles the 61 | decision-making process of decision trees, making TabNet more 62 | interpretable and easier to understand compared to traditional 63 | neural networks. 64 | 5. **Interpretability**: 65 | - TabNet provides interpretability by allowing users to analyse 66 | the importance scores assigned to each input feature. 67 | - These importance scores indicate the contribution of each 68 | feature to the model's predictions, helping users understand 69 | which features are most influential in making decisions. 70 | 71 | ## Session info 72 | 73 | Time built. 74 | 75 | ## [1] "2024-06-20 23:19:50 UTC" 76 | 77 | Session info. 
78 | 79 | ## R version 4.4.0 (2024-04-24) 80 | ## Platform: x86_64-pc-linux-gnu 81 | ## Running under: Ubuntu 22.04.4 LTS 82 | ## 83 | ## Matrix products: default 84 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 85 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so; LAPACK version 3.10.0 86 | ## 87 | ## locale: 88 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 89 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 90 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 91 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 92 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 93 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 94 | ## 95 | ## time zone: Etc/UTC 96 | ## tzcode source: system (glibc) 97 | ## 98 | ## attached base packages: 99 | ## [1] stats graphics grDevices utils datasets methods base 100 | ## 101 | ## other attached packages: 102 | ## [1] lubridate_1.9.3 forcats_1.0.0 stringr_1.5.1 dplyr_1.1.4 103 | ## [5] purrr_1.0.2 readr_2.1.5 tidyr_1.3.1 tibble_3.2.1 104 | ## [9] ggplot2_3.5.1 tidyverse_2.0.0 105 | ## 106 | ## loaded via a namespace (and not attached): 107 | ## [1] gtable_0.3.5 compiler_4.4.0 tidyselect_1.2.1 scales_1.3.0 108 | ## [5] yaml_2.3.8 fastmap_1.1.1 R6_2.5.1 generics_0.1.3 109 | ## [9] knitr_1.46 munsell_0.5.1 pillar_1.9.0 tzdb_0.4.0 110 | ## [13] rlang_1.1.3 utf8_1.2.4 stringi_1.8.3 xfun_0.43 111 | ## [17] timechange_0.3.0 cli_3.6.2 withr_3.0.0 magrittr_2.0.3 112 | ## [21] digest_0.6.35 grid_4.4.0 hms_1.1.3 lifecycle_1.0.4 113 | ## [25] vctrs_0.6.5 evaluate_0.23 glue_1.7.0 fansi_1.0.6 114 | ## [29] colorspace_2.1-0 rmarkdown_2.27 tools_4.4.0 pkgconfig_2.0.3 115 | ## [33] htmltools_0.5.8.1 116 | -------------------------------------------------------------------------------- /tabnet/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "TabNet: Attentive Interpretable Tabular Learning" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | Notes from [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442). 17 | 18 | ## Abstract 19 | 20 | * We propose a novel high-performance and **interpretable** canonical deep tabular data learning architecture, TabNet. 21 | * TabNet uses **sequential attention** to choose which features to reason from at each decision step, enabling interpretability and more efficient learning as the learning capacity is used for the most salient features. 22 | * We demonstrate that TabNet outperforms other neural network and decision tree variants on a wide range of non-performance-saturated tabular datasets and yields interpretable feature attributions plus insights into the global model behavior. 23 | * Finally, for the first time to our knowledge, we demonstrate **self-supervised learning for tabular data**, significantly improving performance with unsupervised representation learning when unlabeled data is abundant. 24 | 25 | ## ChatGPT summary 26 | 27 | TabNet is a deep learning model designed specifically for tabular data, which consists of structured data organised into rows and columns, commonly found in databases and spreadsheets. 
TabNet combines ideas from both neural networks and decision trees to achieve state-of-the-art performance on tabular data while maintaining interpretability. The key features of TabNet include: 28 | 29 | 1. **TabNet Architecture**: 30 | - TabNet is a neural network architecture based on the transformer architecture, which is known for its success in natural language processing tasks. 31 | - It consists of a series of repeated encoder-decoder blocks, where each block performs feature transformation and attention-based feature selection. 32 | - The encoder blocks extract features from the input data, while the decoder blocks reconstruct the original features from the selected features. 33 | 34 | 2. **Attention Mechanism**: 35 | - TabNet utilises an attention mechanism to select informative features at each step of the training process. 36 | - At each layer of the network, a sparse attention mask is computed based on the feature importance scores obtained from the previous layer. 37 | - This attention mask is used to select a subset of features that are most relevant for prediction, allowing the model to focus on the most informative features while ignoring irrelevant ones. 38 | 39 | 3. **Sparse Feature Selection**: 40 | - Unlike traditional neural networks that use all input features for prediction, TabNet employs sparse feature selection to enhance interpretability and efficiency. 41 | - By selecting only a subset of features at each layer based on their importance scores, TabNet reduces the computational overhead and improves the model's ability to learn from high-dimensional tabular data. 42 | 43 | 4. **Decision-Tree-Like Properties**: 44 | - TabNet exhibits decision-tree-like properties, where each layer of the network selects a subset of features to make predictions. 45 | - This hierarchical feature selection process resembles the decision-making process of decision trees, making TabNet more interpretable and easier to understand compared to traditional neural networks. 46 | 47 | 5. **Interpretability**: 48 | - TabNet provides interpretability by allowing users to analyse the importance scores assigned to each input feature. 49 | - These importance scores indicate the contribution of each feature to the model's predictions, helping users understand which features are most influential in making decisions. 50 | 51 | ## Session info 52 | 53 | Time built. 54 | 55 | ```{r time, echo=FALSE} 56 | Sys.time() 57 | ``` 58 | 59 | Session info. 60 | 61 | ```{r session_info, echo=FALSE} 62 | sessionInfo() 63 | ``` 64 | -------------------------------------------------------------------------------- /template/README.md: -------------------------------------------------------------------------------- 1 | Introduction 2 | ------------ 3 | 4 | This README was generated by running from the root directory of this 5 | repository: 6 | 7 | script/rmd_to_md.sh template/template.Rmd 8 | 9 | Install packages if missing and load. 10 | 11 | ``` {.r} 12 | .libPaths('/packages') 13 | my_packages <- 'beepr' 14 | 15 | for (my_package in my_packages){ 16 | if(!require(my_package, character.only = TRUE)){ 17 | install.packages(my_package, '/packages') 18 | library(my_package, character.only = TRUE) 19 | } 20 | } 21 | ``` 22 | 23 | $\LaTeX$: 24 | 25 | $$ e = mc^2 $$ 26 | 27 | Breast cancer data 28 | ------------------ 29 | 30 | Using the [Breast Cancer Wisconsin (Diagnostic) Data 31 | Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 
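The file contains an `id` column (dropped below) plus a `class` column where, as noted in the svm section, 2 is benign and 4 is malignant; the remaining columns are numeric features such as `ucsize`.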
32 | 33 | ``` {.r} 34 | data <- read.table( 35 | "../data/breast_cancer_data.csv", 36 | stringsAsFactors = FALSE, 37 | sep = ',', 38 | header = TRUE 39 | ) 40 | data$class <- factor(data$class) 41 | data <- data[,-1] 42 | ``` 43 | 44 | Separate into training (80%) and testing (20%). 45 | 46 | ``` {.r} 47 | set.seed(31) 48 | my_prob <- 0.8 49 | my_split <- as.logical( 50 | rbinom( 51 | n = nrow(data), 52 | size = 1, 53 | p = my_prob 54 | ) 55 | ) 56 | 57 | train <- data[my_split,] 58 | test <- data[!my_split,] 59 | ``` 60 | 61 | Results 62 | ------- 63 | 64 | ``` {.r} 65 | ggplot(data, aes(class, ucsize)) + 66 | geom_boxplot() 67 | ``` 68 | 69 | ![](img/plot-1.png) 70 | 71 | Session info 72 | ------------ 73 | 74 | Time built. 75 | 76 | ## [1] "2022-10-20 06:50:59 UTC" 77 | 78 | Session info. 79 | 80 | ## R version 4.2.1 (2022-06-23) 81 | ## Platform: x86_64-pc-linux-gnu (64-bit) 82 | ## Running under: Ubuntu 20.04.4 LTS 83 | ## 84 | ## Matrix products: default 85 | ## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 86 | ## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3 87 | ## 88 | ## locale: 89 | ## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C 90 | ## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 91 | ## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 92 | ## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C 93 | ## [9] LC_ADDRESS=C LC_TELEPHONE=C 94 | ## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C 95 | ## 96 | ## attached base packages: 97 | ## [1] stats graphics grDevices utils datasets methods base 98 | ## 99 | ## other attached packages: 100 | ## [1] beepr_1.3 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.9 101 | ## [5] purrr_0.3.4 readr_2.1.2 tidyr_1.2.0 tibble_3.1.7 102 | ## [9] ggplot2_3.3.6 tidyverse_1.3.1 103 | ## 104 | ## loaded via a namespace (and not attached): 105 | ## [1] tidyselect_1.1.2 xfun_0.31 haven_2.5.0 colorspace_2.0-3 106 | ## [5] vctrs_0.4.1 generics_0.1.3 htmltools_0.5.2 yaml_2.3.5 107 | ## [9] utf8_1.2.2 rlang_1.0.3 pillar_1.7.0 glue_1.6.2 108 | ## [13] withr_2.5.0 DBI_1.1.3 dbplyr_2.2.1 modelr_0.1.8 109 | ## [17] readxl_1.4.0 audio_0.1-10 lifecycle_1.0.1 munsell_0.5.0 110 | ## [21] gtable_0.3.0 cellranger_1.1.0 rvest_1.0.2 evaluate_0.15 111 | ## [25] labeling_0.4.2 knitr_1.39 tzdb_0.3.0 fastmap_1.1.0 112 | ## [29] fansi_1.0.3 highr_0.9 broom_1.0.0 scales_1.2.0 113 | ## [33] backports_1.4.1 jsonlite_1.8.0 farver_2.1.1 fs_1.5.2 114 | ## [37] hms_1.1.1 digest_0.6.29 stringi_1.7.6 grid_4.2.1 115 | ## [41] cli_3.3.0 tools_4.2.1 magrittr_2.0.3 crayon_1.5.1 116 | ## [45] pkgconfig_2.0.3 ellipsis_0.3.2 xml2_1.3.3 reprex_2.0.1 117 | ## [49] lubridate_1.8.0 rstudioapi_0.13 assertthat_0.2.1 rmarkdown_2.14 118 | ## [53] httr_1.4.3 R6_2.5.1 compiler_4.2.1 119 | -------------------------------------------------------------------------------- /template/img/plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/template/img/plot-1.png -------------------------------------------------------------------------------- /template/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Template" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 
| 16 | ## Introduction 17 | 18 | This README was generated by running from the root directory of this repository: 19 | 20 | script/rmd_to_md.sh template/template.Rmd 21 | 22 | Install packages if missing and load. 23 | 24 | ```{r load_package, message=FALSE, warning=FALSE} 25 | .libPaths('/packages') 26 | my_packages <- 'beepr' 27 | 28 | for (my_package in my_packages){ 29 | if(!require(my_package, character.only = TRUE)){ 30 | install.packages(my_package, '/packages') 31 | library(my_package, character.only = TRUE) 32 | } 33 | } 34 | ``` 35 | 36 | $\LaTeX$: 37 | 38 | $$ e = mc^2 $$ 39 | 40 | ## Breast cancer data 41 | 42 | Using the [Breast Cancer Wisconsin (Diagnostic) Data Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)). 43 | 44 | ```{r prepare_data} 45 | data <- read.table( 46 | "../data/breast_cancer_data.csv", 47 | stringsAsFactors = FALSE, 48 | sep = ',', 49 | header = TRUE 50 | ) 51 | data$class <- factor(data$class) 52 | data <- data[,-1] 53 | ``` 54 | 55 | Separate into training (80%) and testing (20%). 56 | 57 | ```{r split_data} 58 | set.seed(31) 59 | my_prob <- 0.8 60 | my_split <- as.logical( 61 | rbinom( 62 | n = nrow(data), 63 | size = 1, 64 | p = my_prob 65 | ) 66 | ) 67 | 68 | train <- data[my_split,] 69 | test <- data[!my_split,] 70 | ``` 71 | 72 | ## Results 73 | 74 | ```{r plot} 75 | ggplot(data, aes(class, ucsize)) + 76 | geom_boxplot() 77 | ``` 78 | 79 | ## Session info 80 | 81 | Time built. 82 | 83 | ```{r time, echo=FALSE} 84 | Sys.time() 85 | ``` 86 | 87 | Session info. 88 | 89 | ```{r session_info, echo=FALSE} 90 | sessionInfo() 91 | ``` 92 | 93 | -------------------------------------------------------------------------------- /tidymodels/img/imbalance_pr_curve-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tidymodels/img/imbalance_pr_curve-1.png -------------------------------------------------------------------------------- /tidymodels/img/imbalance_roc_curve-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tidymodels/img/imbalance_roc_curve-1.png -------------------------------------------------------------------------------- /tidymodels/img/pr_curve-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tidymodels/img/pr_curve-1.png -------------------------------------------------------------------------------- /tidymodels/img/roc_curve-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tidymodels/img/roc_curve-1.png -------------------------------------------------------------------------------- /tidymodels/img/rocr_pr_curve-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tidymodels/img/rocr_pr_curve-1.png -------------------------------------------------------------------------------- /tidymodels/img/rocr_roc_curve-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tidymodels/img/rocr_roc_curve-1.png -------------------------------------------------------------------------------- /tidymodels/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Getting started with tidymodels" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | knitr::opts_chunk$set(cache = FALSE) 10 | knitr::opts_chunk$set(echo = TRUE) 11 | knitr::opts_chunk$set(fig.path = "img/") 12 | ``` 13 | 14 | ## Setup 15 | 16 | Install packages. 17 | 18 | ```{r load_packages, message=FALSE, warning=FALSE} 19 | my_packages <- c('tidyverse', 'tidymodels', 'randomForest', 'ROCR') 20 | 21 | for (my_package in my_packages){ 22 | if(!require(my_package, character.only = TRUE)){ 23 | install.packages(my_package) 24 | } 25 | library(my_package, character.only = TRUE) 26 | } 27 | 28 | theme_set(theme_bw()) 29 | ``` 30 | 31 | ## Spam 32 | 33 | Use [spam data](https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.names) to train a Random Forest model to illustrate evaluation measures. 34 | 35 | ```{r load_spam} 36 | spam_data <- read.csv(file = "../data/spambase.csv") 37 | dim(spam_data) 38 | ``` 39 | 40 | Class 0 and 1 are ham (non-spam) and spam, respectively. 41 | 42 | ```{r preview_spam} 43 | spam_data$class <- factor(spam_data$class) 44 | spam_data[c(1:3, (nrow(spam_data)-2):nrow(spam_data)), (ncol(spam_data)-2):ncol(spam_data)] 45 | ``` 46 | 47 | [Split data](https://www.tidymodels.org/start/resampling/#data-split) using [rsample](https://rsample.tidymodels.org/). 48 | 49 | The `initial_split()` function takes the original data and saves the information on how to make the partitions. The `strata` argument conducts a stratified split ensuring that our training and test data sets will keep roughly the same proportion of classes. 50 | 51 | ```{r split_spam} 52 | set.seed(1984) 53 | spam_split <- initial_split(data = spam_data, prop = 0.8, strata = 'class') 54 | spam_split 55 | spam_train <- training(spam_split) 56 | spam_test <- testing(spam_split) 57 | ``` 58 | 59 | ## `parsnip` 60 | 61 | The [parsnip package](https://parsnip.tidymodels.org/index.html) provides a tidy and unified interface to a range of models. 62 | 63 | ```{r train_rf} 64 | my_mtry <- ceiling(sqrt(ncol(spam_data))) 65 | 66 | rf <- list() 67 | rand_forest(mtry = my_mtry, trees = 500) |> 68 | set_engine("randomForest") |> 69 | set_mode("classification") -> rf$model 70 | 71 | rf$model |> 72 | fit(class ~ ., data = spam_train) -> rf$fit 73 | 74 | rf$model 75 | ``` 76 | 77 | `rf` contains the model parameters and the model. 78 | 79 | ```{r str_rf} 80 | str(rf, max.level = 2) 81 | ``` 82 | 83 | ## `yardstick` 84 | 85 | The [yardstick package](https://yardstick.tidymodels.org/) provides a tidy interface to estimate how well models are performing. 86 | 87 | Example data to check how to prepare our data for use with `yardstick`. 88 | 89 | ```{r two_class_example} 90 | data(two_class_example) 91 | str(two_class_example) 92 | ``` 93 | 94 | Make predictions on the test data; `.pred_1` is the "probability" of spam. 95 | 96 | ```{r pred_spam_test} 97 | predict(rf$fit, spam_test, type = 'prob') 98 | ``` 99 | 100 | Predict and generate table in the format of `two_class_example` using a wrapper function. 
101 | 102 | * `fit` - model 103 | * `test_data` - test data 104 | * `pos` - class that the model is testing for 105 | * `neg` - the other class 106 | 107 | Since the model is testing for spam, `pos` is 'spam'. 108 | 109 | ```{r predict} 110 | predict_wrapper <- function(fit, test_data, pos, neg, type = 'prob'){ 111 | predict(fit, test_data, type = type) |> 112 | mutate(truth = ifelse(as.integer(test_data$class) == 2, pos, neg)) |> 113 | mutate(truth = factor(truth, levels = c(pos, neg))) |> 114 | rename( 115 | ham = .pred_0, 116 | spam = .pred_1 117 | ) |> 118 | mutate( 119 | predicted = ifelse(spam > 0.5, pos, neg) 120 | ) |> 121 | mutate( 122 | predicted = factor(predicted, levels = c(pos, neg)) 123 | ) |> 124 | select(truth, everything()) 125 | } 126 | 127 | rf$predictions <- predict_wrapper(rf$fit, spam_test, 'spam', 'ham') 128 | rf$predictions 129 | ``` 130 | 131 | Confusion matrix. 132 | 133 | ```{r confusion_matrix} 134 | cm <- table(rf$predictions$truth, rf$predictions$predicted) 135 | cm |> 136 | prop.table() 137 | ``` 138 | 139 | Metrics. 140 | 141 | ```{r calculate_metrics} 142 | metrics(rf$predictions, truth, predicted) 143 | ``` 144 | 145 | [table_metrics](https://github.com/davetang/learning_r/blob/main/code/table_metrics.R). 146 | 147 | ```{r table_metrics} 148 | source("https://raw.githubusercontent.com/davetang/learning_r/main/code/table_metrics.R") 149 | table_metrics(cm, 'spam', 'ham', 'row', sig_fig = 7) 150 | ``` 151 | 152 | Area under the PR curve. 153 | 154 | ```{r pr_auc} 155 | pr_auc(rf$predictions, truth, spam) 156 | ``` 157 | 158 | [PR curve](https://yardstick.tidymodels.org/reference/pr_curve.html). 159 | 160 | ```{r pr_curve} 161 | pr_curve(rf$predictions, truth, spam) |> 162 | ggplot(aes(x = recall, y = precision)) + 163 | geom_path() + 164 | coord_equal() + 165 | ylim(c(0, 1)) + 166 | ggtitle('PR curve') 167 | ``` 168 | 169 | Area under the ROC curve. 170 | 171 | ```{r roc_auc} 172 | roc_auc(rf$predictions, truth, spam) 173 | ``` 174 | 175 | [ROC curve](https://yardstick.tidymodels.org/reference/roc_curve.html). 176 | 177 | ```{r roc_curve} 178 | roc_curve(rf$predictions, truth, spam) |> 179 | ggplot(aes(x = 1 - specificity, y = sensitivity)) + 180 | geom_path() + 181 | geom_abline(lty = 3) + 182 | coord_equal() + 183 | ggtitle('ROC curve') 184 | ``` 185 | 186 | ### Using ROCR 187 | 188 | Compare with [ROCR](https://cran.rstudio.com/web/packages/ROCR/vignettes/ROCR.html). 189 | 190 | Every classifier evaluation using {ROCR} starts with creating a prediction object. 191 | 192 | ```{r rocr_pred} 193 | predictions <- predict(rf$fit, spam_test, type = 'prob')$.pred_1 194 | labels <- spam_test$class 195 | pred <- prediction(predictions, labels) 196 | aucpr <- performance(pred, "aucpr") 197 | aucroc <- performance(pred, "auc") 198 | str(aucpr) 199 | str(aucroc) 200 | ``` 201 | 202 | PR curve. 203 | 204 | ```{r rocr_pr_curve} 205 | perf <- performance(pred, "prec", "rec") 206 | plot(perf, lwd= 1, main= "PR curve") 207 | ``` 208 | 209 | ROC curve. 210 | 211 | ```{r rocr_roc_curve} 212 | perf <- performance(pred, "tpr", "fpr") 213 | plot(perf, lwd= 1, main= "ROC curve") 214 | ``` 215 | 216 | ### Class imbalance 217 | 218 | Difference in area under the ROC curve and area under the precision recall curve. 
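A quick way to see why the two areas can diverge when classes are imbalanced: the ROC axes (sensitivity and 1 - specificity) do not depend on prevalence, but precision does. A minimal sketch with made-up operating-point numbers (illustrative only, and separate from the simulation below):

```{r precision_vs_prevalence}
# Precision at a fixed operating point (TPR = 0.9, FPR = 0.1) as the
# positive class becomes rarer; the corresponding ROC point never moves.
tpr <- 0.9
fpr <- 0.1
prevalence <- c(0.5, 0.1, 0.05, 0.01)
precision <- (tpr * prevalence) / (tpr * prevalence + fpr * (1 - prevalence))
round(setNames(precision, prevalence), 3)
# prevalence 0.5 -> 0.900; 0.1 -> 0.500; 0.05 -> 0.321; 0.01 -> 0.083
```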
219 | 220 | ```{r class_imbalance} 221 | set.seed(1984) 222 | 223 | n <- 1000 224 | p_positive <- 0.05 225 | 226 | y <- factor(rbinom(n, 1, p_positive)) 227 | x <- rnorm(n, mean = ifelse(y == 1, 2, 0), sd = 1) 228 | 229 | fit <- glm(y ~ x, family = binomial) 230 | probs <- predict(fit, type = "response") 231 | 232 | data.frame( 233 | truth = ifelse(y == 1, 'DA', 'NotDA'), 234 | NotDA = 1 - probs, 235 | DA = probs 236 | ) |> 237 | dplyr::mutate(truth = factor(truth, levels = c('DA', 'NotDA'))) -> toy_data 238 | ``` 239 | 240 | Identify most `DA` (>96%) but a lot of `NotDA` (>43%) are also predicted as `DA`, i.e., false positives. 241 | 242 | ```{r class_imbalance_table} 243 | toy_data |> 244 | dplyr::mutate( 245 | predicted = factor(ifelse(DA > 0.01, 'DA', 'NotDA'), levels = c('DA', 'NotDA')) 246 | ) |> 247 | dplyr::select(truth, predicted) |> 248 | table() |> 249 | prop.table(margin = 1) 250 | ``` 251 | 252 | Area under the ROC curve (high if we can rank positives higher than negatives) 253 | 254 | ```{r imbalance_roc_curve} 255 | roc_curve(toy_data, truth, DA) |> 256 | ggplot(aes(x = 1 - specificity, y = sensitivity)) + 257 | geom_path() + 258 | geom_abline(lty = 3) + 259 | coord_equal() + 260 | ggtitle(round(roc_auc(toy_data, truth, DA)$.estimate, 5)) + 261 | theme_minimal() 262 | ``` 263 | 264 | Area under the precision recall curve (sensitive to false positives) 265 | 266 | ```{r imbalance_pr_curve} 267 | pr_curve(toy_data, truth, DA) |> 268 | ggplot(aes(x = recall, y = precision)) + 269 | geom_path() + 270 | coord_equal() + 271 | ylim(c(0, 1)) + 272 | ggtitle(round(pr_auc(toy_data, truth, DA)$.estimate, 5)) + 273 | theme_minimal() 274 | ``` 275 | 276 | ## Session info 277 | 278 | Time built. 279 | 280 | ```{r time, echo=FALSE} 281 | Sys.time() 282 | ``` 283 | 284 | Session info. 285 | 286 | ```{r session_info, echo=FALSE} 287 | sessionInfo() 288 | ``` 289 | -------------------------------------------------------------------------------- /tidymodels/xgboost.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Getting started with tidymodels" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | knitr::opts_chunk$set(cache = FALSE) 10 | knitr::opts_chunk$set(echo = TRUE) 11 | knitr::opts_chunk$set(fig.path = "img/") 12 | ``` 13 | 14 | ## Setup 15 | 16 | Install packages. 17 | 18 | ```{r load_packages, message=FALSE, warning=FALSE} 19 | my_packages <- c('tidyverse', 'tidymodels', 'randomForest', 'ROCR') 20 | 21 | for (my_package in my_packages){ 22 | if(!require(my_package, character.only = TRUE)){ 23 | install.packages(my_package) 24 | } 25 | library(my_package, character.only = TRUE) 26 | } 27 | 28 | theme_set(theme_bw()) 29 | ``` 30 | 31 | ## Spam 32 | 33 | Use [spam data](https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.names) to train a Random Forest model to illustrate evaluation measures. 34 | 35 | ```{r load_spam} 36 | spam_data <- read.csv(file = "../data/spambase.csv") 37 | dim(spam_data) 38 | ``` 39 | 40 | Class 0 and 1 are ham (non-spam) and spam, respectively. 41 | 42 | ```{r preview_spam} 43 | spam_data$class <- factor(spam_data$class) 44 | spam_data[c(1:3, (nrow(spam_data)-2):nrow(spam_data)), (ncol(spam_data)-2):ncol(spam_data)] 45 | ``` 46 | 47 | [Split data](https://www.tidymodels.org/start/resampling/#data-split) using [rsample](https://rsample.tidymodels.org/). 
48 | 49 | The `initial_split()` function takes the original data and saves the information on how to make the partitions. The `strata` argument conducts a stratified split ensuring that our training and test data sets will keep roughly the same proportion of classes. 50 | 51 | ```{r split_spam} 52 | set.seed(1984) 53 | spam_split <- initial_split(data = spam_data, prop = 0.8, strata = 'class') 54 | spam_split 55 | spam_train <- training(spam_split) 56 | spam_test <- testing(spam_split) 57 | ``` 58 | 59 | ## `parsnip` 60 | 61 | The [parsnip package](https://parsnip.tidymodels.org/index.html) provides a tidy and unified interface to a range of models. 62 | 63 | ```{r train_rf} 64 | my_mtry <- ceiling(sqrt(ncol(spam_data))) 65 | 66 | rf <- list() 67 | rand_forest(mtry = my_mtry, trees = 500) |> 68 | set_engine("randomForest") |> 69 | set_mode("classification") -> rf$model 70 | 71 | rf$model |> 72 | fit(class ~ ., data = spam_train) -> rf$fit 73 | 74 | rf$model 75 | ``` 76 | 77 | `rf` contains the model parameters and the model. 78 | 79 | ```{r str_rf} 80 | str(rf, max.level = 2) 81 | ``` 82 | 83 | ## `yardstick` 84 | 85 | The [yardstick package](https://yardstick.tidymodels.org/) provides a tidy interface to estimate how well models are performing. 86 | 87 | Example data to check how to prepare our data for use with `yardstick`. 88 | 89 | ```{r two_class_example} 90 | data(two_class_example) 91 | str(two_class_example) 92 | ``` 93 | 94 | Make predictions on the test data; `.pred_1` is the "probability" of spam. 95 | 96 | ```{r pred_spam_test} 97 | predict(rf$fit, spam_test, type = 'prob') 98 | ``` 99 | 100 | Predict and generate table in the format of `two_class_example` using a wrapper function. 101 | 102 | * `fit` - model 103 | * `test_data` - test data 104 | * `pos` - class that the model is testing for 105 | * `neg` - the other class 106 | 107 | Since the model is testing for spam, `pos` is 'spam'. 108 | 109 | ```{r predict} 110 | predict_wrapper <- function(fit, test_data, pos, neg, type = 'prob'){ 111 | predict(fit, test_data, type = type) |> 112 | mutate(truth = ifelse(as.integer(test_data$class) == 2, pos, neg)) |> 113 | mutate(truth = factor(truth, levels = c(pos, neg))) |> 114 | rename( 115 | ham = .pred_0, 116 | spam = .pred_1 117 | ) |> 118 | mutate( 119 | predicted = ifelse(spam > 0.5, pos, neg) 120 | ) |> 121 | mutate( 122 | predicted = factor(predicted, levels = c(pos, neg)) 123 | ) |> 124 | select(truth, everything()) 125 | } 126 | 127 | rf$predictions <- predict_wrapper(rf$fit, spam_test, 'spam', 'ham') 128 | rf$predictions 129 | ``` 130 | 131 | Confusion matrix. 132 | 133 | ```{r confusion_matrix} 134 | cm <- table(rf$predictions$truth, rf$predictions$predicted) 135 | cm |> 136 | prop.table() 137 | ``` 138 | 139 | Metrics. 140 | 141 | ```{r calculate_metrics} 142 | metrics(rf$predictions, truth, predicted) 143 | ``` 144 | 145 | [table_metrics](https://github.com/davetang/learning_r/blob/main/code/table_metrics.R). 146 | 147 | ```{r table_metrics} 148 | source("https://raw.githubusercontent.com/davetang/learning_r/main/code/table_metrics.R") 149 | table_metrics(cm, 'spam', 'ham', 'row', sig_fig = 7) 150 | ``` 151 | 152 | Area under the PR curve. 153 | 154 | ```{r pr_auc} 155 | pr_auc(rf$predictions, truth, spam) 156 | ``` 157 | 158 | [PR curve](https://yardstick.tidymodels.org/reference/pr_curve.html). 
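One note for reading the curve below: the no-skill baseline of a PR curve sits at the positive-class prevalence rather than on a diagonal. A small sketch to compute that baseline for this split (assuming `spam_test` from the split above, with spam coded as `1`):

```{r pr_baseline}
# Prevalence of spam in the test set; a random classifier's precision
# stays around this value across all recalls.
mean(spam_test$class == "1")
```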
159 | 160 | ```{r pr_curve} 161 | pr_curve(rf$predictions, truth, spam) |> 162 | ggplot(aes(x = recall, y = precision)) + 163 | geom_path() + 164 | coord_equal() + 165 | ylim(c(0, 1)) + 166 | ggtitle('PR curve') 167 | ``` 168 | 169 | Area under the ROC curve. 170 | 171 | ```{r roc_auc} 172 | roc_auc(rf$predictions, truth, spam) 173 | ``` 174 | 175 | [ROC curve](https://yardstick.tidymodels.org/reference/roc_curve.html). 176 | 177 | ```{r roc_curve} 178 | roc_curve(rf$predictions, truth, spam) |> 179 | ggplot(aes(x = 1 - specificity, y = sensitivity)) + 180 | geom_path() + 181 | geom_abline(lty = 3) + 182 | coord_equal() + 183 | ggtitle('ROC curve') 184 | ``` 185 | 186 | ### Using ROCR 187 | 188 | Compare with [ROCR](https://cran.rstudio.com/web/packages/ROCR/vignettes/ROCR.html). 189 | 190 | Every classifier evaluation using {ROCR} starts with creating a prediction object. 191 | 192 | ```{r rocr_pred} 193 | predictions <- predict(rf$fit, spam_test, type = 'prob')$.pred_1 194 | labels <- spam_test$class 195 | pred <- prediction(predictions, labels) 196 | aucpr <- performance(pred, "aucpr") 197 | aucroc <- performance(pred, "auc") 198 | str(aucpr) 199 | str(aucroc) 200 | ``` 201 | 202 | PR curve. 203 | 204 | ```{r rocr_pr_curve} 205 | perf <- performance(pred, "prec", "rec") 206 | plot(perf, lwd= 1, main= "PR curve") 207 | ``` 208 | 209 | ROC curve. 210 | 211 | ```{r rocr_roc_curve} 212 | perf <- performance(pred, "tpr", "fpr") 213 | plot(perf, lwd= 1, main= "ROC curve") 214 | ``` 215 | 216 | ## Session info 217 | 218 | Time built. 219 | 220 | ```{r time, echo=FALSE} 221 | Sys.time() 222 | ``` 223 | 224 | Session info. 225 | 226 | ```{r session_info, echo=FALSE} 227 | sessionInfo() 228 | ``` 229 | -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-1-1.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-2-1.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-3-1.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-4-1.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-5-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-5-1.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-5-2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-5-2.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-6-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-6-1.png -------------------------------------------------------------------------------- /tree/img/unnamed-chunk-9-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/tree/img/unnamed-chunk-9-1.png -------------------------------------------------------------------------------- /tree/readme.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Decision trees" 3 | output: 4 | md_document: 5 | variant: markdown 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | library(tidyverse) 10 | theme_set(theme_bw()) 11 | knitr::opts_chunk$set(cache = FALSE) 12 | knitr::opts_chunk$set(echo = TRUE) 13 | knitr::opts_chunk$set(fig.path = "img/") 14 | ``` 15 | 16 | ## Introduction 17 | 18 | A [decision tree](https://en.wikipedia.org/wiki/Decision_tree) is a decision support tool that uses a tree-like graph or model of decisions and their possible consequences, including chance event outcomes, resource costs, and utility. 19 | 20 | ```{r load_package, message=FALSE, warning=FALSE} 21 | .libPaths('/packages') 22 | my_packages <- c('tree', 'rpart', 'rpart.plot') 23 | 24 | for (my_package in my_packages){ 25 | if(!require(my_package, character.only = TRUE)){ 26 | install.packages(my_package, '/packages') 27 | library(my_package, character.only = TRUE) 28 | } 29 | } 30 | ``` 31 | 32 | 33 | ```{r} 34 | tree1 <- tree(Species ~ Sepal.Width + Petal.Width, data = iris) 35 | summary(tree1) 36 | plot(tree1) 37 | text(tree1) 38 | ``` 39 | 40 | ```{r} 41 | plot(iris$Petal.Width, 42 | iris$Sepal.Width, 43 | pch=19, 44 | col=as.numeric(iris$Species)) 45 | 46 | partition.tree(tree1, label="Species", add=TRUE) 47 | legend(2.3,4.5, 48 | legend=levels(iris$Species), 49 | col=1:length(levels(iris$Species)), 50 | pch=19, 51 | bty = 'n') 52 | ``` 53 | 54 | ```{r} 55 | tree2 <- tree(Species ~ ., data = iris) 56 | summary(tree2) 57 | plot(tree2); text(tree2) 58 | ``` 59 | 60 | Each node shows: 61 | 62 | 1. The predicted class (setosa, versicolor, and virginica) 63 | 2. The numbers of each class (in the order above) 64 | 3. The percentage of all samples 65 | 66 | ```{r} 67 | rpart <- rpart(Species ~ ., data=iris, method="class") 68 | summary(rpart) 69 | rpart.plot(rpart, type = 4, extra = 101) 70 | ``` 71 | 72 | ## Titanic data 73 | 74 | ```{r} 75 | titanic <- read.csv('../data/titanic.csv.gz') 76 | str(titanic) 77 | 78 | titanic$Pclass <- factor(titanic$Pclass) 79 | boxplot(Fare ~ Pclass, data = titanic) 80 | ``` 81 | 82 | Each node shows: 83 | 84 | 1. The predicted class (0 or 1) 85 | 2. The predicted probability of survival 86 | 3. The percentage of all samples 87 | 88 | ```{r} 89 | t <- rpart(Survived ~ Sex + Fare + Age, data=titanic, method="class") 90 | rpart.plot(t) 91 | ``` 92 | 93 | For example the 0.74 indicates that 74% of females survived. 
94 | 95 | ```{r} 96 | prop.table( 97 | table( 98 | titanic$Sex, titanic$Survived 99 | ), margin = 1 100 | ) 101 | ``` 102 | 103 | ## Breast cancer data 104 | 105 | ```{r} 106 | data <- read.table( 107 | "../data/breast_cancer_data.csv", 108 | stringsAsFactors = FALSE, 109 | sep = ',', 110 | header = TRUE 111 | ) 112 | data$class <- factor(data$class) 113 | data <- data[,-1] 114 | ``` 115 | 116 | Each node shows: 117 | 118 | 1. The predicted class (0 or 1) 119 | 2. The predicted probability of malignancy 120 | 3. The percentage of all samples 121 | 122 | ```{r} 123 | t <- rpart(class ~ ., data = data, method="class") 124 | rpart.plot(t) 125 | ``` 126 | 127 | ## Session info 128 | 129 | Time built. 130 | 131 | ```{r time, echo=FALSE} 132 | Sys.time() 133 | ``` 134 | 135 | Session info. 136 | 137 | ```{r session_info, echo=FALSE} 138 | sessionInfo() 139 | ``` 140 | 141 | -------------------------------------------------------------------------------- /variant/image/swissvar_cadd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/variant/image/swissvar_cadd.png -------------------------------------------------------------------------------- /variant/image/swissvar_gerp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/variant/image/swissvar_gerp.png -------------------------------------------------------------------------------- /variant/kabuki_hg18.bed: -------------------------------------------------------------------------------- 1 | chr12 47706820 47706821 c.G15195A 0 + 2 | chr12 47722237 47722238 c.C6010T 0 + 3 | chr12 47712057 47712058 c.C12697T 0 + 4 | chr12 47718917 47718918 c.C8488T 0 + 5 | chr12 47706397 47706398 c.T15618G 0 + 6 | chr12 47721524 47721525 c.C6295T 0 + 7 | -------------------------------------------------------------------------------- /variant/miller_hg18.bed: -------------------------------------------------------------------------------- 1 | chr16 70603483 70603484 G56A 0 + 2 | chr16 70606040 70606041 C403T 0 + 3 | chr16 70608442 70608443 G454A 0 + 4 | chr16 70612600 70612601 C595T 0 + 5 | chr16 70612610 70612611 G605A 0 + 6 | chr16 70612610 70612611 G605C 0 + 7 | chr16 70613785 70613786 C730T 0 + 8 | chr16 70614595 70614596 C851T 0 + 9 | chr16 70614935 70614936 C1036T 0 + 10 | chr16 70614935 70614936 C1036T 0 + 11 | chr16 70615585 70615586 A1175G 0 + 12 | -------------------------------------------------------------------------------- /variant/myvariant.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Using the myvariant package to annotate variants" 3 | author: "Dave Tang" 4 | date: "3 November 2016" 5 | output: pdf_document 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | knitr::opts_chunk$set(echo = TRUE) 10 | ``` 11 | 12 | # Install necessary packages 13 | 14 | ```{r, eval = FALSE} 15 | source("https://bioconductor.org/biocLite.R") 16 | biocLite('myvariant') 17 | biocLite('VariantAnnotation') 18 | ``` 19 | 20 | # Load libraries 21 | 22 | ```{r} 23 | library(VariantAnnotation) 24 | library(myvariant) 25 | ``` 26 | 27 | # Download example file 28 | 29 | ```{r} 30 | download.file(url = 'http://davetang.org/eg/Pfeiffer.vcf', destfile = 'Pfeiffer.vcf') 31 | my_vcf <- readVcf('Pfeiffer.vcf', genome = 'hg19') 32 | my_hgvs <- formatHgvs(my_vcf) 
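# formatHgvs() converts each VCF record into an HGVS-style identifier
# (e.g. "chr1:g.12345A>G", an illustrative example) that myvariant.info
# queries such as getVariants() accept.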
33 | head(my_hgvs) 34 | length(my_hgvs) 35 | ``` 36 | 37 | # Obtain annotations for your variants 38 | 39 | ```{r} 40 | my_var <- getVariants(my_hgvs) 41 | ``` 42 | 43 | # Checking out the variant annotations 44 | 45 | ```{r} 46 | class(my_var) 47 | 48 | dim(my_var) 49 | ``` 50 | 51 | ```{r} 52 | library(dplyr) 53 | my_var_tbl <- tbl_df(my_var) 54 | dim(my_var_tbl) 55 | my_var_tbl %>% select(notfound) %>% count(notfound) 56 | ``` 57 | 58 | ## Filtering cases that were not found 59 | 60 | ```{r} 61 | my_var_tbl %>% filter (is.na(notfound)) %>% select(query, starts_with('evs')) %>% dim() 62 | ``` 63 | 64 | ```{r} 65 | my_var_tbl %>% filter (is.na(notfound)) %>% select(query, starts_with('cadd')) %>% dim() 66 | ``` 67 | 68 | ## ClinVar 69 | 70 | ```{r} 71 | my_var_tbl %>% filter (is.na(notfound), !is.na(clinvar.omim)) %>% select(query, starts_with('clinvar')) 72 | ``` 73 | 74 | ## dbSNP 75 | 76 | ```{r} 77 | my_var_tbl %>% filter (is.na(notfound), dbsnp.validated == 'TRUE') %>% select(query, starts_with('dbsnp')) 78 | ``` 79 | 80 | -------------------------------------------------------------------------------- /variant/myvariant.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/variant/myvariant.pdf -------------------------------------------------------------------------------- /variant/random_forest/analysis.R: -------------------------------------------------------------------------------- 1 | setwd("~/github/machine_learning/variant/random_forest/") 2 | 3 | data <- read.csv("../DataS1/ToolScores/humvar_tool_scores.csv", stringsAsFactors = FALSE) 4 | head(data) 5 | 6 | features <- c('MutationTaster','MutationAssessor','PolyPhen2','CADD','SIFT','LRT','FatHMM.U','GERP..','PhyloP') 7 | 8 | library(dplyr) 9 | 10 | # Select columns by vector of names using dplyr 11 | # see https://gist.github.com/djhocking/62c76e63543ba9e94ebe 12 | data_subset <- select_(data, .dots=c('True.Label', features)) 13 | data_subset$True.Label <- factor(data_subset$True.Label) 14 | 15 | table(data_subset$True.Label) 16 | 17 | library(randomForest) 18 | 19 | r <- randomForest(True.Label ~ ., data=data_subset, importance=TRUE, do.trace=100, na.action=na.omit, ntree=1000) 20 | # ntree OOB 1 2 21 | # 100: 14.63% 15.32% 14.04% 22 | # 200: 14.47% 15.16% 13.88% 23 | # 300: 14.42% 15.15% 13.80% 24 | # 400: 14.35% 15.08% 13.73% 25 | # 500: 14.31% 15.01% 13.71% 26 | # 600: 14.30% 15.00% 13.70% 27 | # 700: 14.34% 15.06% 13.73% 28 | # 800: 14.37% 15.08% 13.77% 29 | # 900: 14.35% 15.04% 13.76% 30 | # 1000: 14.35% 15.06% 13.75% 31 | 32 | varImpPlot(r) 33 | 34 | library(ROCR) 35 | 36 | # use votes, which are the fraction of (OOB) votes from the random forest 37 | # in the first row, all trees voted for class 0, which is benign 38 | head(r$votes) 39 | 40 | # cases with NaN were omitted 41 | # create vector to store cases that were used 42 | my_pred_vector <- as.numeric(names(r$predicted)) 43 | 44 | # votes for 1 45 | pred <- prediction(r$votes[,2], as.numeric(data_subset$True.Label[my_pred_vector])) 46 | perf <- performance(pred,"tpr","fpr") 47 | plot(perf) 48 | 49 | # area under the curve 50 | auc <- performance(pred, measure = "auc") 51 | auc@y.values 52 | legend('bottomright', legend = paste("AUC = ", auc@y.values)) 53 | 54 | # matrix with votes and true label 55 | votes_and_truth <- cbind(r$votes, data$True.Label[my_pred_vector]) 56 | votes_and_truth <- as.data.frame(votes_and_truth) 57 | 
names(votes_and_truth) <- c('Negative','Positive','True') 58 | 59 | # dplyr magic 60 | filter(votes_and_truth, Positive>0.50) %>% 61 | select(True) %>% 62 | group_by(True) %>% 63 | tally(True) 64 | # A tibble: 2 x 2 65 | # True n 66 | # 67 | # 1 -1 -2169 68 | # 2 1 14783 69 | 70 | filter(votes_and_truth, Negative>0.50) %>% 71 | select(True) %>% 72 | group_by(True) %>% 73 | tally(True) 74 | # A tibble: 2 x 2 75 | # True n 76 | # 77 | # 1 -1 -12454 78 | # 2 1 2364 79 | 80 | # find absolutely misclassified examples 81 | # classified as positive but is really negative 82 | # i.e. false positive 83 | filter(votes_and_truth, Positive==1, True==-1) 84 | 85 | # need row information 86 | votes_truth_row <- votes_and_truth 87 | votes_truth_row$Row <- rownames(votes_and_truth) 88 | 89 | # see http://stackoverflow.com/questions/21618423/extract-a-dplyr-tbl-column-as-a-vector 90 | filter(votes_truth_row, Positive==1, True==-1) %>% 91 | select(Row) %>% 92 | collect %>% .[[1]] 93 | 94 | data[c(7766,14528,19722,20147,23042,28612),c('CHR','Nuc.Pos','REF.Nuc','ALT.Nuc','X.RS.ID','True.Label')] 95 | # CHR Nuc.Pos REF.Nuc ALT.Nuc X.RS.ID True.Label 96 | # 7766 14 62016431 A T rs35561533 -1 97 | # 14528 19 11348960 G A rs12609039 -1 98 | # 19722 1 246930564 G C rs7779 -1 99 | # 20147 1 44456013 G C rs35904809 -1 100 | # 23042 22 45944576 T A rs1802787 -1 101 | # 28612 4 166403424 T A rs34516004 -1 102 | 103 | # false negative 104 | filter(votes_truth_row, Negative>0.99, True==1) %>% 105 | select(Row) %>% 106 | collect %>% .[[1]] 107 | data[c(4015,6867,20187,22798,34844,36297,39089),c('CHR','Nuc.Pos','REF.Nuc','ALT.Nuc','X.RS.ID','True.Label')] 108 | # CHR Nuc.Pos REF.Nuc ALT.Nuc X.RS.ID True.Label 109 | # 4015 11 76853783 T C rs1052030 1 110 | # 6867 13 52544715 C A ? 1 111 | # 20187 1 45481018 G A ? 1 112 | # 22798 22 36661906 A G rs73885319 1 113 | # 34844 8 21976710 T C rs7014851 1 114 | # 36297 9 34649442 A G rs2070074 1 115 | # 39089 X 31496398 T C rs1800279 1 116 | 117 | # prediction 118 | # values from data_subset[10,] 119 | values <- c(0.0004286186,-0.205,0.003,10,0.83,0.0312836,0.36,-1.58,-0.462) 120 | blah <- t(data.frame(x = values)) 121 | colnames(blah) <- features 122 | rownames(blah) <- 10 123 | 124 | predict(r, blah) 125 | 126 | # read dbNSFP annotated variants 127 | 128 | negative <- read.table('negative_dbnsfp.out', header=TRUE, stringsAsFactors=FALSE, quote='', sep="\t", comment='') 129 | dim(negative) 130 | # [1] 42185 452 131 | positive <- read.table('positive_dbnsfp.out', header=TRUE, stringsAsFactors=FALSE, quote='', sep="\t", comment='') 132 | dim(positive) 133 | # [1] 37086 452 134 | 135 | # http://annovar.openbioinformatics.org/en/latest/user-guide/filter/ 136 | # There are two databases for PolyPhen2: HVAR and HDIV. They are explained below: 137 | # ljb2_pp2hvar should be used for diagnostics of Mendelian diseases, which requires distinguishing mutations with drastic effects from all the remaining human variation, including abundant mildly deleterious alleles. 138 | # ljb2_pp2hdiv should be used when evaluating rare alleles at loci potentially involved in complex phenotypes, dense mapping of regions identified by genome-wide association studies, and analysis of natural selection from sequence data. 
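# Assumption behind the feature set chosen below: dbNSFP "rankscore"
# columns are each tool's raw scores rank-transformed onto a common 0-1
# scale, which lets otherwise incomparable scores sit in one model.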
139 | features_dbnsfp <- c('MutationTaster_converted_rankscore', 'MutationAssessor_score_rankscore', 'Polyphen2_HDIV_rankscore', 'Polyphen2_HVAR_rankscore', 'CADD_raw_rankscore', 'SIFT_converted_rankscore', 'LRT_converted_rankscore', 'FATHMM_converted_rankscore', 'fathmm.MKL_coding_rankscore', 'GERP.._RS_rankscore', 'phyloP100way_vertebrate_rankscore', 'phyloP20way_mammalian_rankscore') 140 | 141 | dbnsfp_positive <- select_(positive, .dots=features_dbnsfp) 142 | dbnsfp_positive$True.Label <- rep(1, nrow(dbnsfp_positive)) 143 | dbnsfp_negative <- select_(negative, .dots=features_dbnsfp) 144 | dbnsfp_negative$True.Label <- rep(-1, nrow(dbnsfp_negative)) 145 | dbnsfp <- rbind(dbnsfp_positive, dbnsfp_negative) 146 | 147 | dim(dbnsfp) 148 | # [1] 79271 13 149 | 150 | dbnsfp <- apply(dbnsfp, 2, function(x) as.numeric(gsub(x = x, pattern = '^\\.$', replacement = NA, perl = TRUE))) 151 | dbnsfp <- as.data.frame(dbnsfp) 152 | dbnsfp$True.Label <- factor(dbnsfp$True.Label) 153 | 154 | ################# 155 | # Below doesn't work 156 | # library(foreach) 157 | # library(doSNOW) 158 | # registerDoSNOW(makeCluster(10, type='SOCK')) 159 | # system.time(r2 <- foreach(ntree = rep(100, 10), .combine = combine, .multicombine=TRUE, .packages = "randomForest") %dopar% randomForest(True.Label ~ ., data=dbnsfp, proximity=TRUE, importance=TRUE, na.action=na.omit, ntree = ntree)) 160 | # Error in randomForest(True.Label ~ ., data = dbnsfp, proximity = TRUE, : 161 | # task 5 failed - "long vectors (argument 18) are not supported in .Fortran" 162 | # Timing stopped at: 0.14 0.064 38.082 163 | ################# 164 | 165 | r2 <- randomForest(True.Label ~ ., data=dbnsfp, importance=TRUE, do.trace=100, na.action=na.omit, ntree=1000) 166 | # ntree OOB 1 2 167 | # 100: 14.46% 13.26% 15.69% 168 | # 200: 14.31% 13.01% 15.64% 169 | # 300: 14.27% 12.89% 15.68% 170 | # 400: 14.19% 12.81% 15.59% 171 | # 500: 14.15% 12.76% 15.57% 172 | # 600: 14.18% 12.70% 15.70% 173 | # 700: 14.17% 12.71% 15.66% 174 | # 800: 14.12% 12.66% 15.61% 175 | # 900: 14.10% 12.60% 15.63% 176 | # 1000: 14.10% 12.60% 15.63% 177 | 178 | varImpPlot(r2) 179 | 180 | features_dbnsfp_expression <- c('MutationTaster_converted_rankscore', 'MutationAssessor_score_rankscore', 'Polyphen2_HDIV_rankscore', 'Polyphen2_HVAR_rankscore', 'CADD_raw_rankscore', 'SIFT_converted_rankscore', 'LRT_converted_rankscore', 'FATHMM_converted_rankscore', 'fathmm.MKL_coding_rankscore', 'GERP.._RS_rankscore', 'phyloP100way_vertebrate_rankscore', 'phyloP20way_mammalian_rankscore', grep('sample', names(positive), value=TRUE)) 181 | 182 | dbnsfp_positive_exp <- select_(positive, .dots=features_dbnsfp_expression) 183 | dbnsfp_positive_exp$True.Label <- rep(1, nrow(dbnsfp_positive_exp)) 184 | dbnsfp_negative_exp <- select_(negative, .dots=features_dbnsfp_expression) 185 | dbnsfp_negative_exp$True.Label <- rep(-1, nrow(dbnsfp_negative_exp)) 186 | dbnsfp_exp <- rbind(dbnsfp_positive_exp, dbnsfp_negative_exp) 187 | 188 | dim(dbnsfp_exp) 189 | # [1] 79271 66 190 | 191 | dbnsfp_exp <- apply(dbnsfp, 2, function(x) as.numeric(gsub(x = x, pattern = '^\\.$', replacement = NA, perl = TRUE))) 192 | dbnsfp_exp <- as.data.frame(dbnsfp_exp) 193 | dbnsfp_exp$True.Label <- factor(dbnsfp_exp$True.Label) 194 | 195 | # performs slightly worst with expression data 196 | r3 <- randomForest(True.Label ~ ., data=dbnsfp_exp, importance=TRUE, do.trace=100, na.action=na.omit, ntree=1000) 197 | # ntree OOB 1 2 198 | # 100: 17.24% 18.77% 15.69% 199 | # 200: 16.67% 17.83% 15.51% 200 | # 300: 16.54% 17.21% 15.85% 201 
| # 400: 16.23% 16.41% 16.05% 202 | # 500: 16.29% 16.30% 16.28% 203 | # 600: 16.31% 16.51% 16.11% 204 | # 700: 16.27% 16.31% 16.23% 205 | # 800: 16.26% 16.31% 16.20% 206 | # 900: 16.24% 16.09% 16.40% 207 | # 1000: 16.14% 15.76% 16.53% 208 | 209 | 210 | -------------------------------------------------------------------------------- /variant/random_forest/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # clinvar.vcf.gz is a symbolic link to the latest version of the VCF file 4 | wget -c ftp://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh37/clinvar.vcf.gz 5 | 6 | -------------------------------------------------------------------------------- /variant/random_forest/image/CFTR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/variant/random_forest/image/CFTR.png -------------------------------------------------------------------------------- /variant/random_forest/stratify.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | 3 | use strict; 4 | use warnings; 5 | 6 | my $usage = "Usage: $0 <infile> <code>\n"; 7 | my $infile = shift or die $usage; 8 | my $code = shift or die $usage; 9 | 10 | if ($infile =~ /\.gz$/){ 11 | open(IN, '-|', "gunzip -c $infile") || die "Could not open $infile: $!\n"; 12 | } else { 13 | open(IN, '<', $infile) || die "Could not open $infile: $!\n"; 14 | } 15 | 16 | VARIANT: while(<IN>){ 17 | chomp; 18 | my $current_line = $_; 19 | if ($current_line =~ /^#/){ 20 | print "$current_line\n"; 21 | } else { 22 | if (/CLNSIG=(.*?);/){ 23 | my $clnsig = $1; 24 | # split by pipes or commas 25 | my @clnsig = split(/\||,/, $clnsig); 26 | 27 | # has only one significance code 28 | if (scalar(@clnsig) == 1){ 29 | if ($clnsig == $code){ 30 | print "$current_line\n"; 31 | next VARIANT; 32 | } 33 | } 34 | 35 | # $match if there is a match to input code 36 | my ($benign, $pathogenic, $match) = (0, 0, 0); 37 | foreach my $c (@clnsig){ 38 | if ($c == 2 || $c == 3){ 39 | $benign = 1; 40 | } 41 | if ($c == 4 || $c == 5){ 42 | $pathogenic = 1; 43 | } 44 | if ($c == $code){ 45 | $match = 1; 46 | } 47 | } 48 | 49 | # check and skip conflicting codes 50 | if ($benign == 1 && $pathogenic == 1){ 51 | next VARIANT; 52 | } 53 | 54 | if ($match == 1){ 55 | print "$current_line\n"; 56 | } 57 | } 58 | } 59 | } 60 | close(IN); 61 | 62 | exit(0); 63 | 64 | -------------------------------------------------------------------------------- /variant/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # replace with locations of the binary and chain file 4 | liftOver kabuki_hg18.bed hg18ToHg19.over.chain kabuki_hg19.bed unmapped.txt 5 | 6 | # bases are on the negative strand 7 | cat kabuki_hg19.bed | perl -nle '@a=split; if($a[3] =~ /([ACGT])\d+([ACGT])/){ $r = $1; $r =~ tr/ACGT/TGCA/; $a = $2; $a =~ tr/ACGT/TGCA/; print join("\t", @a[0..2], $r, $a) }' > kabuki_hg19.tsv 8 | 9 | # create VCF file 10 | ~/github/learning_vcf_file/script/create_vcf.pl ~/genome/hg19/hg19.fa kabuki_hg19.tsv > kabuki_hg19.vcf 11 | 12 | # dbNSFP 13 | wget -c ftp://dbnsfp:dbnsfp@dbnsfp.softgenetics.com/dbNSFPv3.2a.zip 14 | unzip dbNSFPv3.2a.zip 15 | java search_dbNSFP32a -i kabuki_hg19.vcf -o kabuki_hg19_dbnsfp.out -v hg19 16 | 17 | # repeat steps above with variants causing Miller syndrome 18 | liftOver miller_hg18.bed 
~/data/ucsc/hg18ToHg19.over.chain miller_hg19.bed unmapped.txt 19 | # bases are on the positive strand 20 | cat miller_hg19.bed | perl -nle '@a=split; if($a[3] =~ /([ACGT])\d+([ACGT])/){ print join("\t", @a[0..2], $1, $2) }' > miller_hg19.tsv 21 | ~/github/learning_vcf_file/script/create_vcf.pl ~/genome/hg19/hg19.fa miller_hg19.tsv > miller_hg19.vcf 22 | java search_dbNSFP32a -i miller_hg19.vcf -o miller_hg19_dbnsfp.out -v hg19 23 | 24 | # ClinVar variants 25 | 26 | wget -c ftp://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh37/clinvar_20160831.vcf.gz 27 | 28 | -------------------------------------------------------------------------------- /xgboost/img/arthritis_imp_plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/xgboost/img/arthritis_imp_plot-1.png -------------------------------------------------------------------------------- /xgboost/img/breast_cancer_feature_importance-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/xgboost/img/breast_cancer_feature_importance-1.png -------------------------------------------------------------------------------- /xgboost/img/feature_importance-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/xgboost/img/feature_importance-1.png -------------------------------------------------------------------------------- /xgboost/img/feature_importance_plot-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davetang/machine_learning/a04feb25e480078d1af390036fbe70b250687801/xgboost/img/feature_importance_plot-1.png --------------------------------------------------------------------------------