├── scripts
│   └── spinme.R
├── .gitignore
├── overview
│   ├── calibration-plots.R
│   ├── bigdata-mtry-nsplit.R
│   ├── cohort-tables.R
│   ├── all-models.R
│   ├── cohort-tables-full.R
│   ├── performance-differences.R
│   ├── missing-values-risk.R
│   ├── explore-dataset.R
│   ├── variable-importances.R
│   └── variable-effects.R
├── cox-ph
│   ├── rapsomaniki-cox-values-from-paper.csv
│   ├── caliber-scale.R
│   ├── cox-discretised-imputed.R
│   ├── caliber-replicate-with-imputation.R
│   ├── cox-discretised.R
│   └── cox-discrete-varsellogrank.R
├── random-forest
│   ├── rf-classification.R
│   ├── rfsrc-cv.R
│   ├── rf-imputed.R
│   ├── rf-age.R
│   ├── rf-varselmiss.R
│   └── rf-varsellogrank.R
├── lib
│   ├── shared.R
│   ├── rfsrc-cv-mtry-nsplit-logical.R
│   ├── rfsrc-cv-nsplit-bootstrap.R
│   └── all-cv-bootstrap.R
├── age-only
│   └── age-only.R
└── README.md

/scripts/spinme.R:
--------------------------------------------------------------------------------
1 | require(knitr)
2 | spin(commandArgs(trailingOnly=TRUE)[1])
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # History files
2 | .Rhistory
3 | .Rapp.history
4 | 
5 | # Example code in package build process
6 | *-Ex.R
7 | 
8 | # RStudio files
9 | .Rproj.user/
10 | 
11 | # produced vignettes
12 | vignettes/*.html
13 | vignettes/*.pdf
14 | 
15 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
16 | .httr-oauth
--------------------------------------------------------------------------------
/overview/calibration-plots.R:
--------------------------------------------------------------------------------
1 | source('../lib/handymedical.R', chdir = TRUE)
2 | requirePlus('ggplot2', 'cowplot')
3 | 
4 | cox.calibration <-
5 |   read.csv('../../output/cox-bigdata-varsellogrank-01-calibration-table.csv')
6 | rf.calibration <-
7 |   read.csv('../../output/rfsrc-cv-nsplit-try3-calibration-table.csv')
8 | 
9 | cox.calibration.plot <- calibrationPlot(cox.calibration, max.points = 2000)
10 | rf.calibration.plot <- calibrationPlot(rf.calibration, max.points = 2000)
11 | 
12 | plot_grid(
13 |   cox.calibration.plot, rf.calibration.plot,
14 |   labels = c("C", ""),
15 |   align = "v", ncol = 2
16 | )
17 | 
18 | ggsave(
19 |   '../../output/models-calibration-eg.pdf',
20 |   width = 16,
21 |   height = 8,
22 |   units = 'cm',
23 |   useDingbats = FALSE
24 | )
--------------------------------------------------------------------------------
/overview/bigdata-mtry-nsplit.R:
--------------------------------------------------------------------------------
1 | source('../lib/handy.R')
2 | requirePlus('ggplot2', 'cowplot')
3 | 
4 | # Read in two cross-validation data files
5 | cv.performance <-
6 |   read.csv('../../output/rf-bigdata-try7-ALL-cv-largemtry-calibration.csv')
7 | cv.performance <-
8 |   rbind(
9 |     cv.performance,
10 |     read.csv('../../output/rf-bigdata-try7-ALL-cv-smallmtry-calibration.csv')
11 |   )
12 | 
13 | # Read in overall model performance
14 | models.performance <- readTablePlus('../../output/models-performance.tsv')
15 | 
16 | 
17 | 
18 | cv.performance.avg <-
19 |   aggregate(
20 |     c.index.val ~ n.splits + m.try,
21 |     data = cv.performance,
22 |     mean
23 |   )
24 | 
25 | ggplot(cv.performance.avg, aes(x = n.splits, y = c.index.val, colour = factor(m.try), group = m.try)) +
26 |   geom_line() +
27 |   geom_point(data = cv.performance) +
28 |   geom_hline(
29 |     yintercept =
30 |       models.performance$c.index[models.performance$model == 'rf-varselrf'],
31 |     colour = 'grey'
32 |   ) +
33 |   geom_hline(
34 |     yintercept =
35 |
models.performance$c.index[models.performance$model == 'rf-varselmiss'], 36 | colour = 'grey' 37 | ) + 38 | coord_cartesian( 39 | ylim = c(0.65, 0.8) 40 | ) 41 | -------------------------------------------------------------------------------- /overview/cohort-tables.R: -------------------------------------------------------------------------------- 1 | data.filename <- '../../data/cohort-sanitised.csv' 2 | 3 | source('../lib/shared.R') 4 | 5 | # Load the data and convert to data frame to make column-selecting code in 6 | # prepData simpler 7 | COHORT.full <- fread(data.filename) 8 | 9 | print(nrow(COHORT.full)) 10 | 11 | # Remove the patients we shouldn't include 12 | COHORT.full <- 13 | COHORT.full[ 14 | # remove negative times to death 15 | COHORT.full$time_death > 0 & 16 | # remove patients who should be excluded 17 | !COHORT.full$exclude 18 | , 19 | ] 20 | 21 | # Total study population 22 | print(nrow(COHORT.full)) 23 | 24 | # Age, 5, 50, 95, %missing 25 | print(quantile(COHORT.full$age, c(0.5, 0.05, 0.95))) 26 | 27 | # Gender 28 | print(table(COHORT.full$gender)) 29 | print(table(COHORT.full$gender))/nrow(COHORT.full)*100 30 | 31 | # Deprivation, 5, 50, 95, %missing 32 | print(quantile(COHORT.full$imd_score, c(0.5, 0.05, 0.95), na.rm = TRUE)) 33 | print(percentMissing(COHORT.full$imd_score)) 34 | 35 | # Smoking, by category, %missing 36 | print(table(COHORT.full$smokstatus))/nrow(COHORT.full)*100 37 | print(percentMissing(COHORT.full$smokstatus)) 38 | 39 | # Diabetes, yes/no 40 | print( 41 | ( sum(COHORT.full$diabetes == 'Diabetes unspecified type') + 42 | sum(COHORT.full$diabetes == 'Type 1 diabetes') + 43 | sum(COHORT.full$diabetes == 'Type 2 diabetes')) /nrow(COHORT.full)*100 44 | ) 45 | 46 | # Follow-up, 5, 50, 95 47 | print(quantile(COHORT.full$endpoint_death_date, c(0.5, 0.05, 0.95)))/365.25 48 | 49 | # Death vs censored, % 50 | print(table(COHORT.full$endpoint_death)) 51 | print(table(COHORT.full$endpoint_death)) /nrow(COHORT.full)*100 52 | -------------------------------------------------------------------------------- /cox-ph/rapsomaniki-cox-values-from-paper.csv: -------------------------------------------------------------------------------- 1 | quantity,quantity.level,unit,long_name,their_value,their_lower,their_upper 2 | age,age,years,Age if man,1.065,1.063,1.067 3 | age:gender,age:genderWomen,years,Age if woman,1.081,1.078,1.083 4 | gender,genderWomen,female,Gender,0.204,0.162,0.257 5 | most_deprived,most_deprived,bottom quintile,IMD,1.151,1.111,1.192 6 | most_deprived_missing,most_deprived_missingTRUE,missing,IMD,,, 7 | diagnosis,diagnosisCHD,,CHD,1.024,0.982,1.067 8 | diagnosis,diagnosisUA,,UA,1.021,0.97,1.075 9 | diagnosis,diagnosisSTEMI,,STEMI,1.083,1.006,1.166 10 | diagnosis,diagnosisNSTEMI,,NSTEMI,1.298,1.238,1.36 11 | pci_6mo,pci_6moTRUE,,PCI,0.651,0.605,0.699 12 | cabg_6mo,cabg_6moTRUE,,CABG,0.516,0.469,0.566 13 | hx_mi,hx_miTRUE,,Previous MI,1.136,1.095,1.179 14 | long_nitrate,long_nitrateTRUE,,Nitrates,1.152,1.118,1.188 15 | smokstatus,smokstatusEx,ex,Smoking,1.11,1.065,1.157 16 | smokstatus,smokstatusCurrent,current,Smoking,1.315,1.245,1.389 17 | smokstatus_missing,smokstatus_missingTRUE,missing,Smoking,,, 18 | hypertension,hypertensionTRUE,,Hypertension,0.965,0.929,1.001 19 | diabetes_logical,diabetes_logicalTRUE,,Diabetes,1.203,1.16,1.248 20 | total_chol_6mo,total_chol_6mo,,Total cholesterol,1.012,0.983,1.042 21 | total_chol_6mo_missing,total_chol_6mo_missingTRUE,missing,Total cholesterol,,, 22 | hdl_6mo,hdl_6mo,mmol/L,HDL,1.006,0.987,1.025 23 | 
hdl_6mo_missing,hdl_6mo_missingTRUE,missing,HDL,,, 24 | heart_failure,heart_failureTRUE,,Heart failure,1.543,1.495,1.593 25 | pad,padTRUE,,PAD,1.286,1.234,1.34 26 | hx_af,hx_afTRUE,,Atrial fibrillation,1.28,1.236,1.326 27 | hx_stroke,hx_strokeTRUE,,Prior stroke,1.329,1.277,1.382 28 | hx_renal,hx_renalTRUE,,Chronic renal disease,1.116,1.058,1.178 29 | hx_copd,hx_copdTRUE,,COPD,1.15,1.114,1.187 30 | hx_cancer,hx_cancerTRUE,,Cancer,1.377,1.324,1.432 31 | hx_liver,hx_liverTRUE,,Chronic liver disease,1.631,1.443,1.842 32 | hx_depression,hx_depressionTRUE,,Depression,1.179,1.135,1.225 33 | hx_anxiety,hx_anxietyTRUE,,Anxiety,1.172,1.116,1.231 34 | pulse_6mo,pulse_6mo,beats/min,Heart rate,1.098,1.084,1.112 35 | pulse_6mo_missing,pulse_6mo_missingTRUE,missing,Heart rate,,, 36 | crea_6mo,crea_6mo,mmol/L,Creatinine,1.065,1.051,1.08 37 | crea_6mo_missing,crea_6mo_missingTRUE,missing,Creatinine,,, 38 | total_wbc_6mo,total_wbc_6mo,10^9/L,White cell count,1.12,1.106,1.135 39 | total_wbc_6mo_missing,total_wbc_6mo_missingTRUE,missing,White cell count,,, 40 | haemoglobin_6mo,haemoglobin_6mo,g/dL,Haemoglobin,0.758,0.724,0.794 41 | haemoglobin_6mo_missing,haemoglobin_6mo_missingTRUE,missing,Haemoglobin,,, 42 | -------------------------------------------------------------------------------- /overview/all-models.R: -------------------------------------------------------------------------------- 1 | models.include <- 2 | c( 3 | 'age', 'cox', 'cox disc', 'cox imp', 'cox imp disc', 'rfsrc', 'rfsrc imp', 4 | 'rf-logrank', 'cox-logrank disc', 'cox-elnet disc' 5 | ) 6 | 7 | source('../lib/handy.R') 8 | requirePlus('ggplot2', 'cowplot') 9 | 10 | models.performance.all <- readTablePlus('../../output/models-performance-manual.tsv') 11 | 12 | models.performance.all$x.labels <- 13 | paste0( 14 | models.performance.all$model, 15 | ifelse(models.performance.all$imputation, ' imp', ''), 16 | ifelse(models.performance.all$discretised, ' disc', '') 17 | ) 18 | 19 | # Currently different scripts quote either the area under the curve or a pre 20 | # one-minused the area under the curve...so standardise that 21 | big.calibration.scores <- models.performance.all$calibration.score > 0.5 22 | models.performance.all[big.calibration.scores, c('calibration.score', 'calibration.score.lower', 'calibration.score.upper')] <- 23 | 1 - models.performance.all[big.calibration.scores, c('calibration.score', 'calibration.score.lower', 'calibration.score.upper')] 24 | 25 | models.performance <- data.frame() 26 | for(model in models.include) { 27 | models.performance <- 28 | rbind( 29 | models.performance, 30 | models.performance.all[models.performance.all$x.labels == model, ] 31 | ) 32 | } 33 | # Turn this into a factor with defined levels so ggplot respects the order above 34 | models.performance$x.labels <- 35 | factor(models.performance$x.labels, levels = models.include) 36 | 37 | plot.c.index <- 38 | ggplot(models.performance, aes(x = x.labels, y = c.index)) + 39 | geom_bar(stat='identity', aes(fill = model)) + 40 | geom_errorbar( 41 | aes(ymin = c.index.lower, ymax = c.index.upper), width = 0.1 42 | ) + 43 | coord_cartesian( 44 | ylim = c(0.75, 0.81) 45 | ) + 46 | theme(legend.position = "none") 47 | 48 | plot.calibration <- 49 | ggplot(models.performance, aes(x = x.labels, y = 1 - calibration.score)) + 50 | geom_bar(stat='identity', aes(fill = model)) + 51 | geom_errorbar( 52 | aes( 53 | ymin = 1 - calibration.score.lower, ymax = 1 - calibration.score.upper 54 | ), 55 | width = 0.1 56 | ) + 57 | coord_cartesian( 58 | ylim = c(0.8, 1.0) 59 | 
) + 60 | theme(legend.position = "none") 61 | 62 | plot_grid( 63 | plot.c.index, plot.calibration, 64 | labels = c("A", "B"), 65 | align = "v", ncol = 1 66 | ) 67 | 68 | ggsave( 69 | '../../output/all-models-performance.pdf', 70 | width = 16, 71 | height = 10, 72 | units = 'cm', 73 | useDingbats = FALSE 74 | ) 75 | -------------------------------------------------------------------------------- /overview/cohort-tables-full.R: -------------------------------------------------------------------------------- 1 | data.filename <- '../../data/cohort-sanitised.csv' 2 | require(data.table) 3 | COHORT <- fread(data.filename) 4 | 5 | percentMissing <- function(x, sf = 3) { 6 | round(sum(is.na(x))/length(x), digits = sf)*100 7 | } 8 | 9 | # Remove the patients we shouldn't include 10 | COHORT <- 11 | COHORT[ 12 | # remove negative times to death 13 | COHORT$time_death > 0 & 14 | # remove patients who should be excluded 15 | !COHORT$exclude 16 | , 17 | ] 18 | 19 | # Age, 5, 50, 95, %missing 20 | print(quantile(COHORT$age, c(0.5, 0.025, 0.975))) 21 | 22 | # Gender 23 | print(table(COHORT$gender)) 24 | print(table(COHORT$gender)/nrow(COHORT)*100) 25 | 26 | # Deprivation, 5, 50, 95, %missing 27 | print(quantile(COHORT$imd_score, c(0.5, 0.025, 0.975), na.rm = TRUE)) 28 | print(percentMissing(COHORT$imd_score)) 29 | 30 | # SCAD subtype 31 | print(table(COHORT$diagnosis)/nrow(COHORT)*100) 32 | 33 | # PCI 34 | print(sum(COHORT$pci_6mo)/nrow(COHORT)*100) 35 | 36 | # CABG 37 | print(sum(COHORT$cabg_6mo)/nrow(COHORT)*100) 38 | 39 | # previous/recurrent MI 40 | print(sum(COHORT$hx_mi)/nrow(COHORT)*100) 41 | 42 | # nitrates (listed as 1 and NA not T and F) 43 | print(sum(COHORT$long_nitrate, na.rm = TRUE)/nrow(COHORT)*100) 44 | 45 | # Smoking, by category, %missing 46 | print(table(COHORT$smokstatus)/nrow(COHORT)*100) 47 | print(percentMissing(COHORT$smokstatus)) 48 | 49 | # Hypertension 50 | print(sum(COHORT$hypertension)/nrow(COHORT)*100) 51 | 52 | # Diabetes, yes/no 53 | print( 54 | (sum(COHORT$diabetes == 'Diabetes unspecified type') + 55 | sum(COHORT$diabetes == 'Type 1 diabetes') + 56 | sum(COHORT$diabetes == 'Type 2 diabetes')) /nrow(COHORT)*100 57 | ) 58 | 59 | # Total cholesterol 60 | print(quantile(COHORT$total_chol_6mo, c(0.5, 0.025, 0.975), na.rm = TRUE)) 61 | print(percentMissing(COHORT$total_chol_6mo)) 62 | 63 | # HDL 64 | print(quantile(COHORT$hdl_6mo, c(0.5, 0.025, 0.975), na.rm = TRUE)) 65 | print(percentMissing(COHORT$hdl_6mo)) 66 | 67 | # Heart failure 68 | print(sum(COHORT$heart_failure)/nrow(COHORT)*100) 69 | 70 | # Peripheral arterial disease 71 | print(sum(COHORT$pad)/nrow(COHORT)*100) 72 | 73 | # Atrial fibrillation 74 | print(sum(COHORT$hx_af)/nrow(COHORT)*100) 75 | 76 | # Stroke 77 | print(sum(COHORT$hx_stroke)/nrow(COHORT)*100) 78 | 79 | # Chronic kidney disease 80 | print(sum(COHORT$hx_renal)/nrow(COHORT)*100) 81 | 82 | # COPD 83 | print(sum(COHORT$hx_copd)/nrow(COHORT)*100) 84 | 85 | # Cancer 86 | print(sum(COHORT$hx_cancer)/nrow(COHORT)*100) 87 | 88 | # Chronic liver disease 89 | print(sum(COHORT$hx_liver)/nrow(COHORT)*100) 90 | 91 | # Depression 92 | print(sum(COHORT$hx_depression)/nrow(COHORT)*100) 93 | 94 | # Anxiety 95 | print(sum(COHORT$hx_anxiety)/nrow(COHORT)*100) 96 | 97 | # Heart rate 98 | print(quantile(COHORT$pulse_6mo, c(0.5, 0.025, 0.975), na.rm = TRUE)) 99 | print(percentMissing(COHORT$pulse_6mo)) 100 | 101 | # Creatinine 102 | print(quantile(COHORT$crea_6mo, c(0.5, 0.025, 0.975), na.rm = TRUE)) 103 | print(percentMissing(COHORT$crea_6mo)) 104 | 105 | # WCC 106 | 
print(quantile(COHORT$total_wbc_6mo, c(0.5, 0.025, 0.975), na.rm = TRUE)) 107 | print(percentMissing(COHORT$total_wbc_6mo)) 108 | 109 | # Haemoglobin 110 | print(quantile(COHORT$haemoglobin_6mo, c(0.5, 0.025, 0.975), na.rm = TRUE)) 111 | print(percentMissing(COHORT$haemoglobin_6mo)) 112 | 113 | # Follow-up, 5, 50, 95 114 | print(quantile(COHORT$endpoint_death_date, c(0.5, 0.025, 0.975)))/365.25 115 | 116 | # Death vs censored, % 117 | print(table(COHORT$endpoint_death)) /nrow(COHORT)*100 118 | -------------------------------------------------------------------------------- /overview/performance-differences.R: -------------------------------------------------------------------------------- 1 | source('../lib/handymedical.R', chdir = TRUE) 2 | 3 | bootstrap.base <- '../../output' 4 | 5 | bootstrap.files <- 6 | c( 7 | cox.miss = 'caliber-replicate-with-missing-survreg-6-linear-age-surv-boot.rds', 8 | cox.disc = 'all-cv-survreg-boot-try5-surv-model.rds', 9 | cox.imp = 'caliber-replicate-imputed-survreg-4-surv-boot-imp.rds', 10 | rf = 'rfsrc-cv-nsplit-try3-boot-all.csv', 11 | rf.imp = 'rf-imputed-try1-boot.rds', 12 | rfbig = 'rf-bigdata-varsellogrank-02-boot-all.csv', 13 | coxbig = 'cox-bigdata-varsellogrank-01-boot-all.csv' 14 | ) 15 | 16 | # Helper functions 17 | 18 | # Turn a boot object into a data frame 19 | bootstrap2Df <- function(x) { 20 | df <- data.frame(x$t) 21 | names(df) <- names(x$t0) 22 | df 23 | } 24 | 25 | # Make sure calibration scores are bigger = better 26 | calibrationFix <- function(x) { 27 | if(mean(x) < 0.5) { 28 | x <- 1 - x 29 | } 30 | x 31 | } 32 | 33 | n <- length(bootstrap.files) 34 | 35 | bootstraps <- list() 36 | 37 | for(i in 1:n) { 38 | if(fileExt(bootstrap.files[i]) == 'rds'){ 39 | bootstraps[[i]] <- readRDS(file.path(bootstrap.base, bootstrap.files[i])) 40 | 41 | if(class(bootstraps[[i]]) == 'list') { 42 | # If it's a list, then it's from an imputed dataset with separate bootstraps 43 | # Turn each of these into a data frame and then combine them together. 
44 |       # (data.frame is needed because rbindlist returns a data.table)
45 |       bootstraps[[i]] <-
46 |         data.frame(rbindlist(lapply(bootstraps[[i]], bootstrap2Df)))
47 |     } else {
48 |       bootstraps[[i]] <- bootstrap2Df(bootstraps[[i]])
49 |     }
50 |   } else {
51 |     bootstraps[[i]] <- read.csv(file.path(bootstrap.base, bootstrap.files[i]))
52 |   }
53 | }
54 | 
55 | x1x2 <- combn(1:n, 2)
56 | x1 <- x1x2[1,]
57 | x2 <- x1x2[2,]
58 | 
59 | 
60 | bootstrap.differences <- data.frame()
61 | for(i in 1:length(x1)) {
62 |   # C-index
63 |   col.1.c.index <-
64 |     which(names(bootstraps[[x1[i]]]) %in% c('c.test', 'c.index'))
65 |   col.2.c.index <-
66 |     which(names(bootstraps[[x2[i]]]) %in% c('c.test', 'c.index'))
67 |   boot.diff <-
68 |     bootstrapDiff(
69 |       bootstraps[[x1[i]]][, col.1.c.index],
70 |       bootstraps[[x2[i]]][, col.2.c.index]
71 |     )
72 | 
73 |   bootstrap.differences <-
74 |     rbind(
75 |       bootstrap.differences,
76 |       data.frame(
77 |         model.1 = names(bootstrap.files)[x1[i]],
78 |         model.2 = names(bootstrap.files)[x2[i]],
79 |         var = 'c.index',
80 |         diff = boot.diff['val'],
81 |         lower = boot.diff['lower'],
82 |         upper = boot.diff['upper']
83 |       )
84 |     )
85 | 
86 |   # Calibration score
87 |   col.1.calib <-
88 |     which(names(bootstraps[[x1[i]]]) == 'calibration.score')
89 |   col.2.calib <-
90 |     which(names(bootstraps[[x2[i]]]) == 'calibration.score')
91 |   boot.diff <-
92 |     bootstrapDiff(
93 |       calibrationFix(bootstraps[[x1[i]]][, col.1.calib]),
94 |       calibrationFix(bootstraps[[x2[i]]][, col.2.calib])
95 |     )
96 | 
97 |   bootstrap.differences <-
98 |     rbind(
99 |       bootstrap.differences,
100 |       data.frame(
101 |         model.1 = names(bootstrap.files)[x1[i]],
102 |         model.2 = names(bootstrap.files)[x2[i]],
103 |         var = 'calibration.score',
104 |         diff = boot.diff['val'],
105 |         lower = boot.diff['lower'],
106 |         upper = boot.diff['upper']
107 |       )
108 |     )
109 | }
110 | 
111 | # Remove nonsense row names
112 | rownames(bootstrap.differences) <- NULL
113 | 
114 | print(cbind(bootstrap.differences[, c('model.1', 'model.2', 'var')], round(bootstrap.differences[, 4:6], 3)))
115 | 
116 | write.csv(bootstrap.differences, '../../output/bootstrap-differences.csv')
--------------------------------------------------------------------------------
/random-forest/rf-classification.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- FALSE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Random forest classification at a fixed time horizon
11 | #' 
12 | #' Unlike the survival models elsewhere in this repository, this script recasts
13 | #' the problem as binary classification: did the patient die within risk.time years?
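#' 
#' As a minimal sketch of the labelling rule used further down (toy values, not
#' cohort data): patients who die within the horizon become positive cases,
#' patients followed up beyond it become negative cases, and patients censored
#' before it are dropped.
#' 
#+ toy_labelling_sketch, eval = FALSE
toy <- data.frame(surv_time = c(2, 7, 3), surv_event = c(TRUE, FALSE, FALSE))
horizon <- 5
toy$event <- NA
toy$event[toy$surv_event & toy$surv_time <= horizon] <- TRUE   # event by horizon
toy$event[toy$surv_time > horizon] <- FALSE                    # survived past horizon
toy <- toy[!is.na(toy$event), ]   # censored before horizon: row dropped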
14 | 15 | # The first part of the filename for any output 16 | output.filename.base <- '../../output/rfsrc-classification-try1' 17 | 18 | risk.time <- 5 19 | 20 | n.data <- NA 21 | split.rule <- 'logrank' 22 | n.trees <- 2000 23 | n.threads <- 19 24 | 25 | continuous.vars <- 26 | c( 27 | 'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo', 28 | 'total_wbc_6mo', 'haemoglobin_6mo' 29 | ) 30 | 31 | untransformed.vars <- c('anonpatid', 'surv_time', 'imd_score', 'exclude') 32 | 33 | # If surv.vars is defined as a character vector here, the model only uses those 34 | # variables specified, eg c('age') would build a model purely based on age. If 35 | # not specified (ie commented out), it will use the defaults. 36 | # surv.predict <- c('age') 37 | 38 | #' ## Fit the model 39 | #' 40 | #' Now, let's fit the model, but without cross-validating the number of factor 41 | #' levels! The issue is that, if we're allowing factor levels to be grouped into 42 | #' two branches arbitrarily, there are 2^n - 2 combinations, which rapidly 43 | #' becomes a huge number. Thus, cross-validating, especially with large numbers 44 | #' of factor levels, is very impractical. 45 | #' 46 | #' We'll also leave age as a pure number: we know that it's both a very 47 | #' significant variable, and also that it makes sense to treat it as though it's 48 | #' ordered, because risk should increase monotonically with it. 49 | #' 50 | #+ rf_discretised, cache=cacheoption 51 | 52 | source('../lib/shared.R') 53 | 54 | # Load the data and convert to data frame to make column-selecting code in 55 | # prepData simpler 56 | COHORT.full <- data.frame(fread(data.filename)) 57 | 58 | # If n.data was specified... 59 | if(!is.na(n.data)){ 60 | # Take a subset n.data in size 61 | COHORT.use <- sample.df(COHORT.full, n.data) 62 | rm(COHORT.full) 63 | } else { 64 | # Use all the data 65 | COHORT.use <- COHORT.full 66 | rm(COHORT.full) 67 | } 68 | 69 | # Process settings: don't touch anything!! 
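# (For each variable named in 'var' below, the parallel 'method' and 'settings'
# entries are NA, which tells prepData to leave that column untransformed rather
# than discretising it by quantile as it otherwise would -- see the
# discretise.settings block in lib/shared.R for the same pattern.)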
70 | process.settings <- 71 | list( 72 | var = c(untransformed.vars, continuous.vars), 73 | method = rep(NA, length(untransformed.vars) + length(continuous.vars)), 74 | settings = rep(NA, length(untransformed.vars) + length(continuous.vars)) 75 | ) 76 | 77 | COHORT.prep <- 78 | prepData( 79 | COHORT.use, 80 | cols.keep, 81 | process.settings, 82 | surv.time, surv.event, 83 | surv.event.yes, 84 | extra.fun = caliberExtraPrep 85 | ) 86 | n.data <- nrow(COHORT.prep) 87 | 88 | # Define indices of test set 89 | test.set <- sample(1:n.data, (1/3)*n.data) 90 | 91 | # Create column for whether or not the patient had an event before risk.time 92 | COHORT.prep$event <- NA 93 | # Event before risk.time 94 | COHORT.prep$event[ 95 | COHORT.prep$surv_event & COHORT.prep$surv_time <= risk.time 96 | ] <- TRUE 97 | # Event after, whether censorship or not, means no event by risk.time 98 | COHORT.prep$event[COHORT.prep$surv_time > risk.time] <- FALSE 99 | # Otherwise, censored before risk.time, let's remove the row 100 | COHORT.prep <- COHORT.prep[!is.na(COHORT.prep$event), ] 101 | 102 | surv.model.fit <- 103 | rfsrc( 104 | as.formula( 105 | paste('event ~', paste(surv.predict, collapse = '+')) 106 | ), 107 | COHORT.prep[-test.set,], # Training set 108 | ntree = n.trees, 109 | splitrule = 'gini', 110 | n.threads = n.threads, 111 | na.action = 'na.impute', 112 | nimpute = 3 113 | ) -------------------------------------------------------------------------------- /lib/shared.R: -------------------------------------------------------------------------------- 1 | #' # prep-data.R 2 | #' We start by preparing the data for reproducible comparisons... 3 | 4 | # Set the random seed for reproducibility 5 | random.seed <- 35498L 6 | 7 | # Specify the data file containing the patient cohort 8 | data.filename <- '../../data/cohort-sanitised.csv' 9 | 10 | # Specify file to write performance characteristics to 11 | performance.file <- '../../output/models-performance.tsv' 12 | 13 | # The fraction of the data to use as the test set (1 - this will be used as the 14 | # training set) 15 | test.fraction <- 1/3 16 | 17 | # If surv.predict wasn't already specified, use the defaults... 18 | if(!exists('surv.predict')) { 19 | # Column names of variables to use for predictions 20 | surv.predict <- c( 21 | 'age', 'gender', 'diagnosis', 'pci_6mo', 'cabg_6mo', 22 | 'hx_mi', 'long_nitrate', 'smokstatus', 'hypertension', 'diabetes', 23 | 'total_chol_6mo', 'hdl_6mo', 'heart_failure', 'pad', 'hx_af', 'hx_stroke', 24 | 'hx_renal', 'hx_copd', 'hx_cancer', 'hx_liver', 'hx_depression', 25 | 'hx_anxiety', 'pulse_6mo', 'crea_6mo', 'total_wbc_6mo','haemoglobin_6mo', 26 | 'most_deprived' 27 | ) 28 | } 29 | 30 | cols.keep <- c(surv.predict, 'exclude', 'imd_score') 31 | 32 | exclude.vars <- c('hx_mi') 33 | surv.predict <- surv.predict[!(surv.predict %in% exclude.vars)] 34 | 35 | # Check to see if endpoint exists to avoid error 36 | if(!exists('endpoint')) { 37 | # Default is all-cause mortality 38 | endpoint <- 'death' 39 | } 40 | 41 | # If we're looking at MI... 42 | if(endpoint == 'mi') { 43 | surv.time <- 'time_coronary' 44 | surv.event <- 'endpoint_coronary' 45 | surv.event.yes <- c('Nonfatal MI', 'Coronary death') 46 | 47 | # If dealing with death in an imputed dataset... 48 | } else if(endpoint == 'death.imputed') { 49 | surv.time <- 'time_death' 50 | surv.event <- 'endpoint_death' 51 | surv.event.yes <- 1 # Coded as 1s and 0s for imputation 52 | 53 | # Default is all-cause mortality... 
54 | } else { 55 | # Name of column giving time for use in survival object 56 | surv.time <- 'time_death' 57 | # Name of event column for survival object 58 | surv.event <- 'endpoint_death' # Cannot be 'surv_event' or will break later! 59 | # Value of surv.event column if an event is recorded 60 | surv.event.yes <- 'Death' 61 | } 62 | 63 | 64 | 65 | # Quantile boundaries for discretisation 66 | discretise.quantiles <- c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1) 67 | 68 | # Columns to discretise in specific ways, or not discretise at all. Those not 69 | # listed here will be discretised by quantile with the default quantiles listed 70 | # above. 71 | discretise.settings <- 72 | list( 73 | var = c('anonpatid', 'surv_time', 'imd_score', 'exclude'), 74 | method = c(NA, NA, NA, NA), 75 | settings = list(NA, NA, NA, NA) 76 | ) 77 | 78 | ################################################################################ 79 | ### END USER VARIABLES ######################################################### 80 | ################################################################################ 81 | 82 | if(!is.na(random.seed)) { 83 | set.seed(random.seed) 84 | } 85 | 86 | source('../lib/handymedical.R', chdir = TRUE) 87 | 88 | # Define a function of extra non-general prep to be done on this dataset 89 | caliberExclude <- function(df) { 90 | df <- 91 | df[ 92 | # remove negative times to death 93 | df$surv_time > 0 & 94 | # remove patients who should be excluded 95 | !df$exclude 96 | , 97 | ] 98 | # Remove the exclude column, which we don't need any more 99 | df$exclude <- NULL 100 | 101 | df 102 | } 103 | 104 | caliberExtraPrep <- function(df) { 105 | df <- caliberExclude(df) 106 | 107 | # Create most_deprived, as defined in the paper: the bottom 20% 108 | df$most_deprived <- df$imd_score > quantile(df$imd_score, 0.8, na.rm = TRUE) 109 | df$most_deprived <- factorNAfix(factor(df$most_deprived), NAval = 'missing') 110 | # Remove the imd_score, to avoid confusion later 111 | df$imd_score <- NULL 112 | 113 | df 114 | } -------------------------------------------------------------------------------- /age-only/age-only.R: -------------------------------------------------------------------------------- 1 | #+ knitr_setup, include = FALSE 2 | 3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate 4 | # everything afresh. 5 | cacheoption <- FALSE 6 | # Disable lazy caching globally, because it fails for large objects, and all the 7 | # objects we wish to cache are large... 8 | opts_chunk$set(cache.lazy = FALSE) 9 | 10 | #' # Performance of an age-only model 11 | #' 12 | #+ setup, message=FALSE 13 | 14 | bootstraps <- 200 15 | 16 | data.filename <- '../../data/cohort-sanitised.csv' 17 | n.threads <- 16 18 | 19 | source('../lib/shared.R') 20 | require(rms) 21 | 22 | COHORT.full <- data.frame(fread(data.filename)) 23 | 24 | COHORT.use <- subset(COHORT.full, !exclude) 25 | 26 | n.data <- nrow(COHORT.use) 27 | 28 | # Define indices of test set 29 | test.set <- testSetIndices(COHORT.use, random.seed = 78361) 30 | 31 | COHORT.use <- prepSurvCol(COHORT.use, surv.time, surv.event,surv.event.yes) 32 | 33 | #' ## Concordance index 34 | #' 35 | #' The c-index can be calculated taking age itself as a risk score. Since it's 36 | #' purely rank-based, it doesn't matter that age is nonlinearly related to true 37 | #' risk. 38 | #' 39 | #' We bootstrap this based purely on the test set to make the bootstrap 40 | #' variability commensurate with other models tested on the test set. 
You can
41 | #' imagine that this model was 'trained' on the training set, even though it's
42 | #' so simple that we actually did no such thing...
43 | #' 
44 | #+ c_index
45 | 
46 | # Define a trivial function used for bootstrapping which simply returns the
47 | # calibration score of a 'model' based purely on age.
48 | 
49 | ageOnly <- function(df, indices, df.test) {
50 |   # Create a Kaplan-Meier curve from the bootstrap sample
51 |   km.by.age <-
52 |     survfit(
53 |       Surv(surv_time, surv_event) ~ age,
54 |       data = df[indices, ],
55 |       conf.type = "log-log"
56 |     )
57 | 
58 |   # Return the calibration score
59 |   c(
60 |     calibration.score =
61 |       calibrationScore(
62 |         calibrationTable(km.by.age, df.test)
63 |       )$area
64 |   )
65 | }
66 | 
67 | # C-index is by definition non-variable because we do not bootstrap or otherwise
68 | # permute the test set, and there's no 'training' phase
69 | age.c.index <-
70 |   as.numeric(
71 |     survConcordance(
72 |       Surv(surv_time, surv_event) ~ age,
73 |       COHORT.use[test.set,]
74 |     )$concordance
75 |   )
76 | 
77 | # Bootstrap to establish variability of calibration
78 | age.only.boot <-
79 |   boot(
80 |     data = COHORT.use[-test.set,],
81 |     statistic = ageOnly,
82 |     R = bootstraps,
83 |     parallel = 'multicore',
84 |     ncpus = n.threads,
85 |     df.test = COHORT.use[test.set,]
86 |   )
87 | 
88 | age.only.boot.stats <- bootStats(age.only.boot, uncertainty = '95ci')
89 | 
90 | #' C-index is
91 | #' **`r round(age.c.index, 3)`**
92 | #' on the held-out test set (not that it really matters: the model isn't
93 | #' 'trained' as such for the discrimination test, it simply predicts that the
94 | #' oldest patient dies first).
95 | #' 
96 | #' 
97 | #' ## Calibration
98 | #' 
99 | #' 
100 | #+ calibration
101 | 
102 | km.by.age <-
103 |   survfit(
104 |     Surv(surv_time, surv_event) ~ age,
105 |     data = COHORT.use[-test.set,],
106 |     conf.type = "log-log"
107 |   )
108 | 
109 | calibration.table <- calibrationTable(km.by.age, COHORT.use[test.set, ])
110 | 
111 | print(calibrationScore(calibration.table))
112 | 
113 | calibrationPlot(calibration.table)
114 | 
115 | #' Calibration score is
116 | #' **`r round(age.only.boot.stats['calibration.score', 'val'], 3)`
117 | #' (`r round(age.only.boot.stats['calibration.score', 'lower'], 3)` -
118 | #' `r round(age.only.boot.stats['calibration.score', 'upper'], 3)`)**
119 | #' on the held-out test set.
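#' 
#' (`bootStats` is a helper from lib/handymedical.R. As a rough sketch of what a
#' percentile 95% interval from a boot object involves -- an illustration, not
#' the lib's actual implementation -- it is equivalent to something like this.)
#' 
#+ boot_stats_sketch, eval = FALSE
ci.from.boot <- function(b, stat = 1) {
  # b$t0 holds the point estimate; b$t holds the bootstrap replicates
  c(
    val   = as.numeric(b$t0[stat]),
    lower = as.numeric(quantile(b$t[, stat], 0.025)),
    upper = as.numeric(quantile(b$t[, stat], 0.975))
  )
}
ci.from.boot(age.only.boot)  # compare with age.only.boot.stats above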
120 | 121 | varsToTable( 122 | data.frame( 123 | model = 'age', 124 | imputation = FALSE, 125 | discretised = FALSE, 126 | c.index = age.c.index, 127 | c.index.lower = NA, 128 | c.index.upper = NA, 129 | calibration.score = age.only.boot.stats['calibration.score', 'val'], 130 | calibration.score.lower = age.only.boot.stats['calibration.score', 'lower'], 131 | calibration.score.upper = age.only.boot.stats['calibration.score', 'upper'] 132 | ), 133 | performance.file, 134 | index.cols = c('model', 'imputation', 'discretised') 135 | ) 136 | -------------------------------------------------------------------------------- /cox-ph/caliber-scale.R: -------------------------------------------------------------------------------- 1 | ageSpline <- function(x) { 2 | max((x-51)/10.289,0)^3 + 3 | (69-51) * (max((x-84)/10.289,0)^3) - 4 | ((84-51) * (max(((x-69))/10.289,0))^3)/(84-69) 5 | } 6 | 7 | caliberScaleUnits <- function(x, quantity) { 8 | if(quantity == 'age') { 9 | ## Spline function 10 | x <- sapply(x, ageSpline) 11 | } else if(quantity == 'total_chol_6mo') { 12 | ## Total cholesterol, per 1 mmol/L increase 13 | x <- x - 5 14 | } else if(quantity == 'hdl_6mo') { 15 | ## HDL, per 0.5 mmol/L increase 16 | x <- (x - 1.5) / 0.5 17 | } else if(quantity == 'pulse_6mo') { 18 | ## Heart rate, per 10 b.p.m increase 19 | x <- (x - 70) / 10 20 | } else if(quantity == 'crea_6mo') { 21 | ## Creatinine, per 30 μmol/L increase 22 | x <- (x - 60) / 30 23 | } else if(quantity == 'total_wbc_6mo') { 24 | ## White cell count, per 1.5 109/L increase 25 | x <- (x - 7.5) / 1.5 26 | } else if(quantity == 'haemoglobin_6mo') { 27 | ## Haemoglobin, per 1.5 g/dL increase 28 | x <- (x - 13.5) / 1.5 29 | } 30 | 31 | # Return transformed values 32 | x 33 | } 34 | 35 | caliberScale <- function(df, surv.time, surv.event) { 36 | # Return a data frame with all quantities normalised/scaled/standardised to 37 | # ranges etc used in Rapsomaniki et al. 2014 38 | 39 | # imd_score is sometimes turned into a factor, which causes an error with the 40 | # quantile function. Since it's just integers, make it numeric. 41 | df$imd_score <- as.numeric(df$imd_score) 42 | 43 | data.frame( 44 | ## Time to event 45 | surv_time = df[, surv.time], 46 | ## Death/censorship 47 | surv_event = df[, surv.event] %in% surv.event.yes, 48 | ## Rescaled age 49 | age = sapply(df$age, ageSpline), 50 | ## Gender 51 | gender = df$gender, 52 | ## Most deprived quintile, yes vs. no 53 | most_deprived = 54 | df$imd_score > quantile(df$imd_score, 0.8, na.rm = TRUE), 55 | ### SCAD diagnosis and severity ############################################ 56 | ## Other CHD / unstable angina / NSTEMI / STEMI vs. stable angina 57 | diagnosis = factorChooseFirst(factor(df$diagnosis), 'SA'), 58 | ## PCI in last 6 months, yes vs. no 59 | pci_6mo = df$pci_6mo, 60 | ## CABG in last 6 months, yes vs. no 61 | cabg_6mo = df$cabg_6mo, 62 | ## Previous/recurrent MI, yes vs. no 63 | hx_mi = df$hx_mi, 64 | ## Use of nitrates, yes vs. no 65 | long_nitrate = df$long_nitrate, 66 | ### CVD risk factors ####################################################### 67 | ## Ex-smoker vs. never / Current smoker vs. never 68 | smokstatus = factorChooseFirst(factor(df$smokstatus), 'Non'), 69 | ## Hypertension, present vs. absent 70 | hypertension = df$hypertension, 71 | ## Diabetes mellitus, present vs. 
absent
72 |     diabetes_logical = df$diabetes != 'No diabetes',
73 |     ## Total cholesterol, per 1 mmol/L increase
74 |     total_chol_6mo = caliberScaleUnits(df$total_chol_6mo, 'total_chol_6mo'),
75 |     ## HDL, per 0.5 mmol/L increase
76 |     hdl_6mo = caliberScaleUnits(df$hdl_6mo, 'hdl_6mo'),
77 |     ### CVD co-morbidities #####################################################
78 |     ## Heart failure, present vs. absent
79 |     heart_failure = df$heart_failure,
80 |     ## Peripheral arterial disease, present vs. absent
81 |     pad = df$pad,
82 |     ## Atrial fibrillation, present vs. absent
83 |     hx_af = df$hx_af,
84 |     ## Stroke, present vs. absent
85 |     hx_stroke = df$hx_stroke,
86 |     ### Non-CVD comorbidities ##################################################
87 |     ## Chronic kidney disease, present vs. absent
88 |     hx_renal = df$hx_renal,
89 |     ## Chronic obstructive pulmonary disease, present vs. absent
90 |     hx_copd = df$hx_copd,
91 |     ## Cancer, present vs. absent
92 |     hx_cancer = df$hx_cancer,
93 |     ## Chronic liver disease, present vs. absent
94 |     hx_liver = df$hx_liver,
95 |     ### Psychosocial characteristics ###########################################
96 |     ## Depression at diagnosis, present vs. absent
97 |     hx_depression = df$hx_depression,
98 |     ## Anxiety at diagnosis, present vs. absent
99 |     hx_anxiety = df$hx_anxiety,
100 |     ### Biomarkers #############################################################
101 |     ## Heart rate, per 10 b.p.m. increase
102 |     pulse_6mo = caliberScaleUnits(df$pulse_6mo, 'pulse_6mo'),
103 |     ## Creatinine, per 30 μmol/L increase
104 |     crea_6mo = caliberScaleUnits(df$crea_6mo, 'crea_6mo'),
105 |     ## White cell count, per 1.5 10^9/L increase
106 |     total_wbc_6mo = caliberScaleUnits(df$total_wbc_6mo, 'total_wbc_6mo'),
107 |     ## Haemoglobin, per 1.5 g/dL increase
108 |     haemoglobin_6mo = caliberScaleUnits(df$haemoglobin_6mo, 'haemoglobin_6mo')
109 |   )
110 | }
111 | 
--------------------------------------------------------------------------------
/random-forest/rfsrc-cv.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- FALSE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Cross-validating discretisation of input variables in a survival model
11 | #' 
12 | #' Unlike previous attempts at cross-validation, this uses between 10
13 | #' and 20 bins, not between 2 and 20, in an attempt to avoid throwing away data.
14 | 
15 | output.filename.base <- '../../output/rfsrc-cv-nsplit-try3'
16 | data.filename <- '../../data/cohort-sanitised.csv'
17 | 
18 | # If surv.vars is defined as a character vector here, the model only uses those
19 | # variables specified, eg c('age') would build a model purely based on age. If
20 | # not specified (ie commented out), it will use the defaults.
21 | # surv.predict <- c('age')
22 | 
23 | #' ## Do the cross-validation
24 | #' 
25 | #' The steps for this are common regardless of model type, so run the script to
26 | #' get a cross-validated model to further analyse...
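#' 
#' (The lib's cross-validation scripts build their folds with a cvFolds helper.
#' As an illustrative stand-in for the fold construction -- a sketch, not the
#' lib code -- k roughly equal, disjoint index sets can be built like this.)
#' 
#+ cv_folds_sketch, eval = FALSE
cvFoldsSketch <- function(n, k) {
  # Shuffle the row indices, then deal them round-robin into k folds
  split(sample(n), rep_len(1:k, n))
}
folds <- cvFoldsSketch(100, 3)
# training set for fold j: data[-folds[[j]], ]; validation set: data[folds[[j]], ]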
27 | #+ rf_discretised_cv, cache=cacheoption 28 | 29 | source('../lib/rfsrc-cv-nsplit-bootstrap.R', chdir = TRUE) 30 | 31 | # Save the resulting 'final' model 32 | saveRDS(surv.model.fit, paste0(output.filename.base, '-final-model.rds')) 33 | 34 | #' # Results 35 | #' 36 | #' 37 | #' ## Performance 38 | #' 39 | #' ### Discrimination 40 | #' 41 | #' C-indices are **`r round(surv.model.fit.coeffs['c.train', 'val'], 3)` +/- 42 | #' `r round(surv.model.fit.coeffs['c.train', 'err'], 3)`** on the training set and 43 | #' **`r round(surv.model.fit.coeffs['c.test', 'val'], 3)` +/- 44 | #' `r round(surv.model.fit.coeffs['c.test', 'err'], 3)`** on the held-out test set. 45 | #' 46 | #' ### Calibration 47 | #' 48 | #' Does the model predict realistic probabilities of an event? 49 | #' 50 | #+ calibration_plot 51 | 52 | calibration.table <- 53 | calibrationTable( 54 | # Standard calibration options 55 | surv.model.fit, COHORT.prep[test.set,], 56 | # Always need to specify NA imputation for rfsrc 57 | na.action = 'na.impute' 58 | ) 59 | 60 | calibration.score <- calibrationScore(calibration.table) 61 | 62 | calibrationPlot(calibration.table, show.censored = TRUE, max.points = 10000) 63 | 64 | # Save the calibration data for later plotting 65 | write.csv( 66 | calibration.table, paste0(output.filename.base, '-calibration-table.csv') 67 | ) 68 | 69 | #' The area between the calibration curve and the diagonal is 70 | #' **`r round(calibration.score['area'], 3)`** +/- 71 | #' **`r round(calibration.score['se'], 3)`**. 72 | #' 73 | #' ## Model fit 74 | #' 75 | #+ resulting_fit 76 | 77 | print(surv.model.fit) 78 | 79 | #' ## Variable importance 80 | 81 | # First, load data from Cox modelling for comparison 82 | cox.var.imp <- read.csv(comparison.filename) 83 | 84 | # Then, get the variable importance from the model just fitted 85 | var.imp <- 86 | data.frame( 87 | var.imp = importance(surv.model.fit)/max(importance(surv.model.fit)) 88 | ) 89 | var.imp$quantity <- rownames(var.imp) 90 | 91 | var.imp <- merge(var.imp, cox.var.imp) 92 | 93 | # Save the results as a CSV 94 | write.csv(var.imp, paste0(output.filename, '-var-imp.csv')) 95 | 96 | #' ## Variable importance vs Cox model replication variable importance 97 | 98 | print( 99 | ggplot(var.imp, aes(x = our_range, y = var.imp)) + 100 | geom_point() + 101 | geom_text_repel(aes(label = quantity)) + 102 | # Log both...old coefficients for linearity, importance to shrink range! 
103 | scale_x_log10() + 104 | scale_y_log10() 105 | ) 106 | 107 | print(cor(var.imp[, c('var.imp', 'our_range')], method = 'spearman')) 108 | 109 | #' ## Variable effects 110 | #' 111 | #+ variable_effects 112 | 113 | risk.by.variables <- data.frame() 114 | 115 | for(variable in continuous.vars) { 116 | # Create a partial effect table for this variable 117 | risk.by.variable <- 118 | partialEffectTable( 119 | surv.model.fit, COHORT.prep[-test.set,], variable, na.action = 'na.impute' 120 | ) 121 | # Slight kludge...rename the column which above is given the variable name to 122 | # just val, to allow rbinding 123 | names(risk.by.variable)[2] <- 'val' 124 | # Append a column with the variable's name so we can distinguish this in 125 | # a long data frame 126 | risk.by.variable$var <- variable 127 | # Append it to our ongoing big data frame 128 | risk.by.variables <- rbind(risk.by.variables, risk.by.variable) 129 | # Save the risks as we go 130 | write.csv(risk.by.variables, paste0(output.filename.base, '-var-effects.csv')) 131 | 132 | # Get the mean of the normalised risk for every value of the variable 133 | risk.aggregated <- 134 | aggregate( 135 | as.formula(paste0('risk.normalised ~ ', variable)), 136 | risk.by.variable, mean 137 | ) 138 | 139 | # work out the limits on the x-axis by taking the 1st and 99th percentiles 140 | x.axis.limits <- 141 | quantile(COHORT.full[, variable], c(0.01, 0.99), na.rm = TRUE) 142 | 143 | print( 144 | ggplot(risk.by.variable, aes_string(x = variable, y = 'risk.normalised')) + 145 | geom_line(alpha=0.01, aes(group = id)) + 146 | geom_line(data = risk.aggregated, colour = 'blue') + 147 | coord_cartesian(xlim = c(x.axis.limits)) 148 | ) 149 | } -------------------------------------------------------------------------------- /overview/missing-values-risk.R: -------------------------------------------------------------------------------- 1 | source('../lib/handymedical.R', chdir = TRUE) 2 | 3 | requirePlus('survminer', 'cowplot') 4 | 5 | models.base <- '../../output' 6 | cox.missing.filename <- 'caliber-replicate-with-missing-model-survreg-bootstrap-1.rds' 7 | cox.missing.riskdists.filename <- 'caliber-replicate-with-missing-survreg-3-risk-violins.csv' 8 | cox.missing.riskcats.filename <- 'caliber-replicate-with-missing-survreg-3-risk-cats.csv' 9 | 10 | cox.imp.filename <- 'caliber-replicate-imputed-survreg-4-surv-boot-imp.rds' 11 | 12 | cox.disc.filename <- 'all-cv-survreg-boot-try5-surv-boot.csv' 13 | 14 | data.filename <- '../../data/cohort-sanitised.csv' 15 | 16 | survPlots <- function(...) { 17 | surv.fits <- list(...) 18 | 19 | df <- data.frame() 20 | for(i in 1:length(surv.fits)) { 21 | df <- 22 | rbind( 23 | df, 24 | data.frame( 25 | variable = as.character(surv.fits[[i]]$call)[2], 26 | stratum = '1', 27 | time = surv.fits[[i]][1]$time, 28 | surv = surv.fits[[i]][1]$surv, 29 | lower = surv.fits[[i]][1]$lower, 30 | upper = surv.fits[[i]][1]$upper 31 | ), 32 | data.frame( 33 | variable = as.character(surv.fits[[i]]$call)[2], 34 | stratum = '2', 35 | time = surv.fits[[i]][2]$time, 36 | surv = surv.fits[[i]][2]$surv, 37 | lower = surv.fits[[i]][2]$lower, 38 | upper = surv.fits[[i]][2]$upper 39 | ) 40 | ) 41 | } 42 | 43 | ggplot( 44 | df, 45 | aes(x = time, y = surv, ymin = lower, ymax = upper) 46 | ) + 47 | geom_line(aes(colour = stratum)) + 48 | geom_ribbon(aes(fill = stratum), alpha = 0.4) + 49 | facet_grid(variable ~ .) 
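  # facet_grid gives one panel per input survfit object, so the stratum curves
  # for all three variables share a common time axis for easy comparison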
50 | } 51 | 52 | cox.missing.boot <- readRDS(file.path(models.base, cox.missing.filename)) 53 | fit.risks.miss <- bootStats(cox.missing.boot, '95ci') 54 | fit.risks.miss$var <- rownames(fit.risks.miss) 55 | 56 | cox.imp.boot <- readRDS(file.path(models.base, cox.imp.filename)) 57 | fit.risks.imp <- bootMIStats(cox.imp.boot, '95ci') 58 | fit.risks.imp$var <- rownames(fit.risks.imp) 59 | 60 | cox.disc.boot <- read.csv(file.path(models.base, cox.disc.filename)) 61 | fit.risks.disc <- bootStatsDf(cox.disc.boot) 62 | fit.risks.disc$var <- rownames(fit.risks.disc) 63 | 64 | # Create two data frames, one for missing vs imputed and one for discrete 65 | fit.risks.imp.vs.miss <- merge(fit.risks.imp, fit.risks.miss, by = c('var')) 66 | fit.risks.imp.vs.miss$model <- 'miss' 67 | # This is a slight cheat... The discrete model here was done by my home-made 68 | # bootstrap function, whose variable/level names have been preprocessed with 69 | # make.names for slightly annoying internal reasons. Luckily, all the logical 70 | # variables escape unscathed from this, and they're the only ones we can compare 71 | # anyway! 72 | fit.risks.imp.vs.disc <- merge(fit.risks.imp, fit.risks.disc, by = c('var')) 73 | fit.risks.imp.vs.disc$model <- 'disc' 74 | 75 | fit.risks <- rbind(fit.risks.imp.vs.miss, fit.risks.imp.vs.disc) 76 | 77 | # Lose a few irrelevant variables 78 | fit.risks <- 79 | subset( 80 | fit.risks, !(var %in% c('(Intercept)', 'c.index', 'calibration.score')) 81 | ) 82 | 83 | # Make them 84 | 85 | disc.vs.cont.risks.plot <- 86 | ggplot( 87 | fit.risks, 88 | aes( 89 | x = val.x, xmin = lower.x, xmax = upper.x, 90 | y = val.y, ymin = lower.y, ymax = upper.y, 91 | label = var, colour = model 92 | ) 93 | ) + 94 | geom_point() + 95 | geom_errorbar() + 96 | geom_errorbarh() + 97 | geom_text() + 98 | # Remove the legend, massively compresses the plot horizontally 99 | theme(legend.position="none") 100 | 101 | 102 | 103 | risk.dist.by.var <- 104 | read.csv(file.path(models.base, cox.missing.riskdists.filename)) 105 | risk.cats <- read.csv(file.path(models.base, cox.missing.riskcats.filename)) 106 | 107 | 108 | # Plot the results 109 | risk.violins.plot <- 110 | ggplot() + 111 | # First, and therefore at the bottom, draw the reference line at risk = 1 112 | geom_hline(yintercept = 1) + 113 | # Then, on top of that, draw the violin plot of the risk from the data 114 | geom_violin(data = risk.dist.by.var, aes(x = quantity, y = risk)) + 115 | geom_pointrange( 116 | data = risk.cats, 117 | aes(x = quantity, y = our_value, ymin = our_lower, 118 | ymax = our_upper), 119 | 120 | position = position_jitter(width = 0.1) 121 | ) + 122 | geom_text( 123 | data = risk.cats, 124 | aes( 125 | x = quantity, 126 | y = our_value, 127 | label = quantity.level 128 | ) 129 | ) + 130 | scale_y_continuous(breaks = c(0.75, 1.0, 1.25, 1.5)) 131 | 132 | 133 | # Kaplan-Meier survival curves for a few example variables being missing 134 | COHORT <- fread(data.filename) 135 | COHORT <- subset(COHORT, !exclude & time_death > 0) 136 | 137 | # Calculate the curves 138 | km.hdl <- 139 | survfit( 140 | Surv(time_death, endpoint_death == 'Death') ~ is.na(hdl_6mo), 141 | data = COHORT 142 | ) 143 | km.total_chol <- 144 | survfit( 145 | Surv(time_death, endpoint_death == 'Death') ~ is.na(total_chol_6mo), 146 | data = COHORT 147 | ) 148 | km.crea <- 149 | survfit( 150 | Surv(time_death, endpoint_death == 'Death') ~ is.na(crea_6mo), 151 | data = COHORT 152 | ) 153 | 154 | km.missingness <- 155 | survPlots(km.hdl, km.total_chol, km.crea) + 156 | 
theme( 157 | legend.position = "none", 158 | # Remove grey labels on facets 159 | strip.background = element_blank(), 160 | strip.text = element_blank() 161 | ) 162 | 163 | theme_set(theme_cowplot(font_size = 10)) 164 | plot_grid( 165 | disc.vs.cont.risks.plot, risk.violins.plot, 166 | km.missingness, 167 | labels = c("A", "B", "C"), 168 | align = "v", nrow = 1 169 | ) 170 | ggsave( 171 | '../../output/missing-values-risk.pdf', 172 | width = 16, 173 | height = 5, 174 | units = 'cm', 175 | useDingbats = FALSE 176 | ) 177 | -------------------------------------------------------------------------------- /lib/rfsrc-cv-mtry-nsplit-logical.R: -------------------------------------------------------------------------------- 1 | bootstraps <- 3 2 | split.rule <- 'logrank' 3 | n.threads <- 20 4 | 5 | # Cross-validation variables 6 | ns.splits <- c(0, 5, 10, 15, 20, 30) 7 | ms.try <- c(50, 100, 200, 300, 400) 8 | n.trees.cv <- 500 9 | n.imputations <- 3 10 | cv.n.folds <- 3 11 | n.trees.final <- 2000 12 | n.data <- NA # This is of full dataset...further rows may be excluded in prep 13 | 14 | calibration.filename <- paste0(output.filename.base, '-calibration.csv') 15 | 16 | # If we've not already done a calibration, then do one 17 | if(!file.exists(calibration.filename)) { 18 | # Create an empty data frame to aggregate stats per fold 19 | cv.performance <- data.frame() 20 | 21 | # Items to cross-validate over 22 | cv.vars <- expand.grid(ns.splits, ms.try) 23 | names(cv.vars) <- c('n.splits', 'm.try') 24 | 25 | COHORT.cv <- COHORT.bigdata[-test.set, ] 26 | 27 | # Run crossvalidations. No need to parallelise because rfsrc is parallelised 28 | for(i in 1:nrow(cv.vars)) { 29 | cat( 30 | 'Calibration', i, '...\n' 31 | ) 32 | 33 | # Get folds for cross-validation 34 | cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds) 35 | 36 | cv.fold.performance <- data.frame() 37 | 38 | for(j in 1:cv.n.folds) { 39 | time.start <- handyTimer() 40 | # Fit model to the training set 41 | surv.model.fit <- 42 | survivalFit( 43 | surv.predict, 44 | COHORT.cv[-cv.folds[[j]],], 45 | model.type = 'rfsrc', 46 | n.trees = n.trees.cv, 47 | split.rule = split.rule, 48 | n.threads = n.threads, 49 | nsplit = cv.vars$n.splits[i], 50 | nimpute = n.imputations, 51 | na.action = 'na.impute', 52 | mtry = cv.vars$m.try[i] 53 | ) 54 | time.learn <- handyTimer(time.start) 55 | 56 | time.start <- handyTimer() 57 | # Get C-index on validation set 58 | c.index.val <- 59 | cIndex( 60 | surv.model.fit, COHORT.cv[cv.folds[[j]],], 61 | na.action = 'na.impute' 62 | ) 63 | time.c.index <- handyTimer(time.start) 64 | 65 | time.start <- handyTimer() 66 | # Get C-index on validation set 67 | calibration.score <- 68 | calibrationScore( 69 | calibrationTable( 70 | surv.model.fit, COHORT.cv[cv.folds[[j]],], na.action = 'na.impute' 71 | ) 72 | ) 73 | time.calibration <- handyTimer(time.start) 74 | 75 | # Append the stats we've obtained from this fold 76 | cv.fold.performance <- 77 | rbind( 78 | cv.fold.performance, 79 | data.frame( 80 | calibration = i, 81 | cv.fold = j, 82 | n.splits = cv.vars$n.splits[i], 83 | m.try = cv.vars$m.try[i], 84 | c.index.val, 85 | time.learn, 86 | time.c.index, 87 | time.calibration 88 | ) 89 | ) 90 | 91 | } # End cross-validation loop (j) 92 | 93 | 94 | # rbind the performance by fold 95 | cv.performance <- 96 | rbind( 97 | cv.performance, 98 | cv.fold.performance 99 | ) 100 | 101 | # Save output at the end of each loop 102 | write.csv(cv.performance, calibration.filename) 103 | 104 | } # End calibration loop (i) 105 | 106 | 107 
| 108 | } else { # If we did previously calibrate, load it 109 | cv.performance <- read.csv(calibration.filename) 110 | } 111 | 112 | # Find the best calibration... 113 | # First, average performance across cross-validation folds 114 | cv.performance.average <- 115 | aggregate( 116 | c.index.val ~ calibration, 117 | data = cv.performance, 118 | mean 119 | ) 120 | # Find the highest value 121 | best.calibration <- 122 | cv.performance.average$calibration[ 123 | which.max(cv.performance.average$c.index.val) 124 | ] 125 | # And finally, find the first row of that calibration to get the n.bins values 126 | best.calibration.row1 <- 127 | min(which(cv.performance$calibration == best.calibration)) 128 | 129 | #' ## Fit the final model 130 | #' 131 | #' This may take some time, so we'll cache it if possible... 132 | 133 | #+ fit_final_model 134 | 135 | surv.model.fit <- 136 | survivalFit( 137 | surv.predict, 138 | COHORT.bigdata[-test.set,], 139 | model.type = 'rfsrc', 140 | n.trees = n.trees.final, 141 | split.rule = split.rule, 142 | n.threads = n.threads, 143 | nimpute = n.imputations, 144 | nsplit = cv.performance[best.calibration.row1, 'n.splits'], 145 | mtry = cv.performance[best.calibration.row1, 'm.try'], 146 | na.action = 'na.impute', 147 | importance = 'permute' 148 | ) 149 | 150 | # Save the fit object 151 | saveRDS( 152 | surv.model.fit, 153 | paste0(output.filename.base, '-surv-model.rds') 154 | ) 155 | 156 | surv.model.fit.boot <- 157 | survivalBootstrap( 158 | surv.predict, 159 | COHORT.bigdata[-test.set,], # Training set 160 | COHORT.bigdata[test.set,], # Test set 161 | model.type = 'rfsrc', 162 | n.trees = n.trees.final, 163 | split.rule = split.rule, 164 | n.threads = n.threads, 165 | nimpute = n.imputations, 166 | nsplit = cv.performance[best.calibration.row1, 'n.splits'], 167 | mtry = cv.performance[best.calibration.row1, 'm.try'], 168 | na.action = 'na.impute', 169 | bootstraps = bootstraps 170 | ) 171 | 172 | # Save the fit object 173 | saveRDS( 174 | surv.model.fit.boot, 175 | paste0(output.filename.base, '-surv-model-bootstraps.rds') 176 | ) 177 | 178 | # Get C-indices for training and test sets 179 | surv.model.fit.coeffs <- bootStats(surv.model.fit.boot, uncertainty = '95ci') 180 | 181 | # Save them to the all-models comparison table 182 | varsToTable( 183 | data.frame( 184 | model = 'rfbigdata', 185 | imputation = FALSE, 186 | discretised = FALSE, 187 | c.index = surv.model.fit.coeffs['c.test', 'val'], 188 | c.index.lower = surv.model.fit.coeffs['c.test', 'lower'], 189 | c.index.upper = surv.model.fit.coeffs['c.test', 'upper'], 190 | calibration.score = surv.model.fit.coeffs['calibration.score', 'val'], 191 | calibration.score.lower = 192 | surv.model.fit.coeffs['calibration.score', 'lower'], 193 | calibration.score.upper = 194 | surv.model.fit.coeffs['calibration.score', 'upper'] 195 | ), 196 | performance.file, 197 | index.cols = c('model', 'imputation', 'discretised') 198 | ) -------------------------------------------------------------------------------- /overview/explore-dataset.R: -------------------------------------------------------------------------------- 1 | #+ knitr_setup, include = FALSE 2 | 3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate 4 | # everything afresh. 5 | cacheoption <- FALSE 6 | # Disable lazy caching globally, because it fails for large objects, and all the 7 | # objects we wish to cache are large... 
8 | opts_chunk$set(cache.lazy = FALSE) 9 | 10 | #' # Simple data investigations 11 | #' 12 | #+ setup, message=FALSE 13 | 14 | data.filename <- '../../data/cohort-sanitised.csv' 15 | 16 | source('../lib/shared.R') 17 | requirePlus('rms') 18 | 19 | COHORT.full <- data.frame(fread(data.filename)) 20 | 21 | COHORT.use <- subset(COHORT.full, !exclude) 22 | 23 | #' ## Missing data 24 | #' 25 | #' How much is there, and where is it concentrated? 26 | #' 27 | #+ missing_data_plot 28 | 29 | interesting.vars <- 30 | c( 31 | 'age', 'gender', 'diagnosis', 'pci_6mo', 'cabg_6mo', 32 | 'hx_mi', 'long_nitrate', 'smokstatus', 'hypertension', 'diabetes', 33 | 'total_chol_6mo', 'hdl_6mo', 'heart_failure', 'pad', 'hx_af', 'hx_stroke', 34 | 'hx_renal', 'hx_copd', 'hx_cancer', 'hx_liver', 'hx_depression', 35 | 'hx_anxiety', 'pulse_6mo', 'crea_6mo', 'total_wbc_6mo','haemoglobin_6mo', 36 | 'imd_score' 37 | ) 38 | 39 | missingness <- 40 | unlist(lapply(COHORT.use[, interesting.vars], function(x){sum(is.na(x))})) 41 | 42 | missingness <- data.frame(var = names(missingness), n.missing = missingness) 43 | 44 | missingness$pc.missing <- missingness$n.missing / nrow(COHORT.use) 45 | 46 | ggplot(subset(missingness, n.missing > 0), aes(x = var, y = pc.missing)) + 47 | geom_bar(stat = 'identity') + 48 | ggtitle('% missingness by variable') 49 | 50 | #' Are any variables commonly found to be jointly missing? 51 | #' 52 | #+ missing_jointly_plot 53 | 54 | COHORT.missing <- 55 | data.frame( 56 | lapply(COHORT.use[, interesting.vars], function(x){is.na(x)}) 57 | ) 58 | 59 | COHORT.missing.cor <- data.frame( 60 | t(combn(1:ncol(COHORT.missing), 2)), 61 | var1 = NA, var2 = NA, joint.missing = NA 62 | ) 63 | 64 | for(i in 1:nrow(COHORT.missing.cor)) { 65 | var1 <- sort(names(COHORT.missing))[COHORT.missing.cor[i, 'X1']] 66 | var2 <- sort(names(COHORT.missing))[COHORT.missing.cor[i, 'X2']] 67 | COHORT.missing.cor[i, c('var1', 'var2')] <- c(var1, var2) 68 | if(any(COHORT.missing[, var1]) & any(COHORT.missing[, var2])) { 69 | COHORT.missing.cor[i, 'joint.missing'] <- 70 | sum(!(COHORT.missing[, var1]) & !(COHORT.missing[, var2])) / 71 | sum(!(COHORT.missing[, var1]) | !(COHORT.missing[, var2])) 72 | } 73 | } 74 | 75 | ggplot(subset(COHORT.missing.cor, !is.na(joint.missing)), aes(x = var1, y = var2, fill = joint.missing)) + 76 | geom_tile() 77 | 78 | #' Some tests are usually ordered together. Are they missing together? 
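#' 
#' (Note that the joint.missing ratio computed above uses the negated
#' indicators, so it is the share of rows where both values are *present* out
#' of rows where either is present. The same Jaccard idea applied to the
#' missingness indicators directly is sketched below -- illustrative, not part
#' of the original analysis.)
#' 
#+ jaccard_missing_sketch, eval = FALSE
jaccardMissing <- function(a, b) {
  # of the rows missing at least one of the two values, what fraction miss both?
  sum(is.na(a) & is.na(b)) / sum(is.na(a) | is.na(b))
}
jaccardMissing(COHORT.use$hdl_6mo, COHORT.use$total_chol_6mo)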
79 | #' 
80 | #+ missing_venn
81 | 
82 | ggplot(
83 |   data.frame(
84 |     x = 1,
85 |     category = factor(c('HDL', 'both', 'total cholesterol', 'neither'),
86 |                       levels = c('HDL', 'both', 'total cholesterol', 'neither')),
87 |     val = c(
88 |       sum(!COHORT.missing$hdl_6mo) - sum(!COHORT.missing$hdl_6mo & !COHORT.missing$total_chol_6mo),
89 |       sum(!COHORT.missing$hdl_6mo & !COHORT.missing$total_chol_6mo),
90 |       sum(!COHORT.missing$total_chol_6mo) - sum(!COHORT.missing$hdl_6mo & !COHORT.missing$total_chol_6mo),
91 |       sum(COHORT.missing$hdl_6mo & COHORT.missing$total_chol_6mo)
92 |     )
93 |   ),
94 |   aes(x = x, y = val, fill = category)
95 | ) +
96 |   geom_bar(stat='identity') +
97 |   scale_fill_manual(values=c("#cc0000", "#990099", "#0000cc", '#cccccc'))
98 | 
99 | ggplot(
100 |   data.frame(
101 |     x = 1,
102 |     category = factor(c('haemoglobin', 'both', 'WBC', 'neither'),
103 |                       levels = c('haemoglobin', 'both', 'WBC', 'neither')),
104 |     val = c(
105 |       sum(!COHORT.missing$haemoglobin_6mo) - sum(!COHORT.missing$haemoglobin_6mo & !COHORT.missing$total_wbc_6mo),
106 |       sum(!COHORT.missing$haemoglobin_6mo & !COHORT.missing$total_wbc_6mo),
107 |       sum(!COHORT.missing$total_wbc_6mo) - sum(!COHORT.missing$haemoglobin_6mo & !COHORT.missing$total_wbc_6mo),
108 |       sum(COHORT.missing$haemoglobin_6mo & COHORT.missing$total_wbc_6mo)
109 |     )
110 |   ),
111 |   aes(x = x, y = val, fill = category)
112 | ) +
113 |   geom_bar(stat='identity') +
114 |   scale_fill_manual(values=c("#cc0000", "#990099", "#0000cc", '#cccccc'))
115 | 
116 | 
117 | #' ### Survival and missingness
118 | #' 
119 | #' Let's plot survival curves for a few types of data by missingness...
120 | #' 
121 | #+ missingness_survival
122 | 
123 | COHORT.use$surv <- with(COHORT.use, Surv(time_death, endpoint_death == 'Death'))
124 | 
125 | plotSurvSummary <- function(df, var) {
126 |   df$var_summary <- factorNAfix(binByQuantile(df[, var], c(0,0.1,0.9,1)))
127 |   surv.curve <- npsurv(surv ~ var_summary, data = df)
128 |   print(survplot(surv.curve, ylab = var))
129 | }
130 | 
131 | plotSurvSummary(COHORT.use, 'total_wbc_6mo')
132 | 
133 | surv.curve <- npsurv(surv ~ is.na(crea_6mo), data = COHORT.use)
134 | survplot(surv.curve)
135 | 
136 | surv.curve <- npsurv(surv ~ is.na(haemoglobin_6mo), data = COHORT.use)
137 | survplot(surv.curve)
138 | 
139 | surv.curve <- npsurv(surv ~ is.na(hdl_6mo), data = COHORT.use)
140 | survplot(surv.curve)
141 | 
142 | surv.curve <- npsurv(surv ~ is.na(pulse_6mo), data = COHORT.use)
143 | survplot(surv.curve)
144 | 
145 | surv.curve <- npsurv(surv ~ is.na(smokstatus), data = COHORT.use)
146 | survplot(surv.curve)
147 | 
148 | surv.curve <- npsurv(surv ~ is.na(total_chol_6mo), data = COHORT.use)
149 | survplot(surv.curve)
150 | 
151 | surv.curve <- npsurv(surv ~ is.na(total_wbc_6mo), data = COHORT.use)
152 | survplot(surv.curve)
153 | 
154 | #' Variables where it's safer to be missing data:
155 | #' * Creatinine
156 | #' * 
157 | #' 
158 | #' Variables where it's safer to have data:
159 | #' * 
160 | 
161 | #' ## Data distributions
162 | #' 
163 | #' How are the data distributed?
164 | #' 
165 | #+ data_distributions
166 | 
167 | ggplot(COHORT.full) +
168 |   geom_histogram(aes(x = crea_6mo, fill = crea_6mo %% 10 != 0), binwidth = 1) +
169 |   xlim(0, 300)
170 | 
171 | #' Creatinine levels are fairly smoothly distributed. The highlighted bins
172 | #' indicate numerical values divisible by 10, and there seems to be no
173 | #' particular bias. The small cluster of extremely low values could be
174 | #' misrecorded somehow.
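#' 
#' (Digit preference like this can be quantified as well as eyeballed: a sketch,
#' tabulating the last digit of each recorded value -- illustrative, not part of
#' the original analysis.)
#' 
#+ digit_preference_sketch, eval = FALSE
lastDigit <- function(x) round(x) %% 10
table(lastDigit(na.omit(COHORT.full$crea_6mo)))
# a roughly flat table suggests no digit preference; spikes at 0 and 5 (or at
# multiples of 4 for pulse_6mo, below) suggest rounded or scaled-up measurements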
175 | 
176 | ggplot(COHORT.full) +
177 |   geom_histogram(aes(x = pulse_6mo, fill = pulse_6mo %% 4 != 0), binwidth = 1) +
178 |   xlim(0, 150)
179 | 
180 | #' Heart rate data have high missingness, and those few values we do have are
181 | #' very heavily biased towards multiples of 4. This is likely because heart rate
182 | #' is commonly measured for 15 seconds and then multiplied by four to give a
183 | #' result in beats per minute! There is also a bias towards round numbers, with
184 | #' large peaks at 60, 80, 100 and 120...
--------------------------------------------------------------------------------
/random-forest/rf-imputed.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- FALSE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Random survival forests on pre-imputed data
11 | #' 
12 | #' The imputation algorithm used for the Cox modelling 'cheats' in a couple of
13 | #' different ways. Firstly, the imputation model is fitted on the whole dataset,
14 | #' rather than being trained on a training set and then applied to the test
15 | #' set. Secondly, causality is violated in that future measurements and even the
16 | #' outcome variable (ie whether a patient died) are included in the model. The
17 | #' rationale for this is that you want to have the best and most complete
18 | #' dataset possible, and if doctors are going to go on and use this in clinical
19 | #' practice they will simply take the additional required measurements when
20 | #' calculating a patient's risk score, rather than trying to do the modelling on
21 | #' incomplete data.
22 | #' 
23 | #' However, this could allow the imputation to 'pass back' useful data to the
24 | #' Cox model, and thus artificially inflate its performance statistics.
25 | #' 
26 | #' It is non-trivial to work around this, not least because the imputation
27 | #' package we used does not expose the model to allow training on a subset of
28 | #' the data. A quick and dirty method to check whether this may be an issue,
29 | #' therefore, is to train a random forest model on the imputed dataset, and see
30 | #' if it can outperform the Cox model. So, let's try it...
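#' (Before doing so, a toy sketch of the leakage risk, using made-up data
#' rather than anything from the study: if the imputation model is allowed to
#' see the outcome, the imputed values carry outcome information even in rows
#' where the variable was never observed.)
#' 
#+ leakage_toy_sketch, eval=FALSE

set.seed(42)
n <- 1000
died <- rbinom(n, 1, 0.5)
x <- rnorm(n) + died                    # x is genuinely associated with death
x.obs <- ifelse(runif(n) < 0.3, NA, x)  # ~30% of x goes missing at random
# A 'cheating' imputation: the outcome is a predictor in the imputation model
imp.model <- lm(x.obs ~ died)
x.imputed <-
  ifelse(is.na(x.obs), predict(imp.model, data.frame(died)), x.obs)
# Among the rows where x was never observed, the imputed values are perfectly
# correlated with the outcome
cor(x.imputed[is.na(x.obs)], died[is.na(x.obs)])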
31 | 
32 | #+ user_variables, message=FALSE
33 | 
34 | imputed.data.filename <- '../../data/COHORT_complete.rds'
35 | endpoint <- 'death.imputed'
36 | 
37 | n.trees <- 500
38 | n.split <- 10
39 | n.threads <- 40
40 | split.rule <- 'logrank'
41 | bootstraps <- 50
42 | 
43 | output.filename.base <- '../../output/rf-imputed-try1'
44 | boot.filename <- paste0(output.filename.base, '-boot.rds')
45 | 
46 | # Variables to use in the model
47 | continuous.vars <-
48 |   c(
49 |     'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo',
50 |     'total_wbc_6mo', 'haemoglobin_6mo'
51 |   )
52 | untransformed.vars <- c('anonpatid', 'surv_time', 'imd_score', 'exclude')
53 | 
54 | source('../lib/shared.R')
55 | require(ggrepel)
56 | 
57 | #' ## Read imputed data
58 | 
59 | # Load the data from its RDS file
60 | imputed.data <- readRDS(imputed.data.filename)
61 | 
62 | # Remove rows with death time of 0 to avoid fitting errors, and get the survival
63 | # columns ready
64 | for(i in 1:length(imputed.data)) {
65 |   imputed.data[[i]] <- imputed.data[[i]][imputed.data[[i]][, surv.time] > 0, ]
66 |   # Put in a fake exclude column for the next function (exclusions are already
67 |   # excluded in the imputed dataset)
68 |   imputed.data[[i]]$exclude <- FALSE
69 |   imputed.data[[i]]$imd_score <- as.numeric(imputed.data[[i]]$imd_score)
70 |   imputed.data[[i]] <-
71 |     caliberExtraPrep(
72 |       prepSurvCol(
73 |         imputed.data[[i]], 'time_death', 'endpoint_death', 1
74 |       )
75 |     )
76 | }
77 | 
78 | # Define test set
79 | test.set <- testSetIndices(imputed.data[[1]], random.seed = 78361)
80 | 
81 | 
82 | #' ## Fit random forest model
83 | 
84 | # Do a quick and dirty fit on a single imputed dataset, to draw a calibration
85 | # curve from
86 | time.start <- handyTimer()
87 | surv.model.fit <-
88 |   survivalFit(
89 |     surv.predict,
90 |     imputed.data[[1]][-test.set,],
91 |     model.type = 'rfsrc',
92 |     n.trees = n.trees,
93 |     split.rule = split.rule,
94 |     n.threads = n.threads,
95 |     nsplit = n.split
96 |   )
97 | time.fit <- handyTimer(time.start)
98 | 
99 | fit.rf.boot <- list()
100 | 
101 | # Perform bootstrap fitting for every multiply imputed dataset
102 | time.start <- handyTimer()
103 | for(i in 1:length(imputed.data)) {
104 |   fit.rf.boot[[i]] <-
105 |     survivalBootstrap(
106 |       surv.predict,
107 |       imputed.data[[i]][-test.set, ],
108 |       imputed.data[[i]][test.set, ],
109 |       model.type = 'rfsrc',
110 |       bootstraps = bootstraps,
111 |       n.threads = n.threads
112 |     )
113 | }
114 | time.boot <- handyTimer(time.start)
115 | 
116 | # Save the fits, because it might've taken a while!
117 | saveRDS(fit.rf.boot, boot.filename)
118 | 
119 | #' Model fitted in `r round(time.fit)` seconds, and `r bootstraps` fits
120 | #' performed on `r length(imputed.data)` imputed datasets in
121 | #' `r round(time.boot)` seconds.
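#' (`bootMIStats`, used below, comes from the shared library. As a rough
#' sketch of the kind of pooling involved, and an illustration of the idea
#' rather than necessarily the exact implementation, one can concatenate the
#' bootstrap replicates of a statistic across the imputed datasets and take
#' percentile intervals:)
#' 
#+ pooling_sketch, eval=FALSE

# boot.reps: a list containing one numeric vector of bootstrap replicates of
# some statistic (eg the C-index) per imputed dataset
poolBootMI <- function(boot.reps) {
  all.reps <- unlist(boot.reps)
  c(
    val   = mean(all.reps),
    lower = unname(quantile(all.reps, 0.025)),
    upper = unname(quantile(all.reps, 0.975))
  )
}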
122 | 
123 | # Unpackage the uncertainties from the bootstrapped data
124 | fit.rf.boot.ests <- bootMIStats(fit.rf.boot)
125 | 
126 | # Save bootstrapped performance values
127 | varsToTable(
128 |   data.frame(
129 |     model = 'rfsrc',
130 |     imputation = TRUE,
131 |     discretised = FALSE,
132 |     c.index = fit.rf.boot.ests['c.test', 'val'],
133 |     c.index.lower = fit.rf.boot.ests['c.test', 'lower'],
134 |     c.index.upper = fit.rf.boot.ests['c.test', 'upper'],
135 |     calibration.score = fit.rf.boot.ests['calibration.score', 'val'],
136 |     calibration.score.lower = fit.rf.boot.ests['calibration.score', 'lower'],
137 |     calibration.score.upper = fit.rf.boot.ests['calibration.score', 'upper']
138 |   ),
139 |   performance.file,
140 |   index.cols = c('model', 'imputation', 'discretised')
141 | )
142 | 
143 | #' ## Performance
144 | #' 
145 | #' Having fitted the random forest model, how did we do? The c-indices were
146 | #' calculated as part of the bootstrapping, so we just need to take a look at those...
147 | #' 
148 | #' C-indices are **`r round(fit.rf.boot.ests['c.train', 'val'], 3)`
149 | #' (`r round(fit.rf.boot.ests['c.train', 'lower'], 3)` -
150 | #' `r round(fit.rf.boot.ests['c.train', 'upper'], 3)`)** on the training set and
151 | #' **`r round(fit.rf.boot.ests['c.test', 'val'], 3)`
152 | #' (`r round(fit.rf.boot.ests['c.test', 'lower'], 3)` -
153 | #' `r round(fit.rf.boot.ests['c.test', 'upper'], 3)`)** on the test set.
154 | #' 
155 | #' 
156 | #' ### Calibration
157 | #' 
158 | #' The bootstrapped calibration score is
159 | #' **`r round(fit.rf.boot.ests['calibration.score', 'val'], 3)`
160 | #' (`r round(fit.rf.boot.ests['calibration.score', 'lower'], 3)` -
161 | #' `r round(fit.rf.boot.ests['calibration.score', 'upper'], 3)`)**.
162 | #' 
163 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be
164 | #' better to draw all the curves from the bootstrap fit to get an idea of
165 | #' variability, but I've not implemented this yet.)
166 | #' 
167 | #+ calibration_plot
168 | 
169 | calibration.table <-
170 |   calibrationTable(surv.model.fit, imputed.data[[1]][test.set, ])
171 | 
172 | calibration.score <- calibrationScore(calibration.table)
173 | 
174 | calibrationPlot(calibration.table)
175 | 
176 | #' The area between the calibration curve and the diagonal is
177 | #' **`r round(calibration.score[['area']], 3)`** +/-
178 | #' **`r round(calibration.score[['se']], 3)`**.
--------------------------------------------------------------------------------
/cox-ph/cox-discretised-imputed.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- FALSE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Discrete Cox model with imputed data
11 | #' 
12 | #' The discrete Cox model performs very similarly to the normal Cox model, even
13 | #' without performing imputation first. Let's try it with imputation, and see
14 | #' if the performance is boosted.
15 | #' 
16 | #' We're going to use the same parameters for discretising as were found for the
17 | #' data with missing values. On the one hand, this is slightly unfair because
18 | #' these values might be suboptimal now that the data are imputed but, on the
19 | #' other, it does allow for direct comparison. If we cross-validated again, we
#' would very likely find different parameters (there are far more combinations
21 | #' than could plausibly be tried), and performance might then be better or
22 | #' worse entirely at random.
23 | 
24 | # Calibration from cross-validation performed on data before imputation
25 | calibration.filename <- '../../output/survreg-crossvalidation-try4.csv'
26 | # The first part of the filename for any output
27 | output.filename.base <- '../../output/survreg-discrete-imputed-try1'
28 | imputed.data.filename <- '../../data/COHORT_complete.rds'
29 | boot.filename <- paste0(output.filename.base, '-boot.rds')
30 | 
31 | n.threads <- 20
32 | bootstraps <- 50
33 | 
34 | model.type <- 'survreg'
35 | 
36 | #' ## Discretise data
37 | #' 
38 | #' First, let's load the results from the calibrations and find the parameters
39 | #' for discretisation.
40 | 
41 | source('../lib/shared.R')
42 | 
43 | continuous.vars <-
44 |   c(
45 |     'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo',
46 |     'total_wbc_6mo', 'haemoglobin_6mo'
47 |   )
48 | 
49 | # Read file containing calibrations
50 | cv.performance <- read.csv(calibration.filename)
51 | 
52 | # Find the best calibration...
53 | # First, average performance across cross-validation folds
54 | cv.performance.average <-
55 |   aggregate(
56 |     c.index.val ~ calibration,
57 |     data = cv.performance,
58 |     mean
59 |   )
60 | # Find the highest value
61 | best.calibration <-
62 |   cv.performance.average$calibration[
63 |     which.max(cv.performance.average$c.index.val)
64 |   ]
65 | # And finally, find the first row of that calibration to get the n.bins values
66 | best.calibration.row1 <-
67 |   min(which(cv.performance$calibration == best.calibration))
68 | 
69 | # Get its parameters
70 | n.bins <-
71 |   t(
72 |     cv.performance[best.calibration.row1, continuous.vars]
73 |   )
74 | 
75 | 
76 | # Reset process settings with the base settings
77 | process.settings <-
78 |   list(
79 |     var = c('anonpatid', 'time_death', 'imd_score', 'exclude'),
80 |     method = c(NA, NA, NA, NA),
81 |     settings = list(NA, NA, NA, NA)
82 |   )
83 | for(j in 1:length(continuous.vars)) {
84 |   process.settings$var <- c(process.settings$var, continuous.vars[j])
85 |   process.settings$method <- c(process.settings$method, 'binByQuantile')
86 |   process.settings$settings <-
87 |     c(
88 |       process.settings$settings,
89 |       list(
90 |         seq(
91 |           # Quantiles are obviously between 0 and 1
92 |           0, 1,
93 |           # Use the number of breaks found by the cross-validation above
94 |           length.out = n.bins[j]
95 |         )
96 |       )
97 |     )
98 | }
99 | 
100 | #' ## Load imputed data
101 | #' 
102 | #' Load the data and prepare it with the settings above
103 | 
104 | # Load the data from its RDS file
105 | imputed.data <- readRDS(imputed.data.filename)
106 | 
107 | # Remove rows with death time of 0 to avoid fitting errors, and get the survival
108 | # columns ready
109 | for(i in 1:length(imputed.data)) {
110 |   imputed.data[[i]] <- imputed.data[[i]][imputed.data[[i]][, surv.time] > 0, ]
111 |   # Put in a fake exclude column for the next function (exclusions are already
112 |   # excluded in the imputed dataset)
113 |   imputed.data[[i]]$exclude <- FALSE
114 |   imputed.data[[i]]$imd_score <- as.numeric(imputed.data[[i]]$imd_score)
115 |   imputed.data[[i]] <-
116 |     prepData(
117 |       imputed.data[[i]],
118 |       cols.keep,
119 |       process.settings,
120 |       'time_death', 'endpoint_death', 1,
121 |       extra.fun = caliberExtraPrep
122 |     )
123 | }
124 | 
125 | # Define test set
126 | test.set <- testSetIndices(imputed.data[[1]], random.seed = 78361)
127 | 
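#' (`binByQuantile` is defined in the shared library; the discretisation used
#' throughout is essentially quantile-based cutting. A minimal sketch of the
#' idea in base R, as an illustration rather than the library implementation:)
#' 
#+ binning_sketch, eval=FALSE

x <- rnorm(1000)
breaks <- quantile(x, seq(0, 1, length.out = 6))  # 6 breaks give 5 bins
binned <- cut(x, breaks, include.lowest = TRUE)
table(binned)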
128 | # Do a quick and dirty fit on a single imputed dataset, to draw a calibration
129 | # curve from
130 | time.start <- handyTimer()
131 | surv.model.fit <-
132 |   survivalFit(
133 |     surv.predict,
134 |     imputed.data[[1]][-test.set, ], # Training set from one imputed dataset
135 |     model.type = model.type,
136 |     n.threads = n.threads
137 |   )
138 | time.fit <- handyTimer(time.start)
139 | 
140 | # Perform bootstrap fitting for every multiply imputed dataset
141 | surv.fit.boot <- list()
142 | time.start <- handyTimer()
143 | for(i in 1:length(imputed.data)) {
144 |   surv.fit.boot[[i]] <-
145 |     survivalBootstrap(
146 |       surv.predict,
147 |       imputed.data[[i]][-test.set, ],
148 |       imputed.data[[i]][test.set, ],
149 |       model.type = model.type,
150 |       bootstraps = bootstraps,
151 |       n.threads = n.threads
152 |     )
153 | }
154 | time.boot <- handyTimer(time.start)
155 | 
156 | # Save the fits, because it might've taken a while!
157 | saveRDS(surv.fit.boot, boot.filename)
158 | 
159 | #' Model fitted in `r round(time.fit)` seconds, and `r bootstraps` fits
160 | #' performed on `r length(imputed.data)` imputed datasets in
161 | #' `r round(time.boot)` seconds.
162 | 
163 | # Unpackage the uncertainties from the bootstrapped data
164 | surv.fit.boot.ests <- bootMIStats(surv.fit.boot)
165 | 
166 | # Save bootstrapped performance values
167 | varsToTable(
168 |   data.frame(
169 |     model = 'cox',
170 |     imputation = TRUE,
171 |     discretised = TRUE,
172 |     c.index = surv.fit.boot.ests['c.test', 'val'],
173 |     c.index.lower = surv.fit.boot.ests['c.test', 'lower'],
174 |     c.index.upper = surv.fit.boot.ests['c.test', 'upper'],
175 |     calibration.score = surv.fit.boot.ests['calibration.score', 'val'],
176 |     calibration.score.lower = surv.fit.boot.ests['calibration.score', 'lower'],
177 |     calibration.score.upper = surv.fit.boot.ests['calibration.score', 'upper']
178 |   ),
179 |   performance.file,
180 |   index.cols = c('model', 'imputation', 'discretised')
181 | )
182 | 
183 | 
184 | #' # Results
185 | #' 
186 | #' ## Performance
187 | #' 
188 | #' Having fitted the Cox model, how did we do? The c-indices were calculated as
189 | #' part of the bootstrapping, so we just need to take a look at those...
190 | #' 
191 | #' C-indices are **`r round(surv.fit.boot.ests['c.train', 'val'], 3)`
192 | #' (`r round(surv.fit.boot.ests['c.train', 'lower'], 3)` -
193 | #' `r round(surv.fit.boot.ests['c.train', 'upper'], 3)`)** on the training set and
194 | #' **`r round(surv.fit.boot.ests['c.test', 'val'], 3)`
195 | #' (`r round(surv.fit.boot.ests['c.test', 'lower'], 3)` -
196 | #' `r round(surv.fit.boot.ests['c.test', 'upper'], 3)`)** on the test set.
197 | #' 
198 | #' 
199 | #' ### Calibration
200 | #' 
201 | #' The bootstrapped calibration score is
202 | #' **`r round(surv.fit.boot.ests['calibration.score', 'val'], 3)`
203 | #' (`r round(surv.fit.boot.ests['calibration.score', 'lower'], 3)` -
204 | #' `r round(surv.fit.boot.ests['calibration.score', 'upper'], 3)`)**.
205 | #' 
206 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be
207 | #' better to draw all the curves from the bootstrap fit to get an idea of
208 | #' variability, but I've not implemented this yet.)
209 | #' 210 | #+ calibration_plot 211 | 212 | calibration.table <- 213 | calibrationTable(surv.model.fit, imputed.data[[1]][test.set, ]) 214 | 215 | calibration.score <- calibrationScore(calibration.table) 216 | 217 | calibrationPlot(calibration.table) 218 | 219 | #' The area between the calibration curve and the diagonal is 220 | #' **`r round(calibration.score[['area']], 3)`** +/- 221 | #' **`r round(calibration.score[['se']], 3)`**. -------------------------------------------------------------------------------- /lib/rfsrc-cv-nsplit-bootstrap.R: -------------------------------------------------------------------------------- 1 | bootstraps <- 20 2 | split.rule <- 'logrank' 3 | n.threads <- 16 4 | 5 | # Cross-validation variables 6 | ns.splits <- 0:20 7 | ns.trees <- c(500, 1000, 2000) 8 | ns.imputations <- 1:3 9 | cv.n.folds <- 3 10 | n.data <- NA # This is of full dataset...further rows may be excluded in prep 11 | 12 | calibration.filename <- paste0(output.filename.base, '-calibration.csv') 13 | 14 | continuous.vars <- 15 | c( 16 | 'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo', 17 | 'total_wbc_6mo', 'haemoglobin_6mo' 18 | ) 19 | 20 | untransformed.vars <- c('anonpatid', 'surv_time', 'imd_score', 'exclude') 21 | 22 | source('../lib/shared.R') 23 | require(ggrepel) 24 | 25 | # Load the data and convert to data frame to make column-selecting code in 26 | # prepData simpler 27 | COHORT.full <- data.frame(fread(data.filename)) 28 | 29 | # If n.data was specified... 30 | if(!is.na(n.data)){ 31 | # Take a subset n.data in size 32 | COHORT.use <- sample.df(COHORT.full, n.data) 33 | rm(COHORT.full) 34 | } else { 35 | # Use all the data 36 | COHORT.use <- COHORT.full 37 | rm(COHORT.full) 38 | } 39 | 40 | # We now need a quick null preparation of the data to get its length (some rows 41 | # may be excluded during preparation) 42 | COHORT.prep <- 43 | prepData( 44 | COHORT.use, 45 | cols.keep, discretise.settings, surv.time, surv.event, 46 | surv.event.yes, extra.fun = caliberExtraPrep, n.keep = n.data 47 | ) 48 | n.data <- nrow(COHORT.prep) 49 | 50 | # Define indices of test set 51 | test.set <- testSetIndices(COHORT.prep, random.seed = 78361) 52 | 53 | # Process settings: don't touch anything!! 
54 | process.settings <- 55 | list( 56 | var = c(untransformed.vars, continuous.vars), 57 | method = rep(NA, length(untransformed.vars) + length(continuous.vars)), 58 | settings = rep(NA, length(untransformed.vars) + length(continuous.vars)) 59 | ) 60 | 61 | # If we've not already done a calibration, then do one 62 | if(!file.exists(calibration.filename)) { 63 | # Create an empty data frame to aggregate stats per fold 64 | cv.performance <- data.frame() 65 | 66 | # Items to cross-validate over 67 | cv.vars <- expand.grid(ns.splits, ns.trees, ns.imputations) 68 | names(cv.vars) <- c('n.splits', 'n.trees', 'n.imputations') 69 | 70 | # prep the data (since we're not cross-validating on data prep this can be 71 | # done before the loop) 72 | 73 | # Prep the data 74 | COHORT.cv <- 75 | prepData( 76 | # Data for cross-validation excludes test set 77 | COHORT.use[-test.set, ], 78 | cols.keep, 79 | process.settings, 80 | surv.time, surv.event, 81 | surv.event.yes, 82 | extra.fun = caliberExtraPrep 83 | ) 84 | 85 | # Finally, add missing flag columns, but leave the missing data intact because 86 | # rfsrc can do on-the-fly imputation 87 | COHORT.cv <- prepCoxMissing(COHORT.cv, missingReplace = NA) 88 | 89 | # Add on those column names we just created 90 | surv.predict <- 91 | c(surv.predict, names(COHORT.cv)[grepl('_missing', names(COHORT.cv))]) 92 | 93 | # Run crossvalidations. No need to parallelise because rfsrc is parallelised 94 | for(i in 1:nrow(cv.vars)) { 95 | cat( 96 | 'Calibration', i, '...\n' 97 | ) 98 | 99 | # Get folds for cross-validation 100 | cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds) 101 | 102 | cv.fold.performance <- data.frame() 103 | 104 | for(j in 1:cv.n.folds) { 105 | time.start <- handyTimer() 106 | # Fit model to the training set 107 | surv.model.fit <- 108 | survivalFit( 109 | surv.predict, 110 | COHORT.cv[-cv.folds[[j]],], 111 | model.type = 'rfsrc', 112 | n.trees = cv.vars$n.trees[i], 113 | split.rule = split.rule, 114 | n.threads = n.threads, 115 | nsplit = cv.vars$n.splits[i], 116 | nimpute = cv.vars$n.imputations[i], 117 | na.action = 'na.impute' 118 | ) 119 | time.learn <- handyTimer(time.start) 120 | 121 | time.start <- handyTimer() 122 | # Get C-indices for training and validation sets 123 | c.index.train <- 124 | cIndex( 125 | surv.model.fit, COHORT.cv[-cv.folds[[j]],], 126 | na.action = 'na.impute' 127 | ) 128 | c.index.val <- 129 | cIndex( 130 | surv.model.fit, COHORT.cv[cv.folds[[j]],], 131 | na.action = 'na.impute' 132 | ) 133 | time.predict <- handyTimer(time.start) 134 | 135 | # Append the stats we've obtained from this fold 136 | cv.fold.performance <- 137 | rbind( 138 | cv.fold.performance, 139 | data.frame( 140 | calibration = i, 141 | cv.fold = j, 142 | n.trees = cv.vars$n.trees[i], 143 | n.splits = cv.vars$n.splits[i], 144 | n.imputations = cv.vars$n.imputations[i], 145 | c.index.train, 146 | c.index.val, 147 | time.learn, 148 | time.predict 149 | ) 150 | ) 151 | 152 | } # End cross-validation loop (j) 153 | 154 | 155 | # rbind the performance by fold 156 | cv.performance <- 157 | rbind( 158 | cv.performance, 159 | cv.fold.performance 160 | ) 161 | 162 | # Save output at the end of each loop 163 | write.csv(cv.performance, calibration.filename) 164 | 165 | } # End calibration loop (i) 166 | 167 | 168 | 169 | } else { # If we did previously calibrate, load it 170 | cv.performance <- read.csv(calibration.filename) 171 | } 172 | 173 | 174 | 175 | # Find the best calibration... 
176 | # First, average performance across cross-validation folds 177 | cv.performance.average <- 178 | aggregate( 179 | c.index.val ~ calibration, 180 | data = cv.performance, 181 | mean 182 | ) 183 | # Find the highest value 184 | best.calibration <- 185 | cv.performance.average$calibration[ 186 | which.max(cv.performance.average$c.index.val) 187 | ] 188 | # And finally, find the first row of that calibration to get the n.bins values 189 | best.calibration.row1 <- 190 | min(which(cv.performance$calibration == best.calibration)) 191 | 192 | # Prep the data to fit and test with 193 | COHORT.prep <- 194 | prepData( 195 | # Data for cross-validation excludes test set 196 | COHORT.use, 197 | cols.keep, 198 | process.settings, 199 | surv.time, surv.event, 200 | surv.event.yes, 201 | extra.fun = caliberExtraPrep 202 | ) 203 | 204 | # Finally, add missing flag columns, but leave the missing data intact because 205 | # rfsrc can do on-the-fly imputation 206 | COHORT.prep <- prepCoxMissing(COHORT.prep, missingReplace = NA) 207 | 208 | # Add on those column names we just created 209 | surv.predict <- 210 | c(surv.predict, names(COHORT.prep)[grepl('_missing', names(COHORT.prep))]) 211 | 212 | #' ## Fit the final model 213 | #' 214 | #' This may take some time, so we'll cache it if possible... 215 | 216 | #+ fit_final_model 217 | 218 | surv.model.fit <- 219 | survivalFit( 220 | surv.predict, 221 | COHORT.prep[-test.set,], 222 | model.type = 'rfsrc', 223 | n.trees = cv.performance[best.calibration.row1, 'n.trees'], 224 | split.rule = split.rule, 225 | n.threads = n.threads, 226 | nsplit = cv.performance[best.calibration.row1, 'n.splits'], 227 | nimpute = cv.performance[best.calibration.row1, 'n.imputations'], 228 | na.action = 'na.impute' 229 | ) 230 | 231 | surv.model.params.boot <- 232 | survivalFitBoot( 233 | surv.predict, 234 | COHORT.prep[-test.set,], # Training set 235 | COHORT.prep[test.set,], # Test set 236 | model.type = 'rfsrc', 237 | n.threads = n.threads, 238 | bootstraps = bootstraps, 239 | n.trees = cv.performance[best.calibration.row1, 'n.trees'], 240 | split.rule = split.rule, 241 | nsplit = cv.performance[best.calibration.row1, 'n.splits'], 242 | nimpute = cv.performance[best.calibration.row1, 'n.imputations'], 243 | na.action = 'na.impute', 244 | filename = paste0(output.filename.base, '-boot-all.csv') 245 | ) 246 | 247 | # Get C-indices for training and test sets 248 | surv.model.fit.coeffs <- bootStatsDf(surv.model.params.boot) 249 | 250 | # Save them to the all-models comparison table 251 | varsToTable( 252 | data.frame( 253 | model = 'rfsrc', 254 | imputation = FALSE, 255 | discretised = FALSE, 256 | c.index = surv.model.fit.coeffs['c.index', 'val'], 257 | c.index.lower = surv.model.fit.coeffs['c.index', 'lower'], 258 | c.index.upper = surv.model.fit.coeffs['c.index', 'upper'], 259 | calibration.score = surv.model.fit.coeffs['calibration.score', 'val'], 260 | calibration.score.lower = 261 | surv.model.fit.coeffs['calibration.score', 'lower'], 262 | calibration.score.upper = 263 | surv.model.fit.coeffs['calibration.score', 'upper'] 264 | ), 265 | performance.file, 266 | index.cols = c('model', 'imputation', 'discretised') 267 | ) 268 | 269 | write.csv( 270 | surv.model.fit.coeffs, 271 | paste0(output.filename.base, '-boot-summary.csv') 272 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine learning on electronic health records 2 | 3 | 
Repository of code developed for *Machine learning models in electronic health records can outperform conventional survival models for predicting patient mortality in coronary artery disease*. 4 | 5 | ## Introduction 6 | 7 | The R scripts provided in this repository were developed to perform survival modelling on 100,000 patients’ electronic health records. 8 | 9 | ## Usage 10 | 11 | To use these scripts, download them into a subfolder ``code`` of a folder which also contains subfolders ``data`` (which should contain input data) and ``output`` (where output files are placed by the scripts). 12 | 13 | Most scripts [can use ``spin`` to generate reports](http://deanattali.com/2015/03/24/knitrs-best-hidden-gem-spin/), and the simplest way to do this is to execute them using the ``spinme.R`` script provided. For example, to run the discretised Cox model script and generate a report, navigate to the ``code/cox-ph`` folder of your project, and run 14 | 15 | ``` 16 | Rscript ../scripts/spinme.R cox-discretised.R 17 | ``` 18 | 19 | This will generate a report ``cox-discretised.html`` in the working directory, as well as additional files in the ``output`` folder. 20 | 21 | ## Project inventory 22 | 23 | This is a brief description of the important files in this project, broken down by folder. 24 | 25 | ### age-only 26 | 27 | #### age-only.R 28 | 29 | This calculates the C-index and calibration score for a model where age is the only variable. For the C-index, it is assumed that the older patient will die first, and for calibration the Kaplan–Meier estimator for patients of a given age is assumed to be the risk estimate for all patients of that age. 30 | 31 | ### cox-ph 32 | 33 | Various Cox proportional hazards models. Those prefixed ``caliber-replicate-`` are based on the Cox model developed in [Rapsomaniki et al. 2014](https://academic.oup.com/eurheartj/article-lookup/doi/10.1093/eurheartj/eht533) (DOI: [10.1093/eurheartj/eht533](https://dx.doi.org/10.1093/eurheartj/eht533)). 34 | 35 | #### caliber-replicate-with-imputation.R 36 | 37 | This model is as close to identical to Rapsomaniki et al. as possible, using a five-fold multiply-imputed dataset, with continuous variables scaled as in that paper. 38 | 39 | #### caliber-replicate-with-missing.R 40 | 41 | This model uses the same scaling as Rapsomaniki et al., but is conducted on a single dataset with missing values represented by missing indicator variables rather than imputed. 42 | 43 | #### caliber-scale.R 44 | 45 | Functions to scale data for Cox modelling. 46 | 47 | #### cox-discrete-elasticnet.R 48 | 49 | Discrete elastic net Cox model for the data-driven modelling, which cross-validates to find the optimal _α_ and then bootstraps to establish distributions for the other fitted parameters. 50 | 51 | #### cox-discrete-varsellogrank.R 52 | 53 | Discrete Cox model for the data-driven modelling, which cross-validates over number of variables used, drawing from a list ranked by univariate logrank tests. 54 | 55 | #### cox-discretised.R 56 | 57 | This model uses the expert-selected dataset with discretised versions of continuous variables to allow missing values to be incorporated, and cross-validates to determine the discretisation scheme. 58 | 59 | #### cox-discretised-imputed.R 60 | 61 | This model uses the imputed version of the expert-selected dataset with discretised versions of continuous variables, following the same method as above. 
62 | 
63 | #### rapsomaniki-cox-values-from-paper.csv
64 | 
65 | Values for Cox coefficients transcribed from Rapsomaniki et al., used to check for consistency between that model and these.
66 | 
67 | ### lib
68 | 
69 | Shared libraries for functions and common routines.
70 | 
71 | #### all-cv-bootstrap.R
72 | 
73 | Script to cross-validate discretisation schemes, followed by bootstrapping the selected optimal model. Works for both Cox modelling and random forests with either ``randomForestSRC`` or ``ranger``. Discretised random forests were not used in the final analysis, as there was no appreciable performance gain.
74 | 
75 | #### handy.R
76 | 
77 | Shortcuts and wrappers, from [Andrew’s](https://github.com/ajsteele/) handy [handy.R](https://github.com/ajsteele/handy.R) script.
78 | 
79 | #### handymedical.R
80 | 
81 | Useful functions and wrappers for preparing and manipulating data, and for making the use of Cox models and random forests (with either ``randomForestSRC`` or ``ranger``, including bootstrapping) as transparent and consistent as possible. These functions are hopefully of general use for other survival modelling projects; dataset-specific functions are defined in ``shared.R``.
82 | 
83 | #### rfsrc-cv-mtry-nsplit-logical.R
84 | 
85 | Script to cross-validate ``randomForestSRC`` survival forests using the large dataset, optimising the ``mtry`` and ``nsplit`` hyperparameters.
86 | 
87 | #### rfsrc-cv-nsplit-bootstrap.R
88 | 
89 | Script to cross-validate ``randomForestSRC`` survival forests using the expert-selected dataset, optimising ``nsplit``.
90 | 
91 | #### shared.R
92 | 
93 | This script is run at the start of most model scripts, and defines a random seed, plus variables and functions which will be useful. The data-parsing functions here are specific to the scheme of this particular dataset and so were excluded from ``handymedical.R``.
94 | 
95 | ### overview
96 | 
97 | Various scripts for exploring the dataset and retrieving and plotting results for publication.
98 | 
99 | #### all-models.R
100 | 
101 | Produces a graph of the C-index and calibration score from all models. The basis of Fig. 1 in the paper.
102 | 
103 | #### bigdata-mtry-nsplit.R
104 | 
105 | Plots a line graph showing C-index performance of random forests depending on ``mtry`` and ``nsplit`` in the large dataset.
106 | 
107 | #### calibration-plots.R
108 | 
109 | Plots two example calibration curves to show how the calibration score is calculated. The basis of Fig. 2 in the paper.
110 | 
111 | #### cohort-tables.R
112 | 
113 | Prints a number of summary statistics used for Table 2 in the paper.
114 | 
115 | #### explore-dataset.R
116 | 
117 | A number of quick exploratory graphs and comparisons to explore the expert-selected dataset, with a particular focus on degrees and distribution of missing data.
118 | 
119 | #### missing-values-risk.R
120 | 
121 | Compares Cox model results in three ways: first, coefficients from the continuous imputed model vs the continuous missing-indicator and discrete models; second, the risks associated with ranges of continuous values vs those associated with a value being missing; and finally, survival curves for patients with a particular value missing vs present. The basis of Fig. 3 in the paper.
122 | 
123 | #### performance-differences.R
124 | 
125 | Pairwise differences with uncertainty in C-index and calibration between all models tested, ascertained by finding the distribution of differences between bootstrap replicates for each model-pair.
126 | 127 | #### variable-effects.R 128 | 129 | Plots of variable effects for continuous and discrete Cox models, and random forests. The basis of Fig. 4 in the paper. 130 | 131 | #### variable-importances.R 132 | 133 | Plots permutation variable importances calculated for the final data-driven models, post variable selection. The basis of Fig. 5 in the paper. 134 | 135 | ### random-forest 136 | 137 | #### rf-age.R 138 | 139 | Building a random forest with fewer variables (including just age) to experiment with predictive power. 140 | 141 | #### rf-classification.R 142 | 143 | Classification forest for death at 5 years, in an attempt to improve calibration score of the resulting model. 144 | 145 | #### rf-imputed.R 146 | 147 | Random forest on the imputed dataset as an empirical test of whether imputation provides an advantage. 148 | 149 | #### rfsrc-cv.R 150 | 151 | Random forest on the expert-selected dataset, which uses ``rfsrc-cv-nsplit-bootstrap.R`` from ``lib`` (see above) to fit its forest. 152 | 153 | #### rf-varsellogrank.R 154 | 155 | Random forest for the data-driven modelling, which cross-validates over number of variables used, drawing from a list ranked by univariate logrank tests. 156 | 157 | #### rf-varselmiss.R 158 | 159 | Random forest for the data-driven modelling, which cross-validates over number of variables used, drawing from a list ranked by decreasing missingness. 160 | 161 | #### rf-varselrf-eqv.R 162 | 163 | Random forest for the data-driven modelling, which cross-validates over number of variables used, drawing from a list ranked by the variable importance of a large random forest fitted to all the data. Modelled after [varSelRF](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-7-3). 164 | 165 | ### scripts 166 | 167 | This folder is for a few miscellaneous short scripts. 168 | 169 | #### export-bigdata.R 170 | 171 | This script was used to export anonymised data for the data-driven modelling with ~600 variables. Its file paths are not correct because they are localised for the secure environment where the raw data are stored. 172 | 173 | #### spinme.R 174 | 175 | This wrapper makes it easy to spin a script into an HTML report from the command line (see the example command at the top of this readme). 176 | 177 | ## Notes 178 | 179 | This repository has been tidied up so that only scripts relevant to the final publication are preserved. Various initial and exploratory analysis scripts have been removed for clarity. If for any reason these are of interest, they are present in commit 08934808c497a0f094c71a731cb9cb2564e4cc0f, the final commit before the tidy-up began. 
-------------------------------------------------------------------------------- /lib/all-cv-bootstrap.R: -------------------------------------------------------------------------------- 1 | # All model types are bootstrapped this many times 2 | bootstraps <- 200 3 | # n.trees is (obviously) only relevant for random forests 4 | n.trees <- 500 5 | # The following two variables are only relevant if the model.type is 'ranger' 6 | split.rule <- 'logrank' 7 | n.threads <- 10 8 | 9 | # Cross-validation variables 10 | input.n.bins <- 10:20 11 | cv.n.folds <- 3 12 | n.calibrations <- 1000 13 | n.data <- NA # This is of full dataset...further rows may be excluded in prep 14 | 15 | continuous.vars <- 16 | c( 17 | 'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo', 18 | 'total_wbc_6mo', 'haemoglobin_6mo' 19 | ) 20 | 21 | source('shared.R') 22 | require(ggrepel) 23 | 24 | # Load the data and convert to data frame to make column-selecting code in 25 | # prepData simpler 26 | COHORT.full <- data.frame(fread(data.filename)) 27 | 28 | # If n.data was specified... 29 | if(!is.na(n.data)){ 30 | # Take a subset n.data in size 31 | COHORT.use <- sample.df(COHORT.full, n.data) 32 | rm(COHORT.full) 33 | } else { 34 | # Use all the data 35 | COHORT.use <- COHORT.full 36 | rm(COHORT.full) 37 | } 38 | 39 | # We now need a quick null preparation of the data to get its length (some rows 40 | # may be excluded during preparation) 41 | COHORT.prep <- 42 | prepData( 43 | COHORT.use, 44 | cols.keep, discretise.settings, surv.time, surv.event, 45 | surv.event.yes, extra.fun = caliberExtraPrep, n.keep = n.data 46 | ) 47 | n.data <- nrow(COHORT.prep) 48 | 49 | # Define indices of test set 50 | test.set <- testSetIndices(COHORT.prep, random.seed = 78361) 51 | 52 | # If we've not already done a calibration, then do one 53 | if(!file.exists(calibration.filename)) { 54 | # Create an empty data frame to aggregate stats per fold 55 | cv.performance <- data.frame() 56 | 57 | # We can parallelise this bit with foreach, so set that up 58 | initParallel(n.threads) 59 | 60 | # Run crossvalidations in parallel 61 | cv.performance <- 62 | foreach(i = 1:n.calibrations, .combine = 'rbind') %dopar% { 63 | cat( 64 | 'Calibration', i, '...\n' 65 | ) 66 | 67 | # Reset process settings with the base setings 68 | process.settings <- 69 | list( 70 | var = c('anonpatid', 'time_death', 'imd_score', 'exclude'), 71 | method = c(NA, NA, NA, NA), 72 | settings = list(NA, NA, NA, NA) 73 | ) 74 | # Generate some random numbers of bins (and for n bins, you need n + 1 breaks) 75 | n.bins <- sample(input.n.bins, length(continuous.vars), replace = TRUE) + 1 76 | names(n.bins) <- continuous.vars 77 | # Go through each variable setting it to bin by quantile with a random number of bins 78 | for(j in 1:length(continuous.vars)) { 79 | process.settings$var <- c(process.settings$var, continuous.vars[j]) 80 | process.settings$method <- c(process.settings$method, 'binByQuantile') 81 | process.settings$settings <- 82 | c( 83 | process.settings$settings, 84 | list( 85 | seq( 86 | # Quantiles are obviously between 0 and 1 87 | 0, 1, 88 | # Choose a random number of bins (and for n bins, you need n + 1 breaks) 89 | length.out = n.bins[j] 90 | ) 91 | ) 92 | ) 93 | } 94 | 95 | # prep the data given the variables provided 96 | COHORT.cv <- 97 | prepData( 98 | # Data for cross-validation excludes test set 99 | COHORT.use[-test.set, ], 100 | cols.keep, 101 | process.settings, 102 | surv.time, surv.event, 103 | surv.event.yes, 104 | extra.fun = caliberExtraPrep 105 
| )
106 | 
107 |     # Get folds for cross-validation
108 |     cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds)
109 | 
110 |     cv.fold.performance <- data.frame()
111 | 
112 |     for(j in 1:cv.n.folds) {
113 |       time.start <- handyTimer()
114 |       # Fit model to the training set
115 |       surv.model.fit <-
116 |         survivalFit(
117 |           surv.predict,
118 |           COHORT.cv[-cv.folds[[j]],],
119 |           model.type = model.type,
120 |           n.trees = n.trees,
121 |           split.rule = split.rule
122 |           # n.threads not used because this is run in parallel
123 |         )
124 |       time.learn <- handyTimer(time.start)
125 | 
126 |       time.start <- handyTimer()
127 |       # Get C-indices for training and validation sets
128 |       c.index.train <-
129 |         cIndex(
130 |           surv.model.fit, COHORT.cv[-cv.folds[[j]],], model.type = model.type
131 |         )
132 |       c.index.val <-
133 |         cIndex(
134 |           surv.model.fit, COHORT.cv[cv.folds[[j]],], model.type = model.type
135 |         )
136 |       time.predict <- handyTimer(time.start)
137 | 
138 |       # Append the stats we've obtained from this fold
139 |       cv.fold.performance <-
140 |         rbind(
141 |           cv.fold.performance,
142 |           data.frame(
143 |             calibration = i,
144 |             cv.fold = j,
145 |             as.list(n.bins),
146 |             c.index.train,
147 |             c.index.val,
148 |             time.learn,
149 |             time.predict
150 |           )
151 |         )
152 | 
153 |     } # End cross-validation loop (j)
154 | 
155 |     # Return the performance by fold, for foreach to rbind
156 |     cv.fold.performance
157 |   } # End calibration loop (i)
158 | 
159 |   # Save output at end of calibration
160 |   write.csv(cv.performance, calibration.filename)
161 | 
162 | } else { # If we did previously calibrate, load it
163 |   cv.performance <- read.csv(calibration.filename)
164 | }
165 | 
166 | # Find the best calibration...
167 | # First, average performance across cross-validation folds
168 | cv.performance.average <-
169 |   aggregate(
170 |     c.index.val ~ calibration,
171 |     data = cv.performance,
172 |     mean
173 |   )
174 | # Find the highest value
175 | best.calibration <-
176 |   cv.performance.average$calibration[
177 |     which.max(cv.performance.average$c.index.val)
178 |   ]
179 | # And finally, find the first row of that calibration to get the n.bins values
180 | best.calibration.row1 <-
181 |   min(which(cv.performance$calibration == best.calibration))
182 | 
183 | # Get its parameters
184 | n.bins <-
185 |   t(
186 |     cv.performance[best.calibration.row1, continuous.vars]
187 |   )
188 | 
189 | # Prepare the data with those settings...
190 | 
191 | # Reset process settings with the base settings
192 | process.settings <-
193 |   list(
194 |     var = c('anonpatid', 'time_death', 'imd_score', 'exclude'),
195 |     method = c(NA, NA, NA, NA),
196 |     settings = list(NA, NA, NA, NA)
197 |   )
198 | for(j in 1:length(continuous.vars)) {
199 |   process.settings$var <- c(process.settings$var, continuous.vars[j])
200 |   process.settings$method <- c(process.settings$method, 'binByQuantile')
201 |   process.settings$settings <-
202 |     c(
203 |       process.settings$settings,
204 |       list(
205 |         seq(
206 |           # Quantiles are obviously between 0 and 1
207 |           0, 1,
208 |           # Use the number of breaks from the best calibration
209 |           length.out = n.bins[j]
210 |         )
211 |       )
212 |     )
213 | }
214 | 
215 | # Prep the data given the settings just defined
216 | COHORT.optimised <-
217 |   prepData(
218 |     # The full dataset is prepared here; the test set is excluded when fitting
219 |     COHORT.use,
220 |     cols.keep,
221 |     process.settings,
222 |     surv.time, surv.event,
223 |     surv.event.yes,
224 |     extra.fun = caliberExtraPrep
225 |   )
226 | 
227 | #' ## Fit the final model
228 | #' 
229 | #' This may take some time, so we'll cache it if possible...
230 | 231 | #+ fit_final_model 232 | 233 | # Fit to whole training set 234 | surv.model.fit <- 235 | survivalFit( 236 | surv.predict, 237 | COHORT.optimised[-test.set,], # Training set 238 | model.type = model.type, 239 | n.trees = n.trees, 240 | split.rule = split.rule, 241 | n.threads = n.threads 242 | ) 243 | 244 | cl <- initParallel(n.threads, backend = 'doParallel') 245 | 246 | surv.model.params.boot <- 247 | foreach( 248 | i = 1:bootstraps, 249 | .combine = rbind, 250 | .packages = c('survival'), 251 | .verbose = TRUE 252 | ) %dopar% { 253 | 254 | # Bootstrap-sampled training set 255 | COHORT.boot <- 256 | sample.df( 257 | COHORT.optimised[-test.set,], 258 | nrow(COHORT.optimised[-test.set,]), 259 | replace = TRUE 260 | ) 261 | 262 | surv.model.fit.i <- 263 | survivalFit( 264 | surv.predict, 265 | COHORT.boot, 266 | model.type = model.type, 267 | n.trees = n.trees, 268 | split.rule = split.rule, 269 | # 1 thread, because we're parallelising the bootstrapping 270 | n.threads = 1 271 | ) 272 | 273 | # Work out other quantities of interest 274 | #var.imp.vector <- bootstrapVarImp(surv.model.fit.i, COHORT.boot) 275 | c.index <- cIndex(surv.model.fit.i, COHORT.optimised[test.set, ]) 276 | calibration.score <- 277 | calibrationScoreWrapper(surv.model.fit.i, COHORT.optimised[test.set, ]) 278 | 279 | data.frame( 280 | i, 281 | t(coef(surv.model.fit.i)), 282 | #t(var.imp.vector), 283 | c.index, 284 | calibration.score 285 | ) 286 | } 287 | 288 | # Save the fit object 289 | write.csv(surv.model.params.boot, paste0(output.filename.base, '-surv-boot.csv')) 290 | 291 | # Tidy up by removing the cluster 292 | stopCluster(cl) 293 | 294 | surv.model.fit.coeffs <- bootStatsDf(surv.model.params.boot) 295 | -------------------------------------------------------------------------------- /overview/variable-importances.R: -------------------------------------------------------------------------------- 1 | source('../lib/handymedical.R', chdir = TRUE) 2 | requirePlus('cowplot') 3 | 4 | 5 | # Load the variable importances from the Cox model 6 | cox.miss <- 7 | readRDS('../../output/caliber-replicate-with-missing-survreg-6-linear-age-surv-boot.rds') 8 | cox.miss.vars <- bootStats(cox.miss, uncertainty = '95ci') 9 | cox.miss.vars$var <- rownames(cox.miss.vars) 10 | cox.miss.vars <- subset(cox.miss.vars, startsWith(var, 'vimp.c.index')) 11 | cox.miss.vars$var <- 12 | substring(cox.miss.vars$var, nchar('vimp.c.index.') + 1) 13 | 14 | # Load the variable importances from the random forest 15 | rf.boot <- read.csv('../../output/rfsrc-cv-nsplit-try3-boot-all.csv') 16 | rf.boot.vars <- bootStatsDf(rf.boot) 17 | rf.boot.vars$var <- rownames(rf.boot.vars) 18 | rf.boot.vars <- subset(rf.boot.vars, startsWith(var, 'vimp.c.index')) 19 | rf.boot.vars$var <- 20 | substring(rf.boot.vars$var, nchar('vimp.c.index.') + 1) 21 | 22 | var.imp.compare <- merge(cox.miss.vars, rf.boot.vars, by = c('var')) 23 | 24 | 25 | # Plot a scatterplot of them 26 | rf.vs.cox <- 27 | ggplot( 28 | var.imp.compare, 29 | aes( 30 | x = val.x, xmin = lower.x, xmax = upper.x, 31 | y = val.y, ymin = lower.y, ymax = upper.y 32 | ) 33 | ) + 34 | geom_point() + 35 | geom_errorbar() + 36 | geom_errorbarh() + 37 | coord_cartesian(xlim = c(0, 0.03), ylim = c(0, 0.03)) 38 | 39 | print('Spearman correlation coefficient of variable importances:') 40 | print(cor(var.imp.compare$val.x, var.imp.compare$val.y, method = 'spearman')) 41 | 42 | # Load the variable importances from the big data model 43 | cox.bigdata <- 
read.csv('../../output/cox-bigdata-varsellogrank-01-boot-all.csv')
44 | cox.bigdata.vars <- bootStatsDf(cox.bigdata)
45 | cox.bigdata.vars$var <- rownames(cox.bigdata.vars)
46 | cox.bigdata.vars <- subset(cox.bigdata.vars, startsWith(var, 'vimp.c.index'))
47 | cox.bigdata.vars$var <-
48 |   substring(cox.bigdata.vars$var, nchar('vimp.c.index.') + 1)
49 | 
50 | cox.bigdata.vars <-
51 |   cox.bigdata.vars[order(cox.bigdata.vars$val, decreasing = TRUE)[1:20], ]
52 | 
53 | cox.bigdata.vars <-
54 |   cox.bigdata.vars[order(cox.bigdata.vars$val, decreasing = FALSE), ]
55 | 
56 | cox.bigdata.vars$description <- lookUpDescriptions(cox.bigdata.vars$var)
57 | 
58 | cat('c(', paste0("'", as.character(cox.bigdata.vars$description), "',"), ')', sep = '\n')
59 | 
60 | cox.bigdata.vars$description.manual <-
61 |   factorOrderedLevels(
62 |     c(
63 |       'ALT',
64 |       'PVD',
65 |       'Hb',
66 |       'Dementia',
67 |       'Albumin',
68 |       'Cardiac glycosides',
69 |       'LV failure',
70 |       'Home visit',
71 |       'Oestrogens/HRT',
72 |       'Chest pain',
73 |       'Na',
74 |       'WCC',
75 |       'ALP',
76 |       'Lymphocyte count',
77 |       'Diabetes',
78 |       'BMI ',
79 |       'Weight',
80 |       'Loop diuretics',
81 |       'Smoking status',
82 |       'Age'
83 |     )
84 |   )
85 | 
86 | # Plot a bar graph of them
87 | cox.bigdata.plot <-
88 |   ggplot(
89 |     cox.bigdata.vars,
90 |     aes(x = description.manual, y = val, ymin = lower, ymax = upper)
91 |   ) +
92 |   geom_bar(stat = 'identity') +
93 |   geom_errorbar(width = 0.25) +
94 |   coord_flip() +
95 |   theme(axis.title.y = element_blank(), axis.text.y = element_text(size = 10)) +
96 |   ylim(0, 0.17)
97 | 
98 | # Random forest big data
99 | rf.bigdata <- read.csv('../../output/rf-bigdata-varsellogrank-02-boot-all.csv')
100 | rf.bigdata.vars <- bootStatsDf(rf.bigdata)
101 | rf.bigdata.vars$var <- rownames(rf.bigdata.vars)
102 | rf.bigdata.vars <- subset(rf.bigdata.vars, startsWith(var, 'vimp.c.index'))
103 | rf.bigdata.vars$var <-
104 |   substring(rf.bigdata.vars$var, nchar('vimp.c.index.') + 1)
105 | 
106 | rf.bigdata.vars <-
107 |   rf.bigdata.vars[order(rf.bigdata.vars$val, decreasing = TRUE)[1:20], ]
108 | 
109 | rf.bigdata.vars <-
110 |   rf.bigdata.vars[order(rf.bigdata.vars$val, decreasing = FALSE), ]
111 | 
112 | rf.bigdata.vars$description <- lookUpDescriptions(rf.bigdata.vars$var)
113 | # Print a template for the manual description list below
114 | cat('c(', paste0("'", as.character(rf.bigdata.vars$description), "',"), ')', sep = '\n')
115 | 
116 | rf.bigdata.vars$description.manual <-
117 |   factorOrderedLevels(
118 |     c(
119 |       'Fit note',
120 |       'Stimulant laxatives',
121 |       'Urea',
122 |       'Hypertension',
123 |       'Cardiac glycosides',
124 |       'Beta2 agonists',
125 |       'Telephone encounter',
126 |       'Feet examination',
127 |       'Creatinine',
128 |       'Screening',
129 |       'Osmotic laxatives',
130 |       'ACE inhibitors',
131 |       'Beta blockers',
132 |       'Home visit',
133 |       'Analgesics',
134 |       'Blood pressure',
135 |       'Chest pain',
136 |       'Loop diuretics',
137 |       'Smoking status',
138 |       'Age'
139 |     )
140 |   )
141 | 
142 | # Plot a bar graph of them
143 | rf.bigdata.plot <-
144 |   ggplot(
145 |     rf.bigdata.vars,
146 |     aes(x = description.manual, y = val, ymin = lower, ymax = upper)
147 |   ) +
148 |   geom_bar(stat = 'identity') +
149 |   geom_errorbar(width = 0.25) +
150 |   coord_flip() +
151 |   theme(axis.title.y = element_blank(), axis.text.y = element_text(size = 10)) +
152 |   ylim(0, 0.17)
153 | 
154 | 
155 | 
156 | # Elastic net Cox model
157 | elastic.bigdata <- read.csv('../../output/cox-discrete-elasticnet-01-boot-all.csv')
158 | elastic.bigdata.vars <- bootStatsDf(elastic.bigdata)
159 | elastic.bigdata.vars$var <- rownames(elastic.bigdata.vars)
160 | elastic.bigdata.vars <- subset(elastic.bigdata.vars, startsWith(var, 'vimp.c.index'))
161 | elastic.bigdata.vars$var <-
162 |   substring(elastic.bigdata.vars$var, nchar('vimp.c.index.') + 1)
163 | 
164 | elastic.bigdata.vars <-
165 |   elastic.bigdata.vars[order(elastic.bigdata.vars$val, decreasing = TRUE)[1:20], ]
166 | 
167 | elastic.bigdata.vars <-
168 |   elastic.bigdata.vars[order(elastic.bigdata.vars$val, decreasing = FALSE), ]
169 | 
170 | elastic.bigdata.vars$description <- factorOrderedLevels(lookUpDescriptions(elastic.bigdata.vars$var))
171 | 
172 | cat('c(', paste0("'", as.character(elastic.bigdata.vars$description), "',"), ')', sep = '\n')
173 | 
174 | elastic.bigdata.vars$description.manual <-
175 |   factorOrderedLevels(
176 |     c(
177 |       'Biguanides',
178 |       'CKD',
179 |       'Osmotic laxatives',
180 |       'MCV',
181 |       'IMD score',
182 |       'Dementia',
183 |       'Home visit',
184 |       'Sulphonylureas',
185 |       'Insulin',
186 |       'LV failure',
187 |       'Cardiac glycosides',
188 |       'Telephone encounter',
189 |       'Records held date',
190 |       'Chest pain',
191 |       'Diabetes',
192 |       'Smoking status',
193 |       'Loop diuretics',
194 |       'Gender',
195 |       'Type 2 diabetes',
196 |       'Age'
197 |     )
198 |   )
199 | 
200 | # Plot a bar graph of them
201 | elastic.bigdata.plot <-
202 |   ggplot(
203 |     elastic.bigdata.vars,
204 |     aes(x = description.manual, y = val, ymin = lower, ymax = upper)
205 |   ) +
206 |   geom_bar(stat = 'identity') +
207 |   geom_errorbar(width = 0.25) +
208 |   coord_flip() +
209 |   theme(axis.title.y = element_blank(), axis.text.y = element_text(size = 10)) +
210 |   ylim(0, 0.17)
211 | 
212 | 
213 | # Combine for output
214 | plot_grid(
215 |   rf.bigdata.plot, cox.bigdata.plot, elastic.bigdata.plot,
216 |   labels = c('A', 'B', 'C'),
217 |   ncol = 3
218 | )
219 | 
220 | ggsave(
221 |   '../../output/variable-importances.pdf',
222 |   width = 24,
223 |   height = 8,
224 |   units = 'cm',
225 |   useDingbats = FALSE
226 | )
227 | 
228 | 
229 | # Print a helpful list of overlapping predictors
230 | print('RF vs Cox')
231 | print(
232 |   paste(
233 |     rf.bigdata.vars$description.manual[
234 |       rf.bigdata.vars$description.manual %in%
235 |         cox.bigdata.vars$description.manual
236 |     ], collapse = ', ')
237 | )
238 | print('Cox vs elastic net')
239 | print(
240 |   paste(
241 |     elastic.bigdata.vars$description.manual[
242 |       elastic.bigdata.vars$description.manual %in%
243 |         cox.bigdata.vars$description.manual
244 |     ], collapse = ', ')
245 | )
246 | print('RF vs elastic net not Cox')
247 | print(
248 |   paste(
249 |     elastic.bigdata.vars$description.manual[
250 |       elastic.bigdata.vars$description.manual %in%
251 |         rf.bigdata.vars$description.manual & !(
252 |           elastic.bigdata.vars$description.manual %in%
253 |             cox.bigdata.vars$description.manual)
254 |     ], collapse = ', ')
255 | )
256 | print('Cox vs elastic net not RF')
257 | print(
258 |   paste(
259 |     elastic.bigdata.vars$description.manual[
260 |       elastic.bigdata.vars$description.manual %in%
261 |         cox.bigdata.vars$description.manual & !(
262 |           elastic.bigdata.vars$description.manual %in%
263 |             rf.bigdata.vars$description.manual)
264 |     ], collapse = ', ')
265 | )
266 | # Should be none or the graph will be challenging to draw!
267 | 
268 | # Aborted idea to draw a rank-change chart which is too messy to be useful...
269 | elastic.bigdata.vars$model <- 'enet'
270 | rf.bigdata.vars$model <- 'rf'
271 | cox.bigdata.vars$model <- 'cox'
272 | 
273 | all.models <- rbind(cox.bigdata.vars, rf.bigdata.vars, elastic.bigdata.vars)
274 | 
275 | all.models$model <- factor(all.models$model, levels = c('rf', 'cox', 'enet'))
276 | 
277 | ggplot(
278 |   all.models,
279 |   aes(
280 |     x = model, y = log(val), ymin = log(lower), ymax = log(upper),
281 |     label = description.manual, group = description.manual
282 |   )) +
283 |   geom_text(position = position_dodge(width = 0.7)) +
284 |   geom_errorbar(width = 0.1, position = position_dodge(width = 0.7)) +
285 |   geom_point(position = position_dodge(width = 0.7))
286 | 
287 | 
--------------------------------------------------------------------------------
/random-forest/rf-age.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- TRUE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Modelling with age
11 | #' 
12 | #' It seems that, no matter what I do, the C-index of a model, random forest or
13 | #' otherwise, is about 0.78. I decided to try some simpler models, initially
14 | #' based purely on age, which is clearly the biggest factor in this dataset.
15 | #' (And, I assume, most datasets where a reasonably broad range of ages is
16 | #' present.)
17 | #' 
18 | #' The sanity check works: giving the model more data does indeed result in a
19 | #' better fit. However, on top of that, I was surprised by just how good the
20 | #' performance can be when age alone is considered!
21 | 
22 | #+ user_variables, message=FALSE
23 | 
24 | data.filename <- '../../data/cohort-sanitised.csv'
25 | 
26 | n.trees <- 500
27 | 
28 | continuous.vars <-
29 |   c(
30 |     'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo',
31 |     'total_wbc_6mo', 'haemoglobin_6mo'
32 |   )
33 | untransformed.vars <- c('anonpatid', 'time_death', 'imd_score', 'exclude')
34 | 
35 | source('../lib/shared.R')
36 | require(ggrepel)
37 | 
38 | #' ## Load and prepare data
39 | #+ load_and_prepare_data
40 | 
41 | # Load the data and convert to data frame to make column-selecting code in
42 | # prepData simpler
43 | COHORT.full <- data.frame(fread(data.filename))
44 | 
45 | # Define process settings: no transformation for the untransformed variables,
46 | # and missingToBig for the continuous ones...
47 | process.settings <-
48 |   list(
49 |     var = c(untransformed.vars, continuous.vars),
50 |     method =
51 |       c(
52 |         rep(NA, length(untransformed.vars)),
53 |         rep('missingToBig', length(continuous.vars))
54 |       ),
55 |     settings = rep(NA, length(untransformed.vars) + length(continuous.vars))
56 |   )
57 | 
58 | COHORT.prep <-
59 |   prepData(
60 |     # Prepare the full dataset; the test set is separated off below
61 |     COHORT.full,
62 |     cols.keep,
63 |     process.settings,
64 |     surv.time, surv.event,
65 |     surv.event.yes,
66 |     extra.fun = caliberExtraPrep
67 |   )
68 | n.data <- nrow(COHORT.prep)
69 | 
70 | # Define indices of test set
71 | test.set <- sample(1:n.data, round((1/3) * n.data))
72 | 
73 | #' ## Models
74 | #' 
75 | #' ### The normal model
76 | #' 
77 | #' All the variables, as in the vector `surv.predict`.
78 | #+ normal_model, cache=cacheoption
79 | 
80 | # Fit random forest
81 | surv.model.fit <-
82 |   survivalFit(
83 |     surv.predict,
84 |     COHORT.prep[-test.set,],
85 |     model.type = 'rfsrc',
86 |     n.trees = n.trees,
87 |     split.rule = 'logrank',
88 |     n.threads = 7,
89 |     nsplit = 20
90 |   )
91 | 
92 | print(surv.model.fit)
93 | 
94 | # Get C-index
95 | c.index.test <-
96 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'rfsrc')
97 | 
98 | #' The C-index on the held-out test set is **`r round(c.index.test, 3)`**.
99 | 
100 | 
101 | #' ### Just age
102 | #' 
103 | #' What if all we had to go on was age?
104 | #+ just_age_model, cache=cacheoption
105 | 
106 | # Fit random forest
107 | surv.model.fit <-
108 |   survivalFit(
109 |     c('age'),
110 |     COHORT.prep[-test.set,],
111 |     model.type = 'ranger',
112 |     n.trees = n.trees,
113 |     split.rule = 'logrank',
114 |     n.threads = 8,
115 |     respect.unordered.factors = 'partition'
116 |   )
117 | 
118 | print(surv.model.fit)
119 | 
120 | # Get C-index
121 | c.index.test <-
122 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'ranger')
123 | 
124 | #' The C-index on the held-out test set is **`r round(c.index.test, 3)`**.
125 | 
126 | #' ### No model, literally just age
127 | #' 
128 | #' What if we constructed the C-index based purely on patients' ages?
129 | #+ just_age_cindex
130 | 
131 | c.index.age <-
132 |   as.numeric(
133 |     survConcordance(
134 |       Surv(time_death, surv_event) ~ age,
135 |       COHORT.prep
136 |     )$concordance
137 |   )
138 | 
139 | #' The C-index on the whole dataset based purely on age is
140 | #' **`r round(c.index.age, 3)`**. That's most of our predictive accuracy right
141 | #' there! Reassuringly, it's also equal to the value predicted by the random
142 | #' forest model based purely on age...
143 | 
144 | #' ### Age and gender
145 | #' 
146 | #' OK, age and gender.
147 | #+ age_gender_model, cache=cacheoption
148 | 
149 | # Fit random forest
150 | surv.model.fit <-
151 |   survivalFit(
152 |     c('age', 'gender'),
153 |     COHORT.prep[-test.set,],
154 |     model.type = 'ranger',
155 |     n.trees = n.trees,
156 |     split.rule = 'logrank',
157 |     n.threads = 8,
158 |     respect.unordered.factors = 'partition'
159 |   )
160 | 
161 | print(surv.model.fit)
162 | 
163 | # Get C-index
164 | c.index.test <-
165 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'ranger')
166 | 
167 | #' The C-index on the held-out test set is **`r round(c.index.test, 3)`**.
168 | 
169 | 
170 | #' ### Age, gender and history of liver disease
171 | #' 
172 | #' Let's add a third variable. In the replication of the Cox model with missing
173 | #' data included, liver disease was the most predictive factor after age, so
174 | #' it's a reasonable next variable to add.
175 | 
176 | #+ age_gender_liver_model, cache=cacheoption
177 | 
178 | # Fit random forest
179 | surv.model.fit <-
180 |   survivalFit(
181 |     c('age', 'gender', 'hx_liver'),
182 |     COHORT.prep[-test.set,],
183 |     model.type = 'ranger',
184 |     n.trees = n.trees,
185 |     split.rule = 'logrank',
186 |     n.threads = 8,
187 |     respect.unordered.factors = 'partition'
188 |   )
189 | 
190 | print(surv.model.fit)
191 | 
192 | # Get C-index
193 | c.index.test <-
194 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'ranger')
195 | 
196 | #' The C-index on the held-out test set is **`r round(c.index.test, 3)`**.
197 | 
198 | #' ### Age, gender and heart failure
199 | #' 
200 | #' A different third variable: heart failure, the second most important variable
201 | #' (after age) from random forest modelling.
202 | 
203 | #+ age_gender_hf_model, cache=cacheoption
204 | 
205 | # Fit random forest
206 | surv.model.fit <-
207 |   survivalFit(
208 |     c('age', 'gender', 'heart_failure'),
209 |     COHORT.prep[-test.set,],
210 |     model.type = 'ranger',
211 |     n.trees = n.trees,
212 |     split.rule = 'logrank',
213 |     n.threads = 8,
214 |     respect.unordered.factors = 'partition'
215 |   )
216 | 
217 | print(surv.model.fit)
218 | 
219 | # Get C-index
220 | c.index.test <-
221 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'ranger')
222 | 
223 | #' The C-index on the held-out test set is **`r round(c.index.test, 3)`**.
224 | 
225 | 
226 | #' ### Just gender
227 | #' 
228 | #' Just gender, as a sanity check.
229 | #+ just_gender_model, cache=cacheoption
230 | 
231 | # Fit random forest
232 | surv.model.fit <-
233 |   survivalFit(
234 |     c('gender'),
235 |     COHORT.prep[-test.set,],
236 |     model.type = 'ranger',
237 |     n.trees = n.trees,
238 |     split.rule = 'logrank',
239 |     n.threads = 8,
240 |     respect.unordered.factors = 'partition'
241 |   )
242 | 
243 | print(surv.model.fit)
244 | 
245 | # Get C-index
246 | c.index.test <-
247 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'ranger')
248 | 
249 | #' The C-index on the held-out test set is **`r round(c.index.test, 3)`**.
250 | 
251 | #' ### Everything except age
252 | #' 
253 | #' How do we do if we use all the variables _except_ age?
254 | #+ no_age_model, cache=cacheoption
255 | 
256 | # Fit random forest
257 | surv.model.fit <-
258 |   survivalFit(
259 |     surv.predict[surv.predict != 'age'],
260 |     COHORT.prep[-test.set,],
261 |     model.type = 'ranger',
262 |     n.trees = n.trees,
263 |     split.rule = 'logrank',
264 |     n.threads = 8,
265 |     respect.unordered.factors = 'partition'
266 |   )
267 | 
268 | print(surv.model.fit)
269 | 
270 | # Get C-index
271 | c.index.test.all.not.age <-
272 |   cIndex(surv.model.fit, COHORT.prep[test.set, ], model.type = 'ranger')
273 | 
274 | #' The C-index on the held-out test set is
275 | #' **`r round(c.index.test.all.not.age, 3)`**.
276 | 
277 | 
278 | #' ### Predicting age
279 | #' 
280 | #' So why does the model which doesn't include age do so well? Clearly the other
281 | #' variables allow you to predict age with reasonable accuracy... So let's try
282 | #' just that as a final test.
283 | #+ predict_age, cache=cacheoption
284 | 
285 | options(rf.cores = 8)  # rfsrc thread count; matches the 8 threads used above
286 | 
287 | age.model <-
288 |   rfsrc(
289 |     formula(
290 |       paste0(
291 |         # Predicting just the age
292 |         'age ~ ',
293 |         # Predictor variables then make up the other side
294 |         paste(surv.predict[!(surv.predict %in% c('age', 'most_deprived'))], collapse = '+')
295 |       )
296 |     ),
297 |     COHORT.prep[-test.set, ],
298 |     ntree = n.trees,
299 |     splitrule = 'mse',
300 |     na.action = 'na.impute',
301 |     nimpute = 3
302 |   )
303 | 
304 | age.predictions <- predict(age.model, COHORT.prep[test.set, ])
305 | 
306 | # predict.rfsrc stores regression predictions in $predicted
307 | age.cor <- cor(age.predictions$predicted, COHORT.prep[test.set, 'age'])
308 | 
309 | to.plot <-
310 |   data.frame(
311 |     age = COHORT.prep[test.set, 'age'],
312 |     predicted = age.predictions$predicted
313 |   )
314 | 
315 | ggplot(sample.df(to.plot, 10000), aes(x = age, y = predicted)) +
316 |   geom_point(alpha = 0.2)
317 | 
318 | #' It doesn't look that great, but there is some correlation...
319 | #' r^2 = `r round(age.cor^2, 3)` which is unremarkable, but OK. A more relevant
320 | #' measure would be the pure-age C-index, ie if I gave you a pair of patients,
320 | #' how often could you predict who was older?
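#' 
#' (For intuition: the C-index here is just pairwise concordance. Below is a
#' minimal illustrative sketch of that quantity — not used in the analysis,
#' and ignoring censoring, which is fine because age is always observed. The
#' real calculation is done far more efficiently by `survConcordance`.)

#+ naive_concordance_sketch, eval = FALSE

# Of all pairs with different observed values, what fraction does the
# predictor rank correctly? Ties in the predictor count as half.
naiveConcordance <- function(predicted, observed) {
  pairs <- combn(length(observed), 2)  # O(n^2), so small samples only!
  obs.differ <- observed[pairs[1, ]] != observed[pairs[2, ]]
  pred.tied <- predicted[pairs[1, ]] == predicted[pairs[2, ]]
  # A pair agrees if the predictor and the observation order it the same way
  agree <-
    sign(predicted[pairs[1, ]] - predicted[pairs[2, ]]) ==
    sign(observed[pairs[1, ]] - observed[pairs[2, ]])
  (sum(agree & obs.differ) + 0.5 * sum(pred.tied & obs.differ)) /
    sum(obs.differ)
}

# e.g. naiveConcordance(to.plot$predicted[1:500], to.plot$age[1:500])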
321 | 
322 | c.index.age.on.age <-
323 |   1 - as.numeric(
324 |     survConcordance(
325 |       age ~ predicted,
326 |       to.plot
327 |     )$concordance
328 |   )
329 | 
330 | #' This comes out as **`r round(c.index.age.on.age, 3)`**, as compared to
331 | #' **`r round(c.index.test.all.not.age, 3)`**, which was the C-index for the
332 | #' survival model based on all other variables.
333 | #' 
334 | #' This makes sense. The maths of C-indices is a bit tricky (how much they add
335 | #' to one-another depends in large part on how correlated the variables are, as
336 | #' well as probably being somewhat non-linear anyway),
337 | #' but clearly some significant fraction of the all-except-age model's
338 | #' predictive power comes from its ability to infer age from the remaining
339 | #' variables, and then use _that_ (implicitly) to predict time to death.
340 | #' 
341 | #' The mechanism for this could be that people are likely to
342 | #' get more diseases as they get older, so if you have a, b _and_ c you're
343 | #' likely to be older. Second-order predictivity may occur if particular
344 | #' combinations of disorders or test results are common in certain age groups.
345 | #' 
346 | #' ## Conclusion
347 | #' 
348 | #' So, in conclusion, the sanity check has worked: giving the
349 | #' random forest model more data to work with improves its performance. Age
350 | #' alone is less predictive, adding gender makes it slightly more predictive,
351 | #' and so on.
352 | #' 
353 | #' Further, though, a huge amount of the model's predictivity arises from just
354 | #' a patient's age. Not only is age alone a good predictor, but the reasonable
355 | #' performance of the model with all factors except age is explained in
356 | #' part by those factors' ability to act as a proxy for age.
--------------------------------------------------------------------------------
/cox-ph/caliber-replicate-with-imputation.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- FALSE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Replicating Rapsomaniki _et al._ 2014
11 | #' 
12 | #' ## User variables
13 | #' 
14 | #' First, define some variables...
15 | 
16 | #+ define_vars
17 | 
18 | imputed.data.filename <- '../../data/COHORT_complete.rds'
19 | n.data <- NA # This is for the full dataset; further rows may be excluded in prep
20 | endpoint <- 'death.imputed'
21 | 
22 | old.coefficients.filename <- 'rapsomaniki-cox-values-from-paper.csv'
23 | 
24 | output.filename.base <- '../../output/caliber-replicate-imputed-survreg-4'
25 | 
26 | 
27 | cox.var.imp.perm.filename <-
28 |   '../../output/caliber-replicate-imputed-survreg-bootstrap-var-imp-perm-1.csv'
29 | cox.var.imp.perm.missing.filename <-
30 |   '../../output/caliber-replicate-with-missing-survreg-bootstrap-var-imp-perm-1.csv'
31 | 
32 | 
33 | bootstraps <- 100
34 | n.threads <- 10
35 | 
36 | #' ## Setup
37 | 
38 | #+ setup, message=FALSE
39 | 
40 | source('../lib/shared.R')
41 | require(xtable)
42 | require(ggrepel)
43 | 
44 | # Load the data and convert to data frame to make column-selecting code in
45 | # prepData simpler
46 | imputed.data <- readRDS(imputed.data.filename)
47 | 
48 | # Remove rows with death time of 0 to avoid fitting errors
49 | for(i in 1:length(imputed.data)) {
50 |   imputed.data[[i]] <- imputed.data[[i]][imputed.data[[i]][, surv.time] > 0, ]
51 | }
52 | 
53 | # Define n.data based on the imputed data, which has already been preprocessed
54 | n.data <- nrow(imputed.data[[1]])
55 | # Define indices of test set
56 | test.set <- testSetIndices(imputed.data[[1]], random.seed = 78361)
57 | 
58 | #' OK, we've now got **`r n.data`** patients, split into a training set of
59 | #' `r n.data - length(test.set)` and a test set of `r length(test.set)`.
60 | #' 
61 | #' 
62 | #' ## Transform variables
63 | #' 
64 | #' The model uses variables which have been standardised in various ways, so
65 | #' let's go through and transform our input variables in the same way...
66 | 
67 | source('caliber-scale.R')
68 | # Remove the age splining
69 | ageSpline <- identity
70 | 
71 | for(i in 1:length(imputed.data)) {
72 |   imputed.data[[i]] <- caliberScale(imputed.data[[i]], surv.time, surv.event)
73 | }
74 | 
75 | #' ## Survival fitting
76 | #' 
77 | #' Fit a Cox model to the preprocessed data. The paper uses a Cox model with an
78 | #' exponential baseline hazard, as here. The paper calculated standard errors
79 | #' with 200 bootstrap samples; here we use `r bootstraps` bootstrap samples.
80 | 
81 | #+ fit_cox_model, cache=cacheoption
82 | 
83 | surv.formula <-
84 |   Surv(surv_time, surv_event) ~
85 |     ### Sociodemographic characteristics #######################################
86 |     ## Age in men, per year
87 |     ## Age in women, per year
88 |     ## Women vs. men
89 |     # ie include interaction between age and gender!
90 |     age*gender +
91 |     ## Most deprived quintile, yes vs. no
92 |     most_deprived +
93 |     ### SCAD diagnosis and severity ############################################
94 |     ## Other CHD vs. stable angina
95 |     ## Unstable angina vs. stable angina
96 |     ## NSTEMI vs. stable angina
97 |     ## STEMI vs. stable angina
98 |     diagnosis +
99 |     #diagnosis_missing +
100 |     ## PCI in last 6 months, yes vs. no
101 |     pci_6mo +
102 |     ## CABG in last 6 months, yes vs. no
103 |     cabg_6mo +
104 |     ## Previous/recurrent MI, yes vs. no
105 |     hx_mi +
106 |     ## Use of nitrates, yes vs. no
107 |     long_nitrate +
108 |     ### CVD risk factors #######################################################
109 |     ## Ex-smoker / current smoker / missing data vs. never
110 |     smokstatus +
111 |     ## Hypertension, present vs. absent
112 |     hypertension +
113 |     ## Diabetes mellitus, present vs.
absent 114 | diabetes_logical + 115 | ## Total cholesterol, per 1 mmol/L increase 116 | total_chol_6mo + 117 | ## HDL, per 0.5 mmol/L increase 118 | hdl_6mo + 119 | ### CVD co-morbidities ##################################################### 120 | ## Heart failure, present vs. absent 121 | heart_failure + 122 | ## Peripheral arterial disease, present vs. absent 123 | pad + 124 | ## Atrial fibrillation, present vs. absent 125 | hx_af + 126 | ## Stroke, present vs. absent 127 | hx_stroke + 128 | ### Non-CVD comorbidities ################################################## 129 | ## Chronic kidney disease, present vs. absent 130 | hx_renal + 131 | ## Chronic obstructive pulmonary disease, present vs. absent 132 | hx_copd + 133 | ## Cancer, present vs. absent 134 | hx_cancer + 135 | ## Chronic liver disease, present vs. absent 136 | hx_liver + 137 | ### Psychosocial characteristics ########################################### 138 | ## Depression at diagnosis, present vs. absent 139 | hx_depression + 140 | ## Anxiety at diagnosis, present vs. absent 141 | hx_anxiety + 142 | ### Biomarkers ############################################################# 143 | ## Heart rate, per 10 b.p.m increase 144 | pulse_6mo + 145 | ## Creatinine, per 30 μmol/L increase 146 | crea_6mo + 147 | ## White cell count, per 1.5 109/L increase 148 | total_wbc_6mo + 149 | ## Haemoglobin, per 1.5 g/dL increase 150 | haemoglobin_6mo 151 | 152 | # Do a quick and dirty fit on a single imputed dataset, to draw calibration 153 | # curve from 154 | fit.exp <- survreg( 155 | formula = surv.formula, 156 | data = imputed.data[[1]][-test.set, ], 157 | dist = "exponential" 158 | ) 159 | 160 | fit.exp.boot <- list() 161 | 162 | # Perform bootstrap fitting for every multiply imputed dataset 163 | for(i in 1:length(imputed.data)) { 164 | fit.exp.boot[[i]] <- 165 | boot( 166 | formula = surv.formula, 167 | data = imputed.data[[i]][-test.set, ], 168 | statistic = bootstrapFitSurvreg, 169 | R = bootstraps, 170 | parallel = 'multicore', 171 | ncpus = n.threads, 172 | test.data = imputed.data[[i]][test.set, ] 173 | ) 174 | } 175 | 176 | # Save the fits, because it might've taken a while! 177 | saveRDS(fit.exp.boot, paste0(output.filename.base, '-surv-boot-imp.rds')) 178 | 179 | # Unpackage the uncertainties from the bootstrapped data 180 | fit.exp.boot.ests <- bootMIStats(fit.exp.boot) 181 | 182 | # Save bootstrapped performance values 183 | varsToTable( 184 | data.frame( 185 | model = 'cox', 186 | imputation = TRUE, 187 | discretised = FALSE, 188 | c.index = fit.exp.boot.ests['c.index', 'val'], 189 | c.index.lower = fit.exp.boot.ests['c.index', 'lower'], 190 | c.index.upper = fit.exp.boot.ests['c.index', 'upper'], 191 | calibration.score = fit.exp.boot.ests['calibration.score', 'val'], 192 | calibration.score.lower = fit.exp.boot.ests['calibration.score', 'lower'], 193 | calibration.score.upper = fit.exp.boot.ests['calibration.score', 'upper'] 194 | ), 195 | performance.file, 196 | index.cols = c('model', 'imputation', 'discretised') 197 | ) 198 | 199 | #' ## Performance 200 | #' 201 | #' Having fitted the Cox model, how did we do? The c-indices were calculated as 202 | #' part of the bootstrapping, so we just need to take a look at those... 
203 | #' 
204 | #' C-indices are **`r round(fit.exp.boot.ests['c.train', 'val'], 3)`
205 | #' (`r round(fit.exp.boot.ests['c.train', 'lower'], 3)` - 
206 | #' `r round(fit.exp.boot.ests['c.train', 'upper'], 3)`)** on the training set and
207 | #' **`r round(fit.exp.boot.ests['c.test', 'val'], 3)`
208 | #' (`r round(fit.exp.boot.ests['c.test', 'lower'], 3)` - 
209 | #' `r round(fit.exp.boot.ests['c.test', 'upper'], 3)`)** on the test set.
210 | #' Not too bad!
211 | #' 
212 | #' 
213 | #' ### Calibration
214 | #' 
215 | #' The bootstrapped calibration score is
216 | #' **`r round(fit.exp.boot.ests['calibration.score', 'val'], 3)`
217 | #' (`r round(fit.exp.boot.ests['calibration.score', 'lower'], 3)` - 
218 | #' `r round(fit.exp.boot.ests['calibration.score', 'upper'], 3)`)**.
219 | #' 
220 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be
221 | #' better to draw all the curves from the bootstrap fit to get an idea of
222 | #' variability, but I've not implemented this yet.)
223 | #' 
224 | #+ calibration_plot
225 | 
226 | calibration.table <-
227 |   calibrationTable(fit.exp, imputed.data[[1]][test.set, ]) # dataset 1: the one fit.exp was fitted to
228 | 
229 | calibration.score <- calibrationScore(calibration.table)
230 | 
231 | calibrationPlot(calibration.table)
232 | 
233 | #' The area between the calibration curve and the diagonal is
234 | #' **`r round(calibration.score[['area']], 3)`**.
235 | #' 
236 | #' ## Coefficients
237 | #' 
238 | #' As well as getting comparable C-indices, it's also worth checking to see how
239 | #' the risk coefficients calculated compare to those found in the original
240 | #' paper. Let's compare...
241 | 
242 | # Load CSV of values from paper
243 | old.coefficients <- read.csv(old.coefficients.filename)
244 | 
245 | # Get coefficients from this fit
246 | new.coefficients <-
247 |   bootMIStats(fit.exp.boot, uncertainty = '95ci', transform = negExp)
248 | names(new.coefficients) <- c('our_value', 'our_lower', 'our_upper')
249 | new.coefficients$quantity.level <- rownames(new.coefficients)
250 | 
251 | # Create a data frame comparing them
252 | compare.coefficients <- merge(old.coefficients, new.coefficients)
253 | 
254 | # Kludge because age:genderWomen is the pure interaction term, not the risk for
255 | # a woman per unit of advancing spline-transformed age
256 | compare.coefficients[
257 |   compare.coefficients$quantity.level == 'age:genderWomen', 'our_value'
258 | ] <-
259 |   compare.coefficients[
260 |     compare.coefficients$quantity.level == 'age:genderWomen', 'our_value'
261 |   ] *
262 |   compare.coefficients[
263 |     compare.coefficients$quantity.level == 'age', 'our_value'
264 |   ]
265 | 
266 | # Save CSV of results (output.filename was undefined; name assumed from base)
267 | write.csv(compare.coefficients, paste0(output.filename.base, '-coeff-comparison.csv'))
268 | 
269 | # Plot a graph by which to judge success
270 | ggplot(compare.coefficients, aes(x = their_value, y = our_value)) +
271 |   geom_abline(intercept = 0, slope = 1) +
272 |   geom_hline(yintercept = 1, colour = 'grey') +
273 |   geom_vline(xintercept = 1, colour = 'grey') +
274 |   geom_point() +
275 |   geom_errorbar(aes(ymin = our_lower, ymax = our_upper)) +
276 |   geom_errorbarh(aes(xmin = their_lower, xmax = their_upper)) +
277 |   geom_text_repel(aes(label = long_name)) +
278 |   theme_classic(base_size = 8)
279 | 
280 | #+ coefficients_table, results='asis'
281 | 
282 | print(
283 |   xtable(
284 |     data.frame(
285 |       variable =
286 |         paste(
287 |           compare.coefficients$long_name, compare.coefficients$unit, sep=', '
288 |         ),
289 |       compare.coefficients[c('our_value', 'their_value')]
290 |     ),
291 |     digits = c(0,0,3,3)
292 |   ),
293 |   type =
'html', 294 | include.rownames = FALSE 295 | ) 296 | 297 | #' ### Variable importance 298 | #' 299 | #' Let's compare the variable importance from this method with accounting for 300 | #' missing values explicitly. Slight kludge as it's only using one imputed 301 | #' dataset and a fit based on another, but should give some idea. 302 | #' 303 | #+ cox_variable_importance 304 | 305 | cox.var.imp.perm <- 306 | generalVarImp( 307 | fit.exp, imputed.data[[2]][test.set, ], model.type = 'survreg' 308 | ) 309 | 310 | write.csv(cox.var.imp.perm, cox.var.imp.perm.filename, row.names = FALSE) 311 | 312 | cox.var.imp.perm.missing <- read.csv(cox.var.imp.perm.missing.filename) 313 | 314 | cox.var.imp.comparison <- 315 | merge( 316 | cox.var.imp.perm, 317 | cox.var.imp.perm.missing, 318 | by = 'var', 319 | suffixes = c('', '.missing') 320 | ) 321 | 322 | ggplot(cox.var.imp.comparison, aes(x = var.imp.missing, y = var.imp)) + 323 | geom_point() + 324 | scale_x_log10() + 325 | scale_y_log10() 326 | 327 | #' There's a good correlation! 328 | -------------------------------------------------------------------------------- /cox-ph/cox-discretised.R: -------------------------------------------------------------------------------- 1 | #+ knitr_setup, include = FALSE 2 | 3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate 4 | # everything afresh. 5 | cacheoption <- FALSE 6 | # Disable lazy caching globally, because it fails for large objects, and all the 7 | # objects we wish to cache are large... 8 | opts_chunk$set(cache.lazy = FALSE) 9 | 10 | #' # Cross-validating discretisation of input variables in a survival model 11 | #' 12 | #' 13 | 14 | calibration.filename <- '../../output/survreg-crossvalidation-try5.csv' 15 | caliber.missing.coefficients.filename <- 16 | '../../output/caliber-replicate-with-missing-survreg-bootstrap-coeffs-1.csv' 17 | comparison.filename <- 18 | '../../output/caliber-replicate-with-missing-var-imp-try2.csv' 19 | # The first part of the filename for any output 20 | output.filename.base <- '../../output/all-cv-survreg-boot-try5' 21 | 22 | 23 | # What kind of model to fit to...currently 'cph' (Cox model), 'ranger' or 24 | # 'rfsrc' (two implementations of random survival forests) 25 | model.type <- 'survreg' 26 | 27 | # If surv.vars is defined as a character vector here, the model only uses those 28 | # variables specified, eg c('age') would build a model purely based on age. If 29 | # not specified (ie commented out), it will use the defaults. 30 | # surv.predict <- c('age') 31 | 32 | #' ## Do the cross-validation 33 | #' 34 | #' The steps for this are common regardless of model type, so run the script to 35 | #' get a cross-validated model to further analyse... 36 | #+ cox_discretised_cv, cache=cacheoption 37 | 38 | source('../lib/all-cv-bootstrap.R', chdir = TRUE) 39 | 40 | #' # Results 41 | #' 42 | #' ## Performance 43 | #' 44 | #' ### C-index 45 | #' 46 | #' C-index is **`r round(surv.model.fit.coeffs['c.index', 'val'], 3)` 47 | #' (`r round(surv.model.fit.coeffs['c.index', 'lower'], 3)` - 48 | #' `r round(surv.model.fit.coeffs['c.index', 'upper'], 3)`)** on the held-out 49 | #' test set. 50 | #' 51 | #' 52 | #' ### Calibration 53 | #' 54 | #' The bootstrapped calibration score is 55 | #' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)` 56 | #' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` - 57 | #' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**. 
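#' 
#' (As an aside: the calibration score used throughout is, in essence, the
#' area between the calibration curve and the diagonal. A minimal sketch of
#' the idea — the column names `risk` and `observed` are assumptions for
#' illustration, and may not match `calibrationTable`'s actual output:)

#+ calibration_area_sketch, eval = FALSE

# Trapezoidal integration of |observed - predicted| over predicted risk;
# 0 would mean perfect calibration
calibrationAreaSketch <- function(calib) {
  calib <- calib[order(calib$risk), ]
  gap <- abs(calib$observed - calib$risk)
  sum(diff(calib$risk) * (head(gap, -1) + tail(gap, -1)) / 2)
}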
58 | #' 59 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be 60 | #' better to draw all the curves from the bootstrap fit to get an idea of 61 | #' variability, but I've not implemented this yet.) 62 | #' 63 | #+ calibration_plot 64 | 65 | calibration.table <- 66 | calibrationTable(surv.model.fit, COHORT.optimised[test.set, ]) 67 | 68 | calibration.score <- calibrationScore(calibration.table) 69 | 70 | calibrationPlot(calibration.table) 71 | 72 | #' 73 | #' ## Model fit 74 | #' 75 | #+ resulting_fit 76 | 77 | print(surv.model.fit) 78 | 79 | #' ## Cox coefficients 80 | #' 81 | #+ cox_coefficients_plot 82 | 83 | 84 | # Save bootstrapped performance values 85 | varsToTable( 86 | data.frame( 87 | model = 'cox', 88 | imputation = FALSE, 89 | discretised = TRUE, 90 | c.index = surv.model.fit.coeffs['c.index', 'val'], 91 | c.index.lower = surv.model.fit.coeffs['c.index', 'lower'], 92 | c.index.upper = surv.model.fit.coeffs['c.index', 'upper'], 93 | calibration.score = surv.model.fit.coeffs['calibration.score', 'val'], 94 | calibration.score.lower = 95 | surv.model.fit.coeffs['calibration.score', 'lower'], 96 | calibration.score.upper = 97 | surv.model.fit.coeffs['calibration.score', 'upper'] 98 | ), 99 | performance.file, 100 | index.cols = c('model', 'imputation', 'discretised') 101 | ) 102 | 103 | # Unpackage the uncertainties again, this time transformed because survreg 104 | # returns negative values 105 | surv.boot.ests <- bootStatsDf(surv.model.params.boot, transform = `-`) 106 | 107 | #' First, plot the factors and logicals as a scatter plot to compare with the 108 | #' continuous Cox model... 109 | 110 | # Pull coefficients from model with missing data 111 | caliber.missing.coeffs <- read.csv(caliber.missing.coefficients.filename) 112 | 113 | # Rename surv.boot.ests ready for merging 114 | names(surv.boot.ests) <- 115 | c('cox_discrete_value', 'cox_discrete_lower', 'cox_discrete_upper') 116 | surv.boot.ests$quantity.level <- rownames(surv.boot.ests) 117 | # Convert variablemissing to variable_missingTRUE for compatibility 118 | vars.with.missing <- endsWith(surv.boot.ests$quantity.level, 'missing') 119 | surv.boot.ests$quantity.level[vars.with.missing] <- 120 | paste0( 121 | substr( 122 | surv.boot.ests$quantity.level[vars.with.missing], 123 | 1, 124 | nchar(surv.boot.ests$quantity.level[vars.with.missing]) - nchar('missing') 125 | ), 126 | '_missingTRUE' 127 | ) 128 | 129 | # Create a data frame comparing them 130 | compare.coefficients <- merge(caliber.missing.coeffs, surv.boot.ests) 131 | 132 | ggplot( 133 | compare.coefficients, 134 | aes(x = our_value, y = cox_discrete_value, colour = unit == 'missing') 135 | ) + 136 | geom_abline(intercept = 0, slope = 1) + 137 | geom_hline(yintercept = 1, colour = 'grey') + 138 | geom_vline(xintercept = 1, colour = 'grey') + 139 | geom_point() + 140 | geom_errorbar(aes(ymin = cox_discrete_lower, ymax = cox_discrete_upper)) + 141 | geom_errorbarh(aes(xmin = our_lower, xmax = our_upper)) + 142 | geom_text_repel(aes(label = long_name)) + 143 | theme_classic(base_size = 8) 144 | 145 | # Unpack variable and level names 146 | cph.coeffs <- cphCoeffs( 147 | bootStats(surv.model.fit.boot, uncertainty = '95ci', transform = `-`), 148 | COHORT.optimised, surv.predict, model.type = 'boot.survreg' 149 | ) 150 | 151 | # We'll need the CALIBER scaling functions for plotting 152 | source('../cox-ph/caliber-scale.R') 153 | 154 | # set up list to store the plots 155 | cox.discrete.plots <- list() 156 | # Add dummy columns for 
x-position of missing values 157 | cph.coeffs$missing.x.pos.cont <- NA 158 | cph.coeffs$missing.x.pos.disc <- NA 159 | 160 | for(variable in unique(cph.coeffs$var)) { 161 | # If it's a continuous variable, get the real centres of the bins 162 | if(variable %in% process.settings$var) { 163 | process.i <- which(variable == process.settings$var) 164 | 165 | if(process.settings$method[[process.i]] == 'binByQuantile') { 166 | 167 | variable.quantiles <- 168 | getQuantiles( 169 | COHORT.use[, variable], 170 | process.settings$settings[[process.i]] 171 | ) 172 | # For those rows which relate to this variable, and whose level isn't 173 | # missing, put in the appropriate quantile boundaries for plotting 174 | cph.coeffs$bin.min[cph.coeffs$var == variable & 175 | cph.coeffs$level != 'missing'] <- 176 | variable.quantiles[1:(length(variable.quantiles) - 1)] 177 | cph.coeffs$bin.max[cph.coeffs$var == variable & 178 | cph.coeffs$level != 'missing'] <- 179 | variable.quantiles[2:length(variable.quantiles)] 180 | # Make the final bin the 99th percentile 181 | cph.coeffs$bin.max[cph.coeffs$var == variable & 182 | cph.coeffs$level != 'missing'][ 183 | length(variable.quantiles) - 1] <- 184 | quantile(COHORT.use[, variable], 0.99, na.rm = TRUE) 185 | 186 | # Add a fake data point at the highest value to finish the graph 187 | cph.coeffs <- 188 | rbind( 189 | cph.coeffs, 190 | cph.coeffs[cph.coeffs$var == variable & 191 | cph.coeffs$level != 'missing', ][ 192 | length(variable.quantiles) - 1, ] 193 | ) 194 | # Change it so that bin.min is bin.max from the old one 195 | cph.coeffs$bin.min[nrow(cph.coeffs)] <- 196 | cph.coeffs$bin.max[cph.coeffs$var == variable & 197 | cph.coeffs$level != 'missing'][ 198 | length(variable.quantiles) - 1] 199 | 200 | # Work out data range by taking the 1st and 99th percentiles 201 | # Use the max to provide a max value for the final bin 202 | # Also use for x-axis limits, unless there are missing values to 203 | # accommodate on the right-hand edge. 204 | x.data.range <- 205 | quantile(COHORT.use[, variable], c(0.01, 0.99), na.rm = TRUE) 206 | x.axis.limits <- x.data.range 207 | 208 | 209 | # Finally, we need to scale this such that the baseline value is equal 210 | # to the value for the equivalent place in the Cox model, to make the 211 | # risks comparable... 212 | 213 | # First, we need to find the average value of this variable in the lowest 214 | # bin (which is always the baseline here) 215 | baseline.bin <- variable.quantiles[1:2] 216 | baseline.bin.avg <- 217 | mean( 218 | # Take only those values of the variable which are in the range 219 | COHORT.use[ 220 | inRange(COHORT.use[, variable], baseline.bin, na.false = TRUE), 221 | variable 222 | ] 223 | ) 224 | # Then, scale it with the caliber scaling 225 | baseline.bin.val <- 226 | caliberScaleUnits(baseline.bin.avg, variable) * 227 | caliber.missing.coeffs$our_value[ 228 | caliber.missing.coeffs$quantity == variable 229 | ] 230 | 231 | # And now, add all the discretised values to that value to make them 232 | # comparable... 
233 |       cph.coeffs[cph.coeffs$var == variable, c('val', 'lower', 'upper')] <-
234 |         cph.coeffs[cph.coeffs$var == variable, c('val', 'lower', 'upper')] -
235 |         baseline.bin.val
236 | 
237 |       # Now, plot this variable as a stepped line plot using those quantile
238 |       # boundaries
239 |       cox.discrete.plot <-
240 |         ggplot(
241 |           subset(cph.coeffs, var == variable),
242 |           aes(x = bin.min, y = val)
243 |         ) +
244 |         geom_step() +
245 |         geom_step(aes(y = lower), colour = 'grey') +
246 |         geom_step(aes(y = upper), colour = 'grey') +
247 |         ggtitle(variable)
248 | 
249 |       # If there's a missing value risk, add it
250 |       if(any(cph.coeffs$var == variable & cph.coeffs$level == 'missing')) {
251 |         # Expand the x-axis to squeeze the missing values in
252 |         x.axis.limits[2] <-
253 |           x.axis.limits[2] + diff(x.data.range) * missing.padding
254 |         # Put this missing value a third of the way into the missing area
255 |         cph.coeffs$missing.x.pos.disc[
256 |           cph.coeffs$var == variable &
257 |             cph.coeffs$level == 'missing'] <-
258 |           x.axis.limits[2] + diff(x.data.range) * missing.padding / 3
259 | 
260 |         # Add the point to the graph (we'll set axis limits later)
261 |         cox.discrete.plot <-
262 |           cox.discrete.plot +
263 |           geom_pointrange(
264 |             data = cph.coeffs[cph.coeffs$var == variable &
265 |                                 cph.coeffs$level == 'missing', ],
266 |             aes(
267 |               x = missing.x.pos.disc,
268 |               y = val, ymin = lower,
269 |               ymax = upper
270 |             ),
271 |             colour = 'red'
272 |           )
273 |       }
274 | 
275 |       # Now, let's add the line from the continuous Cox model. We only need two
276 |       # points because the lines are straight!
277 |       continuous.cox <-
278 |         data.frame(
279 |           var.x.values = x.data.range
280 |         )
281 |       # Scale the x-values
282 |       continuous.cox$var.x.scaled <-
283 |         caliberScaleUnits(continuous.cox$var.x.values, variable)
284 |       # Use the risks to calculate risk per x for central estimate and errors
285 |       continuous.cox$y <-
286 |         -caliber.missing.coeffs$our_value[
287 |           caliber.missing.coeffs$quantity == variable
288 |         ] * continuous.cox$var.x.scaled
289 |       continuous.cox$upper <-
290 |         -caliber.missing.coeffs$our_upper[
291 |           caliber.missing.coeffs$quantity == variable
292 |         ] * continuous.cox$var.x.scaled
293 |       continuous.cox$lower <-
294 |         -caliber.missing.coeffs$our_lower[
295 |           caliber.missing.coeffs$quantity == variable
296 |         ] * continuous.cox$var.x.scaled
297 | 
298 |       cox.discrete.plot <-
299 |         cox.discrete.plot +
300 |         geom_line(
301 |           data = continuous.cox,
302 |           aes(x = var.x.values, y = y),
303 |           colour = 'blue'
304 |         ) +
305 |         geom_line(
306 |           data = continuous.cox,
307 |           aes(x = var.x.values, y = upper),
308 |           colour = 'lightblue'
309 |         ) +
310 |         geom_line(
311 |           data = continuous.cox,
312 |           aes(x = var.x.values, y = lower),
313 |           colour = 'lightblue'
314 |         )
315 | 
316 |       # If there is one, add missing value risk from the continuous model
317 |       if(any(caliber.missing.coeffs$quantity == paste0(variable, '_missing') &
318 |              caliber.missing.coeffs$unit == 'missing')) {
319 |         # Expand the x-axis to squeeze the missing values in
320 |         x.axis.limits[2] <-
321 |           x.axis.limits[2] + diff(x.data.range) * missing.padding
322 |         # Put this missing value 2/3rds of the way into the missing area
323 |         cph.coeffs$missing.x.pos.cont[
324 |           cph.coeffs$var == variable &
325 |             cph.coeffs$level == 'missing'] <-
326 |           x.axis.limits[2] + 2 * diff(x.data.range) * missing.padding / 3
327 | 
328 | 
329 |         cox.discrete.plot <-
330 |           cox.discrete.plot +
331 |           geom_pointrange(
332 |             data = cph.coeffs[
333 |
cph.coeffs$var == variable & 334 | cph.coeffs$level == 'missing', 335 | ], 336 | aes( 337 | x = missing.x.pos.cont, 338 | y = val, ymin = lower, ymax = upper 339 | ), 340 | colour = 'blue' 341 | ) 342 | } 343 | 344 | # Finally, set the x-axis limits; will just be the data range, or data 345 | # range plus a bit if there are missing values to squeeze in 346 | cox.discrete.plot <- 347 | cox.discrete.plot + 348 | coord_cartesian(xlim = x.axis.limits) 349 | 350 | cox.discrete.plots[[variable]] <- cox.discrete.plot 351 | } 352 | } 353 | } 354 | 355 | print(cox.discrete.plots) 356 | -------------------------------------------------------------------------------- /overview/variable-effects.R: -------------------------------------------------------------------------------- 1 | cox.disc.filename <- '../../output/all-cv-survreg-boot-try5-surv-model.rds' 2 | caliber.missing.coefficients.filename <- 3 | '../../output/caliber-replicate-with-missing-survreg-6-linear-age-coeffs-3.csv' 4 | rf.filename <- '../../output/rfsrc-cv-nsplit-try3-var-effects.csv' 5 | 6 | source('../lib/shared.R') 7 | requirePlus('cowplot') 8 | 9 | # Amount of padding at the right-hand side to make space for missing values 10 | missing.padding <- 0.05 11 | 12 | continuous.vars <- 13 | c( 14 | 'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo', 15 | 'total_wbc_6mo', 'haemoglobin_6mo' 16 | ) 17 | 18 | # Load in the discretised Cox model for plotting 19 | surv.model.fit.boot <- readRDS(cox.disc.filename) 20 | 21 | # Pull coefficients from model with missing data 22 | caliber.missing.coeffs <- read.csv(caliber.missing.coefficients.filename) 23 | # Log them to get them on the same scale as discrete model 24 | caliber.missing.coeffs$our_value <- -log(caliber.missing.coeffs$our_value) 25 | caliber.missing.coeffs$our_lower <- -log(caliber.missing.coeffs$our_lower) 26 | caliber.missing.coeffs$our_upper <- -log(caliber.missing.coeffs$our_upper) 27 | 28 | # Load the data 29 | COHORT.use <- data.frame(fread(data.filename)) 30 | 31 | # Open the calibration to find the best binning scheme 32 | calibration.filename <- '../../output/survreg-crossvalidation-try5.csv' 33 | cv.performance <- read.csv(calibration.filename) 34 | 35 | # Find the best calibration... 36 | # First, average performance across cross-validation folds 37 | cv.performance.average <- 38 | aggregate( 39 | c.index.val ~ calibration, 40 | data = cv.performance, 41 | mean 42 | ) 43 | # Find the highest value 44 | best.calibration <- 45 | cv.performance.average$calibration[ 46 | which.max(cv.performance.average$c.index.val) 47 | ] 48 | # And finally, find the first row of that calibration to get the n.bins values 49 | best.calibration.row1 <- 50 | min(which(cv.performance$calibration == best.calibration)) 51 | 52 | # Get its parameters 53 | n.bins <- 54 | t( 55 | cv.performance[best.calibration.row1, continuous.vars] 56 | ) 57 | 58 | # Prepare the data with those settings... 
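
# (Aside: an illustrative sketch of what quantile binning amounts to. The
# real preparation below is done by prepData with the 'binByQuantile' method
# from the lib/ helpers, whose exact semantics are assumed here; this
# standalone version is not used.)
binByQuantileSketch <- function(x, probs) {
  # unique() guards against duplicate breaks when a variable is heavily tied
  breaks <- unique(quantile(x, probs = probs, na.rm = TRUE))
  cut(x, breaks = breaks, include.lowest = TRUE)
}
# e.g. table(binByQuantileSketch(COHORT.use$age, seq(0, 1, length.out = 6)))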
59 | 
60 | # Reset process settings with the base settings
61 | process.settings <-
62 |   list(
63 |     var = c('anonpatid', 'time_death', 'imd_score', 'exclude'),
64 |     method = c(NA, NA, NA, NA),
65 |     settings = list(NA, NA, NA, NA)
66 |   )
67 | for(j in 1:length(continuous.vars)) {
68 |   process.settings$var <- c(process.settings$var, continuous.vars[j])
69 |   process.settings$method <- c(process.settings$method, 'binByQuantile')
70 |   process.settings$settings <-
71 |     c(
72 |       process.settings$settings,
73 |       list(
74 |         seq(
75 |           # Quantiles are obviously between 0 and 1
76 |           0, 1,
77 |           # Use the best bin count from the cross-validation (n bins needs n + 1 breaks)
78 |           length.out = n.bins[j]
79 |         )
80 |       )
81 |     )
82 | }
83 | 
84 | # prep the data given the variables provided
85 | COHORT.optimised <-
86 |   prepData(
87 |     # Data for cross-validation excludes test set
88 |     COHORT.use,
89 |     cols.keep,
90 |     process.settings,
91 |     surv.time, surv.event,
92 |     surv.event.yes,
93 |     extra.fun = caliberExtraPrep
94 |   )
95 | 
96 | # Unpack variable and level names
97 | cph.coeffs <- cphCoeffs(
98 |   bootStats(surv.model.fit.boot, uncertainty = '95ci', transform = `-`),
99 |   COHORT.optimised, surv.predict, model.type = 'boot.survreg'
100 | )
101 | 
102 | # We'll need the CALIBER scaling functions for plotting
103 | source('../cox-ph/caliber-scale.R')
104 | 
105 | # set up list to store the plots
106 | cox.discrete.plots <- list()
107 | # Add dummy columns for x-position of missing values
108 | caliber.missing.coeffs$missing.x.pos.cont <- NA
109 | cph.coeffs$missing.x.pos.disc <- NA
110 | 
111 | for(variable in unique(cph.coeffs$var)) {
112 |   # If it's a continuous variable, get the real centres of the bins
113 |   if(variable %in% process.settings$var) {
114 |     process.i <- which(variable == process.settings$var)
115 | 
116 |     if(process.settings$method[[process.i]] == 'binByQuantile') {
117 | 
118 |       variable.quantiles <-
119 |         getQuantiles(
120 |           COHORT.use[, variable],
121 |           process.settings$settings[[process.i]]
122 |         )
123 |       # For those rows which relate to this variable, and whose level isn't
124 |       # missing, put in the appropriate quantile boundaries for plotting
125 |       cph.coeffs$bin.min[cph.coeffs$var == variable &
126 |                            cph.coeffs$level != 'missing'] <-
127 |         variable.quantiles[1:(length(variable.quantiles) - 1)]
128 |       cph.coeffs$bin.max[cph.coeffs$var == variable &
129 |                            cph.coeffs$level != 'missing'] <-
130 |         variable.quantiles[2:length(variable.quantiles)]
131 |       # Make the final bin the 99th percentile
132 |       cph.coeffs$bin.max[cph.coeffs$var == variable &
133 |                            cph.coeffs$level != 'missing'][
134 |         length(variable.quantiles) - 1] <-
135 |         quantile(COHORT.use[, variable], 0.99, na.rm = TRUE)
136 | 
137 |       # Add a fake data point at the highest value to finish the graph
138 |       cph.coeffs <-
139 |         rbind(
140 |           cph.coeffs,
141 |           cph.coeffs[cph.coeffs$var == variable &
142 |                        cph.coeffs$level != 'missing', ][
143 |             length(variable.quantiles) - 1, ]
144 |         )
145 |       # Change it so that bin.min is bin.max from the old one
146 |       cph.coeffs$bin.min[nrow(cph.coeffs)] <-
147 |         cph.coeffs$bin.max[cph.coeffs$var == variable &
148 |                              cph.coeffs$level != 'missing'][
149 |           length(variable.quantiles) - 1]
150 | 
151 |       # Work out data range by taking the 1st and 99th percentiles
152 |       # Use the max to provide a max value for the final bin
153 |       # Also use for x-axis limits, unless there are missing values to
154 |       # accommodate on the right-hand edge.
155 | x.data.range <- 156 | quantile(COHORT.use[, variable], c(0.01, 0.99), na.rm = TRUE) 157 | x.axis.limits <- x.data.range 158 | 159 | 160 | # Finally, we need to scale this such that the baseline value is equal 161 | # to the value for the equivalent place in the Cox model, to make the 162 | # risks comparable... 163 | 164 | # First, we need to find the average value of this variable in the lowest 165 | # bin (which is always the baseline here) 166 | baseline.bin <- variable.quantiles[1:2] 167 | baseline.bin.avg <- 168 | mean( 169 | # Take only those values of the variable which are in the range 170 | COHORT.use[ 171 | inRange(COHORT.use[, variable], baseline.bin, na.false = TRUE), 172 | variable 173 | ] 174 | ) 175 | # Then, scale it with the caliber scaling 176 | baseline.bin.val <- 177 | caliberScaleUnits(baseline.bin.avg, variable) * 178 | caliber.missing.coeffs$our_value[ 179 | caliber.missing.coeffs$quantity == variable 180 | ] 181 | 182 | # And now, add all the discretised values to that value to make them 183 | # comparable... 184 | cph.coeffs[cph.coeffs$var == variable, c('val', 'lower', 'upper')] <- 185 | cph.coeffs[cph.coeffs$var == variable, c('val', 'lower', 'upper')] - 186 | baseline.bin.val 187 | 188 | # Now, plot this variable as a stepped line plot using those quantile 189 | # boundaries 190 | cox.discrete.plot <- 191 | ggplot( 192 | subset(cph.coeffs, var == variable), 193 | aes(x = bin.min, y = val) 194 | ) + 195 | geom_step() + 196 | geom_step(aes(y = lower), colour = 'grey') + 197 | geom_step(aes(y = upper), colour = 'grey') + 198 | labs(x = variable, y = 'Bx') 199 | 200 | # If there's a missing value risk, add it 201 | if(any(cph.coeffs$var == variable & cph.coeffs$level == 'missing')) { 202 | # Expand the x-axis to squeeze the missing values in 203 | x.axis.limits[2] <- 204 | x.axis.limits[2] + diff(x.data.range) * missing.padding 205 | # Put this missing value a third of the way into the missing area 206 | cph.coeffs$missing.x.pos.disc[ 207 | cph.coeffs$var == variable & 208 | cph.coeffs$level == 'missing'] <- 209 | x.axis.limits[2] + diff(x.data.range) * missing.padding / 3 210 | 211 | # Add the point to the graph (we'll set axis limits later) 212 | cox.discrete.plot <- 213 | cox.discrete.plot + 214 | geom_pointrange( 215 | data = cph.coeffs[cph.coeffs$var == variable & 216 | cph.coeffs$level == 'missing', ], 217 | aes( 218 | x = missing.x.pos.disc, 219 | y = val, ymin = lower, 220 | ymax = upper 221 | ), 222 | colour = 'red' 223 | ) 224 | } 225 | 226 | # Now, let's add the line from the continuous Cox model. We only need two 227 | # points because the lines are straight! 
228 | continuous.cox <- 229 | data.frame( 230 | var.x.values = x.data.range 231 | ) 232 | # Scale the x-values 233 | continuous.cox$var.x.scaled <- 234 | caliberScaleUnits(continuous.cox$var.x.values, variable) 235 | # Use the risks to calculate risk per x for central estimate and errors 236 | continuous.cox$y <- 237 | -caliber.missing.coeffs$our_value[ 238 | caliber.missing.coeffs$quantity == variable 239 | ] * continuous.cox$var.x.scaled 240 | continuous.cox$upper <- 241 | -caliber.missing.coeffs$our_upper[ 242 | caliber.missing.coeffs$quantity == variable 243 | ] * continuous.cox$var.x.scaled 244 | continuous.cox$lower <- 245 | -caliber.missing.coeffs$our_lower[ 246 | caliber.missing.coeffs$quantity == variable 247 | ] * continuous.cox$var.x.scaled 248 | 249 | cox.discrete.plot <- 250 | cox.discrete.plot + 251 | geom_line( 252 | data = continuous.cox, 253 | aes(x = var.x.values, y = y), 254 | colour = 'blue' 255 | ) + 256 | geom_line( 257 | data = continuous.cox, 258 | aes(x = var.x.values, y = upper), 259 | colour = 'lightblue' 260 | ) + 261 | geom_line( 262 | data = continuous.cox, 263 | aes(x = var.x.values, y = lower), 264 | colour = 'lightblue' 265 | ) 266 | 267 | # If there is one, add missing value risk from the continuous model 268 | if(any(caliber.missing.coeffs$quantity == paste0(variable, '_missing') & 269 | caliber.missing.coeffs$unit == 'missing')) { 270 | 271 | # Put this missing value 2/3rds of the way into the missing area 272 | caliber.missing.coeffs$missing.x.pos.cont[ 273 | caliber.missing.coeffs$quantity == paste0(variable, '_missing') & 274 | caliber.missing.coeffs$unit == 'missing'] <- 275 | x.axis.limits[2] + 2 * diff(x.data.range) * missing.padding / 3 276 | 277 | cox.discrete.plot <- 278 | cox.discrete.plot + 279 | geom_pointrange( 280 | data = caliber.missing.coeffs[ 281 | caliber.missing.coeffs$quantity == paste0(variable, '_missing') & 282 | caliber.missing.coeffs$unit == 'missing', 283 | ], 284 | aes( 285 | x = missing.x.pos.cont, 286 | y = our_value, ymin = our_lower, ymax = our_upper 287 | ), 288 | colour = 'blue' 289 | ) 290 | } 291 | 292 | # Finally, set the x-axis limits; will just be the data range, or data 293 | # range plus a bit if there are missing values to squeeze in 294 | cox.discrete.plot <- 295 | cox.discrete.plot + 296 | coord_cartesian(xlim = x.axis.limits) + 297 | theme(axis.title.y = element_blank()) + 298 | theme(plot.margin = unit(c(0.2, 0.1, 0.2, 0.1), "cm")) 299 | 300 | cox.discrete.plots[[variable]] <- cox.discrete.plot 301 | } 302 | } 303 | } 304 | 305 | # Load the random forest variable effects file 306 | risk.by.variables <- read.csv(rf.filename) 307 | rf.vareff.plots <- list() 308 | 309 | for(variable in unique(risk.by.variables$var)) { 310 | # Get the mean of the normalised risk for every value of the variable 311 | risk.aggregated <- 312 | aggregate( 313 | as.formula(paste0('risk.normalised ~ val')), 314 | subset(risk.by.variables, var == variable), median 315 | ) 316 | 317 | # work out the limits on the axes by taking the 1st and 99th percentiles 318 | x.axis.limits <- 319 | quantile(COHORT.use[, variable], c(0.01, 0.99), na.rm = TRUE) 320 | y.axis.limits <- 321 | quantile(subset(risk.by.variables, var == variable)$risk.normalised, c(0.05, 0.95), na.rm = TRUE) 322 | 323 | # If there's a missing value risk in the graph above, expand the axes so they 324 | # match 325 | if(any(cph.coeffs$var == variable & cph.coeffs$level == 'missing')) { 326 | x.axis.limits[2] <- 327 | x.axis.limits[2] + diff(x.data.range) * missing.padding 328 
| } 329 | 330 | rf.vareff.plots[[variable]] <- 331 | ggplot( 332 | subset(risk.by.variables, var == variable), 333 | aes(x = val, y = log(risk.normalised)) 334 | ) + 335 | geom_line(alpha=0.003, aes(group = id)) + 336 | geom_line(data = risk.aggregated, colour = 'blue') + 337 | coord_cartesian(xlim = x.axis.limits, ylim = log(y.axis.limits)) + 338 | labs(x = variable) + 339 | theme( 340 | plot.margin = unit(c(0.2, 0.1, 0.2, 0.1), "cm"), 341 | axis.title.y = element_blank() 342 | ) 343 | } 344 | 345 | 346 | plot_grid( 347 | cox.discrete.plots[['age']], 348 | cox.discrete.plots[['haemoglobin_6mo']], 349 | cox.discrete.plots[['total_wbc_6mo']], 350 | cox.discrete.plots[['crea_6mo']], 351 | rf.vareff.plots[['age']], 352 | rf.vareff.plots[['haemoglobin_6mo']], 353 | rf.vareff.plots[['total_wbc_6mo']], 354 | rf.vareff.plots[['crea_6mo']], 355 | labels = c('A', rep('', 3), 'B', rep('', 3)), 356 | align = "v", ncol = 4 357 | ) 358 | 359 | ggsave( 360 | '../../output/variable-effects.pdf', 361 | width = 16, 362 | height = 10, 363 | units = 'cm', 364 | useDingbats = FALSE 365 | ) 366 | -------------------------------------------------------------------------------- /random-forest/rf-varselmiss.R: -------------------------------------------------------------------------------- 1 | #+ knitr_setup, include = FALSE 2 | 3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate 4 | # everything afresh. 5 | cacheoption <- TRUE 6 | # Disable lazy caching globally, because it fails for large objects, and all the 7 | # objects we wish to cache are large... 8 | opts_chunk$set(cache.lazy = FALSE) 9 | 10 | #' # Variable selection in data-driven health records 11 | #' 12 | #' Having extracted around 600 variables which occur most frequently in patient 13 | #' records, let's try to narrow these down using the methodology of varSelRf. 14 | #' 15 | #' ## User variables 16 | #' 17 | #+ user_variables 18 | 19 | output.filename.base <- '../../output/rf-bigdata-try11-varselmiss' 20 | 21 | nsplit <- 20 22 | n.trees.cv <- 500 23 | n.trees.final <- 500 24 | split.rule <- 'logrank' 25 | n.imputations <- 3 26 | cv.n.folds <- 3 27 | vars.drop.frac <- 0.2 # Fraction of variables to drop at each iteration 28 | bootstraps <- 200 29 | 30 | n.data <- NA # This is after any variables being excluded in prep 31 | 32 | n.threads <- 40 33 | 34 | #' ## Data set-up 35 | #' 36 | #+ data_setup 37 | 38 | data.filename.big <- '../../data/cohort-datadriven-02.csv' 39 | 40 | surv.predict.old <- c('age', 'smokstatus', 'imd_score', 'gender') 41 | untransformed.vars <- c('time_death', 'endpoint_death', 'exclude') 42 | 43 | source('../lib/shared.R') 44 | require(xtable) 45 | 46 | # Define these after shared.R or they will be overwritten! 47 | exclude.vars <- 48 | c( 49 | # Entity type 4 is smoking status, which we already have 50 | "clinical.values.4_data1", "clinical.values.4_data5", 51 | "clinical.values.4_data6", 52 | # Entity 13 data2 is the patient's weight centile, and not a single one is 53 | # entered, but they come out as 0 so the algorithm, looking for NAs, thinks 54 | # it's a useful column 55 | "clinical.values.13_data2", 56 | # Entities 148 and 149 are to do with death certification. I'm not sure how 57 | # it made it into the dataset, but since all the datapoints in this are 58 | # looking back in time, they're all NA. This causes rfsrc to fail. 
59 | "clinical.values.148_data1", "clinical.values.148_data2", 60 | "clinical.values.148_data3", "clinical.values.148_data4", 61 | "clinical.values.148_data5", 62 | "clinical.values.149_data1", "clinical.values.149_data2" 63 | ) 64 | 65 | COHORT <- fread(data.filename.big) 66 | 67 | bigdata.prefixes <- 68 | c( 69 | 'hes.icd.', 70 | 'hes.opcs.', 71 | 'tests.enttype.', 72 | 'clinical.history.', 73 | 'clinical.values.', 74 | 'bnf.' 75 | ) 76 | 77 | bigdata.columns <- 78 | colnames(COHORT)[ 79 | which( 80 | # Does is start with one of the data column names? 81 | startsWithAny(names(COHORT), bigdata.prefixes) & 82 | # And it's not one of the columns we want to exclude? 83 | !(colnames(COHORT) %in% exclude.vars) 84 | ) 85 | ] 86 | 87 | COHORT.bigdata <- 88 | COHORT[, c( 89 | untransformed.vars, surv.predict.old, bigdata.columns 90 | ), 91 | with = FALSE 92 | ] 93 | 94 | # Get the missingness before we start removing missing values 95 | missingness <- sort(sapply(COHORT.bigdata, percentMissing)) 96 | # Remove values for the 'untransformed.vars' above, which are the survival 97 | # values plus exclude column 98 | missingness <- missingness[!(names(missingness) %in% untransformed.vars)] 99 | 100 | # Deal appropriately with missing data 101 | # Most of the variables are number of days since the first record of that type 102 | time.based.vars <- 103 | names(COHORT.bigdata)[ 104 | startsWithAny( 105 | names(COHORT.bigdata), 106 | c('hes.icd.', 'hes.opcs.', 'clinical.history.') 107 | ) 108 | ] 109 | # We're dealing with this as a logical, so we want non-NA values to be TRUE, 110 | # is there is something in the history 111 | for (j in time.based.vars) { 112 | set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]])) 113 | } 114 | 115 | # Again, taking this as a logical, set any non-NA value to TRUE. 116 | prescriptions.vars <- names(COHORT.bigdata)[startsWith(names(COHORT.bigdata), 'bnf.')] 117 | for (j in prescriptions.vars) { 118 | set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]])) 119 | } 120 | 121 | # This leaves tests and clinical.values, which are test results and should be 122 | # imputed. 123 | 124 | # Manually fix clinical values items... 
125 | # 126 | # "clinical.values.1_data1" "clinical.values.1_data2" 127 | # These are just blood pressure values...fine to impute 128 | # 129 | # "clinical.values.13_data1" "clinical.values.13_data3" 130 | # These are weight and BMI...also fine to impute 131 | # 132 | # Entity 5 is alcohol consumption status, 1 = Yes, 2 = No, 3 = Ex, so should be 133 | # a factor, and NA can be a factor level 134 | COHORT.bigdata$clinical.values.5_data1 <- 135 | factorNAfix(factor(COHORT.bigdata$clinical.values.5_data1), NAval = 'missing') 136 | 137 | # Both gender and smokstatus are factors...fix that 138 | COHORT.bigdata$gender <- factor(COHORT.bigdata$gender) 139 | COHORT.bigdata$smokstatus <- 140 | factorNAfix(factor(COHORT.bigdata$smokstatus), NAval = 'missing') 141 | 142 | # Exclude invalid patients 143 | COHORT.bigdata <- COHORT.bigdata[!COHORT.bigdata$exclude] 144 | COHORT.bigdata$exclude <- NULL 145 | 146 | COHORT.bigdata <- 147 | prepSurvCol(data.frame(COHORT.bigdata), 'time_death', 'endpoint_death', 'Death') 148 | 149 | # If n.data was specified, trim the data table down to size 150 | if(!is.na(n.data)) { 151 | COHORT.bigdata <- sample.df(COHORT.bigdata, n.data) 152 | } 153 | 154 | # Define test set 155 | test.set <- testSetIndices(COHORT.bigdata, random.seed = 78361) 156 | 157 | # Start by predicting survival with all the variables provided 158 | surv.predict <- c(surv.predict.old, bigdata.columns) 159 | 160 | # Set up a csv file to store calibration data, or retrieve previous data 161 | calibration.filename <- paste0(output.filename.base, '-varselcalibration.csv') 162 | 163 | #' ## Run random forest calibration 164 | #' 165 | #' If there's not already a calibration file, we run the rfVarSel methodology: 166 | #' 1. Fit a big forest to the whole dataset to obtain variable importances. 167 | #' 2. Cross-validate as number of most important variables kept is reduced. 168 | #' 169 | #' (If there is already a calibration file, just load the previous work.) 170 | #' 171 | #+ rf_var_sel_calibration 172 | 173 | # If we've not already done a calibration, then do one 174 | if(!file.exists(calibration.filename)) { 175 | # Create an empty data frame to aggregate stats per fold 176 | cv.performance <- data.frame() 177 | 178 | # Cross-validate over number of variables to try 179 | cv.vars <- getVarNums(length(missingness)) 180 | 181 | COHORT.cv <- COHORT.bigdata[-test.set, ] 182 | 183 | # Run crossvalidations. 
No need to parallelise because rfsrc is parallelised 184 | for(i in 1:length(cv.vars)) { 185 | # Get the subset of most important variables to use 186 | surv.predict.partial <- names(missingness)[1:cv.vars[i]] 187 | 188 | # Get folds for cross-validation 189 | cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds) 190 | 191 | cv.fold.performance <- data.frame() 192 | 193 | for(j in 1:cv.n.folds) { 194 | time.start <- handyTimer() 195 | # Fit model to the training set 196 | surv.model.fit <- 197 | survivalFit( 198 | surv.predict.partial, 199 | COHORT.cv[-cv.folds[[j]],], 200 | model.type = 'rfsrc', 201 | n.trees = n.trees.cv, 202 | split.rule = split.rule, 203 | n.threads = n.threads, 204 | nsplit = nsplit, 205 | nimpute = n.imputations, 206 | na.action = 'na.impute' 207 | ) 208 | time.learn <- handyTimer(time.start) 209 | 210 | time.start <- handyTimer() 211 | # Get C-index on validation set 212 | c.index.val <- 213 | cIndex( 214 | surv.model.fit, COHORT.cv[cv.folds[[j]],], 215 | na.action = 'na.impute' 216 | ) 217 | time.c.index <- handyTimer(time.start) 218 | 219 | time.start <- handyTimer() 220 | # Get calibration score validation set 221 | calibration.score <- 222 | calibrationScore( 223 | calibrationTable( 224 | surv.model.fit, COHORT.cv[cv.folds[[j]],], na.action = 'na.impute' 225 | ) 226 | ) 227 | time.calibration <- handyTimer(time.start) 228 | 229 | # Append the stats we've obtained from this fold 230 | cv.fold.performance <- 231 | rbind( 232 | cv.fold.performance, 233 | data.frame( 234 | calibration = i, 235 | cv.fold = j, 236 | n.vars = cv.vars[i], 237 | c.index.val, 238 | calibration.score, 239 | time.learn, 240 | time.c.index, 241 | time.calibration 242 | ) 243 | ) 244 | 245 | } # End cross-validation loop (j) 246 | 247 | 248 | # rbind the performance by fold 249 | cv.performance <- 250 | rbind( 251 | cv.performance, 252 | cv.fold.performance 253 | ) 254 | 255 | # Save output at the end of each loop 256 | write.csv(cv.performance, calibration.filename) 257 | 258 | } # End calibration loop (i) 259 | } else { 260 | cv.performance <- read.csv(calibration.filename) 261 | } 262 | 263 | #' ## Find the best model from the calibrations 264 | #' 265 | #' ### Plot model performance 266 | #' 267 | #+ model_performance 268 | 269 | # Find the best calibration... 270 | # First, average performance across cross-validation folds 271 | cv.performance.average <- 272 | aggregate( 273 | c.index.val ~ n.vars, 274 | data = cv.performance, 275 | mean 276 | ) 277 | 278 | cv.calibration.average <- 279 | aggregate( 280 | area ~ n.vars, 281 | data = cv.performance, 282 | mean 283 | ) 284 | 285 | ggplot(cv.performance.average, aes(x = n.vars, y = c.index.val)) + 286 | geom_line() + 287 | geom_point(data = cv.performance) + 288 | ggtitle(label = 'C-index by n.vars') 289 | 290 | ggplot(cv.calibration.average, aes(x = n.vars, y = area)) + 291 | geom_line() + 292 | geom_point(data = cv.performance) + 293 | ggtitle(label = 'Calibration performance by n.vars') 294 | 295 | # Find the highest value 296 | n.vars <- 297 | cv.performance.average$n.vars[ 298 | which.max(cv.performance.average$c.index.val) 299 | ] 300 | 301 | # Fit a full model with the variables provided 302 | surv.predict.partial <- names(missingness)[1:n.vars] 303 | 304 | #' ## Best model 305 | #' 306 | #' The best model contained `r n.vars` variables. Let's see what those were... 
307 | #' 308 | #+ variables_used 309 | 310 | vars.df <- 311 | data.frame( 312 | vars = surv.predict.partial 313 | ) 314 | 315 | vars.df$descriptions <- lookUpDescriptions(surv.predict.partial) 316 | 317 | vars.df$missingness <- missingness[1:n.vars] 318 | 319 | #+ variables_table, results='asis' 320 | 321 | print( 322 | xtable(vars.df), 323 | type = 'html', 324 | include.rownames = FALSE 325 | ) 326 | 327 | #' ## Perform the final fit 328 | #' 329 | #' Having found the best number of variables by cross-validation, let's perform 330 | #' the final fit with the full training set and `r n.trees.final` trees. 331 | #' 332 | #+ final_fit 333 | 334 | time.start <- handyTimer() 335 | surv.model.fit.final <- 336 | survivalFit( 337 | surv.predict.partial, 338 | COHORT.bigdata[-test.set,], 339 | model.type = 'rfsrc', 340 | n.trees = n.trees.final, 341 | split.rule = split.rule, 342 | n.threads = n.threads, 343 | nimpute = 3, 344 | nsplit = nsplit, 345 | na.action = 'na.impute' 346 | ) 347 | time.fit.final <- handyTimer(time.start) 348 | 349 | saveRDS(surv.model.fit.final, paste0(output.filename.base, '-finalmodel.rds')) 350 | 351 | #' Final model of `r n.trees.final` trees fitted in `r round(time.fit.final)` 352 | #' seconds! 353 | #' 354 | #' Also bootstrap this final fitting stage. A fully proper bootstrap would 355 | #' iterate over the whole model-building process including variable selection, 356 | #' but that would be prohibitive in terms of computational time. 357 | #' 358 | #+ bootstrap_final 359 | 360 | surv.model.fit.boot <- 361 | survivalBootstrap( 362 | surv.predict.partial, 363 | COHORT.bigdata[-test.set,], # Training set 364 | COHORT.bigdata[test.set,], # Test set 365 | model.type = 'rfsrc', 366 | n.trees = n.trees.final, 367 | split.rule = split.rule, 368 | n.threads = n.threads, 369 | nimpute = 3, 370 | nsplit = nsplit, 371 | na.action = 'na.impute', 372 | bootstraps = bootstraps 373 | ) 374 | 375 | # Get coefficients and variable importances from bootstrap fits 376 | surv.model.fit.coeffs <- bootStats(surv.model.fit.boot, uncertainty = '95ci') 377 | 378 | #' ## Performance 379 | #' 380 | #' ### C-index 381 | #' 382 | #' C-indices are **`r round(surv.model.fit.coeffs['c.train', 'val'], 3)` 383 | #' (`r round(surv.model.fit.coeffs['c.train', 'lower'], 3)` - 384 | #' `r round(surv.model.fit.coeffs['c.train', 'upper'], 3)`)** 385 | #' on the training set and 386 | #' **`r round(surv.model.fit.coeffs['c.test', 'val'], 3)` 387 | #' (`r round(surv.model.fit.coeffs['c.test', 'lower'], 3)` - 388 | #' `r round(surv.model.fit.coeffs['c.test', 'upper'], 3)`)** on the test set. 389 | #' 390 | #' 391 | #' ### Calibration 392 | #' 393 | #' The bootstrapped calibration score is 394 | #' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)` 395 | #' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` - 396 | #' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**. 397 | #' 398 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be 399 | #' better to draw all the curves from the bootstrap fit to get an idea of 400 | #' variability, but I've not implemented this yet.) 
401 | #' 402 | #+ calibration_plot 403 | 404 | calibration.table <- 405 | calibrationTable( 406 | # Standard calibration options 407 | surv.model.fit.final, COHORT.bigdata[test.set,], 408 | # Always need to specify NA imputation for rfsrc 409 | na.action = 'na.impute' 410 | ) 411 | 412 | calibration.score <- calibrationScore(calibration.table) 413 | 414 | calibrationPlot(calibration.table) 415 | 416 | #' The area between the calibration curve and the diagonal is 417 | #' **`r round(calibration.score[['area']], 3)`** +/- 418 | #' **`r round(calibration.score[['se']], 3)`**. 419 | #' 420 | #+ save_results 421 | 422 | # Save performance results 423 | varsToTable( 424 | data.frame( 425 | model = 'rf-varselmiss', 426 | imputation = FALSE, 427 | discretised = FALSE, 428 | c.index = surv.model.fit.coeffs['c.train', 'val'], 429 | c.index.lower = surv.model.fit.coeffs['c.train', 'lower'], 430 | c.index.upper = surv.model.fit.coeffs['c.train', 'upper'], 431 | calibration.score = surv.model.fit.coeffs['calibration.score', 'val'], 432 | calibration.score.lower = 433 | surv.model.fit.coeffs['calibration.score', 'lower'], 434 | calibration.score.upper = 435 | surv.model.fit.coeffs['calibration.score', 'upper'] 436 | ), 437 | performance.file, 438 | index.cols = c('model', 'imputation', 'discretised') 439 | ) -------------------------------------------------------------------------------- /random-forest/rf-varsellogrank.R: -------------------------------------------------------------------------------- 1 | #+ knitr_setup, include = FALSE 2 | 3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate 4 | # everything afresh. 5 | cacheoption <- TRUE 6 | # Disable lazy caching globally, because it fails for large objects, and all the 7 | # objects we wish to cache are large... 8 | opts_chunk$set(cache.lazy = FALSE) 9 | 10 | #' # Variable selection in data-driven health records 11 | #' 12 | #' Having extracted around 600 variables which occur most frequently in patient 13 | #' records, let's try to narrow these down using a methodology based on varSelRf 14 | #' combined with survival modelling. We'll find the predictability of variables 15 | #' as defined by the p-value of a logrank test on survival curves of different 16 | #' categories within that variable, and then iteratively throw out unimportant 17 | #' variables, cross-validating for optimum performance. 18 | #' 19 | #' ## User variables 20 | #' 21 | #+ user_variables 22 | 23 | output.filename.base <- '../../output/rf-bigdata-varsellogrank-02' 24 | 25 | nsplit <- 20 26 | n.trees.cv <- 500 27 | n.trees.final <- 500 28 | split.rule <- 'logrank' 29 | n.imputations <- 3 30 | cv.n.folds <- 3 31 | vars.drop.frac <- 0.2 # Fraction of variables to drop at each iteration 32 | bootstraps <- 20 33 | 34 | n.data <- NA # This is after any variables being excluded in prep 35 | 36 | n.threads <- 40 37 | 38 | #' ## Data set-up 39 | #' 40 | #+ data_setup 41 | 42 | data.filename.big <- '../../data/cohort-datadriven-02.csv' 43 | 44 | surv.predict.old <- c('age', 'smokstatus', 'imd_score', 'gender') 45 | untransformed.vars <- c('time_death', 'endpoint_death', 'exclude') 46 | 47 | source('../lib/shared.R') 48 | require(xtable) 49 | 50 | # Define these after shared.R or they will be overwritten! 
51 | exclude.vars <-
52 |   c(
53 |     # Entity type 4 is smoking status, which we already have
54 |     "clinical.values.4_data1", "clinical.values.4_data5",
55 |     "clinical.values.4_data6",
56 |     # Entity 13 data2 is the patient's weight centile, and not a single one is
57 |     # entered, but they come out as 0 so the algorithm, looking for NAs, thinks
58 |     # it's a useful column
59 |     "clinical.values.13_data2",
60 |     # Entities 148 and 149 are to do with death certification. I'm not sure how
61 |     # it made it into the dataset, but since all the datapoints in this are
62 |     # looking back in time, they're all NA. This causes rfsrc to fail.
63 |     "clinical.values.148_data1", "clinical.values.148_data2",
64 |     "clinical.values.148_data3", "clinical.values.148_data4",
65 |     "clinical.values.148_data5",
66 |     "clinical.values.149_data1", "clinical.values.149_data2"
67 |   )
68 | 
69 | COHORT <- fread(data.filename.big)
70 | 
71 | bigdata.prefixes <-
72 |   c(
73 |     'hes.icd.',
74 |     'hes.opcs.',
75 |     'tests.enttype.',
76 |     'clinical.history.',
77 |     'clinical.values.',
78 |     'bnf.'
79 |   )
80 | 
81 | bigdata.columns <-
82 |   colnames(COHORT)[
83 |     which(
84 |       # Does it start with one of the data column names?
85 |       startsWithAny(names(COHORT), bigdata.prefixes) &
86 |       # And it's not one of the columns we want to exclude?
87 |       !(colnames(COHORT) %in% exclude.vars)
88 |     )
89 |   ]
90 | 
91 | COHORT.bigdata <-
92 |   COHORT[, c(
93 |     untransformed.vars, surv.predict.old, bigdata.columns
94 |   ),
95 |   with = FALSE
96 |   ]
97 | 
98 | # Get the missingness before we start removing missing values
99 | missingness <- sort(sapply(COHORT.bigdata, percentMissing))
100 | # Remove values for the 'untransformed.vars' above, which are the survival
101 | # values plus exclude column
102 | missingness <- missingness[!(names(missingness) %in% untransformed.vars)]
103 | 
104 | # Deal appropriately with missing data
105 | # Most of the variables are number of days since the first record of that type
106 | time.based.vars <-
107 |   names(COHORT.bigdata)[
108 |     startsWithAny(
109 |       names(COHORT.bigdata),
110 |       c('hes.icd.', 'hes.opcs.', 'clinical.history.')
111 |     )
112 |   ]
113 | # We're dealing with this as a logical, so we want non-NA values to be TRUE
114 | # if there is something in the history
115 | for (j in time.based.vars) {
116 |   set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
117 | }
118 | 
119 | # Again, taking this as a logical, set any non-NA value to TRUE.
120 | prescriptions.vars <- names(COHORT.bigdata)[startsWith(names(COHORT.bigdata), 'bnf.')]
121 | for (j in prescriptions.vars) {
122 |   set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
123 | }
124 | 
125 | # This leaves tests and clinical.values, which are test results and should be
126 | # imputed.
127 | 
128 | # Manually fix clinical values items...
129 | # 130 | # "clinical.values.1_data1" "clinical.values.1_data2" 131 | # These are just blood pressure values...fine to impute 132 | # 133 | # "clinical.values.13_data1" "clinical.values.13_data3" 134 | # These are weight and BMI...also fine to impute 135 | # 136 | # Entity 5 is alcohol consumption status, 1 = Yes, 2 = No, 3 = Ex, so should be 137 | # a factor, and NA can be a factor level 138 | COHORT.bigdata$clinical.values.5_data1 <- 139 | factorNAfix(factor(COHORT.bigdata$clinical.values.5_data1), NAval = 'missing') 140 | 141 | # Both gender and smokstatus are factors...fix that 142 | COHORT.bigdata$gender <- factor(COHORT.bigdata$gender) 143 | COHORT.bigdata$smokstatus <- 144 | factorNAfix(factor(COHORT.bigdata$smokstatus), NAval = 'missing') 145 | 146 | # Exclude invalid patients 147 | COHORT.bigdata <- COHORT.bigdata[!COHORT.bigdata$exclude] 148 | COHORT.bigdata$exclude <- NULL 149 | 150 | COHORT.bigdata <- 151 | prepSurvCol(data.frame(COHORT.bigdata), 'time_death', 'endpoint_death', 'Death') 152 | 153 | # If n.data was specified, trim the data table down to size 154 | if(!is.na(n.data)) { 155 | COHORT.bigdata <- sample.df(COHORT.bigdata, n.data) 156 | } 157 | 158 | # Define test set 159 | test.set <- testSetIndices(COHORT.bigdata, random.seed = 78361) 160 | 161 | # Start by predicting survival with all the variables provided 162 | surv.predict <- c(surv.predict.old, bigdata.columns) 163 | 164 | # Set up a csv file to store calibration data, or retrieve previous data 165 | calibration.filename <- paste0(output.filename.base, '-varselcalibration.csv') 166 | 167 | varLogrankTest <- function(df, var) { 168 | # If there's only one category, this is a single-valued variable so you can't 169 | # do a logrank test on different values of it... 170 | if(length(unique(NArm(df[, var]))) == 1) { 171 | return(NA) 172 | } 173 | 174 | # If it's a logical, make an extra column for consistency of later code 175 | if(class(df[, var]) == 'logical') { 176 | df$groups <- factor(ifelse(df[, var], 'A', 'B')) 177 | # If it's numeric, split it into four quartiles 178 | } else if(class(df[, var]) == 'numeric') { 179 | # First, discard all rows where the value is missing 180 | df <- df[!is.na(df[, var]), ] 181 | # Then, assign quartiles 182 | df$groups <- 183 | factor( 184 | findInterval( 185 | df[, var], 186 | quantile(df[, var], probs=c(0, 0.25, .5, .75, 1)) 187 | ) 188 | ) 189 | 190 | } else { 191 | # Otherwise, it's a factor, so leave it as-is 192 | df$groups <- df[, var] 193 | } 194 | 195 | # Perform a logrank test on the data 196 | lr.test <- 197 | survdiff( 198 | as.formula(paste0('Surv(surv_time, surv_event) ~ groups')), 199 | df 200 | ) 201 | # Return the p-value of the logrank test 202 | pchisq(lr.test$chisq, length(lr.test$n)-1, lower.tail = FALSE) 203 | } 204 | 205 | # Don't use the output variables in our list 206 | vars.to.check <- 207 | names(COHORT.bigdata)[!(names(COHORT.bigdata) %in% c('surv_time', 'surv_event'))] 208 | 209 | var.logrank.p <- 210 | sapply( 211 | X = vars.to.check, FUN = varLogrankTest, 212 | df = COHORT.bigdata[-test.set, ] 213 | ) 214 | 215 | # Sort them, in ascending order because small p-values indicate differing 216 | # survival curves 217 | var.logrank.p <- sort(var.logrank.p) 218 | 219 | #' ## Run random forest calibration 220 | #' 221 | #' If there's not already a calibration file, we run the rfVarSel methodology: 222 | #' 1. Fit a big forest to the whole dataset to obtain variable importances. 223 | #' 2. 
Cross-validate as number of most important variables kept is reduced. 224 | #' 225 | #' (If there is already a calibration file, just load the previous work.) 226 | #' 227 | #+ rf_var_sel_calibration 228 | 229 | # If we've not already done a calibration, then do one 230 | if(!file.exists(calibration.filename)) { 231 | # Create an empty data frame to aggregate stats per fold 232 | cv.performance <- data.frame() 233 | 234 | # Cross-validate over number of variables to try 235 | cv.vars <- 236 | getVarNums( 237 | length(var.logrank.p), 238 | # no point going lower than the point at which all the p-values are 0, 239 | # because the order is alphabetical and therefore meaningless below this! 240 | min = sum(var.logrank.p == 0) 241 | ) 242 | 243 | COHORT.cv <- COHORT.bigdata[-test.set, ] 244 | 245 | # Run crossvalidations. No need to parallelise because rfsrc is parallelised 246 | for(i in 1:length(cv.vars)) { 247 | # Get the subset of most important variables to use 248 | surv.predict.partial <- names(var.logrank.p)[1:cv.vars[i]] 249 | 250 | # Get folds for cross-validation 251 | cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds) 252 | 253 | cv.fold.performance <- data.frame() 254 | 255 | for(j in 1:cv.n.folds) { 256 | time.start <- handyTimer() 257 | # Fit model to the training set 258 | surv.model.fit <- 259 | survivalFit( 260 | surv.predict.partial, 261 | COHORT.cv[-cv.folds[[j]],], 262 | model.type = 'rfsrc', 263 | n.trees = n.trees.cv, 264 | split.rule = split.rule, 265 | n.threads = n.threads, 266 | nsplit = nsplit, 267 | nimpute = n.imputations, 268 | na.action = 'na.impute' 269 | ) 270 | time.learn <- handyTimer(time.start) 271 | 272 | time.start <- handyTimer() 273 | # Get C-index on validation set 274 | c.index.val <- 275 | cIndex( 276 | surv.model.fit, COHORT.cv[cv.folds[[j]],], 277 | na.action = 'na.impute' 278 | ) 279 | time.c.index <- handyTimer(time.start) 280 | 281 | time.start <- handyTimer() 282 | # Get calibration score validation set 283 | calibration.score <- 284 | calibrationScore( 285 | calibrationTable( 286 | surv.model.fit, COHORT.cv[cv.folds[[j]],], na.action = 'na.impute' 287 | ) 288 | ) 289 | time.calibration <- handyTimer(time.start) 290 | 291 | # Append the stats we've obtained from this fold 292 | cv.fold.performance <- 293 | rbind( 294 | cv.fold.performance, 295 | data.frame( 296 | calibration = i, 297 | cv.fold = j, 298 | n.vars = cv.vars[i], 299 | c.index.val, 300 | calibration.score, 301 | time.learn, 302 | time.c.index, 303 | time.calibration 304 | ) 305 | ) 306 | 307 | } # End cross-validation loop (j) 308 | 309 | 310 | # rbind the performance by fold 311 | cv.performance <- 312 | rbind( 313 | cv.performance, 314 | cv.fold.performance 315 | ) 316 | 317 | # Save output at the end of each loop 318 | write.csv(cv.performance, calibration.filename) 319 | 320 | } # End calibration loop (i) 321 | } else { 322 | cv.performance <- read.csv(calibration.filename) 323 | } 324 | 325 | #' ## Find the best model from the calibrations 326 | #' 327 | #' ### Plot model performance 328 | #' 329 | #+ model_performance 330 | 331 | # Find the best calibration... 
332 | # First, average performance across cross-validation folds 333 | cv.performance.average <- 334 | aggregate( 335 | c.index.val ~ n.vars, 336 | data = cv.performance, 337 | mean 338 | ) 339 | 340 | cv.calibration.average <- 341 | aggregate( 342 | area ~ n.vars, 343 | data = cv.performance, 344 | mean 345 | ) 346 | 347 | ggplot(cv.performance.average, aes(x = n.vars, y = c.index.val)) + 348 | geom_line() + 349 | geom_point(data = cv.performance) + 350 | ggtitle(label = 'C-index by n.vars') 351 | 352 | ggplot(cv.calibration.average, aes(x = n.vars, y = area)) + 353 | geom_line() + 354 | geom_point(data = cv.performance) + 355 | ggtitle(label = 'Calibration performance by n.vars') 356 | 357 | # Find the highest value 358 | n.vars <- 359 | cv.performance.average$n.vars[ 360 | which.max(cv.performance.average$c.index.val) 361 | ] 362 | 363 | # Fit a full model with the variables provided 364 | surv.predict.partial <- names(var.logrank.p)[1:n.vars] 365 | 366 | #' ## Best model 367 | #' 368 | #' The best model contained `r n.vars` variables. Let's see what those were... 369 | #' 370 | #+ variables_used 371 | 372 | vars.df <- 373 | data.frame( 374 | vars = surv.predict.partial 375 | ) 376 | 377 | vars.df$descriptions <- lookUpDescriptions(surv.predict.partial) 378 | 379 | vars.df$missingness <- missingness[surv.predict.partial] 380 | 381 | #+ variables_table, results='asis' 382 | 383 | print( 384 | xtable(vars.df), 385 | type = 'html', 386 | include.rownames = FALSE 387 | ) 388 | 389 | #' ## Perform the final fit 390 | #' 391 | #' Having found the best number of variables by cross-validation, let's perform 392 | #' the final fit with the full training set and `r n.trees.final` trees. 393 | #' 394 | #+ final_fit 395 | 396 | time.start <- handyTimer() 397 | surv.model.fit.final <- 398 | survivalFit( 399 | surv.predict.partial, 400 | COHORT.bigdata[-test.set,], 401 | model.type = 'rfsrc', 402 | n.trees = n.trees.final, 403 | split.rule = split.rule, 404 | n.threads = n.threads, 405 | nimpute = 3, 406 | nsplit = nsplit, 407 | na.action = 'na.impute' 408 | ) 409 | time.fit.final <- handyTimer(time.start) 410 | 411 | saveRDS(surv.model.fit.final, paste0(output.filename.base, '-finalmodel.rds')) 412 | 413 | #' Final model of `r n.trees.final` trees fitted in `r round(time.fit.final)` 414 | #' seconds! 415 | #' 416 | #' Also bootstrap this final fitting stage. A fully proper bootstrap would 417 | #' iterate over the whole model-building process including variable selection, 418 | #' but that would be prohibitive in terms of computational time. 419 | #' 420 | #+ bootstrap_final 421 | 422 | time.start <- handyTimer() 423 | surv.model.params.boot <- 424 | survivalFitBoot( 425 | surv.predict.partial, 426 | COHORT.bigdata[-test.set,], # Training set 427 | COHORT.bigdata[test.set,], # Test set 428 | model.type = 'rfsrc', 429 | n.threads = n.threads, 430 | bootstraps = bootstraps, 431 | filename = paste0(output.filename.base, '-boot-all.csv'), 432 | na.action = 'na.impute' 433 | ) 434 | time.fit.boot <- handyTimer(time.start) 435 | 436 | #' `r bootstraps` bootstraps completed in `r round(time.fit.boot)` seconds! 
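#'
#' (bootStatsDf below comes from lib/ and does the real summarising; purely to
#' illustrate the idea, a percentile 95% CI over a table of per-bootstrap
#' statistics can be computed as follows. The example data frame is
#' hypothetical.)
#'
#+ boot_ci_sketch

# Illustration only: one row per bootstrap, one column per statistic,
# summarised to a matrix indexable like surv.model.fit.coeffs['c.index', 'val']
bootPercentileCI <- function(boot.df) {
  t(sapply(
    boot.df,
    function(x) {
      c(
        val = mean(x),
        lower = quantile(x, 0.025, names = FALSE),
        upper = quantile(x, 0.975, names = FALSE)
      )
    }
  ))
}
# e.g. bootPercentileCI(data.frame(c.index = rnorm(100, 0.78, 0.01)))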
437 | 
438 | # Get coefficients and variable importances from bootstrap fits
439 | surv.model.fit.coeffs <- bootStatsDf(surv.model.params.boot)
440 | 
441 | # Save performance results
442 | varsToTable(
443 |   data.frame(
444 |     model = 'rf-varsellr',
445 |     imputation = FALSE,
446 |     discretised = FALSE,
447 |     c.index = surv.model.fit.coeffs['c.index', 'val'],
448 |     c.index.lower = surv.model.fit.coeffs['c.index', 'lower'],
449 |     c.index.upper = surv.model.fit.coeffs['c.index', 'upper'],
450 |     calibration.score = surv.model.fit.coeffs['calibration.score', 'val'],
451 |     calibration.score.lower =
452 |       surv.model.fit.coeffs['calibration.score', 'lower'],
453 |     calibration.score.upper =
454 |       surv.model.fit.coeffs['calibration.score', 'upper']
455 |   ),
456 |   performance.file,
457 |   index.cols = c('model', 'imputation', 'discretised')
458 | )
459 | 
460 | write.csv(
461 |   surv.model.fit.coeffs,
462 |   paste0(output.filename.base, '-boot-summary.csv')
463 | )
464 | 
465 | #' ## Performance
466 | #'
467 | #' ### C-index
468 | #'
469 | #' C-index is **`r round(surv.model.fit.coeffs['c.index', 'val'], 3)`
470 | #' (`r round(surv.model.fit.coeffs['c.index', 'lower'], 3)` -
471 | #' `r round(surv.model.fit.coeffs['c.index', 'upper'], 3)`)**
472 | #' on the held-out test set.
473 | #'
474 | #'
475 | #' ### Calibration
476 | #'
477 | #' The bootstrapped calibration score is
478 | #' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)`
479 | #' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` -
480 | #' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**.
481 | #'
482 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be
483 | #' better to draw all the curves from the bootstrap fit to get an idea of
484 | #' variability, but I've not implemented this yet.)
485 | #'
486 | #+ calibration_plot
487 | 
488 | calibration.table <-
489 |   calibrationTable(
490 |     # Standard calibration options
491 |     surv.model.fit.final, COHORT.bigdata[test.set,],
492 |     # Always need to specify NA imputation for rfsrc
493 |     na.action = 'na.impute'
494 |   )
495 | 
496 | calibration.score <- calibrationScore(calibration.table)
497 | 
498 | calibrationPlot(calibration.table)
499 | 
500 | #' The area between the calibration curve and the diagonal is
501 | #' **`r round(calibration.score[['area']], 3)`** +/-
502 | #' **`r round(calibration.score[['se']], 3)`**.
--------------------------------------------------------------------------------
/cox-ph/cox-discrete-varsellogrank.R:
--------------------------------------------------------------------------------
1 | #+ knitr_setup, include = FALSE
2 | 
3 | # Whether to cache the intensive code sections. Set to FALSE to recalculate
4 | # everything afresh.
5 | cacheoption <- TRUE
6 | # Disable lazy caching globally, because it fails for large objects, and all the
7 | # objects we wish to cache are large...
8 | opts_chunk$set(cache.lazy = FALSE)
9 | 
10 | #' # Variable selection in data-driven health records with discretised Cox models
11 | #'
12 | #'
13 | #' Having extracted around 600 variables which occur most frequently in patient
14 | #' records, let's try to narrow these down using a methodology based on varSelRf
15 | #' combined with survival modelling.
We'll find the predictability of variables
16 | #' as defined by the p-value of a logrank test on survival curves of different
17 | #' categories within that variable, and then iteratively throw out unimportant
18 | #' variables, cross-validating for optimum performance.
19 | #'
20 | #' ## User variables
21 | #'
22 | #+ user_variables
23 | 
24 | output.filename.base <- '../../output/cox-bigdata-varsellogrank-01'
25 | 
26 | cv.n.folds <- 3
27 | vars.drop.frac <- 0.2 # Fraction of variables to drop at each iteration
28 | bootstraps <- 100
29 | 
30 | n.data <- NA # This is after any variables being excluded in prep
31 | 
32 | n.threads <- 20
33 | 
34 | #' ## Data set-up
35 | #'
36 | #+ data_setup
37 | 
38 | data.filename.big <- '../../data/cohort-datadriven-02.csv'
39 | 
40 | surv.predict.old <- c('age', 'smokstatus', 'imd_score', 'gender')
41 | untransformed.vars <- c('time_death', 'endpoint_death', 'exclude')
42 | 
43 | source('../lib/shared.R')
44 | require(xtable)
45 | 
46 | # Define these after shared.R or they will be overwritten!
47 | exclude.vars <-
48 |   c(
49 |     # Entity type 4 is smoking status, which we already have
50 |     "clinical.values.4_data1", "clinical.values.4_data5",
51 |     "clinical.values.4_data6",
52 |     # Entity 13 data2 is the patient's weight centile, and not a single one is
53 |     # entered, but they come out as 0 so the algorithm, looking for NAs, thinks
54 |     # it's a useful column
55 |     "clinical.values.13_data2",
56 |     # Entities 148 and 149 are to do with death certification. I'm not sure how
57 |     # it made it into the dataset, but since all the datapoints in this are
58 |     # looking back in time, they're all NA. This causes rfsrc to fail.
59 |     "clinical.values.148_data1", "clinical.values.148_data2",
60 |     "clinical.values.148_data3", "clinical.values.148_data4",
61 |     "clinical.values.148_data5",
62 |     "clinical.values.149_data1", "clinical.values.149_data2"
63 |   )
64 | 
65 | COHORT <- fread(data.filename.big)
66 | 
67 | bigdata.prefixes <-
68 |   c(
69 |     'hes.icd.',
70 |     'hes.opcs.',
71 |     'tests.enttype.',
72 |     'clinical.history.',
73 |     'clinical.values.',
74 |     'bnf.'
75 |   )
76 | 
77 | bigdata.columns <-
78 |   colnames(COHORT)[
79 |     which(
80 |       # Does it start with one of the data column names?
81 |       startsWithAny(names(COHORT), bigdata.prefixes) &
82 |       # And it's not one of the columns we want to exclude?
83 |       !(colnames(COHORT) %in% exclude.vars)
84 |     )
85 |   ]
86 | 
87 | COHORT.bigdata <-
88 |   COHORT[, c(
89 |     untransformed.vars, surv.predict.old, bigdata.columns
90 |   ),
91 |   with = FALSE
92 |   ]
93 | 
94 | # Get the missingness before we start removing missing values
95 | missingness <- sort(sapply(COHORT.bigdata, percentMissing))
96 | # Remove values for the 'untransformed.vars' above, which are the survival
97 | # values plus exclude column
98 | missingness <- missingness[!(names(missingness) %in% untransformed.vars)]
99 | 
100 | # Deal appropriately with missing data
101 | # Most of the variables are number of days since the first record of that type
102 | time.based.vars <-
103 |   names(COHORT.bigdata)[
104 |     startsWithAny(
105 |       names(COHORT.bigdata),
106 |       c('hes.icd.', 'hes.opcs.', 'clinical.history.')
107 |     )
108 |   ]
109 | # We're dealing with this as a logical, so we want non-NA values to be TRUE
110 | # if there is something in the history
111 | for (j in time.based.vars) {
112 |   set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
113 | }
114 | 
115 | # Again, taking this as a logical, set any non-NA value to TRUE.
116 | prescriptions.vars <- names(COHORT.bigdata)[startsWith(names(COHORT.bigdata), 'bnf.')]
117 | for (j in prescriptions.vars) {
118 |   set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
119 | }
120 | 
121 | # This leaves tests and clinical.values, which are test results and should be
122 | # imputed.
123 | 
124 | # Manually fix clinical values items...
125 | #
126 | # "clinical.values.1_data1" "clinical.values.1_data2"
127 | # These are just blood pressure values...fine to impute
128 | #
129 | # "clinical.values.13_data1" "clinical.values.13_data3"
130 | # These are weight and BMI...also fine to impute
131 | #
132 | # Entity 5 is alcohol consumption status, 1 = Yes, 2 = No, 3 = Ex, so should be
133 | # a factor, and NA can be a factor level
134 | COHORT.bigdata$clinical.values.5_data1 <-
135 |   factorNAfix(factor(COHORT.bigdata$clinical.values.5_data1), NAval = 'missing')
136 | 
137 | # Both gender and smokstatus are factors...fix that
138 | COHORT.bigdata$gender <- factor(COHORT.bigdata$gender)
139 | COHORT.bigdata$smokstatus <-
140 |   factorNAfix(factor(COHORT.bigdata$smokstatus), NAval = 'missing')
141 | 
142 | # Exclude invalid patients
143 | COHORT.bigdata <- COHORT.bigdata[!COHORT.bigdata$exclude]
144 | COHORT.bigdata$exclude <- NULL
145 | 
146 | # Remove negative survival times
147 | COHORT.bigdata <- subset(COHORT.bigdata, time_death > 0)
148 | 
149 | # If n.data was specified, trim the data table down to size
150 | if(!is.na(n.data)) {
151 |   COHORT.bigdata <- sample.df(COHORT.bigdata, n.data)
152 | }
153 | 
154 | # Define the test set (after any trimming, so the indices remain valid)
155 | test.set <- testSetIndices(COHORT.bigdata, random.seed = 78361)
156 | 
157 | # Create an appropriate survival column
158 | COHORT.bigdata <-
159 |   prepSurvCol(
160 |     data.frame(COHORT.bigdata), 'time_death', 'endpoint_death', 'Death'
161 |   )
162 | 
163 | # Start by predicting survival with all the variables provided
164 | surv.predict <- c(surv.predict.old, bigdata.columns)
165 | 
166 | # Set up a csv file to store calibration data, or retrieve previous data
167 | calibration.filename <- paste0(output.filename.base, '-varselcalibration.csv')
168 | 
169 | varLogrankTest <- function(df, var) {
170 |   # If there's only one category, this is a single-valued variable so you can't
171 |   # do a logrank test on different values of it...
172 | if(length(unique(NArm(df[, var]))) == 1) { 173 | return(NA) 174 | } 175 | 176 | # If it's a logical, make an extra column for consistency of later code 177 | if(class(df[, var]) == 'logical') { 178 | df$groups <- factor(ifelse(df[, var], 'A', 'B')) 179 | # If it's numeric, split it into four quartiles 180 | } else if(class(df[, var]) == 'numeric') { 181 | # First, discard all rows where the value is missing 182 | df <- df[!is.na(df[, var]), ] 183 | # Then, assign quartiles 184 | df$groups <- 185 | factor( 186 | findInterval( 187 | df[, var], 188 | quantile(df[, var], probs=c(0, 0.25, .5, .75, 1)) 189 | ) 190 | ) 191 | 192 | } else { 193 | # Otherwise, it's a factor, so leave it as-is 194 | df$groups <- df[, var] 195 | } 196 | 197 | # Perform a logrank test on the data 198 | lr.test <- 199 | survdiff( 200 | as.formula(paste0('Surv(surv_time, surv_event) ~ groups')), 201 | df 202 | ) 203 | # Return the p-value of the logrank test 204 | pchisq(lr.test$chisq, length(lr.test$n)-1, lower.tail = FALSE) 205 | } 206 | 207 | # Don't use the output variables in our list 208 | vars.to.check <- 209 | names(COHORT.bigdata)[!(names(COHORT.bigdata) %in% c('surv_time', 'surv_event'))] 210 | 211 | var.logrank.p <- 212 | sapply( 213 | X = vars.to.check, FUN = varLogrankTest, 214 | df = COHORT.bigdata[-test.set, ] 215 | ) 216 | 217 | # Sort them, in ascending order because small p-values indicate differing 218 | # survival curves 219 | var.logrank.p <- sort(var.logrank.p, na.last = TRUE) 220 | 221 | # Create process settings 222 | 223 | # Variables to leave alone, including those whose logrank p-value is NA because 224 | # that means there is only one value in the column and so it can't be discretised 225 | # properly anyway 226 | vars.noprocess <- c('surv_time', 'surv_event', names(var.logrank.p)[is.na(var.logrank.p)]) 227 | process.settings <- 228 | list( 229 | var = vars.noprocess, 230 | method = rep(NA, length(vars.noprocess)), 231 | settings = rep(list(NA), length(vars.noprocess)) 232 | ) 233 | # Find continuous variables which will need discretising 234 | continuous.vars <- names(COHORT.bigdata)[sapply(COHORT.bigdata, class) %in% c('integer', 'numeric')] 235 | # Remove those variables already explicitly excluded, mainly for those whose 236 | # logrank score was NA 237 | continuous.vars <- continuous.vars[!(continuous.vars %in% process.settings$var)] 238 | process.settings$var <- c(process.settings$var, continuous.vars) 239 | process.settings$method <- 240 | c(process.settings$method, 241 | rep('binByQuantile', length(continuous.vars)) 242 | ) 243 | process.settings$settings <- 244 | c( 245 | process.settings$settings, 246 | rep( 247 | list( 248 | seq( 249 | # Quantiles are obviously between 0 and 1 250 | 0, 1, 251 | # All have the same number of bins 252 | length.out = 10 253 | ) 254 | ), 255 | length(continuous.vars) 256 | ) 257 | ) 258 | 259 | COHORT.prep <- 260 | prepData( 261 | # Data for cross-validation excludes test set 262 | COHORT.bigdata, 263 | names(COHORT.bigdata), 264 | process.settings, 265 | 'surv_time', 'surv_event', 266 | TRUE 267 | ) 268 | 269 | # Kludge...remove surv_time.1 and rename surv_event.1 270 | COHORT.prep$surv_time.1 <- NULL 271 | names(COHORT.prep)[names(COHORT.prep) == 'surv_event.1'] <- 'surv_event' 272 | 273 | #' ## Run variable selection 274 | #' 275 | #' If there's not already a calibration file, we run our variable selection 276 | #' algorithm: 277 | #' 1. 
Perform logrank tests on survival curves of subsets of the data to find 278 | #' those variables which seemingly have the largest effect on survival. 279 | #' 2. Cross-validate as number of most important variables kept is reduced. 280 | #' 281 | #' (If there is already a calibration file, just load the previous work.) 282 | #' 283 | #+ cox_var_sel_calibration 284 | 285 | # If we've not already done a calibration, then do one 286 | if(!file.exists(calibration.filename)) { 287 | # Create an empty data frame to aggregate stats per fold 288 | cv.performance <- data.frame() 289 | 290 | # Cross-validate over number of variables to try 291 | cv.vars <- 292 | getVarNums( 293 | length(var.logrank.p), 294 | # no point going lower than the point at which all the p-values are 0, 295 | # because the order is alphabetical and therefore meaningless below this! 296 | min = sum(var.logrank.p == 0, na.rm = TRUE) 297 | ) 298 | 299 | COHORT.cv <- COHORT.prep[-test.set, ] 300 | 301 | # Run crossvalidations. No need to parallelise because rfsrc is parallelised 302 | for(i in 1:length(cv.vars)) { 303 | # Get the subset of most important variables to use 304 | surv.predict.partial <- names(var.logrank.p)[1:cv.vars[i]] 305 | 306 | # Get folds for cross-validation 307 | cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds) 308 | 309 | cv.fold.performance <- data.frame() 310 | 311 | for(j in 1:cv.n.folds) { 312 | time.start <- handyTimer() 313 | # Fit model to the training set 314 | surv.model.fit <- 315 | survivalFit( 316 | surv.predict.partial, 317 | COHORT.cv[-cv.folds[[j]],], 318 | model.type = 'survreg', 319 | n.threads = n.threads 320 | ) 321 | time.learn <- handyTimer(time.start) 322 | 323 | time.start <- handyTimer() 324 | # Get C-index on validation set 325 | c.index.val <- 326 | cIndex( 327 | surv.model.fit, COHORT.cv[cv.folds[[j]],] 328 | ) 329 | time.c.index <- handyTimer(time.start) 330 | 331 | time.start <- handyTimer() 332 | # Get calibration score validation set 333 | calibration.score <- 334 | calibrationScore( 335 | calibrationTable( 336 | surv.model.fit, COHORT.cv[cv.folds[[j]],] 337 | ) 338 | ) 339 | time.calibration <- handyTimer(time.start) 340 | 341 | # Append the stats we've obtained from this fold 342 | cv.fold.performance <- 343 | rbind( 344 | cv.fold.performance, 345 | data.frame( 346 | calibration = i, 347 | cv.fold = j, 348 | n.vars = cv.vars[i], 349 | c.index.val, 350 | calibration.score, 351 | time.learn, 352 | time.c.index, 353 | time.calibration 354 | ) 355 | ) 356 | 357 | } # End cross-validation loop (j) 358 | 359 | 360 | # rbind the performance by fold 361 | cv.performance <- 362 | rbind( 363 | cv.performance, 364 | cv.fold.performance 365 | ) 366 | 367 | # Save output at the end of each loop 368 | write.csv(cv.performance, calibration.filename) 369 | 370 | } # End calibration loop (i) 371 | } else { 372 | cv.performance <- read.csv(calibration.filename) 373 | } 374 | 375 | #' ## Find the best model from the calibrations 376 | #' 377 | #' ### Plot model performance 378 | #' 379 | #+ model_performance 380 | 381 | # Find the best calibration... 
382 | # First, average performance across cross-validation folds
383 | cv.performance.average <-
384 |   aggregate(
385 |     c.index.val ~ n.vars,
386 |     data = cv.performance,
387 |     mean
388 |   )
389 | 
390 | cv.calibration.average <-
391 |   aggregate(
392 |     area ~ n.vars,
393 |     data = cv.performance,
394 |     mean
395 |   )
396 | 
397 | ggplot(cv.performance.average, aes(x = n.vars, y = c.index.val)) +
398 |   geom_line() +
399 |   geom_point(data = cv.performance) +
400 |   ggtitle(label = 'C-index by n.vars')
401 | 
402 | ggplot(cv.calibration.average, aes(x = n.vars, y = area)) +
403 |   geom_line() +
404 |   geom_point(data = cv.performance) +
405 |   ggtitle(label = 'Calibration performance by n.vars')
406 | 
407 | # Find the highest value
408 | n.vars <-
409 |   cv.performance.average$n.vars[
410 |     which.max(cv.performance.average$c.index.val)
411 |   ]
412 | 
413 | # Fit a full model with the variables provided
414 | surv.predict.partial <- names(var.logrank.p)[1:n.vars]
415 | 
416 | #' ## Best model
417 | #'
418 | #' The best model contained `r n.vars` variables. Let's see what those were...
419 | #'
420 | #+ variables_used
421 | 
422 | vars.df <-
423 |   data.frame(
424 |     vars = surv.predict.partial
425 |   )
426 | 
427 | vars.df$descriptions <- lookUpDescriptions(surv.predict.partial)
428 | 
429 | vars.df$missingness <- missingness[surv.predict.partial]
430 | 
431 | #+ variables_table, results='asis'
432 | 
433 | print(
434 |   xtable(vars.df),
435 |   type = 'html',
436 |   include.rownames = FALSE
437 | )
438 | 
439 | #' ## Perform the final fit
440 | #'
441 | #' Having found the best number of variables by cross-validation, let's perform
442 | #' the final fit with the full training set.
443 | #'
444 | #+ final_fit
445 | 
446 | time.start <- handyTimer()
447 | surv.model.fit.final <-
448 |   survivalFit(
449 |     surv.predict.partial,
450 |     COHORT.prep[-test.set,],
451 |     model.type = 'survreg'
452 |   )
453 | time.fit.final <- handyTimer(time.start)
454 | 
455 | saveRDS(surv.model.fit.final, paste0(output.filename.base, '-finalmodel.rds'))
456 | 
457 | #' Final model fitted in `r round(time.fit.final)` seconds!
458 | #'
459 | #' Also bootstrap this final fitting stage. A fully proper bootstrap would
460 | #' iterate over the whole model-building process including variable selection,
461 | #' but that would be prohibitive in terms of computational time.
462 | #'
463 | #+ bootstrap_final
464 | 
465 | time.start <- handyTimer()
466 | surv.model.params.boot <-
467 |   survivalFitBoot(
468 |     surv.predict.partial,
469 |     COHORT.prep[-test.set,], # Training set
470 |     COHORT.prep[test.set,], # Test set
471 |     model.type = 'survreg',
472 |     bootstraps = bootstraps,
473 |     n.threads = n.threads,
474 |     filename = paste0(output.filename.base, '-boot-all.csv')
475 |   )
476 | time.boot.final <- handyTimer(time.start)
477 | 
478 | #' `r bootstraps` bootstrap fits completed in `r round(time.boot.final)` seconds!
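#'
#' (For intuition before the numbers below: the "calibration score" is the area
#' between the calibration curve and the y = x diagonal. calibrationScore in
#' lib/ does the real computation; this standalone sketch with made-up
#' 'predicted' and 'observed' vectors just illustrates the quantity.)
#'
#+ calibration_area_sketch

# Illustration only: trapezoidal-rule area between a calibration curve and the
# diagonal, where 'predicted' and 'observed' are risks for matched bins
calibrationAreaSketch <- function(predicted, observed) {
  ord <- order(predicted)
  x <- predicted[ord]
  # Distance of the calibration curve from perfect calibration at each point
  y <- abs(observed[ord] - x)
  sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
}
# Perfect calibration scores 0:
# calibrationAreaSketch(seq(0, 1, 0.01), seq(0, 1, 0.01))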
479 | 480 | # Get coefficients and variable importances from bootstrap fits 481 | surv.model.fit.coeffs <- bootStatsDf(surv.model.params.boot) 482 | 483 | # Save performance results 484 | varsToTable( 485 | data.frame( 486 | model = 'cox-logrank', 487 | imputation = FALSE, 488 | discretised = TRUE, 489 | c.index = surv.model.fit.coeffs['c.index', 'val'], 490 | c.index.lower = surv.model.fit.coeffs['c.index', 'lower'], 491 | c.index.upper = surv.model.fit.coeffs['c.index', 'upper'], 492 | calibration.score = surv.model.fit.coeffs['calibration.score', 'val'], 493 | calibration.score.lower = 494 | surv.model.fit.coeffs['calibration.score', 'lower'], 495 | calibration.score.upper = 496 | surv.model.fit.coeffs['calibration.score', 'upper'] 497 | ), 498 | performance.file, 499 | index.cols = c('model', 'imputation', 'discretised') 500 | ) 501 | 502 | #' ## Performance 503 | #' 504 | #' ### C-index 505 | #' 506 | #' C-index is **`r round(surv.model.fit.coeffs['c.index', 'val'], 3)` 507 | #' (`r round(surv.model.fit.coeffs['c.index', 'lower'], 3)` - 508 | #' `r round(surv.model.fit.coeffs['c.index', 'upper'], 3)`)** 509 | #' on the held-out test set. 510 | #' 511 | #' 512 | #' ### Calibration 513 | #' 514 | #' The bootstrapped calibration score is 515 | #' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)` 516 | #' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` - 517 | #' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**. 518 | #' 519 | #' Let's draw a representative curve from the unbootstrapped fit... (It would be 520 | #' better to draw all the curves from the bootstrap fit to get an idea of 521 | #' variability, but I've not implemented this yet.) 522 | #' 523 | #+ calibration_plot 524 | 525 | calibration.table <- 526 | calibrationTable(surv.model.fit.final, COHORT.prep[test.set,]) 527 | 528 | calibration.score <- calibrationScore(calibration.table) 529 | 530 | calibrationPlot(calibration.table, show.censored = TRUE) 531 | 532 | # Save the calibration table for plotting later 533 | write.csv( 534 | calibration.table, 535 | paste0(output.filename.base, '-calibration-table.csv') 536 | ) --------------------------------------------------------------------------------