├── .Rprofile
├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── analysis_ahc.R
├── analysis_data.R
├── analysis_featurebased.R
├── analysis_gbtm.R
├── analysis_gmm.R
├── analysis_kml.R
├── analysis_llpa.R
├── analysis_mixtvem.R
├── data.R
├── demo.Rproj
├── metrics.R
└── plot.R

/.Rprofile:
--------------------------------------------------------------------------------
suppressPackageStartupMessages({
  library(data.table)
  library(magrittr)
  library(assertthat)
  library(ggplot2)
  library(scales)
  library(matrixStats)
  library(IMIFA)
})

source('data.R')
source('plot.R')
source('metrics.R')

theme_minimal(base_size = 9) %+replace%
  theme(
    plot.background = element_rect(colour = NA),
    plot.margin = unit(c(2, 1, 1, 1), 'mm'),
    panel.background = element_rect(colour = NA),
    panel.spacing = unit(1, 'mm'),
    strip.background = element_rect(colour = NA, fill = NA),
    strip.text = element_text(face = 'plain', size = 7, margin = margin()),
    legend.text = element_text(size = 7),
    legend.title = element_text(size = 9),
    legend.position = 'right',
    legend.spacing = unit(1, 'cm'),
    legend.margin = margin(0, 0, 0, 0),
    legend.key.size = unit(9, 'pt'),
    legend.box.margin = margin(0, 0, -5, 0),
    axis.line.x = element_line(color = 'black', size = .1),
    axis.line.y = element_line(color = 'black', size = .1)
  ) %>%
  theme_set()

bicPlotSize = c(6, 5)
groupPlotSize = c(9, 5)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.Rhistory
.Rapp.history
*.RData
*.csv
.Rproj.user/
.httr-oauth
*_cache/
/cache/
*.utf8.md
*.knit.md
.Renviron
/save
MixTVEM.R
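# MixTVEM.R is fetched manually from https://github.com/dziakj1/MixTVEM (see README)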
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
preferred-citation:
  type: article
  title: "Clustering of longitudinal data: A tutorial on a variety of approaches"
  year: 2021
  authors:
    - family-names: "Den Teuling"
      given-names: "Niek G. P."
      orcid: "https://orcid.org/0000-0003-1026-5080"
    - family-names: "Pauws"
      given-names: "Steffen C."
      orcid: "https://orcid.org/0000-0003-2257-9239"
    - family-names: "van den Heuvel"
      given-names: "Edwin R."
      orcid: "https://orcid.org/0000-0001-9157-7224"
  start: 1 # First page number
  end: 37 # Last page number
  version: v1
  status: preprint
  url: "https://arxiv.org/abs/2111.05469"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Philips Labs

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This repository contains the R scripts used for the case study analysis of the manuscript.

# Useful links
* MixTVEM source code used in the MixTVEM demo analysis - https://github.com/dziakj1/MixTVEM
* _lcmm_ R package, used for estimating GMM and GBTM - https://cran.r-project.org/package=lcmm
* _kml_ R package, used for the KmL analysis - https://cran.r-project.org/package=kml
* _mclust_ R package, used for estimating LLPA - https://cran.r-project.org/package=mclust
* _latrend_ R package: the longitudinal clustering framework that we created based on the lessons learned from this work - https://github.com/philips-software/latrend

# Getting started
1. Either load the RStudio project file `demo.Rproj`, or start an R session with the working directory set to the repository root.
2. Install the required packages and dependencies:
   ```
   install.packages(c("assertthat", "data.table", "ggdendro", "ggplot2", "IMIFA", "kml", "lcmm", "lpSolve", "magrittr", "MASS", "matrixStats", "mclust", "nlme", "scales"), dependencies = TRUE)
   ```
3. If you want to run the MixTVEM analysis, you will also need to fetch `MixTVEM.R` from https://github.com/dziakj1/MixTVEM and place it in the repository root.

You should now be able to run any of the analysis scripts.
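For example, a minimal session (sketch) that runs the GMM analysis end-to-end; this assumes R was started in the repository root, so that `.Rprofile` has loaded the required packages and sourced the helper scripts:

```r
# Estimates GMM solutions for 1-8 clusters and plots the BIC curve and
# cluster trajectories; the fitted models are collected in the `gmms` list.
source('analysis_gmm.R')
```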
--------------------------------------------------------------------------------
/analysis_ahc.R:
--------------------------------------------------------------------------------
library(cluster)
library(ggdendro)
data = generate_osa_data()
dataMat = transformToRepeatedMeasures(data)

computeMeanTrajectories = function(clusters) {
  assert_that(is.factor(clusters), length(clusters) == uniqueN(data$Id))
  data[, .(Usage=mean(Usage)), by=.(Group=clusters[as.integer(Id)], Time)]
}

# Estimation ####
start = Sys.time()
D = dist(dataMat, method='euclidean')
htree = hclust(D, method='average')

Ks = 1:8
ahcsLabels = lapply(Ks, function(k) cutree(htree, k=k)) %>%
  set_names(Ks)
runTime = Sys.time() - start

# Solutions ####
dendro = dendro_data(htree, type='rectangle')
ggplot(segment(dendro)) +
  geom_segment(aes(x=x, y=y, xend=xend, yend=yend), size=.001) +
  scale_x_continuous(breaks=pretty_breaks(4)) +
  coord_flip() +
  labs(x='Patient', y='Height') +
  theme(panel.grid.major.y=element_blank(),
        panel.grid.minor.y=element_blank())
# ggsave('save/ahc_tree.pdf', width=15, height=3, units='cm')

# Compute average silhouette widths (undefined for k = 1)
sils = lapply(ahcsLabels[as.character(setdiff(Ks, 1))], silhouette, dist=D) %>%
  lapply(summary) %>%
  sapply('[[', 'avg.width')

# Compute the residual sum of squares around the cluster mean trajectories
rsss = sapply(ahcsLabels, function(clusters) {
  clusters = factor(clusters, labels=LETTERS[1:uniqueN(clusters)])
  dtPatClus = data.table(Id=levels(data$Id), Group=clusters)
  dtTraj = computeMeanTrajectories(clusters)
  dtFitted = data[dtPatClus, on='Id'] %>%
    merge(dtTraj, by.x=c('i.Group', 'Time'), by.y=c('Group', 'Time'))
  dtFitted[, sum((Usage.x - Usage.y)^2)]
})

plotMetric(c(NA, sils), Ks, 'Silhouette width')
# ggsave('save/ahc_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm')

plotMetric(rsss, Ks, 'RSS')

# Assess result ####
k = 4
clusters = ahcsLabels[[as.character(k)]] %>%
  factor(labels=LETTERS[1:k])
groupProps = table(clusters) %>% prop.table() %>% as.numeric()

computeMeanTrajectories(clusters) %>% plotGroupTrajectories(groupProps)
# ggsave('save/ahc_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm')
--------------------------------------------------------------------------------
/analysis_data.R:
--------------------------------------------------------------------------------
data = generate_osa_data()
plotTrajectories(data)
# ggsave('save/trajectories.pdf', width=10.5, height=6, units='cm')

plotGroupTrajectories(data) +
  guides(shape='none')
# ggsave('save/groups.pdf', width=6, height=6, units='cm')

# Same plot, with the cluster legend shown
plotGroupTrajectories(data)

ggplot(data, aes(x=Usage)) +
  geom_histogram(aes(y=after_stat(count / sum(count))), breaks=seq(0, 12, by=.5), color='black', fill='gray') +
  scale_x_continuous(breaks=seq(0, 12, by=2)) +
  scale_y_continuous(breaks=pretty_breaks(8), labels=percent) +
  labs(x='Hours of Use', y='Proportion of observations')


# View an individual trajectory (a random patient by default)
plotTrajectory(data)
--------------------------------------------------------------------------------
/analysis_featurebased.R:
--------------------------------------------------------------------------------
library(ggdendro)
library(MASS)
library(cluster)
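# Feature-based clustering (sketch of the idea): each trajectory is first
# reduced to a small per-patient feature vector (quadratic trend coefficients
# and the log-count of days with any usage), after which patients are
# clustered in the standardized feature space.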
data = generate_osa_data() %>% 5 | .[, NormTime := (Time - min(Time)) / (max(Time) - min(Time))] 6 | 7 | computeMeanTrajectories = function(clusters) { 8 | assert_that(is.factor(clusters), length(clusters) == uniqueN(data$Id)) 9 | data[, .(Usage=mean(Usage)), by=.(Group=clusters[as.integer(Id)], Time)] 10 | } 11 | 12 | # Compute patient representation #### 13 | { 14 | start = Sys.time() 15 | patientFeatures = data[, { 16 | mod = lm(Usage ~ poly(NormTime, 2), data=.SD) 17 | as.list(coef(mod)) %>% c(logN=log(sum(Usage > 0)), sigma=sigma(mod)) 18 | }, keyby=Id] 19 | 20 | X = as.matrix(patientFeatures[, -c('Id', 'sigma')]) %>% 21 | set_rownames(patientFeatures$Id) %>% 22 | scale() 23 | 24 | # Hierarchical clustering 25 | D = dist(X, method='euclidean') 26 | 27 | htree = hclust(D, method='average') 28 | 29 | Ks = 1:8 30 | ahcs = lapply(Ks, function(k) pam(D, k=k)) %>% 31 | set_names(Ks) 32 | sils = lapply(ahcs, function(ahc) ahc$silinfo$avg.width) %>% 33 | sapply(function(x) ifelse(is.null(x), NA, x)) 34 | runTime = Sys.time() - start 35 | } 36 | max(sils, na.rm=TRUE) 37 | which.max(sils) 38 | 39 | # Solutions #### 40 | dendro = dendro_data(htree, type='rectangle') 41 | 42 | ggplot(segment(dendro)) + 43 | geom_segment(aes(x=x, y=y, xend=xend, yend=yend)) + 44 | scale_x_continuous(breaks=pretty_breaks(10)) + 45 | coord_flip() + 46 | labs(x='Trajectories\n\n', y='Height') + 47 | theme(panel.grid.major=element_blank(), 48 | panel.grid.minor=element_blank()) 49 | 50 | plotMetric(sils, Ks, 'Silhouette width') 51 | # ggsave('save/featurebased_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm') 52 | 53 | # Assess the best solution #### 54 | k = 7 55 | bestAhc = ahcs[[as.character(k)]] 56 | clusters = factor(bestAhc$clustering, labels=LETTERS[1:k]) 57 | groupProps = table(clusters) %>% prop.table() %>% as.numeric() 58 | 59 | computeMeanTrajectories(clusters) %>% plotGroupTrajectories(groupProps) 60 | # ggsave('save/featurebased_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm') 61 | -------------------------------------------------------------------------------- /analysis_gbtm.R: -------------------------------------------------------------------------------- 1 | library(lcmm) 2 | library(splines) 3 | data = generate_osa_data() %>% 4 | .[, Id := as.integer(Id)] %>% 5 | .[, NormTime := (Time - min(Time)) / (max(Time) - min(Time))] 6 | 7 | makeGbtmCall = function(k) { 8 | substitute( 9 | hlme(fixed=Usage ~ NormTime + I(NormTime^2), 10 | mixture=~NormTime + I(NormTime^2), 11 | random=~-1, 12 | subject='Id', ng=k, data=data), 13 | env=list(k=k) 14 | ) 15 | } 16 | 17 | computeHlmeTrajectories = function(model) { 18 | times = sort(unique(data$Time)) 19 | normTimes = sort(unique(data$NormTime)) 20 | predictY(model, newdata=data.frame(NormTime=normTimes))$pred %>% 21 | data.table(Time=times) %>% 22 | melt(id.vars='Time', value.name='Usage', variable.name='Group') %>% 23 | .[, Group := factor(Group, levels=paste0('Ypred_class', 1:model$ng), labels=LETTERS[1:model$ng])] %>% 24 | .[] 25 | } 26 | 27 | # Single-group analysis #### 28 | mod00 = hlme(fixed=Usage ~ 1, random=~1, subject='Id', ng=1, data=data) 29 | summary(mod00) 30 | residuals(mod00) %T>% qqnorm %>% qqline 31 | 32 | mod11 = hlme(fixed=Usage ~ NormTime, random=~NormTime, subject='Id', ng=1, data=data) 33 | summary(mod11) 34 | residuals(mod11) %T>% qqnorm %>% qqline 35 | 36 | mod22 = hlme(fixed=Usage ~ poly(NormTime, 2, raw=TRUE), random=~poly(NormTime, 2, raw=TRUE), subject='Id', ng=1, data=data) 37 | 
summary(mod22) 38 | residuals(mod22) %T>% qqnorm %>% qqline 39 | 40 | mod33 = hlme(fixed=Usage ~ poly(NormTime, 3, raw=TRUE), random=~poly(NormTime, 3, raw=TRUE), subject='Id', ng=1, data=data) 41 | summary(mod33) 42 | residuals(mod33) %T>% qqnorm %>% qqline 43 | 44 | modbs = hlme(fixed=Usage ~ bs(NormTime), random=~bs(NormTime), subject='Id', ng=1, data=data) 45 | summary(modbs) 46 | residuals(modbs) %T>% qqnorm %>% qqline 47 | 48 | # Estimation #### 49 | gbtms = list() 50 | # gbtms = readRDS('save/gbtm.rds') 51 | gbtms[['1']] = hlme(fixed=Usage ~ NormTime + I(NormTime^2), random=~-1, subject='Id', ng=1, data=data) 52 | 53 | fitGbtm = function(k) { 54 | start = Sys.time() 55 | model = do.call(gridsearch, list(m=makeGbtmCall(k), rep=20, maxiter=1, minit=gbtms[['1']])) 56 | model$runTime = Sys.time() - start 57 | return(model) 58 | } 59 | 60 | gbtms[['2']] = fitGbtm(2) 61 | gbtms[['3']] = fitGbtm(3) 62 | gbtms[['4']] = fitGbtm(4) 63 | gbtms[['5']] = fitGbtm(5) 64 | gbtms[['6']] = fitGbtm(6) 65 | gbtms[['7']] = fitGbtm(7) 66 | gbtms[['8']] = fitGbtm(8) 67 | # saveRDS(gbtms, file='save/gbtm.rds') 68 | 69 | # Solutions #### 70 | plotMetric(sapply(gbtms, '[[', 'BIC'), as.integer(names(gbtms)), 'BIC') 71 | # ggsave('save/gbtm_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm') 72 | 73 | # Assess the best solution #### 74 | k = 4 75 | bestGbtm = gbtms[[as.character(k)]] 76 | 77 | pp = bestGbtm$pprob[paste0('prob', 1:k)] %>% 78 | as.matrix() %>% 79 | set_colnames(LETTERS[1:k]) 80 | groupProps = colMeans(pp) %T>% print() 81 | 82 | computeHlmeTrajectories(bestGbtm) %>% plotGroupTrajectories(groupProps) 83 | # ggsave('save/gbtm_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm') 84 | 85 | appa(pp) 86 | relativeEntropy(pp) 87 | -------------------------------------------------------------------------------- /analysis_gmm.R: -------------------------------------------------------------------------------- 1 | library(lcmm) 2 | data = generate_osa_data() %>% 3 | .[, Id := as.integer(Id)] %>% 4 | .[, NormTime := (Time - min(Time)) / (max(Time) - min(Time))] 5 | 6 | makeGmmCall = function(k) { 7 | substitute( 8 | hlme(fixed=Usage ~ NormTime + I(NormTime^2), 9 | mixture=~NormTime + I(NormTime^2), 10 | random=~NormTime, 11 | idiag=TRUE, 12 | nwg=TRUE, 13 | subject='Id', ng=k, data=data), 14 | env=list(k=k) 15 | ) 16 | } 17 | 18 | computeHlmeTrajectories = function(model) { 19 | times = sort(unique(data$Time)) 20 | normTimes = sort(unique(data$NormTime)) 21 | predictY(model, newdata=data.frame(NormTime=normTimes))$pred %>% 22 | data.table(Time=times) %>% 23 | melt(id.vars='Time', value.name='Usage', variable.name='Group') %>% 24 | .[, Group := factor(Group, levels=paste0('Ypred_class', 1:model$ng), labels=LETTERS[1:model$ng])] %>% 25 | .[] 26 | } 27 | 28 | # Estimation #### 29 | gmms = list() 30 | # gmms = readRDS('save/gmm.rds') 31 | gmms[['1']] = hlme(fixed=Usage ~ NormTime + I(NormTime^2), random=~NormTime, subject='Id', idiag=TRUE, ng=1, data=data) 32 | 33 | fitGmm = function(k) { 34 | set.seed(1) 35 | start = Sys.time() 36 | model = do.call(gridsearch, list(m=makeGmmCall(k), rep=20, maxiter=1, minit=gmms[['1']])) 37 | model$runTime = Sys.time() - start 38 | return(model) 39 | } 40 | 41 | gmms[['2']] = fitGmm(2) 42 | gmms[['3']] = fitGmm(3) 43 | gmms[['4']] = fitGmm(4) 44 | gmms[['5']] = fitGmm(5) 45 | gmms[['6']] = fitGmm(6) 46 | gmms[['7']] = fitGmm(7) 47 | gmms[['8']] = fitGmm(8) #nwg=TRUE takes 60s per iter as opposed to 7s 48 | # saveRDS(gmms, file='save/gmm.rds') 
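# Optional overview (a sketch, not part of the original analysis): tabulate
# the BIC and lcmm convergence code of each solution before plotting, so that
# non-converged fits (conv != 1) are easy to spot.
gmmOverview = data.table(
  k = as.integer(names(gmms)),
  BIC = sapply(gmms, '[[', 'BIC'),
  conv = sapply(gmms, '[[', 'conv')
)
print(gmmOverview)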
# Solutions ####
plotMetric(sapply(gmms, '[[', 'BIC'), as.integer(names(gmms)), 'BIC')
# ggsave('save/gmm_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm')

# Assess the best solution ####
k = 7
bestGmm = gmms[[as.character(k)]]

pp = bestGmm$pprob[paste0('prob', 1:k)] %>%
  as.matrix() %>%
  set_colnames(LETTERS[1:k])
groupProps = colMeans(pp) %T>% print()

computeHlmeTrajectories(bestGmm) %>% plotGroupTrajectories(groupProps)
# ggsave('save/gmm_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm')

appa(pp)
relativeEntropy(pp)
--------------------------------------------------------------------------------
/analysis_kml.R:
--------------------------------------------------------------------------------
library(kml)
data = generate_osa_data()

dataMat = transformToRepeatedMeasures(data)
cld = clusterLongData(dataMat, idAll=rownames(dataMat), time=sort(unique(data$Time)), varNames='Usage')
par = parALGO(startingCond='kmeans++')

# Mean trajectory per cluster for the k-cluster solution
computeKmlTrajectories = function(k) {
  calculTrajMeanC(cld@traj, clust=getClusters(cld, k, asInteger=TRUE)) %>%
    set_rownames(LETTERS[seq_len(nrow(.))]) %>%
    melt(varnames=c('Group', 'Time'), value.name='Usage') %>%
    as.data.table() %>%
    .[, Time := cld@time[Time]]
}

# Evaluate for all clusters ####
# cld = readRDS('save/kml.rds')
start = Sys.time(); kml(cld, nbClusters=1:8, nbRedrawing=20, parAlgo=par); runTime = Sys.time() - start
# saveRDS(cld, file='save/kml.rds')

# Select best solution ####
evalGroups = mapply(slot, list(cld), paste0('c', 1:26)) %>%
  lengths() %>%
  {which(. > 0)}
kmls = mapply(slot, list(cld), paste0('c', evalGroups), SIMPLIFY=FALSE) %>%
  lapply(first) %>%
  set_names(evalGroups)

bics = -1 * sapply(kmls, function(part) part@criterionValues['BIC'])
plotMetric(bics, evalGroups, 'BIC') + expand_limits(y=6e4)
# ggsave('save/kml_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm')

# Assess the best solution ####
bestK = 7
bestKml = kmls[[as.character(bestK)]]
print(bestKml)

pp = bestKml@postProba %>%
  set_colnames(levels(bestKml@clusters))
groupProps = colMeans(pp) %T>% print()

computeKmlTrajectories(bestK) %>% plotGroupTrajectories(groupProps)
# ggsave('save/kml_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm')

relativeEntropy(pp)
confusionMatrix(pp) %>% round(4)
--------------------------------------------------------------------------------
/analysis_llpa.R:
--------------------------------------------------------------------------------
library(mclust)
data = generate_osa_data()
dataMat = transformToRepeatedMeasures(data)

fitLlpa = function(k) {
  cat('Fitting for k =', k, '\n')
  start = Sys.time()
  bestModel = NULL
  for (n in 1:10) {
    cat(sprintf('fit %d ', n))
    model = Mclust(dataMat, G=k, modelNames='VVI',
                   prior=priorControl(functionName='defaultPrior'),
                   verbose=FALSE)

    if (!is.null(model)) {
      if (is.null(bestModel) || model$bic > bestModel$bic) {
        bestModel = model
      }
    } else {
      cat('conv error')
    }
    cat('\n')
  }
  if (is.null(bestModel)) {
    bestModel = list(bic=NA)
  }
  bestModel$runTime = Sys.time() - start
  return(bestModel)
}
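# Note on the specification above: modelNames='VVI' restricts each latent
# profile to a diagonal covariance matrix, i.e., measurement occasions are
# treated as independent within a profile, with variances that may differ per
# occasion and per profile.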
computeLlpaTrajectories = function(model) { 32 | times = sort(unique(data$Time)) 33 | model$parameters$mean %>% 34 | set_rownames(NULL) %>% 35 | set_colnames(LETTERS[seq_len(model$G)]) %>% 36 | melt(varnames=c('Time', 'Group'), value.name='Usage') %>% 37 | as.data.table() %>% 38 | .[, Time := times[Time]] %>% 39 | .[] 40 | } 41 | 42 | # Estimation #### 43 | llpas = list() 44 | # llpas = readRDS('save/llpa.rds') 45 | llpas[['1']] = fitLlpa(1) 46 | llpas[['2']] = fitLlpa(2) 47 | llpas[['3']] = fitLlpa(3) 48 | llpas[['4']] = fitLlpa(4) 49 | llpas[['5']] = fitLlpa(5) 50 | llpas[['6']] = fitLlpa(6) 51 | llpas[['7']] = fitLlpa(7) 52 | llpas[['8']] = fitLlpa(8) 53 | 54 | # saveRDS(llpas, file='save/llpa.rds') 55 | 56 | # Solutions #### 57 | plotMetric(-1 * sapply(llpas, '[[', 'bic'), as.integer(names(llpas)), 'BIC') 58 | # ggsave('save/llpa_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm') 59 | 60 | # Assess the best solution #### 61 | k = 5 62 | bestLlpa = llpas[[as.character(k)]] 63 | 64 | pp = bestLlpa$z %>% 65 | set_colnames(LETTERS[1:k]) 66 | groupProps = colMeans(pp) %T>% print() 67 | 68 | computeLlpaTrajectories(bestLlpa) %>% plotGroupTrajectories(groupProps) 69 | # ggsave('save/llpa_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm') 70 | 71 | 72 | appa(pp) 73 | relativeEntropy(pp) 74 | -------------------------------------------------------------------------------- /analysis_mixtvem.R: -------------------------------------------------------------------------------- 1 | source('MixTVEM.R') 2 | data = generate_osa_data() %>% 3 | .[, NormTime := (Time - min(Time)) / (max(Time) - min(Time))] 4 | 5 | fitMixTvem = function(k) { 6 | start = Sys.time() 7 | if(k == 1) { 8 | # needed to prevent consistent estimation errors for the single-group solution 9 | y = ifelse(data$Usage == 0, rnorm(nrow(data), sd=.1), data$Usage) 10 | } else { 11 | y = data$Usage 12 | } 13 | model = TVEMMixNormal(dep=y, 14 | id=data$Id, 15 | numInteriorKnots=6, 16 | deg=3, 17 | doPlot=FALSE, 18 | numClasses=k, 19 | numStarts=20, 20 | gridSize=365, 21 | maxVarianceRatio=NA, 22 | convergenceCriterion=1e-4, 23 | getSEs=FALSE, 24 | tcov=rep(1, nrow(data)), 25 | time=data$NormTime) 26 | model$runTime = Sys.time() - start 27 | return(model) 28 | } 29 | 30 | computeMixTvemTrajectories = function(model) { 31 | k = ncol(model$fittedValues) 32 | times = model$timeGrid * (max(data$Time) - min(data$Time)) + min(data$Time) 33 | data.table(model$betaByGrid[[1]], Time=times) %>% 34 | setnames(paste0('V', 1:k), LETTERS[1:k]) %>% 35 | melt(id.vars='Time', value.name='Usage', variable.name='Group') 36 | } 37 | 38 | # mixtvems = readRDS(file='save/mixtvems.rds') 39 | mixtvems = list() 40 | mixtvems[['1']] = fitMixTvem(1) 41 | mixtvems[['2']] = fitMixTvem(2) 42 | mixtvems[['3']] = fitMixTvem(3) 43 | mixtvems[['4']] = fitMixTvem(4) 44 | mixtvems[['5']] = fitMixTvem(5) 45 | mixtvems[['6']] = fitMixTvem(6) 46 | mixtvems[['7']] = fitMixTvem(7) 47 | mixtvems[['8']] = fitMixTvem(8) 48 | # saveRDS(mixtvems, file='save/mixtvems.rds') 49 | 50 | # Solutions #### 51 | plotMetric(sapply(mixtvems, function(m) m$bestFit$bic), as.integer(names(mixtvems)), 'BIC') 52 | # ggsave('save/mixtvem_bic.pdf', width=bicPlotSize[1], height=bicPlotSize[2], units='cm') 53 | 54 | # Assess the best solution #### 55 | k = 4 56 | bestMixTvem = mixtvems[[as.character(k)]] 57 | 58 | pp = bestMixTvem$bestFit$postProbsBySub 59 | groupProps = colMeans(pp) %T>% print() 60 | 61 | computeMixTvemTrajectories(bestMixTvem) %>% 
plotGroupTrajectories(groupProps)
# ggsave('save/mixtvem_groups.pdf', width=groupPlotSize[1], height=groupPlotSize[2], units='cm')

relativeEntropy(pp)
# confusion matrix in percentages, rounded to one decimal
confusionMatrix(pp) %>% {round(. * 100, 1)}
--------------------------------------------------------------------------------
/data.R:
--------------------------------------------------------------------------------
#' @description Generates data from the group definitions provided by M. Aloia et al. (2008).
#' @param patients Number of patients
#' @param times Times at which to generate the data points
#' @param nAggr Number of measurements per bin
#' @param props Group proportions
#' @param dropoutTimes Mean drop-out time per group (Inf for no drop-out)
#' @param sd.dropoutTimes Standard deviation of the patient-level drop-out times
#' @param attemptProbs Per-group probability of a usage attempt on a given day
#' @param intercepts Group trajectory level
#' @param sd.intercepts Standard deviation of the patient-level intercepts
#' @param slopes Group trajectory slope
#' @param sd.slopes Standard deviation of the patient-level slopes
#' @param quads Group trajectory quadratic coefficient
#' @param sd.quads Standard deviation of the patient-level quadratic coefficients
#' @param vars Measurement variance per group
#' @param sd.vars Standard deviation of the patient-level measurement variances
#' @param autocors Residual autocorrelation per group
#' @param groupNames Group names
#' @param missing Whether to simulate measurements being missing (including drop-out)
#' @param seed The seed for the PRNG
#' @references
#' Mark S. Aloia, Matthew S. Goodwin, Wayne F. Velicer, J. Todd Arnedt, Molly Zimmerman, Jaime Skrekas, Sarah Harris, Richard P. Millman,
#' Time Series Analysis of Treatment Adherence Patterns in Individuals with Obstructive Sleep Apnea, Annals of Behavioral Medicine,
#' Volume 36, Issue 1, August 2008, Pages 44–53, https://doi.org/10.1007/s12160-008-9052-9
generate_osa_data = function(patients=500,
                             times=seq(1, 365, by=1),
                             nAggr=14,
                             props=c(GU=.24, SI=.13, SD=.14, VU=.17, OA=.08, ED=.13, NU=.11),
                             # N=c(354, 344, 280, 299, 106, 55, 10),
                             # sd.N=c(31, 49, 68, 55, 64, 30, 4),
                             dropoutTimes=c(Inf, Inf, Inf, Inf, Inf, 80, 20),
                             sd.dropoutTimes=c(0, 0, 0, 0, 0, 30, 10),
                             attemptProbs=c(354, 344, 280, 299, 106, 55, 14) / pmin(dropoutTimes, 365), # based on median
                             intercepts=c(GU=6.6, SI=5.8-1, SD=6.1, VU=4.9-.5, OA=3.2, ED=4.0, NU=2.5),
                             sd.intercepts=c(GU=.81, SI=1.6-.1, SD=.95, VU=1.3, OA=1.6, ED=1.6, NU=1.4) * .667,
                             slopes=c(GU=0, SI=.0058*3, SD=-.0038*5, VU=.0004*24, OA=-.003, ED=-.0014, NU=-.015),
                             sd.slopes=c(GU=.0016, SI=.0031/2, SD=.0027/2, VU=.0032*0, OA=.0091, ED=.01, NU=.01),
                             quads=c(GU=0, SI=-.00003, SD=.00003, VU=-.00003, OA=0, ED=-.0001, NU=-.0001),
                             sd.quads=c(GU=0, SI=0, SD=0, VU=0, OA=0, ED=0, NU=0),
                             vars=c(GU=2.0, SI=3.6, SD=3.2, VU=3.4, OA=3.6, ED=5.0, NU=3.0),
                             sd.vars=c(.82, 1.3, .85, 1.2, 1.8, 2.6, 1.7),
                             autocors=c(GU=.056, SI=.11, SD=.073, VU=.048, OA=.006, ED=-.044, NU=-.31),
                             groupNames=c('Good users', 'Slow improvers', 'Slow decliners', 'Variable users', 'Occasional attempters', 'Early drop-outs', 'Non-users'),
                             missing=FALSE,
                             seed=1) {
  set.seed(seed)
  groupCounts = floor(patients * props)
  incrIdx = order((patients * props) %% 1, decreasing=TRUE) %>%
    head(patients - sum(groupCounts))
  groupCounts[incrIdx] = groupCounts[incrIdx] + 1 # increment the groups that were closest to receiving another patient
  assert_that(sum(groupCounts) == patients)
  groupNames = factor(groupNames, levels=groupNames)

  # generate patient coefficients
  groupCoefs = data.table(Group=groupNames, Patients=groupCounts,
                          TDrop=dropoutTimes, Sd.TDrop=sd.dropoutTimes,
                          AProb=attemptProbs,
                          Int=intercepts, Sd.Int=sd.intercepts,
                          Slope=slopes, Sd.Slope=sd.slopes,
                          Quad=quads, Sd.Quad=sd.quads,
                          Var=vars, Sd.Var=sd.vars,
                          R=autocors)
  assert_that(nrow(groupCoefs) == 7)
  patCoefs = groupCoefs[, .(TDrop=rnorm(Patients, TDrop, Sd.TDrop) %>% round %>%
pmax(7), 57 | AProb=AProb, 58 | Intercept=rnorm(Patients, Int, Sd.Int) %>% pmax(0), 59 | Slope=rnorm(Patients, Slope, Sd.Slope), 60 | Quad=rnorm(Patients, Quad, Sd.Quad), 61 | Variance=rnorm(Patients, Var, Sd.Var) %>% pmax(.75), 62 | R=rep(R, Patients)), by=Group] 63 | 64 | # generate patient measurements 65 | genTs = function(N, Intercept, Slope, Quad, Variance, R, AProb, TDrop, ...) { 66 | y = as.numeric(Intercept + times * Slope + times^2 * Quad + arima.sim(list(ar=R), n=length(times), sd=sqrt(Variance))) 67 | 68 | skipMask = !rbinom(length(times), size=1, prob=AProb) 69 | y[skipMask] = 0 70 | 71 | if(missing) { 72 | obsMask = times <= TDrop 73 | list(Time=times[obsMask], Usage=pmax(y[obsMask], 0)) 74 | } else { 75 | y[times > TDrop] = 0 76 | list(Time=times, Usage=pmax(y, 0)) 77 | } 78 | } 79 | 80 | patNames = paste0('P', 1:patients) 81 | alldata = patCoefs[, do.call(genTs, .SD), by=.(Group, Id=factor(patNames, levels=patNames))] %>% 82 | setkey(Id, Time) 83 | 84 | # generate group trajectories 85 | groupTrajs = groupCoefs[, .(Time=times, 86 | Usage=(pmax(Int + times * Slope + times^2 * Quad, 0) * AProb) %>% ifelse(times > TDrop, 0, .) 87 | ), by=Group] 88 | 89 | setattr(alldata, 'patCoefs', patCoefs) 90 | setattr(alldata, 'groupCoefs', groupCoefs) 91 | setattr(alldata, 'groupTrajs', groupTrajs) 92 | 93 | # Extra step: downsampling 94 | if(is.finite(nAggr)) { 95 | alldata = transformToAverage(alldata, binSize=nAggr) 96 | } 97 | return(alldata[]) 98 | } 99 | 100 | export_data = function(file, ..., seed = 1) { 101 | data = generate_osa_data(..., seed = seed) 102 | write.csv(data, file = file, quote = FALSE, row.names = FALSE) 103 | } 104 | 105 | transformToAverage = function(data, binSize=14) { 106 | bins = seq(min(data$Time), max(data$Time), by=binSize) 107 | nBins = length(bins) - 1 108 | bindata = data[, .(Usage=mean(Usage)), keyby=.(Group, Id, Bin=findInterval(Time, bins, all.inside=TRUE))] %>% 109 | .[, Time := bins[Bin]] 110 | 111 | groupTrajs = attr(data, 'groupTrajs') 112 | bingroupTrajs = groupTrajs[, .(Usage=mean(Usage)), keyby=.(Group, Bin=findInterval(Time, bins, all.inside=TRUE))] %>% 113 | .[, Time := bins[Bin]] 114 | 115 | setattr(bindata, 'groupTrajs', bingroupTrajs) 116 | return(bindata[]) 117 | } 118 | 119 | 120 | 121 | transformToRepeatedMeasures = function(data) { 122 | assert_that(is.data.frame(data), has_name(data, c('Id', 'Time', 'Usage'))) 123 | dtWide = dcast(data, Id ~ Time, value.var='Usage') 124 | 125 | dataMat = as.matrix(dtWide[, -'Id']) 126 | assert_that(nrow(dataMat) == uniqueN(data$Id), ncol(dataMat) == uniqueN(data$Time)) 127 | rownames(dataMat) = dtWide$Id 128 | colnames(dataMat) = names(dtWide)[-1] 129 | return(dataMat) 130 | } 131 | 132 | -------------------------------------------------------------------------------- /demo.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: No 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 4 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | -------------------------------------------------------------------------------- /metrics.R: -------------------------------------------------------------------------------- 1 | # Average posterior probability of assignments (APPA) 2 | appa = function(pp) { 3 | rowMaxs(pp) %>% mean() 4 | } 5 | 6 | entropy = function(pp) { 7 | 
assert_that(is.matrix(pp), min(pp) >= 0, max(pp) <= 1) 8 | pp = pmax(pp, .Machine$double.xmin) 9 | -sum(rowSums(pp * log(pp))) 10 | } 11 | 12 | relativeEntropy = function(pp) { 13 | N = nrow(pp) 14 | K = ncol(pp) 15 | 1 - entropy(pp) / (N * log(K)) 16 | } 17 | 18 | confusionMatrix = function(pp) { 19 | IMIFA::post_conf_mat(pp) 20 | } 21 | -------------------------------------------------------------------------------- /plot.R: -------------------------------------------------------------------------------- 1 | plotTrajectory = function(data, id=sample(levels(data$Id), size=1)) { 2 | p = ggplot(data[Id == id], aes(x=Time, y=Usage, group=Id)) + 3 | geom_line() + 4 | scale_x_continuous(breaks=seq(1, 361, by=30)) + 5 | scale_y_continuous(breaks=pretty_breaks(10)) + 6 | expand_limits(x=range(data$Time)) + 7 | coord_cartesian(ylim=c(0, 12)) + 8 | labs(x='Day', y='Hours of use', title=sprintf('Id %s of group "%s"', id, first(data[Id == id, Group]))) 9 | 10 | if(uniqueN(data$Time) < 30) { 11 | p = p + geom_point() 12 | } 13 | return(p) 14 | } 15 | 16 | 17 | plotTrajectories = function(data) { 18 | props = data[, .(N=uniqueN(Id)), keyby=Group]$N / uniqueN(data$Id) 19 | plotdata = copy(data) %>% 20 | .[, GroupLabel := factor(Group, levels=levels(Group), labels=sprintf('%s (%s%%)', gsub(' ', ' ', levels(Group)), round(props * 100)))] 21 | 22 | ggplot(plotdata, aes(x=Time, y=Usage, group=Id)) + 23 | geom_line(size=.0001, alpha=.25) + 24 | scale_x_continuous(breaks=seq(1, 361, by=90)) + 25 | scale_y_continuous(breaks=seq(0, 12, by=2)) + 26 | expand_limits(x=range(data$Time)) + 27 | coord_cartesian(ylim=c(0, 10)) + 28 | labs(x='Day', y='Hours of use') + 29 | facet_wrap(~GroupLabel) 30 | } 31 | 32 | 33 | plotGroupTrajectories = function(data, props) { 34 | if(has_attr(data, 'groupTrajs')) { 35 | groupTrajs = attr(data, 'groupTrajs') %>% copy() 36 | if(missing(props)) { 37 | props = data[, .(N=uniqueN(Id)), keyby=Group]$N / uniqueN(data$Id) 38 | } 39 | } else { 40 | groupTrajs = copy(data) 41 | } 42 | 43 | if(!missing(props)) { 44 | groupTrajs[, GroupLabel := factor(Group, levels=levels(Group), labels=sprintf('%s (%s%%)', gsub(' ', '\n', levels(Group)), round(props * 100)))] 45 | } else { 46 | groupTrajs[, GroupLabel := Group] 47 | } 48 | 49 | # Point positions 50 | times = sort(unique(groupTrajs$Time)) 51 | nPoints = min(length(times), 10) 52 | pointIdx = seq(1, length(times), length.out=nPoints) %>% round 53 | pointData = groupTrajs[Time %in% times[pointIdx]] 54 | 55 | # Compute label positions 56 | library(lpSolve) 57 | groupTrajs[, Id := Group] 58 | groupMat = transformToRepeatedMeasures(groupTrajs) 59 | k = nlevels(groupTrajs$Group) 60 | labelIdx = seq(1, length(times), length.out=k+2)[2:(k+1)] %>% round 61 | m = apply(groupMat[, labelIdx], 2, function(x) vapply(1:k, function(i) min(abs(x[i] - x[-i])), FUN.VALUE=0)) 62 | lp = lp.assign(m, direction='max') 63 | groupOrder = apply(lp$solution, 2, which.max) 64 | 65 | textData = groupTrajs[data.frame(Group=levels(groupTrajs$Group)[groupOrder], 66 | Time=times[labelIdx]), on=c('Group', 'Time')] 67 | 68 | # Plot 69 | ggplot(groupTrajs, aes(x=Time, y=Usage, group=Group)) + 70 | geom_line(size=.1) + 71 | geom_point(data=pointData, aes(x=Time, y=Usage, shape=GroupLabel), size=2) + 72 | geom_label(data=textData, aes(x=Time, y=Usage, label=Group), 73 | color='gray20', 74 | alpha=.9, 75 | size=ifelse(nchar(levels(groupTrajs$Group)) == 1, 2.5, 1.5)) + 76 | scale_x_continuous(breaks=seq(1, 361, by=60)) + 77 | scale_y_continuous(breaks=seq(0, 12, by=2)) + 78 | 
scale_shape_manual(name='Cluster', values=1:k) + 79 | expand_limits(x=range(data$Time), y=c(0, 7)) + 80 | labs(x='Day', y='Hours of use') 81 | } 82 | 83 | 84 | plotMetric = function(values, numGroups, name='Metric') { 85 | assert_that(length(values) == length(numGroups)) 86 | 87 | data.frame(NumGroups=numGroups, Metric=values) %>% 88 | ggplot(aes(NumGroups, Metric)) + 89 | geom_line(size=.5) + 90 | geom_point(size=1) + 91 | scale_x_continuous(breaks=seq(1, max(numGroups)), minor_breaks=NULL) + 92 | scale_y_continuous(breaks=pretty_breaks(5), labels=comma) + 93 | labs(x='Number of clusters', y=name) 94 | } 95 | --------------------------------------------------------------------------------