├── examples ├── diamonds_col2skip.txt ├── diamonds.RData ├── diamonds_types.txt ├── predict_csv.conf ├── predict_rdata.conf ├── data_csv.conf ├── model.conf └── data_rdata.conf ├── conf ├── EXPORT.conf ├── PREDICT.conf ├── PROFILES.conf ├── PARDEP.conf ├── MODEL.conf └── DATA.conf ├── bin ├── trainModel.bat ├── trainModel.sh ├── runModel.sh ├── exportModel.sh └── runModelSQL.sh ├── src ├── winsorize.R ├── rfGraphics.R ├── rfCV.R ├── logger.R ├── rfExportSQL_main.R ├── rfPrintProfiles_main.R ├── rfCV_main.R ├── rfTrain.R ├── rfPreproc.R ├── rfTrain_main.R ├── rfR2HTML.R ├── rfAPI.R ├── rfExport.R ├── rfPredict_main.R ├── rfRulesIO.R └── rfPardep_main.R ├── doc ├── EXPORT_CONF.md ├── PARDEP_CONF.md ├── DATA_CONF.md ├── MODEL_CONF.md └── REgo.html ├── README.md └── LICENSE /examples/diamonds_col2skip.txt: -------------------------------------------------------------------------------- 1 | x 2 | y 3 | z 4 | -------------------------------------------------------------------------------- /examples/diamonds.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuit/rego/HEAD/examples/diamonds.RData -------------------------------------------------------------------------------- /examples/diamonds_types.txt: -------------------------------------------------------------------------------- 1 | carat,1 2 | cut,2 3 | color,2 4 | clarity,2 5 | depth,1 6 | table,1 7 | price,1 8 | x,1 9 | y,1 10 | z,1 11 | -------------------------------------------------------------------------------- /conf/EXPORT.conf: -------------------------------------------------------------------------------- 1 | param value 2 | do.dedup 1 3 | expand.lcl.mode 2 4 | out.type score 5 | sql.dialect HiveQL 6 | #levels.fname 7 | max.sqle.length 500 8 | log.level kLogLevelINFO 9 | out.fname rules_forSQL.txt 10 | -------------------------------------------------------------------------------- /conf/PREDICT.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Data source 3 | data.source.type csv 4 | ## ...CSV spec 5 | csv.path d:/data/csv_400k/ 6 | csv.fname test_01.csv 7 | ## Column specs 8 | col.y HAS_PAYMENTS 9 | col.id ROWNUM 10 | ## Output 11 | out.path d:/model/re/400k/type.both_size.3/test 12 | out.fname id_y_yHat_test.csv 13 | rf.working.dir d:/tmp/RuleFit 14 | ## Other 15 | log.level kLogLevelINFO 16 | -------------------------------------------------------------------------------- /examples/predict_csv.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Data source 3 | data.source.type csv 4 | ## ...CSV spec 5 | csv.path . 6 | csv.fname diamonds.csv 7 | ## Column specs 8 | ## ... 
if y is known: 9 | col.y price 10 | # col.id ROWNUM 11 | ## Output 12 | out.path /tmp/REgo/Diamonds_wd/predict 13 | out.fname diamonds_predict.csv 14 | ## Working directory 15 | rf.working.dir /tmp/REgo/Diamonds_wd 16 | ## Other 17 | log.level kLogLevelINFO 18 | -------------------------------------------------------------------------------- /conf/PROFILES.conf: -------------------------------------------------------------------------------- 1 | param value 2 | mod.path /home/.../mod.sel.2/type.both_size.4_class.bal_mlc.500/export 3 | ## Define subgroups based on score 4 | very.unlikely.thresh -1.0 5 | very.likely.thresh +1.1 6 | ## Output 7 | out.path /home/.../mod.sel.2/type.both_size.4_class.bal_mlc.500/export/PROFILES 8 | html.min.var.imp 5 9 | yHat.hist.fname yHat.png 10 | yHat.hist.title "Model: Pr(2ndCheck | x)" 11 | -------------------------------------------------------------------------------- /conf/PARDEP.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Variable to plot 3 | var.name carat 4 | ## Partial dependence specs 5 | num.obs 500 6 | var.num.values 200 7 | var.trim.qntl 0.025 8 | var.rug.qntl 0.1 9 | var.levels.las 1 10 | show.pdep.dist 1 11 | show.yhat.mean 1 12 | var.boxplot.range 0 13 | ## Output 14 | out.path /tmp/REgo/Diamonds_wd/export/R2HTML/singleplot_new 15 | ## Model info 16 | model.path /tmp/REgo/Diamonds_wd/export 17 | rf.working.dir /tmp/REgo/Diamonds_wd 18 | ## Other 19 | log.level kLogLevelINFO 20 | -------------------------------------------------------------------------------- /examples/predict_rdata.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Data source 3 | data.source.type rdata 4 | ## RData spec 5 | rdata.path . 6 | rdata.fname diamonds.RData 7 | ## Name of dataframe object; can be omitted if there is only one object in the RData file 8 | # rdata.dfname X 9 | ## Column specs 10 | ## ... if y is known: 11 | col.y price 12 | # col.id ROWNUM 13 | ## Output 14 | out.path /tmp/REgo/Diamonds_wd/predict 15 | out.fname diamonds_predict.csv 16 | ## Working directory 17 | rf.working.dir /tmp/REgo/Diamonds_wd 18 | ## Other 19 | log.level kLogLevelINFO 20 | -------------------------------------------------------------------------------- /bin/trainModel.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | REM Provides a "batch" interface to the RuleFit statistical model building 3 | REM program. To invoke it use: 4 | REM 5 | REM %REGO_HOME%/bin/trainModel.bat DATA.conf MODEL.conf [LOGGER.txt] 6 | REM 7 | REM You need to set two environment variables: 8 | REM - RS_PATH: path to R installation where R's Rscript is installed 9 | REM E.g., set RS_PATH=D:/R/R-2.15.1/bin/x64 10 | REM - REGO_HOME: path to where the REgo scripts are located 11 | REM E.g., set REGO_HOME=D:/rego 12 | 13 | IF %3.==. ( 14 | %RS_PATH%/Rscript %REGO_HOME%/src/rfTrain_main.R -d %1 -m %2 15 | ) ELSE ( 16 | %RS_PATH%/Rscript %REGO_HOME%/src/rfTrain_main.R -d %1 -m %2 -l %3 17 | ) 18 | -------------------------------------------------------------------------------- /examples/data_csv.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Data source 3 | data.source.type csv 4 | ## ...CSV spec 5 | csv.path . 
6 | csv.fname diamonds.csv 7 | ## Column specs 8 | col.types.fname diamonds_types.txt 9 | col.y price 10 | col.weights "" 11 | col.skip.fname diamonds_col2skip.txt 12 | col.winz.fname "" 13 | ## Any preprocessing 14 | na.threshold 0.95 15 | min.level.count 0 16 | do.class.balancing 0 17 | ## HTML model report 18 | html.fname "model_summary" 19 | html.title "Diamonds -- Model Summary" 20 | html.title2 "Data used: diamonds.csv" 21 | html.min.var.imp 5 22 | html.min.rule.imp 5 23 | html.singleplot.fname "model_singleplot" 24 | html.singleplot.title "Diamonds -- Dependence Plots" 25 | html.singleplot.nvars 15 26 | ## Other 27 | rand.seed 135711 28 | log.level kLogLevelINFO 29 | -------------------------------------------------------------------------------- /conf/MODEL.conf: -------------------------------------------------------------------------------- 1 | param value 2 | task "classification" 3 | ## Model spec 4 | model.type "both" 5 | model.max.rules 2000 6 | model.max.terms 500 7 | ## Tree Ensemble control 8 | te.tree.size 4 9 | te.sample.fraction 0.5 10 | te.interaction.suppress 3.0 11 | te.memory.param 0.01 12 | ## Regularization method 13 | sparsity.method "Lasso" 14 | ## Model selection 15 | score.criterion "1-AUC" 16 | crossvalidation.num.folds 10 17 | crossvalidation.fold.size 0.1 18 | misclassification.costs 1,1 19 | ## Preprocessing 20 | data.trim.quantile 0.025 21 | data.NA.value 9.0e30 22 | ## Iteration Control 23 | convergence.threshold 1.0e-3 24 | ## Memory management 25 | mem.tree.store 10000000 26 | mem.cat.store 1000000 27 | ## Working directory (model output will be in rf.working.dir/export) 28 | rf.working.dir d:/tmp/RuleFit 29 | -------------------------------------------------------------------------------- /conf/DATA.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Data source 3 | data.source.type csv 4 | ## ...DB spec 5 | db.dsn 6 | db.name 7 | db.type 8 | db.tbl.name 9 | db.tbl.maxrows 10 | db.query.tmpl 11 | ## ...CSV spec 12 | csv.path 13 | csv.fname 14 | ## ...RData spec 15 | rdata.path 16 | rdata.fname 17 | rdata.dfname 18 | ## Column specs 19 | col.types.fname 20 | col.y y 21 | col.weights "" 22 | col.skip.fname "" 23 | col.winz.fname "" 24 | ## Any preprocessing 25 | na.threshold 0.95 26 | min.level.count 0 27 | do.class.balancing 0 28 | ## HTML model report 29 | html.fname "model_summary" 30 | html.title "" 31 | html.title2 "" 32 | html.min.var.imp 5 33 | html.min.rule.imp 5 34 | html.singleplot.fname "model_singleplot" 35 | html.singleplot.title "" 36 | html.singleplot.nvars 15 37 | ## Other 38 | rand.seed 135711 39 | log.level kLogLevelINFO 40 | -------------------------------------------------------------------------------- /examples/model.conf: -------------------------------------------------------------------------------- 1 | ## WARNING: the settings here are chosen to reduce the running time 2 | ## and differ from the RuleFit defaults 3 | param value 4 | task "regression" 5 | ## Model spec 6 | model.type "both" 7 | model.max.rules 100 8 | model.max.terms 20 9 | ## Tree Ensemble control 10 | te.tree.size 3 11 | te.sample.fraction 0.5 12 | te.interaction.suppress 3.0 13 | te.memory.param 0.01 14 | ## Regularization method 15 | sparsity.method "Lasso" 16 | ## Model selection 17 | score.criterion "AAE" 18 | crossvalidation.num.folds 10 19 | crossvalidation.fold.size 0.1 20 | misclassification.costs 1,1 21 | ## Preprocessing 22 | data.trim.quantile 0.025 23 | data.NA.value 9.0e30 24 | ## Iteration 
Control 25 | convergence.threshold 1.0e-3 26 | ## Memory management 27 | mem.tree.store 10000000 28 | mem.cat.store 1000000 29 | ## Working directory (model output will be in rf.working.dir/export) 30 | rf.working.dir /tmp/REgo/Diamonds_wd 31 | -------------------------------------------------------------------------------- /examples/data_rdata.conf: -------------------------------------------------------------------------------- 1 | param value 2 | ## Data source 3 | data.source.type rdata 4 | ## RData spec 5 | rdata.path . 6 | rdata.fname diamonds.RData 7 | ## Name of dataframe object; can be omitted if there is only one object in the RData file 8 | # rdata.dfname X 9 | ## Column specs, optional when data.source.type is rdata 10 | # col.types.fname diamonds_types.txt 11 | col.y price 12 | col.weights "" 13 | col.skip.fname diamonds_col2skip.txt 14 | col.winz.fname "" 15 | ## Any preprocessing 16 | na.threshold 0.95 17 | min.level.count 0 18 | do.class.balancing 0 19 | ## HTML model report 20 | html.fname "model_summary" 21 | html.title "Diamonds -- Model Summary" 22 | html.title2 "Data used: diamonds.RData" 23 | html.min.var.imp 5 24 | html.min.rule.imp 5 25 | html.singleplot.fname "model_singleplot" 26 | html.singleplot.title "Diamonds -- Dependence Plots" 27 | html.singleplot.nvars 15 28 | ## Other 29 | rand.seed 135711 30 | log.level kLogLevelINFO 31 | -------------------------------------------------------------------------------- /src/winsorize.R: -------------------------------------------------------------------------------- 1 | source(file.path(REGO_HOME, "/src/logger.R")) 2 | 3 | winsorize <- function(x, beta = 0.025) 4 | { 5 | # Return l(x) = min(delta_+, max(delta_-, x)) where delta_- and delta_+ are 6 | # the beta and (1 - beta) quantiles of the data distribution {x_i}. The 7 | # value for beta reflects one's prior suspicions concerning the fraction of 8 | # such outliers. 9 | x.n <- length(x) 10 | if ( beta < 0 || beta > 0.5 ) { 11 | error(logger, paste("winsorize: invalid argument beta =", beta)) 12 | } 13 | if ( length(x) < 1 || length(which(is.na(x) == FALSE)) == 0) { 14 | error(logger, paste("winsorize: invalid argument x -- length =", length(x), "NAs =", length(which(is.na(x))))) 15 | } 16 | quant <- quantile(x, probs = c(beta, 1.0 - beta), na.rm = T) 17 | x.min2keep <- quant[1] 18 | x.max2keep <- quant[2] 19 | x.copy <- x 20 | x.copy[which(x < x.min2keep)] <- x.min2keep 21 | x.copy[which(x > x.max2keep)] <- x.max2keep 22 | return(list(x = x.copy, min2keep = x.min2keep, max2keep = x.max2keep)) 23 | } 24 | -------------------------------------------------------------------------------- /bin/trainModel.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | #=================================================================================== 3 | # FILE: trainModel.sh 4 | # 5 | # USAGE: trainModel.sh --d=<DATA.conf> --m=<MODEL.conf> 6 | # 7 | # DESCRIPTION: Builds a RuleFit model with the given data source and model spec files. 
8 | #=================================================================================== 9 | 10 | USAGESTR="usage: trainModel.sh --d=<DATA.conf> --m=<MODEL.conf>" 11 | 12 | # Parse arguments 13 | for i in $* 14 | do 15 | case $i in 16 | --d=*) 17 | DATA_CONF=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 18 | ;; 19 | --m=*) 20 | MODEL_CONF=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 21 | ;; 22 | *) 23 | # unknown option 24 | echo $USAGESTR 25 | exit 1 26 | ;; 27 | esac 28 | done 29 | 30 | # Validate command-line arguments 31 | if [ -z "$DATA_CONF" -o -z "$MODEL_CONF" ]; then 32 | echo $USAGESTR 33 | exit 1 34 | fi 35 | 36 | # Invoke R code 37 | $REGO_HOME/src/rfTrain_main.R -d ${DATA_CONF} -m ${MODEL_CONF} 38 | -------------------------------------------------------------------------------- /bin/runModel.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | #=================================================================================== 3 | # FILE: runModel.sh 4 | # 5 | # USAGE: runModel.sh --m=<model dir> --d=<DATA.conf> 6 | # 7 | # DESCRIPTION: Computes predictions using a previously built RuleFit model on the 8 | # specified data. 9 | #=================================================================================== 10 | 11 | USAGESTR="usage: runModel.sh --m=<model dir> --d=<DATA.conf>" 12 | 13 | # Parse arguments 14 | for i in $* 15 | do 16 | case $i in 17 | --m=*) 18 | MODEL_PATH=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 19 | ;; 20 | --d=*) 21 | DATA_CONF=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 22 | ;; 23 | *) 24 | # unknown option 25 | echo $USAGESTR 26 | exit 1 27 | ;; 28 | esac 29 | done 30 | 31 | # Validate command-line arguments 32 | if [ -z "$MODEL_PATH" -o -z "$DATA_CONF" ]; then 33 | echo $USAGESTR 34 | exit 1 35 | fi 36 | 37 | # Invoke R code 38 | $REGO_HOME/src/rfPredict_main.R -m ${MODEL_PATH} -d ${DATA_CONF} 39 | 40 | -------------------------------------------------------------------------------- /bin/exportModel.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | #=================================================================================== 3 | # FILE: exportModel.sh 4 | # 5 | # USAGE: exportModel.sh --m=<model dir> [--c=<EXPORT.conf>] 6 | # 7 | # DESCRIPTION: Exports a previously built RuleFit model to SQL. 8 | #=================================================================================== 9 | 10 | USAGESTR="usage: exportModel.sh --m=<model dir> [--c=<EXPORT.conf>]" 11 | 12 | # Parse arguments 13 | for i in $* 14 | do 15 | case $i in 16 | --m=*) 17 | MODEL_PATH=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 18 | ;; 19 | --c=*) 20 | EXPORT_CONF=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 21 | ;; 22 | *) 23 | # unknown option 24 | echo $USAGESTR 25 | exit 1 26 | ;; 27 | esac 28 | done 29 | 30 | # Validate command-line arguments 31 | if [ -z "$MODEL_PATH" ]; then 32 | echo $USAGESTR 33 | exit 1 34 | fi 35 | 36 | # Invoke R code 37 | if [ -z "$EXPORT_CONF" ]; then 38 | $REGO_HOME/src/rfExportSQL_main.R -m ${MODEL_PATH} 39 | else 40 | $REGO_HOME/src/rfExportSQL_main.R -m ${MODEL_PATH} -c ${EXPORT_CONF} 41 | fi 42 | -------------------------------------------------------------------------------- /doc/EXPORT_CONF.md: -------------------------------------------------------------------------------- 1 | # Rego Export Configuration File 2 | | parameter | Value| 3 | | -------------- |:-----:| 4 | | do.dedup | set to 1 to merge duplicate rules in the exported rule set; 0 otherwise (default is 1)| 5 | | expand.lcl.mode | **1**: replace _LowCountLevels_ in the SQL scoring expression with the corresponding levels (default); **2**: keep _LowCountLevels_ in the SQL scoring expression (keeps it shorter, easier to read) and generate an extra SQL clause with logic to substitute low count levels with _LowCountLevels_ in a data preparation step prior to scoring| 6 | | out.type | One of **score** (generates the full scoring expression -- i.e., a0 + a1*b1(x) + a2*b2(x) + ...), **rulesonly** (generates rules clauses without coefficients -- i.e., *b1(x), *b2(x), ...; useful to monitor the firing pattern of each individual rule), **rulesscore**, or **rulescoeff** (generates rules clauses with coefficients)| 7 | | sql.dialect| One of **SQLServer**, **HiveQL**, **Netezza**, or **MySQL** (default is **SQLServer**)| 8 | | max.sqle.length | Max SQL expression length, in number of characters (some SQL engines have a limit on this) | 9 | | log.level | one of **kLogLevelDEBUG**, **kLogLevelINFO**, **kLogLevelWARNING**, **kLogLevelERROR**, **kLogLevelFATAL**, **kLogLevelCRITICAL**. Controls the verbosity of the logging messages | 10 | | out.fname | Output file name (default is "rules_forSQL.txt")| 11 | -------------------------------------------------------------------------------- /doc/PARDEP_CONF.md: -------------------------------------------------------------------------------- 1 | # Rego Partial Dependence Plot Configuration File 2 | | parameter | Value| 3 | | -------------- |:-----:| 4 | | var.name | Name of variable to be considered. 
| 5 | | ## *Partial Dependence plot control* || 6 | | num.obs | Number of observations to include in averaging calculation (default is 500).| 7 | | var.num.values | Number of distinct variable evaluation points (default is 200).| 8 | | var.trim.qntl | Trim extreme values of variable (default is 0.025) | 9 | | var.rug.qntl | Rug quantile to show numeric variable data density (default is 0.1). | 10 | | var.levels.las |Text orientation of level names (for categorical variable). Default is 1. | 11 | | show.pdep.dist | Show partial dependence distribution (default is 0 -- i.e FALSE)| 12 | | show.yhat.mean | Show partial dependence mean| 13 | | var.boxplot.range | This determines how far the whiskers of a categorical variable extend out from the boxplot's box (this value times interquartile range gives whiskers range). Default is 1.0e-4.| 14 | | ## *Output* || 15 | | out.path | | 16 | | out.fname | output file name (default is var.name.PNG) | 17 | | ## *Model and installation info* || 18 | | model.path | | 19 | | rf.working.dir | path to working directory where model will be saved to. If not specified, an attempt to read environment variable RF_WORKING_DIR will be made.| 20 | |log.level | one of **kLogLevelDEBUG**, **kLogLevelINFO**, **kLogLevelWARNING**, **kLogLevelERROR**, **kLogLevelFATAL**, **kLogLevelCRITICAL**. Controls the verbosity of the logging messages| 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /doc/DATA_CONF.md: -------------------------------------------------------------------------------- 1 | # Rego Data Configuration File 2 | 3 | | parameter | Value| 4 | | -------------- |:-----:| 5 | | ## Data source | where is the data coming from | 6 | | data.source.type | can be **db** (database) or **csv** (comma-separated values file) or **rdata** (RData file save()'d by R) | 7 | | ## ...CSV spec| | 8 | |csv.path | specifies the location of the cdv file, when the data source type is **csv**| 9 | | csv.fname | csv file name| 10 | | csv.sep | the field separator character (optional) | 11 | | ## ...RData spec| | 12 | | rdata.path| specifies the location of the RData file, when the data source type is **rdata**| 13 | | rdata.fname| RData file name| 14 | | rdata.dfname| specifies the name of the data-frame object in the RData file. Can be omitted if there is only one object in the RData file. | 15 | | ## ...DB spec| | 16 | | db.dsn | | 17 | | db.name| | 18 | | db.tbl.name| | 19 | | db.tbl.maxrows| | 20 | | db.query.tmpl | a SQL query "template" file to use when fetching the data -- e.g., "SELECT * FROM _TBLNAME_ WHERE Y IN (0, 1) LIMIT _MAXROWS_ "| 21 | | ## Column specs| | 22 | | col.types.fname | a text file with **column name, column type** pairs (column type is **1** for continuous, and **2** for categorical variables). Can be omitted if the data source type is **rdata**, in which case columns inheriting from factor will be treated as categorical| 23 | | col.y | name of response variable.| 24 | | col.id | name of row-id column (optional). Often useful during prediction when tuples are generated. | 25 | | col.weights | name of weights column (optional). 
| 26 | |row.weights.fname | name of text file with customized weights for each row (no header, one weight per line) | 27 | | col.skip.fname| name of text file listing columns to skip (like a row-id column); one column name per line (optional)| 28 | | col.winz.fname| text file with column-specific [winsorizing](http://en.wikipedia.org/wiki/Winsorising) parameters (optional).| 29 | | ## Any preprocessing| | 30 | | na.threshold | maximum fraction of NA values to allow per column (optional).| 31 | | min.level.count | levels with fewer than this count will be merged.| 32 | | do.class.balancing| set to 1 to have classes to be equally weighted; 0 otherwise.| 33 | | ## HTML model report| | 34 | | html.fname | file where to write model summary as an HTML report.| 35 | | html.title | | 36 | | html.title2 | | 37 | | html.min.var.imp | exclude from HTML report variables with importance score lower than this.| 38 | | html.min.rule.imp | exclude from HTML report rules with importance score lower than this.| 39 | | ## Other| | 40 | | rand.seed| random number seed | 41 | |log.level | one of **kLogLevelDEBUG**, **kLogLevelINFO**, **kLogLevelWARNING**, **kLogLevelERROR**, **kLogLevelFATAL**, **kLogLevelCRITICAL**. Controls the verbosity of the logging messages | 42 | -------------------------------------------------------------------------------- /src/rfCV.R: -------------------------------------------------------------------------------- 1 | rfCV <- function(x, y, wts, x.cat.vars, rf.ctxt, nfold=5, yHat.return=FALSE, seed=135711) 2 | { 3 | # TODO: config$rand.seed 4 | # 5 | # Performs outer cross-validation of RuleFit models. 6 | # 7 | # Args: 8 | # x: predictor matrix 9 | # y: response vector 10 | # wts: observation weights 11 | # x.cat.vars: list of categorical variable indices 12 | # rf.ctxt: RuleFit-specific configuration parameters 13 | # nfold: number of cross-validation runs to do 14 | # yHat.return: whether to return out-of-sample yHat 15 | # seed: random number seed (for partitioning x's rows) 16 | # 17 | # Returns: 18 | # A list with: 19 | # stats: a nfold-by-10 matrix with the following columns: 20 | # ECV, ECV std: Estimated criterion value (e.g., AAE) reported by RuleFit 21 | # terms: model size reported by RuleFit 22 | # train.*: in-sample error 23 | # test.*: out-of-sample error 24 | # oos.yHat: out-of-sample yHat (if requested) 25 | dbg(logger, "rfCV:") 26 | 27 | # Init return structures 28 | cv.stats <- matrix(NA, nrow = nfold, ncol = 10) 29 | oos.yHat <- NULL 30 | oos.y.idx <- NULL 31 | colnames(cv.stats) <- c("ECV", "ECV_std", "terms", 32 | "train.error", "train.med.error", "train.aae", 33 | "test.error", "test.med.error", "test.aae", 34 | "cor.test") 35 | # Generate data splits 36 | set.seed(seed) 37 | group <- sample(rep(1:nfold, length = nrow(x))) 38 | 39 | # Build 'nfold' models 40 | for (i.cv in 1:nfold) { 41 | # Subset data 42 | test <- which(group == i.cv) 43 | x.train <- x[-test,] 44 | x.test <- x[test,] 45 | y.train <- y[-test] 46 | y.test <- y[test] 47 | wt.train <- wts[-test] 48 | wt.test <- wts[test] 49 | 50 | # Build model-i (which internally also uses cv for stopping param) 51 | set.seed(config$rand.seed) 52 | rfmod <- TrainRF(x, y, obs.wt, rf.ctxt, x.cat.vars) 53 | # ... 
Log (estimated generalization) model error and size 54 | rfmod.stats <- runstats(rfmod) 55 | info(logger, paste("Estimated criterion value:", rfmod.stats$cri, paste("(+/- ", rfmod.stats$err, "),", sep=""), 56 | "Num terms:", rfmod.stats$terms)) 57 | 58 | # Collect model stats: "Criterion", "terms" 59 | cv.stats[i.cv, 1] <- rfmod.stats$cri 60 | cv.stats[i.cv, 2] <- rfmod.stats$err 61 | cv.stats[i.cv, 3] <- rfmod.stats$terms 62 | 63 | # Collect in-sample accuracy 64 | yHat.train <- rfpred(x.train) 65 | re.train.error <- sum(wt.train*abs(yHat.train - y.train))/sum(wt.train) 66 | train.med.error <- sum(wt.train*abs(y.train - wMed(y.train, wt.train)))/sum(wt.train) 67 | train.aae <- re.train.error / train.med.error 68 | cv.stats[i.cv, 4] <- re.train.error 69 | cv.stats[i.cv, 5] <- train.med.error 70 | cv.stats[i.cv, 6] <- train.aae 71 | # Collect out-of-sample accuracy 72 | yHat.test <- rfpred(x.test) 73 | re.test.error <- sum(wt.test*abs(yHat.test - y.test))/sum(wt.test) 74 | test.med.error <- sum(wt.test*abs(y.test - wMed(y.train, wt.train)))/sum(wt.test) 75 | test.aae <- re.test.error / test.med.error 76 | cv.stats[i.cv, 7] <- re.test.error 77 | cv.stats[i.cv, 8] <- test.med.error 78 | cv.stats[i.cv, 9] <- test.aae 79 | cv.stats[i.cv, 10] <- cor(yHat.test, y.test) 80 | # Collect out-of-sample yHat 81 | if (yHat.return) { 82 | if (is.null(oos.yHat) && is.null(oos.y.idx)) { 83 | oos.yHat <- list(yHat.test) 84 | oos.y.idx <- list(test) 85 | } else { 86 | oos.yHat <- c(oos.yHat, list(yHat.test)) 87 | oos.y.idx <- c(oos.y.idx, list(test)) 88 | } 89 | } 90 | } 91 | 92 | return(list(stats = cv.stats, oos.y.idx = oos.y.idx, oos.yHat = oos.yHat)) 93 | } 94 | -------------------------------------------------------------------------------- /doc/MODEL_CONF.md: -------------------------------------------------------------------------------- 1 | # Rego Model Configuration File 2 | 3 | | parameter | Value| 4 | | -------------- |:-----:| 5 | | task | **regression** or (binary) **classification** | 6 | | ## *Model spec* | | 7 | | model.type | **rules**, **linear** or **both**| 8 | | model.max.rules | the approximate number of rules generated (prior to the regularization phase). It is approximate because the rules are not generated directly but via the tree ensemble. Since a *J*−terminal node tree generates *2 × (J − 1)* rules, *max.rules/J* defines the number of trees. | 9 | | model.max.terms | maximum number of terms selected for final model| 10 | | ## *Tree Ensemble control* | | 11 | | te.tree.size | the average number of terminal nodes in trees generated during phase-1 of the ensemble generation phase (i.e., before each tree is built, a size value is drawn from a distribution that has a mean of *J*) . This controls the dominant [interaction order](http://en.wikipedia.org/wiki/Interaction_(statistics)) of the model -- e.g., size-2 trees implies a "main-effects" only model, size-3 trees allow for first-order interactions, etc. | 12 | | te.sample.fraction | represents the size of the training data sub-sample used to build each tree. It controls "diversity" in the ensemble (correlation among the trees) and speed. Smaller values result in higher diversity and faster compute times. Default exponentially decays from ~0.5 (for small data sets) towards ~0.01, but overwrite it if you would like to have each tree fit to larger data subsets than what this default will produce.| 13 | | te.interaction.suppress | interaction suppression factor. 
"Boosts" split scores to favor reusing the same variable along a given (root to node) tree path. This places a preference on fewer variables defining splits along such paths, which will be later converted into rules. This makes it harder for "spurious" interactions to come into the model. These spurious interactions can occur in the presence of high [collinearity](http://en.wikipedia.org/wiki/Multicollinearity) among the input variables.| 14 | | te.memory.param | tree ensemble learning rate | 15 | | ## *Regularization method* | | 16 | | sparsity.method | one of **Lasso**, **Lasso+FSR** (lasso to select variable entry order, followed by a forward stepwise regression), **FSR** (forward stepwise to select variables and fit model), **ElasticNet** | 17 | | elastic.net.param | Required when sparsity.method = "ElasticNet"; 0 < alpha < 1. | 18 | | ## *Model selection* | | 19 | | score.criterion | one of **1-AUC**(optimizes 1.0 minus the Area under the ROC Curve),**AAE**(optimizes average absolute error; if (task == *classification*), optimizes average squared-error loss on predicted probabilities),**LS** (optimizes average squared-error), **Misclassification** (optimizes misclassification risk). | 20 | | crossvalidation.num.folds | number of test replications used for model selection| 21 | | crossvalidation.fold.size | fraction of input observations used it each test fold| 22 | | misclassification.costs | misclassification costs (when task == *classification*)| 23 | | ## *Preprocessing* | | 24 | | data.trim.quantile | [winsorising](http://en.wikipedia.org/wiki/Winsorising) quantile; applies to both ends of the variable range. If you need to winsorise different variables by different amounts, you need to use the col.winz.fname parameter and then set this parameter to 0.| 25 | | data.NA.value | predictor variable values of NA are internally set to this value| 26 | | ## *Iteration Control* | | 27 | | convergence.threshold | coefficient estimation iterations stop when the maximum standardized coefficient change from the previous iteration is less than this value.| 28 | | ## *Memory management* | | 29 | | mem.tree.store | | 30 | | mem.cat.store | | 31 | | ## *Installation info* | | 32 | | rf.working.dir | path to working directory where model will be saved to. If not specified, an attempt to read environment variable RF_WORKING_DIR will be made.| 33 | -------------------------------------------------------------------------------- /src/logger.R: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Simple logging class. 
3 | # 4 | # Usage: 5 | # REGO_HOME = Sys.getenv("REGO_HOME") 6 | # source(file.path(REGO_HOME, "src/logger.R")) 7 | # logger <- new("logger") 8 | # OR logger <- new("logger", file.name="foo.log") 9 | # OR logger <- new("logger", log.level = kLogLevelWARNING, file.name="foo.log") 10 | # info(logger, "hello") 11 | # warn(logger, "world") 12 | # 13 | # Author: Giovanni Seni 14 | ############################################################################### 15 | library(methods) # not automatically loaded by Rscript 16 | 17 | # Standard logging levels that can be used to control logging output (borrowed 18 | # definitions from python) 19 | kLogLevelNOTSET <- 0 20 | kLogLevelDEBUG <- 10 21 | kLogLevelINFO <- 20 22 | kLogLevelWARNING <- 30 23 | kLogLevelERROR <- 40 24 | kLogLevelFATAL <- 50 25 | kLogLevelCRITICAL <- 50 26 | 27 | # Logging class and functions 28 | setClass("logger", 29 | representation(file.name = "character", 30 | log.level = "numeric"), 31 | prototype = list(file.name = "", 32 | log.level = kLogLevelDEBUG)) 33 | 34 | setMethod("initialize", "logger", 35 | def = function(.Object) { 36 | .Object@file.name <- "" 37 | .Object@log.level <- kLogLevelDEBUG 38 | .Object 39 | } 40 | ) 41 | 42 | setMethod("initialize", "logger", 43 | def = function(.Object, file.name) { 44 | if (file.exists(file.name)) { 45 | stop(paste("Log file '", file.name, "' already exists", sep=""), 46 | call. = FALSE) 47 | } 48 | .Object@file.name <- file.name 49 | .Object@log.level <- kLogLevelDEBUG 50 | .Object 51 | } 52 | ) 53 | 54 | setMethod("initialize", "logger", 55 | def = function(.Object, file.name, log.level) { 56 | if (file.exists(file.name)) { 57 | stop(paste("Log file '", file.name, "' already exists", sep=""), 58 | call. = FALSE) 59 | } 60 | .Object@file.name <- file.name 61 | .Object@log.level <- log.level 62 | .Object 63 | } 64 | ) 65 | 66 | setGeneric("loggerWrite", 67 | def = function(this, level, msg, levelName) { 68 | standardGeneric("loggerWrite") 69 | }, 70 | useAsDefault = function(this, level, msg, levelName) { 71 | ts = format(Sys.time(), "%Y-%m-%d %H:%M:%S") 72 | if (level >= this@log.level) { 73 | write(paste(ts, ' ', levelName, ' - ', msg, sep=''), 74 | file = this@file.name, append = T) 75 | } 76 | } 77 | ) 78 | 79 | setGeneric("dbg", 80 | def = function(this, msg) { 81 | standardGeneric("dbg") 82 | }, 83 | useAsDefault = function(this, msg) { 84 | loggerWrite(this, kLogLevelDEBUG, msg, " DEBUG") 85 | } 86 | ) 87 | 88 | setGeneric("info", 89 | def = function(this, msg) { 90 | standardGeneric("info") 91 | }, 92 | useAsDefault = function(this, msg) { 93 | loggerWrite(this, kLogLevelINFO, msg, " INFO") 94 | } 95 | ) 96 | 97 | setGeneric("warn", 98 | def = function(this, msg) { 99 | standardGeneric("warn") 100 | }, 101 | useAsDefault = function(this, msg) { 102 | loggerWrite(this, kLogLevelWARNING, msg, " WARN") 103 | } 104 | ) 105 | 106 | setGeneric("error", 107 | def = function(this, msg) { 108 | standardGeneric("error") 109 | }, 110 | useAsDefault = function(this, msg) { 111 | loggerWrite(this, kLogLevelERROR, msg, " ERROR") 112 | stop("Ending program...", call. = FALSE) 113 | } 114 | ) 115 | -------------------------------------------------------------------------------- /doc/REgo.html: -------------------------------------------------------------------------------- 1 |

Rule Ensembles go! (REgo)

Predictive Learning plays an important role in many areas of science, finance and industry. Here are some examples of learning problems:

  • Predict whether a customer would be attracted to a new service offering. Recognizing such customers can reduce the cost of a campaign by reducing the number of contacts.
  • Predict whether a web site visitor is unlikely to become a customer. The prediction allows prioritization of customer support resources.
  • Identify the risk factors for churn, based on the content of customer support messages.

REgo is a collection of R-based scripts intended to facilitate the process of building, interpreting, and deploying state-of-the-art predictive learning models. REgo can:
  • Enable rapid experimentation
  • Increase self-service capability
  • Support easy model deployment into a production environment

Under the hood REgo uses RuleFit, a statistical model building program created by Prof. Jerome Friedman. RuleFit was written in Fortran but has an R interface. RuleFit implements a model building methodology known as "ensembling," where multiple simple models (base learners) are combined into one usually more accurate than the best of its components. This type of model can be described as an additive expansion of the form F(x) = a0 + a1*b1(x) + a2*b2(x) + ... + aM*bM(x), where the bj(x)'s are the base-learners.

In the case of RuleFit, the bj(x) terms are conjunctive rules of the form "if x1 > 22 and x2 > 27 then 1 else 0" or linear functions of a single variable -- e.g., bj(x) = xj. Using base-learners of this type is attractive because they constitute easily interpretable statements about attributes xj. They also preserve the desirable characteristics of Decision Trees such as easy handling of categorical attributes, robustness to outliers in the distribution of x, etc.
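
To make the additive form concrete, here is a minimal R sketch (illustrative only -- not REgo or RuleFit code; the rules, coefficients and data values below are made up):

  # Two conjunctive-rule base-learners, written as indicator functions
  b1 <- function(x1, x2) as.numeric(x1 > 22 & x2 > 27)
  b2 <- function(x1) as.numeric(x1 > 40)
  # A linear base-learner on a single variable
  b3 <- function(x3) x3
  # Hypothetical intercept a0 and coefficients a1..a3
  Fx <- function(x1, x2, x3) 0.5 + 1.2*b1(x1, x2) - 0.8*b2(x1) + 0.1*b3(x3)
  Fx(x1 = 25, x2 = 30, x3 = 10)  # 0.5 + 1.2 - 0 + 1.0 = 2.7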


RuleFit builds model F(x) in a three-step process: 1. build a tree ensemble (one where the bj(x)'s are decision trees), 2. generate candidate rules from the tree ensemble, and 3. fit coefficients aj via regularized regression.
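
As an illustration of step 3 only (REgo drives RuleFit's own implementation, not this), a lasso fit over a matrix of candidate-rule activations could be sketched in R with the glmnet package:

  library(glmnet)
  set.seed(1)
  n  <- 200
  x1 <- runif(n, 0, 50); x2 <- runif(n, 0, 50)
  y  <- 2 * (x1 > 22 & x2 > 27) + rnorm(n)
  # Step 2 would emit many candidate rules; here, two hand-made ones
  rules <- cbind(r1 = as.numeric(x1 > 22 & x2 > 27),
                 r2 = as.numeric(x1 > 40))
  # Step 3: regularized regression picks sparse coefficients a1, a2 (plus a0)
  fit <- cv.glmnet(rules, y, alpha = 1)
  coef(fit, s = "lambda.min")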

REgo consists of additional R code that we've written to make working with RuleFit easier, including:
  • The ability to have multiple RuleFit batch jobs running simultaneously
  • Easily specifying a data source
  • Automatically executing common preprocessing operations
  • Automatically generating a model summary report with interpretation plots and quality assessment
  • Exporting a model from R to SQL for deployment in a production environment

Build a Model

  • Call: $REGO_HOME/bin/trainModel.sh --d=DATA.conf --m=MODEL.conf [--l LOGGER.txt]
  • Input:
      • DATA.conf: data configuration file specifying options such as where the data is coming from, what column corresponds to the target, etc.
      • MODEL.conf: model configuration file specifying options such as the type of model being fit, the criteria being optimized, etc.
      • LOGGER.txt: optional file name where to write logging messages
  • Output:
      • model_summary.html: model summary and assessment
      • model_singleplot.html: interpretation plots
      • Model definition files: for later export or prediction
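
For example, a complete training run on the bundled diamonds data set (using the configuration files shipped under examples/, where csv.path is the current directory) could look like:

  cd $REGO_HOME/examples
  $REGO_HOME/bin/trainModel.sh --d=data_csv.conf --m=model.conf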

Export a Model

  • Call: $REGO_HOME/bin/exportModel.sh --m=MODEL.dir [--c=EXPORT.conf]
  • Input:
      • MODEL.dir: path to model definition files
      • EXPORT.conf: the configuration file specifying export options such as desired sql dialect, type of scoring clause, etc.
  • Output:
      • SQL_FILE: output file name containing model as a SQL expression
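
For example, to export the diamonds model trained above -- examples/model.conf sets rf.working.dir to /tmp/REgo/Diamonds_wd, so the model definition files land in /tmp/REgo/Diamonds_wd/export -- using the HiveQL settings in conf/EXPORT.conf:

  $REGO_HOME/bin/exportModel.sh --m=/tmp/REgo/Diamonds_wd/export --c=$REGO_HOME/conf/EXPORT.conf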

Predict on New Data

  • Call: $REGO_HOME/bin/runModel.sh --m=MODEL.dir --d=DATA.conf
  • Input:
      • MODEL.dir: path to model definition files
      • DATA.conf: specifies test data location
  • Output:
      • Text file with <id, y, yHat> tuples
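
Continuing the diamonds example, examples/predict_csv.conf points at a diamonds.csv test file and writes the <id, y, yHat> tuples to /tmp/REgo/Diamonds_wd/predict/diamonds_predict.csv:

  cd $REGO_HOME/examples
  $REGO_HOME/bin/runModel.sh --m=/tmp/REgo/Diamonds_wd/export --d=predict_csv.conf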

Deploy a Model

  • Call: $REGO_HOME/bin/runModelSQL.sh --host=<DB server> --dbIn=<Database storing feature table> --tblIn=<Feature table> --pk=<Primary Key> --model=<Model Definition SQL File> --dbOut=<Database storing score table> --tblOut=<Score table>
  • Input:
      • dbIn.tblIn: new data to be scored
      • model: previously built (and exported) model
  • Output:
      • dbOut.tblOut: Computed scores

Dependencies

  • R and R packages R2HTML, ROCR, RODBC, getopt
  • Other:
      • REGO_HOME: environment variable pointing to where you have checked out the REgo code
      • RF_HOME: environment variable pointing to appropriate RuleFit executable -- e.g., export RF_HOME=$REGO_HOME/lib/mac
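
For instance (the REgo checkout path below is illustrative):

  export REGO_HOME=/home/me/rego
  export RF_HOME=$REGO_HOME/lib/mac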
-------------------------------------------------------------------------------- /src/rfExportSQL_main.R: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/Rscript 2 | 3 | ############################################################################### 4 | # FILE: rfExportSQL_main.R 5 | # 6 | # USAGE: rfExportSQL_main.R -m -c EXPORT.conf 7 | # 8 | # DESCRIPTION: 9 | # Exports a previously built RuleFit model to SQL -- i.e., generates a 10 | # SQL expression corresponding to the scoring function defined by the model. 11 | # 12 | # ARGUMENTS: 13 | # model.dir: path to RuleFit model exported files 14 | # EXPORT.conf: the configuration file specifying export options such as 15 | # desired sql dialect, type of scoring clause, etc. 16 | # 17 | # REQUIRES: 18 | # REGO_HOME: environment variable pointing to the directory where you 19 | # have placed this file (and its companion ones) 20 | # 21 | # AUTHOR: Giovanni Seni 22 | ############################################################################### 23 | REGO_HOME <- Sys.getenv("REGO_HOME") 24 | source(file.path(REGO_HOME, "/src/logger.R")) 25 | source(file.path(REGO_HOME, "/src/rfExportSQL.R")) 26 | library(getopt) 27 | 28 | ValidateConfigArgs <- function(conf) 29 | { 30 | # Validates and initializes configuration parameters. 31 | # 32 | # Args: 33 | # conf: A list of pairs 34 | # Returns: 35 | # A list of pairs 36 | 37 | # SQl variant? 38 | if (is.null(conf$sql.dialect)) { 39 | conf$sql.dialect <- "SQLServer" 40 | } else { 41 | stopifnot(conf$sql.dialect %in% c("SQLServer", "HiveQL", "Netezza", "MySQL")) 42 | } 43 | 44 | # Do we need to dedup rule set? 45 | if (is.null(conf$do.dedup)) { 46 | conf$do.dedup <- TRUE 47 | } else { 48 | conf$do.dedup <- (as.numeric(conf$do.dedup) == 1) 49 | } 50 | 51 | # How do we expand low-count levels? 52 | # 1: replace _LowCountLevels_ in SQL scoring expression with corresponding levels 53 | # 2: keep _LowCountLevels_ in SQL scoring expression (keeps it shorter, easier to read) 54 | # and generate an extra sql clause with logic to substitute low count levels with 55 | # _LowCountLevels_ in a data preparation step prior to scoring 56 | if (is.null(conf$expand.lcl.mode)) { 57 | conf$expand.lcl.mode <- 1 58 | } else { 59 | conf$expand.lcl.mode <- as.numeric(conf$expand.lcl.mode) 60 | } 61 | 62 | # log level? 63 | if (is.null(conf$log.level)) { 64 | conf$log.level <- kLogLevelINFO 65 | } else { 66 | conf$log.level <- get(conf$log.level) 67 | } 68 | 69 | # Output type? 70 | if (is.null(conf$out.type)) { 71 | conf$out.type <- "score" 72 | } else { 73 | stopifnot(conf$out.type %in% c("score", "rulesonly", "rulesscore", "rulescoeff")) 74 | } 75 | 76 | # Max sql expression length (in number of characters) 77 | if (is.null(conf$max.sql.length)) { 78 | conf$max.sql.length <- 500 79 | } 80 | 81 | # Levels file? 82 | if (is.null(conf$in.fname)) { 83 | conf$levels.fname <- "xtrain_levels.txt" 84 | } 85 | 86 | # Output file? 87 | if (is.null(conf$out.path)) { 88 | conf$out.path <- conf$model_path 89 | } 90 | if (is.null(conf$out.fname)) { 91 | conf$out.fname <- "rules_forSQL.txt" 92 | } 93 | 94 | return(conf) 95 | } 96 | 97 | ValidateCmdArgs <- function(opt, args.m) 98 | { 99 | # Parses and validates command line arguments. 100 | # 101 | # Args: 102 | # opt: getopt() object 103 | # args.m: valid arguments spec passed to getopt(). 
104 | # 105 | # Returns: 106 | # A list of pairs 107 | kUsageString <- "/path/to/rfExportSQL_main.R -m [-c ] [-l ]" 108 | 109 | # Validate command line arguments 110 | if ( !is.null(opt$help) || is.null(opt$model_path) ) { 111 | self <- commandArgs()[1] 112 | cat("Usage: ", kUsageString, "\n") 113 | q(status=1); 114 | } 115 | 116 | # Do we have export options? If so, read them from given file (two columns assumed: 'param' and 'value') 117 | if (is.null(opt$export_conf)) { 118 | conf <- list() 119 | } else { 120 | tmp <- read.table(opt$export_conf, header=TRUE, as.is=T) 121 | conf <- as.list(tmp$value) 122 | names(conf) <- tmp$param 123 | } 124 | conf <- ValidateConfigArgs(conf) 125 | 126 | if (!(file.exists(opt$model_path))) { 127 | stop("Didn't find model directory:", opt$model_path, "\n") 128 | } else { 129 | conf$model.path <- opt$model_path 130 | } 131 | 132 | # Do we have a log file name? "" will send messages to stdout 133 | if (is.null(opt$log)) { 134 | opt$log <- "" 135 | } 136 | conf$log.fname <- opt$log 137 | 138 | return(conf) 139 | } 140 | 141 | ############## 142 | ## Main 143 | # 144 | 145 | # Grab command-line arguments 146 | args.m <- matrix(c( 147 | 'model_path' ,'m', 1, "character", 148 | 'export_conf' ,'c', 1, "character", 149 | 'log' ,'l', 1, "character", 150 | 'help' ,'h', 0, "logical" 151 | ), ncol=4,byrow=TRUE) 152 | opt <- getopt(args.m) 153 | conf <- ValidateCmdArgs(opt, args.m) 154 | 155 | # Create logging object 156 | logger <- new("logger", log.level = conf$log.level, file.name = conf$log.fname) 157 | info(logger, paste("rfExportSQL_main args:", 'model.path =', conf$model.path, ', do.dedup =', conf$do.dedup, 158 | ', expand.lcl.mode =', conf$expand.lcl.mode, ', sql =', conf$sql.dialect, ', log.level =', conf$log.level, 159 | ', out.type =', conf$out.type, ', out.fname =', conf$out.fname, ', max.sql.length =', conf$max.sql.length)) 160 | 161 | # Run export 162 | ExportModel2SQL(model.path = conf$model.path, merge.dups = conf$do.dedup, expand.lcl.mode = conf$expand.lcl.mode, 163 | db.type = conf$sql.dialect, export.type = conf$out.type, levels.fname = conf$levels.fname, 164 | out.path = conf$out.path, out.fname = conf$out.fname, 165 | max.sql.length = conf$max.sql.length) 166 | 167 | q(status=0) 168 | -------------------------------------------------------------------------------- /src/rfPrintProfiles_main.R: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/Rscript 2 | 3 | ############################################################################### 4 | # FILE: printProfiles.R 5 | # 6 | # USAGE: printProfiles.R -c printProfiles.cfg 7 | # 8 | # DESCRIPTION: 9 | # Prints a summary of the training data population (for a given RF model) 10 | # for the "extreme" values of yHat. 11 | ############################################################################### 12 | REGO_HOME <- Sys.getenv("REGO_HOME") 13 | source(file.path(REGO_HOME, "/src/rfRulesIO.R")) 14 | source(file.path(REGO_HOME, "/src/rfExport.R")) 15 | source(file.path(REGO_HOME, "/src/rfGraphics.R")) 16 | library(getopt) 17 | kPlotWidth <- 620 18 | kPlotHeight <- 480 19 | 20 | ValidateCmdArgs <- function(opt, args.m) 21 | { 22 | # Parses and validates command line arguments. 23 | # 24 | # Args: 25 | # opt: getopt() object 26 | # args.m: valid arguments spec passed to getopt(). 
27 | # 28 | # Returns: 29 | # A list of pairs 30 | kUsageString <- "/path/to/printProfiles.R -c printProfiles.cfg" 31 | 32 | # Validate command line arguments 33 | if ( !is.null(opt$help) || is.null(opt$conf)) { 34 | self <- commandArgs()[1]; 35 | cat("Usage: ", kUsageString, "\n"); 36 | q(status=1); 37 | } 38 | 39 | # Read config file (two columns assumed: 'param' and 'value') 40 | tmp <- read.table(opt$conf, header=T, as.is=T) 41 | conf <- as.list(tmp$value) 42 | names(conf) <- tmp$param 43 | 44 | # Must have a valid model and output path 45 | stopifnot("mod.path" %in% names(conf)) 46 | stopifnot("out.path" %in% names(conf)) 47 | 48 | # Set defaults for the options that were not specified 49 | if (!("very.unlikely.thresh" %in% names(conf))) { 50 | conf$v.u.thresh <- -1.0 51 | } else { 52 | conf$v.u.thresh <- as.numeric(conf$very.unlikely.thresh) 53 | } 54 | if (!("very.likely.thresh" %in% names(conf))) { 55 | conf$v.l.thresh <- 1.0 56 | } else { 57 | conf$v.l.thresh <- as.numeric(conf$very.likely.thresh) 58 | } 59 | 60 | if (!("html.min.var.imp" %in% names(conf))) { 61 | conf$html.min.var.imp <- 5 62 | } else { 63 | conf$html.min.var.imp <- as.numeric(conf$html.min.var.imp) 64 | } 65 | if (!("yHat.hist.fname" %in% names(conf))) { 66 | conf$yHat.hist.fname <- "yHat.png" 67 | } 68 | if (!("yHat.hist.title" %in% names(conf))) { 69 | conf$yHat.hist.title <- "" 70 | } 71 | 72 | return(conf) 73 | } 74 | 75 | PrintSegmentProfile <- function(x, x.levels, i.segment, vars2print=NULL) 76 | { 77 | if (is.null(vars2print)) { 78 | vars2print <- colnames(x) 79 | } else { 80 | nvars <- length(vars2print) 81 | if (length(intersect(vars2print, colnames(x))) != nvars) { 82 | error(logger, "PrintSegmentProfile: variable name mismatch") 83 | } 84 | } 85 | 86 | for (var.name in vars2print) { 87 | cat(var.name, ":\n") 88 | var.x <- x[, var.name] 89 | i.NA <- which(var.x == kMissing) 90 | var.x[i.NA] <- NA 91 | var.levels <- NULL 92 | # Fetch level names (if appropriate) 93 | for (iVar in 1:length(x.levels)) { 94 | if ( x.levels[[iVar]]$var == var.name ) { 95 | var.levels <- x.levels[[iVar]]$levels 96 | break; 97 | } 98 | } 99 | # Print summary according to type 100 | if (!is.null(var.levels)) { 101 | # Categorical variable 102 | var <- as.factor(var.x) 103 | levels(var) <- var.levels 104 | tmp <- summary(var[i.segment], maxsum = 10) 105 | print(tmp) 106 | print(paste(round(100.0*summary(var[i.segment], maxsum = 10)/sum(tmp), 2), "%", sep = "")) 107 | } else { 108 | # Numeric variable 109 | print(summary(var.x[i.segment])) 110 | } 111 | } 112 | } 113 | 114 | ############## 115 | ## Main 116 | 117 | # Grab command-line arguments 118 | args.m <- matrix(c( 119 | 'conf' ,'c', 1, "character", 120 | 'help' ,'h', 0, "logical" 121 | ), ncol=4,byrow=TRUE) 122 | opt <- getopt(args.m) 123 | conf <- ValidateCmdArgs(opt, args.m) 124 | 125 | ## Load x, y, yhat, levels 126 | load(file.path(conf$mod.path, kMod.yHat.fname)) 127 | load(file.path(conf$mod.path, kMod.xyw.fname)) 128 | x.levels <- ReadLevels(file.path(conf$mod.path, kMod.x.levels.fname)) 129 | 130 | ## Plot score histogram 131 | plot.fname <- conf$yHat.hist.fname 132 | 133 | ## Use own version of png() if necessary: 134 | if (isTRUE(conf$html.graph.dev == "Bitmap")) { 135 | png <- png_via_bitmap 136 | if (!CheckWorkingPNG(png)) stop("cannot generate PNG graphics") 137 | } else { 138 | png <- GetWorkingPNG() 139 | if (is.null(png)) stop("cannot generate PNG graphics") 140 | } 141 | 142 | png(file = file.path(conf$out.path, plot.fname), width=kPlotWidth, 
height=kPlotHeight) 143 | 144 | hist(y.hat, main=conf$yHat.hist.title, cex.main=0.8, xlab="score", breaks=22) 145 | abline(v = conf$v.u.thresh, col = "red", lty = "dashed") 146 | abline(v = conf$v.l.thresh, col = "blue", lty = "dashed") 147 | dev.off() 148 | 149 | # Which variables are to be summarized (and in what order)? 150 | # ... read variable importance table 151 | vi.df <- read.table(file.path(conf$mod.path, kMod.varimp.fname), header = F, sep = "\t") 152 | colnames(vi.df) <- c("Importance", "Variable") 153 | # ... pick subset from varimp list 154 | # ... ... first, filter out low importance entries 155 | min.var.imp <- max(conf$html.min.var.imp, 1.0) 156 | i.zero.imp <- which(vi.df$Importance < min.var.imp) 157 | if ( length(i.zero.imp == 1) ) { 158 | vi.df <- vi.df[-i.zero.imp, ] 159 | } 160 | nvars <- min(conf$html.singleplot.nvars, nrow(vi.df)) 161 | vars2print <- vi.df$Variable[1:nvars] 162 | 163 | ## "Very Unlikely" subgroup 164 | i.yHat.low <- which(y.hat <= conf$v.u.thresh) 165 | cat("'Very Unlikely' subgroup:", length(i.yHat.low), "(", 166 | round(100*length(i.yHat.low)/nrow(x), 2), "% )\n") 167 | ## ... Accuracy 168 | tbl <- table(y[i.yHat.low], sign(y.hat[i.yHat.low])) 169 | print(tbl) 170 | cat("Accuracy:", round(100*diag(tbl)/sum(tbl), 2), "\n") 171 | ## ... Profile 172 | PrintSegmentProfile(x, x.levels, i.yHat.low, vars2print) 173 | 174 | 175 | ## "Very Likely" subgroup 176 | i.yHat.hi <- which(y.hat >= conf$v.l.thresh) 177 | cat("'Very Likely' subgroup:", length(i.yHat.hi), "(", 178 | round(100*length(i.yHat.hi)/nrow(x), 2), "% )\n") 179 | ## ... Accuracy 180 | tbl <- table(y[i.yHat.hi], sign(y.hat[i.yHat.hi])) 181 | print(tbl) 182 | cat("Accuracy:", round(100 - 100*diag(tbl)/sum(tbl), 2), "\n") 183 | ## ... Profile 184 | PrintSegmentProfile(x, x.levels, i.yHat.hi, vars2print) 185 | q(status=0) 186 | -------------------------------------------------------------------------------- /src/rfCV_main.R: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/Rscript 2 | 3 | ############################################################################### 4 | # FILE: rfCV_main.R 5 | # 6 | # USAGE: rfCV_main.R -c CV.conf 7 | # 8 | # DESCRIPTION: 9 | # Cross-validate a fitted RuleFit model. 10 | # 11 | # ARGUMENTS: 12 | # CV.conf: specifies data & model definition location . 13 | # 14 | # REQUIRES: 15 | # REGO_HOME: environment variable pointing to the directory where you 16 | # have placed this file (and its companion ones) 17 | # RF_HOME: environment variable pointing to appropriate RuleFit 18 | # executable -- e.g., export RF_HOME=$REGO_HOME/lib/RuleFit/mac 19 | # 20 | # AUTHOR: Giovanni Seni 21 | ############################################################################### 22 | REGO_HOME <- Sys.getenv("REGO_HOME") 23 | source(file.path(REGO_HOME, "/src/logger.R")) 24 | source(file.path(REGO_HOME, "/src/rfExport.R")) 25 | source(file.path(REGO_HOME, "/src/rfTrain.R")) 26 | source(file.path(REGO_HOME, "/src/rfGraphics.R")) 27 | source(file.path(REGO_HOME, "/lib/RuleFit", ifelse(nchar(Sys.getenv("RF_API")) > 0, Sys.getenv("RF_API"), "rulefit.r"))) 28 | library(ROCR, verbose = FALSE, quietly=TRUE, warn.conflicts = FALSE) 29 | library(getopt) 30 | 31 | ValidateCmdArgs <- function(opt, args.m) 32 | { 33 | # Parses and validates command line arguments. 34 | # 35 | # Args: 36 | # opt: getopt() object 37 | # args.m: valid arguments spec passed to getopt(). 
38 | # 39 | # Returns: 40 | # A list of pairs 41 | kUsageString <- "/path/to/rfCV_main.R -c CV.conf -d -m [-l ]" 42 | 43 | # Validate command line arguments 44 | if ( !is.null(opt$help) || is.null(opt$conf) || is.null(opt$data_conf) || is.null(opt$model_conf) ) { 45 | self <- commandArgs()[1] 46 | cat("Usage: ", kUsageString, "\n") 47 | q(status=1); 48 | } 49 | 50 | # Do we have a log file name? "" will send messages to stdout 51 | if (is.null(opt$log)) { 52 | opt$log <- "" 53 | } 54 | 55 | # Read config file (two columns assumed: 'param' and 'value') 56 | tmp <- read.table(opt$conf, header=T, as.is=T) 57 | conf <- as.list(tmp$value) 58 | names(conf) <- tmp$param 59 | 60 | # Must have a valid data source type 61 | stopifnot("data.source.type" %in% names(conf)) 62 | stopifnot(conf$data.source.type %in% c("csv", "db")) 63 | if (conf$data.source.type == "db") { 64 | stopifnot("db.dsn" %in% names(conf) && "db.name" %in% names(conf) && 65 | "db.type" %in% names(conf) && "db.tbl.name" %in% names(conf)) 66 | } else { 67 | stopifnot("csv.path" %in% names(conf) && "csv.fname" %in% names(conf)) 68 | } 69 | 70 | # Did user specified a log level? 71 | conf$log.fname <- opt$log 72 | if (is.null(conf$log.level)) { 73 | conf$log.level <- kLogLevelINFO 74 | } else { 75 | conf$log.level <- get(conf$log.level) 76 | } 77 | 78 | # Did user specified an output file? 79 | if (is.null(conf$out.fname)) { 80 | conf$out.fname <- "rfCV_out.csv" 81 | } 82 | 83 | return(conf) 84 | } 85 | 86 | 87 | ############## 88 | ## Main 89 | # 90 | 91 | # Grab command-line arguments 92 | args.m <- matrix(c( 93 | 'conf' ,'c', 1, "character", 94 | 'data_conf' ,'d', 1, "character", 95 | 'model_conf' ,'m', 1, "character", 96 | 'log' ,'l', 1, "character", 97 | 'help' ,'h', 0, "logical" 98 | ), ncol=4,byrow=TRUE) 99 | opt <- getopt(args.m) 100 | conf <- ValidateCmdArgs(opt, args.m) 101 | 102 | # Create logging object 103 | logger <- new("logger", log.level = conf$log.level, file.name = conf$log.fname) 104 | 105 | # Read data config file (same one used to build original model) 106 | tmp <- read.table(opt$data_conf, header=T, as.is=T) 107 | train.data.conf <- as.list(tmp$value) 108 | names(train.data.conf) <- tmp$param 109 | train.data.conf <- ValidateConfigArgs(train.data.conf) 110 | 111 | # Load model specification parameters file (same one used to build original model) 112 | train.rf.ctxt <- IntiRFContext(opt$model_conf) 113 | 114 | # Set global env variables required by RuleFit 115 | platform <- conf$rf.platform 116 | RF_HOME <- Sys.getenv("RF_HOME") 117 | RF_WORKING_DIR <- conf$rf.working.dir 118 | 119 | # Create logging object 120 | logger <- new("logger", log.level = conf$log.level, file.name = "") 121 | info(logger, paste("rfCV_main args:", 'model.path =', conf$model.path, 122 | ', log.level =', conf$log.level, ', out.fname =', conf$out.fname)) 123 | 124 | ## Use own version of png() if necessary: 125 | if (isTRUE(conf$graph.dev == "Bitmap")) { 126 | png <- png_via_bitmap 127 | if (!CheckWorkingPNG(png)) error(logger, "cannot generate PNG graphics") 128 | } else { 129 | png <- GetWorkingPNG() 130 | if (is.null(png)) error(logger, "cannot generate PNG graphics") 131 | } 132 | 133 | # Load x, y, w used to build original model 134 | load(file.path(conf$mod.path, kMod.xyw.fname)) 135 | info(logger, paste("Data loaded: dim =", nrow(x), "x", ncol(mod$x), "; NAs =", 136 | length(which(x == 9.0e30)), "(", 137 | round(100*length(which(x == 9.0e30))/(nrow(x)*ncol(x)), 2), 138 | "%)")) 139 | 140 | # Get list of categorical variables 141 | 
x.levels <- as.data.frame(do.call("rbind", ReadLevels(file.path(conf$mod.path, kMod.x.levels.fname)))) 142 | x.cat.vars <- which(sapply(x.levels[1:10,'levels'], function(x) {ifelse(length(x) > 0,1,0)}) == 1) 143 | 144 | # Do CV 145 | cv.out <- rfCV(x, y, wt, x.cat.vars, train.rf.ctxt, conf$nfold, conf$yHat.return, conf$seed) 146 | 147 | # Compute CV error 148 | if (train.rf.ctxt$task == "class") { 149 | # "classification" model... convert from log-odds to probability estimates 150 | cv.out$oos.yHat <- 1.0/(1.0+exp(-cv.out$oos.yHat)) 151 | # Confusion matrix 152 | conf.m <- table(y, sign(y.hat - 0.5)) 153 | test.acc <- 100*sum(diag(conf.m))/sum(conf.m) 154 | info(logger, paste("CV acc:", round(test.acc, 2))) 155 | info(logger, sprintf("CV confusion matrix - 0/0: %d, 0/1: %d, 1/0: %d, 1/1: %d", 156 | conf.m[1, 1], conf.m[1, 2], conf.m[2, 1], conf.m[2, 2])) 157 | 158 | # Generate ROC plot 159 | kPlotWidth <- 620 160 | kPlotHeight <- 480 161 | plot.fname <- "ROC.png" 162 | pred <- prediction(y.hat, y) 163 | perf <- performance(pred, "tpr", "fpr") 164 | png(file = file.path(conf$out.path, plot.fname), width=kPlotWidth, height=kPlotHeight) 165 | plot(perf, colorize=T, main="") 166 | lines(x=c(0, 1), y=c(0,1)) 167 | dev.off() 168 | } else { 169 | re.test.error <- sum(abs(y.hat - y))/nrow(x) 170 | med.test.error <- sum(abs(y - median(y)))/nrow(x) 171 | aae.test <- re.test.error / med.test.error 172 | info(logger, sprintf("CV AAE: %f (RE:%f, Med:%f)", aae.test, re.test.error, med.test.error)) 173 | } 174 | } 175 | 176 | # Dump tuples, as appropriate 177 | if ("col.id" %in% names(conf)) { 178 | obs.id <- data[,conf$col.id] 179 | } else { 180 | obs.id <- rep(NA, nrow(data)) 181 | } 182 | WriteObsIdYhat(out.path = conf$out.path, obs.id = obs.id, y = y, y.hat = y.hat, file.name = conf$out.fname) 183 | 184 | q(status=0) 185 | -------------------------------------------------------------------------------- /src/rfTrain.R: -------------------------------------------------------------------------------- 1 | source(file.path(REGO_HOME, "/src/logger.R")) 2 | source(file.path(REGO_HOME, "/src/rfPreproc.R")) 3 | source(file.path(REGO_HOME, "/src/rfAPI.R")) 4 | source(file.path(REGO_HOME, "/src/rfExport.R")) 5 | 6 | ReadColumnTypes <- function(fname) 7 | { 8 | # Reads a text file of pairs. 9 | # 10 | # Args: 11 | # fname: character string naming the file 12 | # Returns: 13 | # A numeric vector 14 | stopifnot(!file.access(fname, mode = 4)) 15 | dbg(logger, "ReadColumTypes:") 16 | 17 | col.types.df <- read.table(fname, sep = ",", header = F) 18 | if (ncol(col.types.df) != 2) { 19 | error(logger, paste("Expecting 2 columns in", fname, "file!")) 20 | } 21 | col.types <- col.types.df[, 2] 22 | names(col.types) <- col.types.df[, 1] 23 | dbg(logger, paste(" ", length(col.types), "elements read")) 24 | if (!all(col.types %in% c(1, 2))) { 25 | error(logger, paste("Invalid column types found")) 26 | } 27 | return(col.types) 28 | } 29 | 30 | TrainModel <- function(data, config, rf.ctxt) 31 | { 32 | # Builds a RuleFit model on the given data set. 33 | # 34 | # Args: 35 | # data: data-frame with rows 36 | # config: configuration parameters 37 | # rf.ctxt: RuleFit-specific configuration parameters 38 | # 39 | # Returns: 40 | # A tuple . 
41 | stopifnot(!is.na(config$col.y)) 42 | dbg(logger, "TrainModel:") 43 | 44 | # Pop the target column from the feature space 45 | i <- grep(paste("^", config$col.y, "$", sep=""), colnames(data), perl=T) 46 | if (length(i) == 0) { 47 | error(logger, paste("Target column", config$col.y, "not found in input data-frame!")) 48 | } 49 | y <- data[, i] 50 | data <- data[, -i] # pop 51 | 52 | # If classification task, check y values 53 | if (rf.ctxt$rfmode == "class") { 54 | if (length(unique(y)) != 2) { 55 | error(logger, paste("Target column", config$col.y, "must have 2 values only")) 56 | } 57 | y.summary <- summary(as.factor(y)) 58 | info(logger, "Y summary:") 59 | info(logger, sprintf(" %s: %d, %s: %d", names(y.summary)[1], y.summary[1], 60 | names(y.summary)[2], y.summary[2])) 61 | if (min(y) != -1 || max(y) != 1) { 62 | if (min(y) == 0 && max(y) == 1) { 63 | y <- 2 * y - 1 # switch from (0,1) boolean to (-1,1) boolean 64 | } else { 65 | error(logger, paste("Target column", config$col.y, "must be 0/1")) 66 | } 67 | } # else, leave it as is 68 | } 69 | 70 | # Set observation weights 71 | if (!is.na(config$col.weights)) { 72 | obs.wt <- data[[config$col.weights]] 73 | # Pop weight column from data 74 | i <- grep(paste("^", config$col.weights, "$", sep=""), colnames(data), perl=T); 75 | if (length(i) > 0) { 76 | dbg(logger, paste("Pop weight Column: ", i)) 77 | data <- data[, -i] # pop 78 | } else { 79 | error(logger, paste("Weight column", config$col.weights, "not found!")) 80 | } 81 | } else if (conf$do.class.balancing && rf.ctxt$rfmode == "class") { 82 | # Balance classes via observation weights - equal ratio 83 | i.pos <- which(y == 1) 84 | obs.wt <- rep(1, nrow(data)) 85 | obs.wt[-i.pos] <- nrow(data[i.pos, ]) / nrow(data[-i.pos, ]) 86 | obs.wt.summary <- summary(as.factor(obs.wt)) 87 | info(logger, "Class weights:") 88 | info(logger, paste(names(obs.wt.summary), obs.wt.summary, sep=":", collapse=" ")) 89 | } else if ("row.weights.fname" %in% names(conf)) { 90 | # Customized weights for each row 91 | obs.wt <- read.table(conf$row.weights.fname, header = F)[,1] 92 | obs.wt.summary <- summary(obs.wt) 93 | info(logger, "Row weights (from weight file):") 94 | info(logger, paste(names(obs.wt.summary), obs.wt.summary, sep=":", collapse=" ")) 95 | } else { 96 | obs.wt <- rep(1, nrow(data)) 97 | } 98 | 99 | # Pop the row-id column from the feature space, if present 100 | if (!is.na(config$col.id)) { 101 | i <- grep(paste("^", config$col.id, "$", sep=""), colnames(data), perl=T) 102 | if (length(i) == 0) { 103 | error(logger, paste("ID column", config$col.id, "not found in input data-frame!")) 104 | } 105 | obs.id <- data[, i] 106 | dbg(logger, paste("Pop ID Column: ", i)) 107 | data <- data[, -i] # pop 108 | } else { 109 | obs.id <- NULL 110 | } 111 | 112 | # Prune data to contain only useful predictor variables 113 | xDF <- PruneFeatures(data, conf$na.threshold, conf$col.skip.fname) 114 | if (ncol(xDF) == 0) { 115 | error(logger, paste("PruneFeatures pruned everything! 
Nothing to do!")); 116 | } 117 | 118 | # Make sure categorical variables are factors 119 | if ("col.types.fname" %in% names(conf)) { 120 | col.types <- ReadColumnTypes(config$col.types.fname) 121 | } else { 122 | col.types <- ifelse(sapply(data, inherits, what = "factor"), 2, 1) 123 | } 124 | ef.out <- EnforceFactors(xDF, col.types, config$min.level.count) 125 | xDF <- ef.out$x 126 | x.cat.vars <- ef.out$cat.vars 127 | x.recoded.cat.vars <- ef.out$recoded.cat.vars 128 | rm(ef.out) 129 | 130 | # Check column types 131 | if (!CheckColumnTypes(xDF)) { 132 | error(logger, paste("Invalid column type(s) found!")); 133 | } 134 | 135 | # Any numeric column preprocessing? 136 | x.trims <- NULL 137 | if (nchar(config$col.winz.fname) > 0) { 138 | wf.out <- WinsorizeFeatures(xDF, config$col.winz.fname) 139 | xDF <- wf.out$x 140 | x.trims <- wf.out$trims 141 | # Overwrite trim.qntl param 142 | rf.ctxt$trim.qntl <- 0.0 143 | } 144 | 145 | # Coerce to numeric matrix & mark NAs 146 | x <- data.matrix(xDF) 147 | if ("data.NA.value" %in% names(conf)) { 148 | x[x == conf$data.NA.value] <- 9.0e30 149 | } else { 150 | x[is.na(x)] <- 9.0e30 151 | } 152 | 153 | # Save workspace (before training - mostly for debugging purposes)? 154 | if (config$save.workspace) { 155 | dbg(logger, "Saving workspace") 156 | save(xDF, x, y, obs.wt, rf.ctxt, x.cat.vars, file = file.path(rf.ctxt$export.dir , "workspace.rdata")) 157 | } 158 | 159 | # Build model 160 | set.seed(config$rand.seed) 161 | rfmod <- TrainRF(x, y, obs.wt, rf.ctxt, x.cat.vars) 162 | # ... Log (estimated generalization) model error and size 163 | rfmod.stats <- runstats(rfmod) 164 | info(logger, paste("Estimated criterion value:", rfmod.stats$cri, paste("(+/- ", rfmod.stats$err, "),", sep=""), 165 | "Num terms:", rfmod.stats$terms)) 166 | 167 | # Training error... 
just for info
168 |   y.hat.train <- rfpred(x)
169 |   train.stats <- list()
170 |   if (rf.ctxt$rfmode == "class") {
171 |     conf.m <- table(y, sign(y.hat.train))
172 |     stopifnot("-1" %in% rownames(conf.m))
173 |     stopifnot("1" %in% rownames(conf.m))
174 |     TN <- ifelse("-1" %in% colnames(conf.m), conf.m["-1", "-1"], 0)
175 |     FP <- ifelse("1" %in% colnames(conf.m), conf.m["-1", "1"], 0)
176 |     FN <- ifelse("-1" %in% colnames(conf.m), conf.m["1", "-1"], 0)
177 |     TP <- ifelse("1" %in% colnames(conf.m), conf.m["1", "1"], 0)
178 |     train.acc <- 100*(TN+TP)/length(y.hat.train)
179 |     info(logger, paste("Training acc:", round(train.acc, 2)))
180 |     info(logger, sprintf("Training confusion matrix - 0/0: %d, 0/1: %d, 1/0: %d, 1/1: %d",
181 |          TN, FP, FN, TP))
182 |     train.stats$type <- "class"
183 |     train.stats$acc <- train.acc
184 |     train.stats$conf.m <- conf.m
185 |   } else {
186 |     re.train.error <- sum(abs(y.hat.train - y))/nrow(x)
187 |     med.train.error <- sum(abs(y - median(y)))/nrow(x)
188 |     aae.train <- re.train.error / med.train.error
189 |     info(logger, sprintf("Training AAE: %f (RE:%f, Med:%f)", aae.train, re.train.error, med.train.error))
190 |     train.stats$type <- "regress"
191 |     train.stats$aae <- aae.train
192 |     train.stats$re.error <- re.train.error
193 |     train.stats$med.error <- med.train.error
194 |   }
195 | 
196 |   # Export (save) model
197 |   ExportModel(rfmod = rfmod, rfmod.path = rf.ctxt$working.dir,
198 |               x = x, y = y, wt = obs.wt, y.hat = y.hat.train,
199 |               out.path = rf.ctxt$export.dir,
200 |               x.df = xDF, col.types = col.types,
201 |               winz = ifelse(is.null(x.trims), rf.ctxt$trim.qntl, x.trims),
202 |               x.recoded.cat.vars)
203 | 
204 |   # Dump <id, y, yHat> tuples, if appropriate
205 |   if (!is.null(obs.id)) {
206 |     WriteObsIdYhat(out.path = rf.ctxt$export.dir, obs.id = obs.id, y = y, y.hat = y.hat.train)
207 |   }
208 | 
209 |   return(list(rfmod = rfmod, train.stats = train.stats))
210 | }
211 | 
--------------------------------------------------------------------------------
/src/rfPreproc.R:
--------------------------------------------------------------------------------
1 | source(file.path(REGO_HOME, "/src/logger.R"))
2 | source(file.path(REGO_HOME, "/src/winsorize.R"))
3 | 
4 | # Constants
5 | kLowCountLevelsName <- "_LowCountLevels_"
6 | 
7 | PruneFeatures <- function(x, na.thresh = 0.95, col.skip.fname = "")
8 | {
9 |   # Prune df to contain only useful predictor variables -- e.g., remove
10 |   # singletons, all NAs.
11 |   #
12 |   # Args:
13 |   #   x              : data frame
14 |   #   na.thresh      : max percentage of NAs allowed in a column of x
15 |   #   col.skip.fname : character string naming the file with user-specified
16 |   #                    columns to remove from x
17 |   #
18 |   # Returns:
19 |   #   A new data frame without pruned columns
20 |   dbg(logger, "PruneFeatures:")
21 |   stopifnot(class(x) == "data.frame")
22 | 
23 |   colNum2Remove <- vector()
24 |   N <- nrow(x)
25 |   ## 0) all identical values?
26 |   colNum2Remove.0 <- vector()
27 |   for (i in 1:ncol(x)) {
28 |     if (length(unique(x[,i])) == 1) {
29 |       dbg(logger, paste("  i = ", i, " name = ", colnames(x)[i], "is constant"))
30 |       colNum2Remove.0 <- c(colNum2Remove.0, i)
31 |     }
32 |   }
33 |   info(logger, paste("Singletons:", length(colNum2Remove.0)))
34 |   colNum2Remove <- c(colNum2Remove, colNum2Remove.0)
35 | 
36 |   ## 1) only two values, one of which is NA?
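  ##    (i.e., the column is constant except for missing entries, so it
  ##     carries no usable signal)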
37 |   colNum2Remove.1 <- vector()
38 |   for (i in 1:ncol(x)) {
39 |     if (length(unique(x[,i])) == 2 && length(which(is.na(x[,i]) == T)) > 0) {
40 |       dbg(logger, paste("  i = ", i, " name = ", colnames(x)[i], "is NA+constant"))
41 |       colNum2Remove.1 <- c(colNum2Remove.1, i)
42 |     }
43 |   }
44 |   info(logger, paste("Quasi-Singletons:", length(colNum2Remove.1)))
45 |   colNum2Remove <- c(colNum2Remove, colNum2Remove.1)
46 | 
47 |   ## 2) columns w "many" NA's?
48 |   max.num.NAs <- na.thresh*N
49 |   dbg(logger, paste("Max Num NAs:", max.num.NAs))
50 |   colNum2Remove.2 <- vector()
51 |   for (i in 1:ncol(x)) {
52 |     if ( length(which(is.na(x[,i]) == T)) > max.num.NAs ) {
53 |       dbg(logger, paste("  i =", i, " name =", colnames(x)[i], "is mostly NA"))
54 |       colNum2Remove.2 <- c(colNum2Remove.2, i)
55 |     }
56 |   }
57 |   info(logger, paste("Mostly NAs:", length(colNum2Remove.2)))
58 |   colNum2Remove <- c(colNum2Remove, colNum2Remove.2)
59 | 
60 |   ## 3) user-specified columns?
61 |   if (nchar(col.skip.fname) > 0) {
62 |     if (file.exists(col.skip.fname)) {
63 |       colNum2Remove.3 <- vector()
64 |       feat2skip <- read.table(col.skip.fname, sep = ",", header = F, col.names = c("var.name"), as.is = T)[,1]
65 |       for (cname in feat2skip) {
66 |         iCol <- grep(paste("^", cname, "$", sep=""), colnames(x), perl=T, ignore.case=T)
67 |         if (length(iCol) == 1) {
68 |           colNum2Remove.3 <- c(colNum2Remove.3, iCol)
69 |         } else {
70 |           dbg(logger, paste("Didn't find: ", cname))
71 |         }
72 |       }
73 |       info(logger, paste("User specified:", length(colNum2Remove.3)))
74 |       dbg(logger, paste(colnames(x)[colNum2Remove.3]))
75 |       colNum2Remove <- c(colNum2Remove, colNum2Remove.3)
76 |     } else {
77 |       warn(logger, paste("File ", col.skip.fname, "doesn't exist"))
78 |     }
79 |   }
80 | 
81 |   # Columns w "some" NA's? Just for info...
82 |   num.some.NA <- 0
83 |   for (i in 1:ncol(x)) {
84 |     if (!(i %in% colNum2Remove)) {
85 |       num.NA <- length(which(is.na(x[,i]) == T))
86 |       if (num.NA > 0) {
87 |         dbg(logger, paste("  i =", i, " NA-rate =", round(100.0*num.NA/N),
88 |             "\tname =", colnames(x)[i]))
89 |         num.some.NA <- num.some.NA + 1
90 |       }
91 |     }
92 |   }
93 |   info(logger, paste("Some NAs:", num.some.NA))
94 | 
95 |   if (length(colNum2Remove) > 0) {
96 |     if (ncol(x) - length(colNum2Remove) == 1) {
97 |       # Only one column left... avoid coercion to vector
98 |       tmp.df <- data.frame(x[,-colNum2Remove])
99 |       colnames(tmp.df) <- setdiff(colnames(x), colnames(x)[colNum2Remove])
100 |       return(tmp.df)
101 |     } else {
102 |       return(x[, -colNum2Remove])
103 |     }
104 |   } else {
105 |     return(x)
106 |   }
107 | }
108 | 
109 | WinsorizeFeatures <- function(x, feat2winz.fname)
110 | {
111 |   # Applies Winsorization transformation to the specified variables.
112 |   #
113 |   # Args:
114 |   #   x               : data frame
115 |   #   feat2winz.fname : text file with names of columns to Winsorize
116 |   #                     (<name, beta> pairs expected)
117 |   # Returns:
118 |   #   A list with
119 |   #     x     : copy of input data frame with transformed columns
120 |   #     trims : data frame with the trim values used
121 |   stopifnot(class(x) == "data.frame")
122 |   stopifnot(file.access(feat2winz.fname, mode = 4) == 0)
123 | 
124 |   dbg(logger, "WinsorizeFeatures:")
125 |   feat2winz <- read.table(feat2winz.fname, sep = ",", as.is = T)
126 |   dbg(logger, paste(nrow(feat2winz), "columns to Winsorize"))
127 | 
128 |   # Augment data-frame to remember computed trims
129 |   feat2winz <- cbind(feat2winz, rep(NA, nrow(feat2winz)),
130 |                      rep(NA, nrow(feat2winz)), rep(NA, nrow(feat2winz)))
131 |   names(feat2winz) <- c("vname", "beta", "min2keep", "max2keep", "mean")
132 | 
133 |   # Loop over columns to winsorize
134 |   for (i in 1:nrow(feat2winz)) {
135 |     if (is.finite(feat2winz$beta[i]) && feat2winz$beta[i] >= 0 && feat2winz$beta[i] < 1) {
136 |       iCol <- grep(paste("^", feat2winz$vname[i], "$", sep=""), colnames(x), perl=T)
137 |       if (length(iCol) == 1) {
138 |         if (class(x[, iCol]) %in% c("numeric", "integer")) {
139 |           l.x <- winsorize(x[, iCol], feat2winz$beta[i])
140 |           x[, iCol] <- l.x$x
141 |           feat2winz$min2keep[i] <- l.x$min2keep
142 |           feat2winz$max2keep[i] <- l.x$max2keep
143 |           feat2winz$mean[i] <- mean(l.x$x, na.rm = T)
144 |         } else {
145 |           warn(logger, paste("Can't trim non-continuous var: ", feat2winz$vname[i]))
146 |         }
147 |       } else {
148 |         warn(logger, paste("Didn't find: ", feat2winz$vname[i]))
149 |       }
150 |     }
151 |   }
152 | 
153 |   return(list(x=x, trims=feat2winz))
154 | }
155 | 
156 | CheckColumnTypes <- function(x)
157 | {
158 |   # Checks if there are columns which are not numeric, factor or logical
159 |   #
160 |   # Args: x : data frame
161 |   # Returns: boolean
162 |   dbg(logger, "CheckColumnTypes:")
163 |   good_class <- sapply(x, inherits, what = c("numeric", "integer", "factor", "logical"))
164 |   all.good <- all(good_class)
165 |   sapply(which(!good_class), function(i) {
166 |     warn(logger, paste("  i = ", i, " name = ", colnames(x)[i], "type: ", paste(class(x[,i]), collapse = ", ")))
167 |   })
168 |   return(all.good)
169 | }
170 | 
171 | EnforceFactors <- function(x, col.types, min.level.count = 0)
172 | {
173 |   # Make sure categorical variables are factors in given data-frame.
174 |   #
175 |   # Args:
176 |   #   x               : data frame
177 |   #   col.types       : vector indicating column types -- 1:continuous,
178 |   #                     2:categorical
179 |   #   min.level.count : merge levels with fewer than this count
180 |   # Returns:
181 |   #   A copy of input data frame with transformed columns, and a
182 |   #   list of categorical variable indices
183 |   dbg(logger, "EnforceFactors:")
184 |   stopifnot(ncol(x) <= length(col.types))
185 | 
186 |   cat.vars <- c()
187 |   recoded.cat.vars <- list()
188 |   i.recoded <- 1
189 |   for (i in 1:ncol(x)) {
190 |     cname <- colnames(x)[i]
191 |     i.col.type <- grep(paste("^", cname, "$", sep=""), names(col.types), perl=T, ignore.case=T)
192 |     if (length(i.col.type) == 1) {
193 |       if (col.types[i.col.type] == 2) {
194 |         cat.vars <- c(cat.vars, i)
195 |         if (!inherits(x[,i], "factor")) {
196 |           dbg(logger, paste("  Converting i =", i, " name =", cname,
197 |               "from: '", class(x[, i]), "' to 'factor'"))
198 |           x[, i] <- as.factor(x[, i])
199 |         }
200 |         # Check for low count levels?
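        # (Example: with min.level.count = 20, any level of this factor seen
        #  fewer than 20 times is collapsed below into the single level
        #  "_LowCountLevels_", and the affected levels are recorded in
        #  recoded.cat.vars so the recoding can be reported later.)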
201 |         if (min.level.count > 0) {
202 |           low.count.levels <- c()
203 |           level.hist <- summary(x[, i], maxsum = nlevels(x[, i]))
204 |           for (i.level in names(level.hist)) {
205 |             if (level.hist[i.level] < min.level.count) {
206 |               low.count.levels <- c(low.count.levels, i.level)
207 |             }
208 |           }
209 |           # Recode factor if necessary
210 |           if (length(low.count.levels) > 0) {
211 |             x[, i] <- factor(ifelse(x[, i] %in% low.count.levels, kLowCountLevelsName, as.character(x[, i])))
212 |             recoded.cat.vars[[i.recoded]] <- list(var=cname, low.count.levels = low.count.levels)
213 |             i.recoded <- i.recoded + 1
214 |             dbg(logger, paste("  Collapsing low-count levels for '", cname, "' ; ", paste(low.count.levels, collapse = ', ')))
215 |           }
216 |         }
217 |       }
218 |     } else {
219 |       error(logger, paste("Didn't find type for:", cname))
220 |     }
221 |   }
222 | 
223 |   return(list(x = x, cat.vars = cat.vars, recoded.cat.vars = recoded.cat.vars))
224 | }
225 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Rego (Rule Ensembles Go!)
2 | =========================
3 | 
4 | Rego provides a command-line batch interface to the [RuleFit](http://statweb.stanford.edu/~jhf/R_RuleFit.html) statistical model building program, making it possible to:
5 | 
6 | * Declaratively train and run a model using ensemble methods
7 | * Export a model as production-ready SQL (for MySQL, SQLServer, Hive, and Netezza)
8 | * Incorporate best practices for data pre-processing and model interpretation
9 | * Provide a detailed, human-readable model explanation in HTML
10 | 
11 | Rego uses RuleFit, Stanford Professor Jerome Friedman's implementation of *Rule Ensembles*, an interpretable type of ensemble model where the base-learners consist of conjunctive rules derived from decision trees.
12 | 
13 | Rego is developed and maintained by Dr. Giovanni Seni, and was initially sponsored by the Data Engineering and Analytics group (IDEA) of [Intuit, Inc](http://intuit.com).
14 | 
15 | 
16 | Rego is released as open source under the Eclipse Public License - v 1.0.
17 | 
18 | ## What is Rego?
19 | 
20 | Predictive learning plays an important role in many areas of science, finance and industry. Here are some examples of learning problems:
21 | 
22 | * Predict whether a customer would be attracted to a new service offering. Recognizing such customers can reduce the cost of a campaign by reducing the number of contacts.
23 | * Predict whether a web site visitor is unlikely to become a customer. The prediction allows prioritization of customer support resources.
24 | * Identify the risk factors for churn, based on the content of customer support messages.
25 | 
26 | Rego is a collection of R-based scripts intended to facilitate the process of building, interpreting, and deploying state-of-the-art predictive learning models. Rego can:
27 | 
28 | * Enable rapid experimentation
29 | * Increase self-service capability
30 | * Support easy model deployment into a production environment
31 | 
32 | Under the hood Rego uses [RuleFit](http://statweb.stanford.edu/~jhf/R_RuleFit.html), a statistical model building program created by Prof. Jerome Friedman. RuleFit was written in Fortran but has an R interface.
RuleFit implements a model building methodology known as ["ensembling"](http://www.amazon.com/Ensemble-Methods-Data-Mining-Predictions/dp/1608452840), where multiple simple models (base learners) are combined into one usually more accurate than the best of its components. This type of model can be described as an additive expansion of the form: 33 | 34 | F(x) = a0 + a1*b1(x) + a2*b2(x) + ... + aM*bM(x) where the bj(x)'s are the base-learners. 35 | 36 | In the case of RuleFit, the bj(x) terms are conjunctive rules of the form 37 | 38 | if x1 > 22 and x2 > 27 then 1 else 0 39 | 40 | or linear functions of a single variable -- e.g., bj(x) = xj. 41 | 42 | Using base-learners of this type is attractive because they constitute easily interpretable statements about attributes xj. They also preserve the desirable characteristics of Decision Trees such as easy handling of categorical attributes, robustness to outliers in the distribution of x, etc. 43 | 44 | RuleFit builds model F(x) in a three-step process: 45 | 46 | 1. build a tree ensemble (one where the bj(x)'s are decision trees), 47 | 2. generate candidate rules from the tree ensemble, and 48 | 3. fit coefficients aj via regularized regression. 49 | 50 | Rego consists of additional R code that we've written to make working with RuleFit easier, including: 51 | 52 | * The ability to have multiple rulefit batch jobs running simultaneously 53 | * Easily specifying a data source 54 | * Automatically executing common preprocessing operations 55 | * Automatically generating a model summary report with interpretation plots and quality assessment 56 | * Exporting a model from R to SQL for deployment in a production environment 57 | 58 | 59 | ## Getting Started 60 | 61 | ### 1. Dependencies 62 | 63 | Install [R](http://cran.us.r-project.org/) and the following R packages: R2HTML, ROCR, RODBC, getopt 64 | 65 | ### 2. Environment variables 66 | 67 | * ```REGO_HOME```: environment variable pointing to where you have checked out the Rego code 68 | * ```RF_HOME```: environment variable pointing to appropriate RuleFit executable -- e.g., export ```RF_HOME=$REGO_HOME/lib/mac``` 69 | 70 | ### 3. RuleFit binaries 71 | 72 | Before using Rego, you must download the [RuleFit](http://statweb.stanford.edu/~jhf/R_RuleFit.html) binary appropriate to your platform: 73 | 74 | #### Windows (64-bit) 75 | 76 | Place the following files in ```$REGO_HOME/lib/RuleFit/windows/``` 77 | 78 | * [rf_go.exe for Windows](http://statweb.stanford.edu/~jhf/r-rulefit/rulefit3/windows/windows64/rf_go.exe) 79 | * [move.bat](http://statweb.stanford.edu/~jhf/r-rulefit/rulefit3/windows/move.bat) 80 | 81 | #### Linux (64-bit) 82 | 83 | Place the following file in ```$REGO_HOME/lib/RuleFit/linux/``` 84 | 85 | * [rf_go.exe for Linux](http://statweb.stanford.edu/~jhf/r-rulefit/rulefit3/linux/linux64/rf_go.exe) 86 | 87 | Run ```chmod u+x ${REGO_HOME}/lib/RuleFit/linux/rf_go.exe``` to make the file executable. 88 | 89 | #### Mac (64-bit) 90 | 91 | Place the following file in ```$REGO_HOME/lib/RuleFit/mac/``` 92 | 93 | * [rf_go.exe for Mac OS X](http://statweb.stanford.edu/~jhf/r-rulefit/rulefit3/mac/mac64/rf_go.exe) 94 | 95 | Run ```chmod u+x $REGO_HOME/lib/RuleFit/mac/rf_go.exe``` to make the file executable. 
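To make the additive expansion from the *What is Rego?* section concrete, here is a small illustrative R sketch (not part of Rego -- the rules, coefficients, and variable names are invented) that evaluates a two-term rule ensemble by hand:

    ## Illustrative only: F(x) = a0 + a1*b1(x) + a2*b2(x)
    b1 <- function(x1, x2) ifelse(x1 > 22 & x2 > 27, 1, 0)  # conjunctive rule
    b2 <- function(x3) x3                                   # linear base-learner
    F.hat <- function(x1, x2, x3) 0.3 + 1.5*b1(x1, x2) - 0.2*b2(x3)
    F.hat(25, 30, 1.0)  # 0.3 + 1.5 - 0.2 = 1.6

In a real Rego run, RuleFit learns the rules and fits the coefficients, and the resulting model is what gets summarized, exported, and scored by the commands below.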
96 | 
97 | ## Commands
98 | 
99 | ### Build a model
100 | 
101 | * ```$REGO_HOME/bin/trainModel.sh --d=DATA.conf --m=MODEL.conf [--l LOG.txt]```
102 | * Input:
103 |   * ```DATA.conf```: [data configuration file](http://github.com/intuit/rego/blob/master/doc/DATA_CONF.md) specifying options such as where the data is coming from, what column corresponds to the target, etc.
104 |   * ```MODEL.conf```: [model configuration file](http://github.com/intuit/rego/blob/master/doc/MODEL_CONF.md) specifying options such as the type of model being fit, the criteria being optimized, etc.
105 |   * ```LOG.txt```: optional file name where to write logging messages
106 | * Output:
107 |   * ```model_summary.html```: model summary and assessment
108 |   * ```model_singleplot.html```: interpretation plots
109 |   * ```MODEL_DIR```: model definition files, for later export or prediction
110 | 
111 | ### Export a Model
112 | 
113 | * ```$REGO_HOME/bin/exportModel.sh --m=MODEL_DIR [--c=EXPORT.conf]```
114 | * Input
115 |   * ```MODEL_DIR```: path to model definition files
116 |   * ```EXPORT.conf```: the [export configuration file](http://github.com/intuit/rego/blob/master/doc/EXPORT_CONF.md) specifying options such as desired sql dialect, type of scoring clause, etc.
117 | * Output:
118 |   * ```SQL_FILE```: output file name containing model as a SQL expression
119 | 
120 | ### Predict on New Data
121 | 
122 | * ```$REGO_HOME/bin/runModel.sh --m=MODEL_DIR --d=DATA.conf```
123 | * Input:
124 |   * ```MODEL_DIR```: path to model definition files
125 |   * ```DATA.conf```: [data configuration file](http://github.com/intuit/rego/blob/master/doc/DATA_CONF.md) specifying test data location
126 | * Output:
127 |   * Text file with ```<id, y, yHat>``` tuples
128 | 
129 | ### Deploy a Model
130 | 
131 | * ```$REGO_HOME/bin/runModelSQL.sh --host=<HOST> --dbIn=<DB_IN> --tblIn=<FEATURE_TBL> --pk=<PRIMARY_KEY> --model=<MODEL_SQL_FILE> --dbOut=<DB_OUT> --tblOut=<SCORE_TBL>```
132 | * Input
133 |   * ```dbIn.tblIn```: new data to be scored
134 |   * ```model```: previously built (and exported) model
135 | * Output:
136 |   * ```dbOut.tblOut```: computed scores
137 | 
138 | ### Create Partial Dependence Plot (under construction)
139 | 
140 | * ```$REGO_HOME/src/rfPardep_main.R -c PARDEP.conf```
141 | * Input
142 |   * ```PARDEP.conf```: [configuration file](https://github.com/intuit/rego/blob/master/doc/PARDEP_CONF.md) specifying variable to be plotted and partial dependence options
143 | * Output:
144 |   * PNG file with partial dependence graph
145 | 
146 | ## Examples
147 | 
148 | These examples show how to use a CSV or RData file as a data source, using the R diamonds dataset.
149 | 
150 | The CSV and RData files were created in R as follows:
151 | 
152 |     X <- ggplot2::diamonds
153 |     write.csv(X, file = "diamonds.csv", na = "", row.names = FALSE)
154 |     save(X, file = "diamonds.RData")
155 | 
156 | CSV:
157 | 
158 |     # training
159 |     $REGO_HOME/bin/trainModel.sh --d=data_csv.conf --m=model.conf
160 | 
161 |     # prediction on csv data file
162 |     $REGO_HOME/bin/runModel.sh --m=/tmp/REgo/Diamonds_wd/export --d=predict_csv.conf
163 | 
164 |     # export model to SQL
165 |     $REGO_HOME/bin/exportModel.sh --m=/tmp/REgo/Diamonds_wd/export --c=$REGO_HOME/conf/EXPORT.conf
166 | 
167 |     # prediction on db table
168 |     $REGO_HOME/bin/runModelSQL.sh --host= --dbIn= --tblIn=diamond_test --pk=id --model=rules_forSQL.txt --sql=HiveQL --typeOut=1
169 | 
170 | RData:
171 | 
172 |     # training
173 |     $REGO_HOME/bin/trainModel.sh --d=data_rdata.conf --m=model.conf
174 | 
175 |     # prediction on RData data file
176 |     $REGO_HOME/bin/runModel.sh --m=/tmp/REgo/Diamonds_wd/export --d=predict_rdata.conf
177 | 
178 | #### Warning!
179 | 180 | If using RData as the data source type for either training or prediction, then it must be used for both. This is because the order of factor levels may be different for the two source types. 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Eclipse Public License - v 1.0 2 | 3 | THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 4 | 5 | 1. DEFINITIONS 6 | 7 | "Contribution" means: 8 | 9 | a) in the case of the initial Contributor, the initial code and documentation distributed under this Agreement, and 10 | b) in the case of each subsequent Contributor: 11 | i) changes to the Program, and 12 | ii) additions to the Program; 13 | where such changes and/or additions to the Program originate from and are distributed by that particular Contributor. A Contribution 'originates' from a Contributor if it was added to the Program by such Contributor itself or anyone acting on such Contributor's behalf. Contributions do not include additions to the Program which: (i) are separate modules of software distributed in conjunction with the Program under their own license agreement, and (ii) are not derivative works of the Program. 14 | "Contributor" means any person or entity that distributes the Program. 15 | 16 | "Licensed Patents" mean patent claims licensable by a Contributor which are necessarily infringed by the use or sale of its Contribution alone or when combined with the Program. 17 | 18 | "Program" means the Contributions distributed in accordance with this Agreement. 19 | 20 | "Recipient" means anyone who receives the Program under this Agreement, including all Contributors. 21 | 22 | 2. GRANT OF RIGHTS 23 | 24 | a) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, distribute and sublicense the Contribution of such Contributor, if any, and such derivative works, in source code and object code form. 25 | b) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free patent license under Licensed Patents to make, use, sell, offer to sell, import and otherwise transfer the Contribution of such Contributor, if any, in source code and object code form. This patent license shall apply to the combination of the Contribution and the Program if, at the time the Contribution is added by the Contributor, such addition of the Contribution causes such combination to be covered by the Licensed Patents. The patent license shall not apply to any other combinations which include the Contribution. No hardware per se is licensed hereunder. 26 | c) Recipient understands that although each Contributor grants the licenses to its Contributions set forth herein, no assurances are provided by any Contributor that the Program does not infringe the patent or other intellectual property rights of any other entity. Each Contributor disclaims any liability to Recipient for claims brought by any other entity based on infringement of intellectual property rights or otherwise. 
As a condition to exercising the rights and licenses granted hereunder, each Recipient hereby assumes sole responsibility to secure any other intellectual property rights needed, if any. For example, if a third party patent license is required to allow Recipient to distribute the Program, it is Recipient's responsibility to acquire that license before distributing the Program. 27 | d) Each Contributor represents that to its knowledge it has sufficient copyright rights in its Contribution, if any, to grant the copyright license set forth in this Agreement. 28 | 3. REQUIREMENTS 29 | 30 | A Contributor may choose to distribute the Program in object code form under its own license agreement, provided that: 31 | 32 | a) it complies with the terms and conditions of this Agreement; and 33 | b) its license agreement: 34 | i) effectively disclaims on behalf of all Contributors all warranties and conditions, express and implied, including warranties or conditions of title and non-infringement, and implied warranties or conditions of merchantability and fitness for a particular purpose; 35 | ii) effectively excludes on behalf of all Contributors all liability for damages, including direct, indirect, special, incidental and consequential damages, such as lost profits; 36 | iii) states that any provisions which differ from this Agreement are offered by that Contributor alone and not by any other party; and 37 | iv) states that source code for the Program is available from such Contributor, and informs licensees how to obtain it in a reasonable manner on or through a medium customarily used for software exchange. 38 | When the Program is made available in source code form: 39 | 40 | a) it must be made available under this Agreement; and 41 | b) a copy of this Agreement must be included with each copy of the Program. 42 | Contributors may not remove or alter any copyright notices contained within the Program. 43 | 44 | Each Contributor must identify itself as the originator of its Contribution, if any, in a manner that reasonably allows subsequent Recipients to identify the originator of the Contribution. 45 | 46 | 4. COMMERCIAL DISTRIBUTION 47 | 48 | Commercial distributors of software may accept certain responsibilities with respect to end users, business partners and the like. While this license is intended to facilitate the commercial use of the Program, the Contributor who includes the Program in a commercial product offering should do so in a manner which does not create potential liability for other Contributors. Therefore, if a Contributor includes the Program in a commercial product offering, such Contributor ("Commercial Contributor") hereby agrees to defend and indemnify every other Contributor ("Indemnified Contributor") against any losses, damages and costs (collectively "Losses") arising from claims, lawsuits and other legal actions brought by a third party against the Indemnified Contributor to the extent caused by the acts or omissions of such Commercial Contributor in connection with its distribution of the Program in a commercial product offering. The obligations in this section do not apply to any claims or Losses relating to any actual or alleged intellectual property infringement. In order to qualify, an Indemnified Contributor must: a) promptly notify the Commercial Contributor in writing of such claim, and b) allow the Commercial Contributor to control, and cooperate with the Commercial Contributor in, the defense and any related settlement negotiations. 
The Indemnified Contributor may participate in any such claim at its own expense. 49 | 50 | For example, a Contributor might include the Program in a commercial product offering, Product X. That Contributor is then a Commercial Contributor. If that Commercial Contributor then makes performance claims, or offers warranties related to Product X, those performance claims and warranties are such Commercial Contributor's responsibility alone. Under this section, the Commercial Contributor would have to defend claims against the other Contributors related to those performance claims and warranties, and if a court requires any other Contributor to pay any damages as a result, the Commercial Contributor must pay those damages. 51 | 52 | 5. NO WARRANTY 53 | 54 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the appropriateness of using and distributing the Program and assumes all risks associated with its exercise of rights under this Agreement , including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations. 55 | 56 | 6. DISCLAIMER OF LIABILITY 57 | 58 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 59 | 60 | 7. GENERAL 61 | 62 | If any provision of this Agreement is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this Agreement, and without further action by the parties hereto, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 63 | 64 | If Recipient institutes patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Program itself (excluding combinations of the Program with other software or hardware) infringes such Recipient's patent(s), then such Recipient's rights granted under Section 2(b) shall terminate as of the date such litigation is filed. 65 | 66 | All Recipient's rights under this Agreement shall terminate if it fails to comply with any of the material terms or conditions of this Agreement and does not cure such failure in a reasonable period of time after becoming aware of such noncompliance. If all Recipient's rights under this Agreement terminate, Recipient agrees to cease use and distribution of the Program as soon as reasonably practicable. However, Recipient's obligations under this Agreement and any licenses granted by Recipient relating to the Program shall continue and survive. 
67 | 68 | Everyone is permitted to copy and distribute copies of this Agreement, but in order to avoid inconsistency the Agreement is copyrighted and may only be modified in the following manner. The Agreement Steward reserves the right to publish new versions (including revisions) of this Agreement from time to time. No one other than the Agreement Steward has the right to modify this Agreement. The Eclipse Foundation is the initial Agreement Steward. The Eclipse Foundation may assign the responsibility to serve as the Agreement Steward to a suitable separate entity. Each new version of the Agreement will be given a distinguishing version number. The Program (including Contributions) may always be distributed subject to the version of the Agreement under which it was received. In addition, after a new version of the Agreement is published, Contributor may elect to distribute the Program (including its Contributions) under the new version. Except as expressly stated in Sections 2(a) and 2(b) above, Recipient receives no rights or licenses to the intellectual property of any Contributor under this Agreement, whether expressly, by implication, estoppel or otherwise. All rights in the Program not expressly granted under this Agreement are reserved. 69 | 70 | This Agreement is governed by the laws of the State of New York and the intellectual property laws of the United States of America. No party to this Agreement will bring a legal action under this Agreement more than one year after the cause of action arose. Each party waives its rights to a jury trial in any resulting litigation. 71 | -------------------------------------------------------------------------------- /src/rfTrain_main.R: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/Rscript 2 | 3 | ############################################################################### 4 | # FILE: rfTrain_main.R 5 | # 6 | # USAGE: rfTrain_main.R -d DATA.conf -m MODEL.conf 7 | # 8 | # DESCRIPTION: 9 | # Provides a "batch" interface to the RuleFit statistical model building 10 | # program. RuleFit refers to Professor Jerome Friedman's implementation of Rule 11 | # Ensembles, an interpretable type of ensemble model where the base-learners 12 | # consist of conjunctive rules derived from decision trees. 13 | # 14 | # ARGUMENTS: 15 | # DATA.conf: the data configuration file specifying options such as 16 | # where the data is coming from, what column corresponds to 17 | # the target, etc. 18 | # MODEL.conf: the model configuration file specifying options such as the 19 | # type of model being fit, the criteria being optimized, etc. 20 | # 21 | # REQUIRES: 22 | # REGO_HOME: environment variable pointing to the directory where you 23 | # have placed this file (and its companion ones) 24 | # RF_HOME: environment variable pointing to appropriate RuleFit 25 | # executable -- e.g., export RF_HOME=$REGO_HOME/lib/RuleFit/mac 26 | # 27 | # AUTHOR: Giovanni Seni 28 | ############################################################################### 29 | REGO_HOME <- Sys.getenv("REGO_HOME") 30 | source(file.path(REGO_HOME, "/src/logger.R")) 31 | source(file.path(REGO_HOME, "/src/rfAPI.R")) 32 | source(file.path(REGO_HOME, "/src/rfTrain.R")) 33 | source(file.path(REGO_HOME, "/src/rfR2HTML.R")) 34 | source(file.path(REGO_HOME, "/src/rfGraphics.R")) 35 | library(getopt) 36 | library(RODBC) 37 | 38 | ValidateConfigArgs <- function(conf) 39 | { 40 | # Validates and initializes configuration parameters. 
41 | # 42 | # Args: 43 | # conf: A list of pairs 44 | # Returns: 45 | # A list of pairs 46 | 47 | # Must have a valid data source type 48 | stopifnot("data.source.type" %in% names(conf)) 49 | stopifnot(conf$data.source.type %in% c("csv", "db", "rdata")) 50 | if (conf$data.source.type == "db") { 51 | stopifnot("db.dsn" %in% names(conf) && "db.name" %in% names(conf) && 52 | "db.type" %in% names(conf) && "db.tbl.name" %in% names(conf)) 53 | } else if (conf$data.source.type == "csv") { 54 | stopifnot("csv.path" %in% names(conf) && "csv.fname" %in% names(conf)) 55 | if ("csv.sep" %in% names(conf)) { 56 | conf$csv.sep <- as.character(conf$csv.sep) 57 | } else { 58 | conf$csv.sep <- "," 59 | } 60 | } else { # rdata 61 | stopifnot(c("rdata.path", "rdata.fname") %in% names(conf)) 62 | } 63 | 64 | # Must have column type specification, unless data source type is 65 | # "rdata" 66 | stopifnot(conf$data.source.type == "rdata" || "col.types.fname" %in% names(conf)) 67 | 68 | ## Set defaults for the options that were not specified 69 | if (!("col.y" %in% names(conf))) { 70 | conf$col.y <- "y" 71 | } 72 | 73 | if (!("db.tbl.maxrows" %in% names(conf))) { 74 | conf$db.tbl.maxrows <- "ALL" 75 | } else { 76 | conf$db.tbl.maxrows <- as.numeric(conf$db.tbl.maxrows) 77 | } 78 | 79 | if (!("col.weights" %in% names(conf)) || 80 | nchar(conf$col.weights) == 0) { 81 | conf$col.weights <- NA 82 | } 83 | 84 | if (!("col.id" %in% names(conf)) || 85 | nchar(conf$col.id) == 0) { 86 | conf$col.id <- NA 87 | } 88 | 89 | if (!("col.skip.fname" %in% names(conf))) { 90 | conf$col.skip.fname <- "" 91 | } 92 | 93 | if (!("col.winz.fname" %in% names(conf))) { 94 | conf$col.winz.fname <- "" 95 | } 96 | 97 | if (!("na.threshold" %in% names(conf))) { 98 | conf$na.threshold <- 0.95 99 | } else { 100 | conf$na.threshold <- as.numeric(conf$na.threshold) 101 | } 102 | 103 | if (!("min.level.count" %in% names(conf))) { 104 | conf$min.level.count <- 0 105 | } else { 106 | conf$min.level.count <- as.numeric(conf$min.level.count) 107 | } 108 | 109 | if (!("do.class.balancing" %in% names(conf))) { 110 | conf$do.class.balancing <- FALSE 111 | } else { 112 | conf$do.class.balancing <- (as.numeric(conf$do.class.balancing) == 1) 113 | } 114 | 115 | if (!("html.min.var.imp" %in% names(conf))) { 116 | conf$html.min.var.imp <- 5 117 | } else { 118 | conf$html.min.var.imp <- as.numeric(conf$html.min.var.imp) 119 | } 120 | 121 | if (!("html.min.rule.imp" %in% names(conf))) { 122 | conf$html.min.rule.imp <- 5 123 | } else { 124 | conf$html.min.rule.imp <- as.numeric(conf$html.min.rule.imp) 125 | } 126 | 127 | if ("html.singleplot.fname" %in% names(conf)) { 128 | if (!("html.singleplot.title" %in% names(conf))) { 129 | conf$html.singleplot.title <- "Dependence Plots:" 130 | } 131 | if (!("html.singleplot.nvars" %in% names(conf))) { 132 | conf$html.singleplot.nvars <- 10 133 | } else { 134 | conf$html.singleplot.nvars <- as.integer(conf$html.singleplot.nvars) 135 | } 136 | } 137 | 138 | if (!("rand.seed" %in% names(conf))) { 139 | conf$rand.seed <- 135711 140 | } else { 141 | conf$rand.seed <- as.numeric(conf$rand.seed) 142 | } 143 | 144 | if (!("log.level" %in% names(conf))) { 145 | conf$log.level <- kLogLevelDEBUG 146 | } else { 147 | conf$log.level <- get(conf$log.level) 148 | } 149 | 150 | # Save workspace before training? 
(for debugging purposes) 151 | if (!("save.workspace" %in% names(conf))) { 152 | conf$save.workspace <- FALSE 153 | } else { 154 | conf$save.workspace <- as.logical(as.numeric(conf$save.workspace)) 155 | } 156 | 157 | return(conf) 158 | } 159 | 160 | ValidateCmdArgs <- function(opt, args.m) 161 | { 162 | # Parses and validates command line arguments. 163 | # 164 | # Args: 165 | # opt: getopt() object 166 | # args.m: valid arguments spec passed to getopt(). 167 | # 168 | # Returns: 169 | # A list of pairs 170 | kUsageString <- "/path/to/rfTrain_main.R -d -m [-l ]" 171 | 172 | # Validate command line arguments 173 | if ( !is.null(opt$help) || is.null(opt$data_conf) || is.null(opt$model_conf) ) { 174 | self <- commandArgs()[1] 175 | cat("Usage: ", kUsageString, "\n") 176 | q(status=1); 177 | } 178 | 179 | # Do we have a log file name? "" will send messages to stdout 180 | if (is.null(opt$log)) { 181 | opt$log <- "" 182 | } 183 | 184 | # Read config file (two columns assumed: 'param' and 'value') 185 | tmp <- read.table(opt$data_conf, header=T, as.is=T) 186 | conf <- as.list(tmp$value) 187 | names(conf) <- tmp$param 188 | 189 | conf <- ValidateConfigArgs(conf) 190 | conf$log.fname <- opt$log 191 | 192 | return(conf) 193 | } 194 | 195 | GetSQLQueryTemplate <- function(conf) 196 | { 197 | # Returns a SQL query template string for fetching training data 198 | if ("db.query.tmpl" %in% names(conf)) { 199 | # User-supplied template 200 | sql.query.tmpl <- scan(conf$db.query.tmpl, "character", quiet = T) 201 | } else { 202 | if (conf$db.type == "SQLServer") { 203 | sql.query.tmpl <- " 204 | SELECT _MAXROWS_ * 205 | FROM _TBLNAME_ 206 | " 207 | } else { 208 | stopifnot(conf$db.type == "Netezza") 209 | sql.query.tmpl <- " 210 | SELECT * 211 | FROM _TBLNAME_ 212 | LIMIT _MAXROWS_ 213 | " 214 | } 215 | } 216 | return(sql.query.tmpl) 217 | } 218 | 219 | CopyConfigFiles <- function(conf, rf.ctxt) 220 | { 221 | # Save configuration files with model export directory 222 | ok <- 1 223 | if (!file.exists(file.path(rf.ctxt$working.dir,"configuration"))) { 224 | ok <- dir.create(file.path(rf.ctxt$working.dir,"configuration")) 225 | } 226 | if (ok) { 227 | ok <- file.copy(from = opt$data_conf, to = file.path(rf.ctxt$working.dir,"configuration")) 228 | } 229 | if (ok) { 230 | ok <- file.copy(from = opt$model_conf, to = file.path(rf.ctxt$working.dir,"configuration")) 231 | } 232 | if ("col.types.fname" %in% names(conf) && nchar(conf$col.types.fname) > 0 && ok) { 233 | ok <- file.copy(from = conf$col.types.fname, to = file.path(rf.ctxt$working.dir,"configuration")) 234 | } 235 | if ("col.skip.fname" %in% names(conf) && nchar(conf$col.skip.fname) > 0 && ok) { 236 | ok <- file.copy(from = conf$col.skip.fname, to = file.path(rf.ctxt$working.dir,"configuration")) 237 | } 238 | if (ok == 0) { 239 | dbg(logger, "CopyConfigFiles: couldn't copy files") 240 | } 241 | } 242 | 243 | ############## 244 | ## Main 245 | # 246 | 247 | # Grab command-line arguments 248 | args.m <- matrix(c( 249 | 'data_conf' ,'d', 1, "character", 250 | 'model_conf' ,'m', 1, "character", 251 | 'log' ,'l', 1, "character", 252 | 'help' ,'h', 0, "logical" 253 | ), ncol=4,byrow=TRUE) 254 | opt <- getopt(args.m) 255 | conf <- ValidateCmdArgs(opt, args.m) 256 | 257 | # Create logging object 258 | logger <- new("logger", log.level = conf$log.level, file.name = conf$log.fname) 259 | 260 | ## Use own version of png() if necessary: 261 | if (isTRUE(conf$html.graph.dev == "Bitmap")) { 262 | png <- png_via_bitmap 263 | if (!CheckWorkingPNG(png)) error(logger, 
"cannot generate PNG graphics") 264 | } else { 265 | png <- GetWorkingPNG() 266 | if (is.null(png)) error(logger, "cannot generate PNG graphics") 267 | } 268 | 269 | # Load data 270 | if (conf$data.source.type == "db") { 271 | SQL.QUERY.TMPL <- GetSQLQueryTemplate(conf) 272 | # Get db connection 273 | ch <- odbcConnect(conf$db.dsn, believeNRows = FALSE) 274 | if (class(ch) != "RODBC") { 275 | error(logger, paste("rfTrain_main.R: Failed to connect to ", conf$db.dsn)) 276 | } 277 | 278 | # Fetch data 279 | SQL.QUERY <- gsub("_TBLNAME_", conf$db.tbl.name, SQL.QUERY.TMPL) 280 | if (conf$db.type == "SQLServer" & conf$db.tbl.maxrows == 'ALL') { 281 | SQL.QUERY <- gsub("_MAXROWS_", "", SQL.QUERY) 282 | } else { 283 | SQL.QUERY <- gsub("_MAXROWS_", conf$db.tbl.maxrows, SQL.QUERY) 284 | } 285 | 286 | data <- sqlQuery(ch, SQL.QUERY, stringsAsFactors = FALSE) 287 | if (class(data) != "data.frame") { 288 | error(logger, paste("rfTrain_main.R: Failed to retrieve data from ", conf$db.tbl.name)) 289 | } 290 | 291 | # Close connection 292 | close(ch) 293 | } else if (conf$data.source.type == "csv") { 294 | data <- read.csv(file.path(conf$csv.path, conf$csv.fname), 295 | na.strings = "", check.names = FALSE, sep=conf$csv.sep) 296 | } else if (conf$data.source.type == "rdata") { 297 | envir <- new.env() 298 | load(file.path(conf$rdata.path, conf$rdata.fname), envir = envir) 299 | if (is.null(conf$rdata.dfname)) { 300 | dfname <- ls(envir) 301 | stopifnot(length(dfname) == 1) 302 | } else { 303 | dfname <- conf$rdata.dfname 304 | } 305 | data <- get(dfname, envir, inherits = FALSE) 306 | stopifnot(is.data.frame(data)) 307 | rm(envir) 308 | } else { 309 | error(logger, paste("rfTrain_main.R: unknown data source type ", conf$data.source.type)) 310 | } 311 | info(logger, paste("Data loaded: dim =", nrow(data), "x", ncol(data), "; NAs =", 312 | length(which(is.na(data) == T)), "(", 313 | round(100*length(which(is.na(data) == T))/(nrow(data)*ncol(data)), 2), 314 | "%)")) 315 | 316 | # Load model specification parameters 317 | rf.ctxt <- InitRFContext(opt$model_conf, conf) 318 | 319 | # Set global env variables required by RuleFit 320 | platform <- rf.ctxt$platform 321 | RF_HOME <- Sys.getenv("RF_HOME") 322 | RF_WORKING_DIR <- rf.ctxt$working.dir 323 | 324 | # Fit (and export) model 325 | train.out <- TrainModel(data, conf, rf.ctxt) 326 | 327 | # Save configuration files with model 328 | CopyConfigFiles(conf, rf.ctxt) 329 | 330 | # Generate HTML report 331 | if ("html.fname" %in% names(conf)) { 332 | rfmod.stats <- runstats(train.out$rfmod) 333 | WriteHTML(conf, model.path = rf.ctxt$export.dir, rfmod.stats = rfmod.stats) 334 | if ("html.singleplot.fname" %in% names(conf)) { 335 | WriteHTMLSinglePlot(conf, model.path = rf.ctxt$export.dir) 336 | } 337 | } 338 | 339 | q(status=0) 340 | -------------------------------------------------------------------------------- /src/rfR2HTML.R: -------------------------------------------------------------------------------- 1 | source(file.path(REGO_HOME, "/src/rfExport.R")) 2 | source(file.path(REGO_HOME, "/src/rfRulesIO.R")) 3 | library(R2HTML) 4 | library(ROCR, verbose = FALSE, quietly=TRUE, warn.conflicts = FALSE) 5 | 6 | WriteHTML <- function(conf, model.path, out.path = model.path, rfmod.stats = NULL) 7 | { 8 | # Writes an HTML page with information about a RuleFit model. Page includes 9 | # variable importance table, rules, training confusion matrix, training ROC 10 | # curve, etc. 
11 | # 12 | # Args: 13 | # model.path : rulefit model location (content generated by 14 | # ExportModel() function) 15 | # conf : html configuration parameters 16 | # out.path : String naming the location to write to (default 17 | # is model.path/R2HTML) 18 | # rfmod.stats : fit statistics of a rulefit model 19 | # 20 | # Returns: 21 | # None. 22 | kPlotWidth <- 620 23 | kPlotHeight <- 480 24 | 25 | if (out.path == model.path) { 26 | out.path <- file.path(out.path, "R2HTML") 27 | } 28 | # Create output directory (if appropriate) 29 | if (!file.exists(out.path)) { 30 | dir.create(out.path) 31 | } 32 | 33 | # Initialize HTML report 34 | html.file <- HTMLInitFile(out.path, conf$html.fname, Title = conf$html.title, 35 | BackGroundColor = "#BBBBEE") 36 | HTML.title(conf$html.title) 37 | 38 | # Read variable importance table 39 | vi.df <- read.table(file.path(model.path, kMod.varimp.fname), header = F, sep = "\t") 40 | colnames(vi.df) <- c("Importance", "Variable") 41 | # ... filter out low importance entries 42 | i.zero.imp <- which(vi.df$Importance < conf$html.min.var.imp) 43 | if ( length(i.zero.imp == 1) ) { 44 | vi.df <- vi.df[-i.zero.imp, ] 45 | } 46 | # ... Format float values for easier reading 47 | vi.df$Importance <- sprintf("%.1f", vi.df$Importance) 48 | 49 | # Write out var imp 50 | HTML(vi.df, caption = "Global Variable Importance", row.names = F) 51 | 52 | # Read model rules 53 | rules <- ReadRules(file.path(model.path, kMod.rules.fname)) 54 | rules.df <- PrintRules(rules, x.levels.fname = file.path(model.path, kMod.x.levels.fname), file = NULL) 55 | if (!setequal(colnames(rules.df), c("type", "supp.std", "coeff", "importance", "def"))) { 56 | error(logger, "WriteHTML: Rules header mismatch") 57 | } 58 | # ... filter out low importance entries 59 | i.zero.imp <- which(rules.df$importance < conf$html.min.rule.imp) 60 | if ( length(i.zero.imp == 1) ) { 61 | rules.df <- rules.df[-i.zero.imp, ] 62 | } 63 | # ... Rename column to more verbose values 64 | colnames(rules.df) <- c("Rule", "Support (or std)", "Coefficient", "Importance", "Definition") 65 | rules.df$Rule <- 1:nrow(rules.df) 66 | # ... Replace 'AND' with '&' in rule strings for easier reading 67 | rules.df$Definition <- gsub(" AND ", " & ", rules.df$Definition) 68 | # ... Format float values for easier reading 69 | rules.df$Importance <- sprintf("%.1f", rules.df$Importance) 70 | rules.df$Coefficient <- sprintf("%.3f", rules.df$Coefficient) 71 | 72 | # Write out rules 73 | HTML(rules.df, caption = "Rule Ensemble Model", row.names = F, innerBorder = 1) 74 | 75 | # Read (train) y and yHat 76 | y <- LoadModel(model.path)$y 77 | saved.y.hat <- load(file = file.path(model.path, kMod.yHat.fname)) 78 | if ( length(which((saved.y.hat == c("y.hat")) == T)) != 1 ) { 79 | stop("Failed to find required objects in: ", file.path(model.path, kMod.yHat.fname)) 80 | } 81 | 82 | # Process y vs yHat according to model type 83 | if (length(unique(y)) == 2) { 84 | # "classification" mode... 
build confusion table
85 |     conf.m <- table(y, sign(y.hat))
86 |     stopifnot("-1" %in% rownames(conf.m))
87 |     stopifnot("1" %in% rownames(conf.m))
88 |     TN <- ifelse("-1" %in% colnames(conf.m), conf.m["-1", "-1"], 0)
89 |     TP <- ifelse("1" %in% colnames(conf.m), conf.m["1", "1"], 0)
90 |     train.acc <- 100*(TN+TP)/sum(conf.m)
91 |     # Write out table
92 |     HTML(conf.m, caption = sprintf("Training Confusion Matrix (accuracy: %.2f%%)", train.acc),
93 |          innerBorder = 1)
94 |     # Generate ROC plot
95 |     pred <- prediction(y.hat, y)
96 |     perf <- performance(pred, "tpr", "fpr")
97 |     plot.fname <- "ROC.png"
98 |     png(file = file.path(out.path, plot.fname), width=kPlotWidth, height=kPlotHeight)
99 |     plot(perf, colorize=T, main="")
100 |     lines(x=c(0, 1), y=c(0,1))
101 |     dev.off()
102 |     auc.value <- unlist(slot(performance(pred, "auc"), "y.values"))
103 |     HTMLInsertGraph(plot.fname, Caption=sprintf("Training ROC curve. AUC:%.4f\n", auc.value), WidthHTML=kPlotWidth, HeightHTML=kPlotHeight)
104 |     info(logger, sprintf("AUC:%.4f", auc.value))
105 |   } else {
106 |     # "regression" mode... create data-frame with simple AAE measure
107 |     re.train.error <- sum(abs(y.hat - y))/length(y)
108 |     med.train.error <- sum(abs(y - median(y)))/length(y)
109 |     aae.train <- re.train.error / med.train.error
110 |     error.df <- data.frame(cbind(c("RE", "Median ", "ratio"),
111 |                            c(round(re.train.error, 4), round(med.train.error, 4), round(aae.train, 2))))
112 |     colnames(error.df) <- c("Model", "Error")
113 |     HTML(error.df, caption = "Training Error (Unweighted)", row.names = F, col.names = F)
114 |   }
115 | 
116 |   # Write out (estimated) test error
117 |   error.df <- data.frame(cbind(c("RE"), c(round(rfmod.stats$cri, 4)), c(round(rfmod.stats$err, 4))))
118 |   colnames(error.df) <- c("Model", "Error", "+/-")
119 |   HTML(error.df, caption = "Test Error (estimated)", row.names = F, col.names = F)
120 | 
121 |   # End report
122 |   HTMLEndFile()
123 | }
124 | 
125 | WriteHTMLSinglePlot <- function(conf, model.path, vars = NULL, out.path = model.path,
126 |                                 do.restore = FALSE, xmiss = 9.0e30)
127 | {
128 |   # Writes an HTML page with single variable partial dependence plots for a given
129 |   # RuleFit model.
130 |   #
131 |   # Args:
132 |   #   conf       : html configuration parameters
133 |   #   model.path : rulefit model location (content generated by
134 |   #                ExportModel() function)
135 |   #   vars       : vector of variable identifiers (column names)
136 |   #                specifying selected variables to be plotted. If NULL,
137 |   #                variable names are taken from varimp table.
138 |   #   out.path   : String naming the location to write to (default
139 |   #                is model.path/R2HTML/singleplot)
140 |   #   do.restore : whether or not a model restore is needed -- e.g., not
141 |   #                necessary if called immediately after a model build.
142 |   # Returns:
143 |   #   None.
144 |   AddQuantileJitter <- function(q)
145 |   {
146 |     # Add a small amount of noise to given numeric vector, assumed to represent
147 |     # quantiles of a variable, when adjacent values are equal.
148 |     set.seed(123)
149 |     q.rle <- rle(q)
150 |     out.q <- c()
151 |     for (i in seq(along=q.rle$lengths)) {
152 |       num.eq <- q.rle$lengths[i]
153 |       if (num.eq > 1) {
154 |         tmp.q <- rep(q.rle$values[i], num.eq)
155 |         if ( q.rle$values[i] > 1) {
156 |           tmp.q <- sort(jitter(tmp.q, factor = 0.75))
157 |         } else {
158 |           # small value...
do jittering by hand 159 | i.mid <- ceiling(num.eq/2) 160 | jitter.amount <- 0.005 * (1:num.eq - i.mid) 161 | tmp.q <- tmp.q + jitter.amount 162 | } 163 | } else { 164 | tmp.q <- c(q.rle$values[i]) 165 | } 166 | out.q <- c(out.q, tmp.q) 167 | } 168 | return(out.q) 169 | } 170 | 171 | GetRugQuantiles <- function(x, var, xmiss, do.jitter = TRUE) 172 | { 173 | # Computes quantiles for generating a "rug" (skips NA in calculation). 174 | # Args: 175 | # x : "matrix" -- e.g., from model restore 176 | # var : column name in x for whcih quantiles are to be computed 177 | # xmiss : NA value 178 | # do.jitter : whether or not "jitter" should be added to differentiate 179 | # equal quantile values 180 | # Returns: 181 | # list(quant:quantile values, na.rate: number of NAs as percentage 182 | quant <- NULL 183 | na.rate <- NULL 184 | i.col <- grep(paste("^", var, "$", sep = ""), colnames(x), perl=T) 185 | if ( length(i.col) > 0 ) { 186 | which.NA <- which(x[,i.col] == xmiss) 187 | na.rate <- round(100*length(which.NA)/nrow(x), 1) 188 | if (length(which.NA) < nrow(x)) { 189 | quant <- quantile(x[, i.col], na.rm = T, probs = seq(0.1, 0.9, 0.1)) 190 | if (do.jitter) { 191 | quant <- AddQuantileJitter(quant) 192 | } 193 | } 194 | } 195 | return(list(quant = quant, na.rate = na.rate)) 196 | } 197 | 198 | if (out.path == model.path) { 199 | out.path <- file.path(out.path, "R2HTML", "singleplot") 200 | } 201 | # Create output directory (if appropriate) 202 | if (!file.exists(out.path)) { 203 | dir.create(out.path) 204 | } 205 | 206 | # Load & restore model 207 | mod <- LoadModel(model.path) 208 | ok <- 1 209 | if (do.restore) { 210 | tryCatch(rfrestore(mod$rfmod, mod$x, mod$y, mod$wt), error = function(err){ok <<- 0}) 211 | if (ok == 0) { 212 | error(logger, "WriteHTMLSinglePlot: got stuck in rfrestore") 213 | } 214 | } 215 | 216 | # Load levels 217 | x.levels <- as.data.frame(do.call("rbind", ReadLevels(file.path(model.path, kMod.x.levels.fname)))) 218 | if (ncol(x.levels) != 2 || any(colnames(x.levels) != c("var", "levels"))) { 219 | error(logger, "WriteHTMLSinglePlot: problem reading level info") 220 | } else { 221 | var.names <- x.levels[, 1] 222 | x.levels <- x.levels[, 2] 223 | names(x.levels) <- var.names 224 | } 225 | 226 | # Initialize HTML report 227 | html.file <- HTMLInitFile(out.path, conf$html.singleplot.fname, Title = conf$html.title, 228 | BackGroundColor = "#BBBBEE") 229 | HTML.title(conf$html.singleplot.title) 230 | 231 | # Which variables are to be plotted? 232 | # ... read variable importance table 233 | vi.df <- read.table(file.path(model.path, kMod.varimp.fname), header = F, sep = "\t") 234 | colnames(vi.df) <- c("Importance", "Variable") 235 | if (is.null(vars)) { 236 | # ... pick subset from varimp list 237 | # ... ... 
first, filter out low importance entries 238 | min.var.imp <- max(conf$html.min.var.imp, 1.0) 239 | i.zero.imp <- which(vi.df$Importance < min.var.imp) 240 | if (length(i.zero.imp) > 0) { 241 | vi.df <- vi.df[-i.zero.imp, ] 242 | } 243 | nvars <- min(conf$html.singleplot.nvars, nrow(vi.df)) 244 | vars <- vi.df$Variable[1:nvars] 245 | } else { 246 | nvars <- length(vars) 247 | if (length(intersect(vars, vi.df$Variable)) != nvars) { 248 | error(logger, "WriteHTMLSinglePlot: variable name mismatch") 249 | } 250 | } 251 | 252 | # Generate plots 253 | for (var in vars) { 254 | plot.fname <- paste(var, "png", sep = ".") 255 | plot.width <- 620 256 | plot.height <- 480 257 | cat.vals <- NULL 258 | rug.vals <- NULL 259 | rug.cols <- NULL 260 | # Get levels, if this is a categorical variable. Get deciles, if continuous. 261 | if (!is.null(x.levels[var][[1]])) { 262 | cat.vals <- sapply(x.levels[var][[1]], substr, 1, 10, USE.NAMES = F) 263 | if (length(cat.vals) > 40 ) { 264 | plot.width <- 920 265 | } 266 | plot.caption <- var 267 | } else { 268 | rug.out <- GetRugQuantiles(mod$x, var, xmiss) 269 | rug.vals <- rug.out$quant 270 | if (!is.null(rug.vals)) { 271 | rug.cols <- rainbow(length(rug.vals)) 272 | } 273 | # Insert NA rate in caption 274 | plot.caption <- paste(var, " (NAs: ", rug.out$na.rate, "%)", sep="") 275 | } 276 | png(file = file.path(out.path, plot.fname), width=plot.width, height=plot.height) 277 | # tryCatch(singleplot(var, catvals=cat.vals, rugvals=rug.vals, rugcols=rug.cols), 278 | # error = function(err){ok <<- 0}) 279 | tryCatch(singleplot(var, catvals=cat.vals), error = function(err){ok <<- 0}) 280 | if (ok == 0) { 281 | error(logger, "WriteHTMLSinglePlot: got stuck in singleplot") 282 | } 283 | dev.off() 284 | HTMLInsertGraph(plot.fname, Caption=plot.caption, WidthHTML=plot.width, HeightHTML=plot.height) 285 | } 286 | 287 | # End report 288 | HTMLEndFile() 289 | } 290 | -------------------------------------------------------------------------------- /bin/runModelSQL.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | #=================================================================================== 3 | # FILE: runModelSQL.sh 4 | # 5 | # USAGE: runModelSQL.sh --host= --dbIn= --tblIn= --pk= 6 | # --model= 7 | # [--sql=] 8 | # [--dbOut=] [--tblOut=] 9 | # [--typeOut=] 10 | # [--uname=] [--upwd=] 11 | # 12 | # DESCRIPTION: Computes predictions using a previously built (and exported) RuleFit 13 | # model for a feature table in a database. 
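# EXAMPLE: a minimal invocation sketch; the host, database, table, key, and credential values below are hypothetical, and the model file is the scoring-SQL expression produced by the export step: # runModelSQL.sh --host=dbhost01 --dbIn=analytics --tblIn=cust_features --pk=CUSTOMER_ID --model=rules_forSQL.txt --sql=MySQL --uname=scorer --upwd=secret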
14 | #=================================================================================== 15 | 16 | USAGESTR="usage: runModelSQL.sh --host= --dbIn= --tblIn= 17 | --pk= --model= 18 | [--sql=] 19 | [--dbOut=] [--tblOut=] 20 | [--typeOut=] 21 | [--uname=] [--upwd=]" 22 | RUNMODEL_TEMP_FILE="runModel_temp.sql" 23 | MODELSQLFNAME_SUFFIX1="2_part1" 24 | MODELSQLFNAME_SUFFIX2="2_part2" 25 | 26 | # Parse arguments 27 | for i in $* 28 | do 29 | case $i in 30 | --host=*) 31 | HOST=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 32 | ;; 33 | --dbIn=*) 34 | INDB=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 35 | ;; 36 | --tblIn=*) 37 | FEATTBL=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 38 | ;; 39 | --pk=*) 40 | PK=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 41 | ;; 42 | --model=*) 43 | MODELSQLFNAME=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 44 | ;; 45 | --sql=*) 46 | SQLTYPE=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 47 | ;; 48 | --dbOut=*) 49 | OUTDB=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 50 | ;; 51 | --tblOut=*) 52 | OUTTBL=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 53 | ;; 54 | --typeOut=*) 55 | OUTTYPE=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 56 | ;; 57 | --uname=*) 58 | UNAME=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 59 | ;; 60 | --upwd=*) 61 | UPWD=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 62 | ;; 63 | --verbose=*) 64 | VERBOSE=`echo $i | sed 's/[-a-zA-Z0-9]*=//'` 65 | ;; 66 | *) 67 | # unknown option 68 | echo $USAGESTR 69 | exit 1 70 | ;; 71 | esac 72 | done 73 | 74 | # Validate command-line arguments 75 | if [ -z "$HOST" -o -z "$INDB" -o -z "$FEATTBL" -o -z "$PK" -o -z "$MODELSQLFNAME" ]; then 76 | echo $USAGESTR 77 | exit 1 78 | fi 79 | 80 | # Set defaults for unspecified arguments 81 | if [ -z "$OUTDB" ]; then 82 | OUTDB=${INDB} 83 | fi 84 | if [ -z "$OUTTYPE" ]; then 85 | OUTTYPE=1 86 | fi 87 | if [ -z "$OUTTBL" ]; then 88 | if [ $OUTTYPE -eq 1 ]; then 89 | OUTTBL=${FEATTBL}_scores 90 | elif [ $OUTTYPE -eq 2 ]; then 91 | OUTTBL=${FEATTBL}_rules 92 | elif [ $OUTTYPE -eq 3 ]; then 93 | OUTTBL=${FEATTBL}_scores 94 | else 95 | echo "Invalid typeOut value: "$OUTTYPE 96 | exit 1 97 | fi 98 | fi 99 | if [ -z "$VERBOSE" ]; then 100 | VERBOSE=1 101 | fi 102 | 103 | set_delete_query() 104 | { 105 | if [ "$1" = "SQLServer" ]; then 106 | DELETE_QUERY=" 107 | IF OBJECT_ID('${OUTTBL}') IS NOT NULL 108 | DROP TABLE ${OUTTBL} 109 | ; 110 | IF OBJECT_ID('${OUTTBL}_TEMP') IS NOT NULL 111 | DROP TABLE ${OUTTBL}_TEMP 112 | " 113 | elif [ "$1" = "MySQL" -o "$1" = "HiveQL" ]; then 114 | DELETE_QUERY=" 115 | DROP TABLE IF EXISTS ${OUTDB}.${OUTTBL}; 116 | DROP TABLE IF EXISTS ${OUTDB}.${OUTTBL}_TEMP; 117 | " 118 | elif [ "$1" = "Netezza" ]; then 119 | DELETE_QUERY1=" 120 | DROP TABLE ${OUTTBL}; 121 | " 122 | DELETE_QUERY2=" 123 | DROP TABLE ${OUTTBL}_TEMP; 124 | " 125 | else 126 | echo "Invalid SQLtype value: "$1 127 | exit 1 128 | fi 129 | return 0 130 | } 131 | 132 | set_score_query() 133 | { 134 | if [ "$1" = "SQLServer" ]; then 135 | SCORE_QUERY=" 136 | SET NOCOUNT ON; 137 | SELECT 138 | ${PK} = LIST.${PK} 139 | ,score = ${SQLCLAUSE} 140 | INTO ${OUTDB}..${OUTTBL} 141 | FROM ${INDB}..${FEATTBL} LIST 142 | ; 143 | -- Add index 144 | CREATE UNIQUE INDEX IX_${OUTTBL}_${PK} ON ${OUTTBL} (${PK}) 145 | " 146 | elif [ "$1" = "MySQL" ]; then 147 | SCORE_QUERY=" 148 | CREATE TABLE ${OUTDB}.${OUTTBL} ( 149 | id INT 150 | ,score DOUBLE 151 | ); 152 | INSERT INTO ${OUTDB}.${OUTTBL} 153 | SELECT 154 | LIST.${PK} 155 | ,${SQLCLAUSE} 156 | FROM ${INDB}.${FEATTBL} LIST 157 | ; 158 | " 159 | elif [ "$1" = "HiveQL" ]; then 160 | SCORE_QUERY=" 161 | CREATE TABLE ${OUTDB}.${OUTTBL} ( 162 | id 
INT 163 | ,score DOUBLE 164 | ); 165 | INSERT INTO TABLE ${OUTDB}.${OUTTBL} 166 | SELECT 167 | LIST.${PK} 168 | ,${SQLCLAUSE} 169 | FROM ${INDB}.${FEATTBL} LIST 170 | ; 171 | " 172 | elif [ "$1" = "Netezza" ]; then 173 | SCORE_QUERY="" 174 | else 175 | echo "Invalid SQLtype value: "$1 176 | exit 1 177 | fi 178 | return 0 179 | } 180 | 181 | set_rules_only_query() 182 | { 183 | if [ "$1" = "SQLServer" ]; then 184 | RULES_ONLY_QUERY=" 185 | SET NOCOUNT ON; 186 | SELECT 187 | ${PK} = LIST.${PK} 188 | ,${SQLCLAUSE} 189 | INTO ${OUTDB}..${OUTTBL} 190 | FROM ${INDB}..${FEATTBL} LIST 191 | ; 192 | -- Add index 193 | CREATE UNIQUE INDEX IX_${OUTTBL}_${PK} ON ${OUTTBL} (${PK}) 194 | " 195 | elif [ "$1" = "MySQL" ]; then 196 | RULES_ONLY_QUERY=" 197 | CREATE TABLE ${OUTDB}.${OUTTBL} AS 198 | SELECT 199 | LIST.${PK} 200 | ,${SQLCLAUSE} 201 | FROM ${INDB}.${FEATTBL} LIST 202 | ; 203 | " 204 | elif [ "$1" = "Netezza" ]; then 205 | RULES_ONLY_QUERY="" 206 | else 207 | echo "Invalid SQLtype value: "$1 208 | exit 1 209 | fi 210 | return 0 211 | } 212 | 213 | set_rules_score_query() 214 | { 215 | if [ "$1" = "SQLServer" ]; then 216 | RULES_SCORE_QUERY=" 217 | SET NOCOUNT ON; 218 | SELECT 219 | ${PK} = LIST.${PK} 220 | ,${SQLCLAUSE} 221 | INTO ${OUTDB}..${OUTTBL}_TEMP 222 | FROM ${INDB}..${FEATTBL} LIST 223 | ; 224 | SELECT 225 | ${PK} 226 | ,score = ${SQLSUMCLAUSE_PART1} 227 | INTO ${OUTDB}..${OUTTBL} 228 | FROM ( 229 | SELECT 230 | ${PK} 231 | ,${SQLSUMCLAUSE_PART2} 232 | FROM ${OUTDB}..${OUTTBL}_TEMP)T 233 | ; 234 | " 235 | elif [ "$1" = "MySQL" ]; then 236 | RULES_SCORE_QUERY="" 237 | elif [ "$1" = "Netezza" ]; then 238 | RULES_SCORE_QUERY=" 239 | SELECT 240 | LIST.${PK} AS ${PK} 241 | ,${SQLCLAUSE} 242 | INTO ${OUTDB}..${OUTTBL}_TEMP 243 | FROM ${INDB}..${FEATTBL} LIST 244 | ; 245 | SELECT 246 | ${PK} 247 | ,${SQLSUMCLAUSE_PART1} AS score 248 | INTO ${OUTDB}..${OUTTBL} 249 | FROM ( 250 | SELECT 251 | ${PK} 252 | ,${SQLSUMCLAUSE_PART2} 253 | FROM ${OUTDB}..${OUTTBL}_TEMP 254 | LIMIT ALL)T 255 | ; 256 | " 257 | else 258 | echo "Invalid SQLtype value: "$1 259 | exit 1 260 | fi 261 | return 0 262 | } 263 | 264 | set_diag_query() 265 | { 266 | if [ "$1" = "SQLServer" ]; then 267 | DIAG_QUERY=" 268 | SET ANSI_WARNINGS OFF; 269 | SET NOCOUNT ON; 270 | SELECT 271 | MIN = MIN(score) 272 | ,AVG = AVG(CAST(score AS BIGINT)) 273 | ,MAX = MAX(score) 274 | ,nNULLs = SUM(CASE 275 | WHEN score IS NULL THEN 1 276 | ELSE 0 277 | END) 278 | FROM ${OUTDB}..${OUTTBL} 279 | " 280 | elif [ "$1" = "MySQL" ]; then 281 | DIAG_QUERY=" 282 | SELECT 283 | MIN(score) AS MIN 284 | ,AVG(score) AS AVG 285 | ,MAX(score) AS MAX 286 | ,SUM(CASE 287 | WHEN score IS NULL THEN 1 288 | ELSE 0 289 | END) AS nNULLs 290 | FROM ${OUTDB}.${OUTTBL} 291 | " 292 | elif [ "$1" = "HiveQL" ]; then 293 | DIAG_QUERY=" 294 | SELECT 295 | MIN(score) AS MIN 296 | ,AVG(CAST(score AS BIGINT)) AS AVG 297 | ,MAX(score) AS MAX 298 | ,SUM(CASE 299 | WHEN score IS NULL THEN 1 300 | ELSE 0 301 | END) AS nNULLs 302 | FROM ${OUTDB}.${OUTTBL} 303 | " 304 | elif [ "$1" = "Netezza" ]; then 305 | DIAG_QUERY="" 306 | else 307 | echo "Invalid SQLtype value: "$1 308 | exit 1 309 | fi 310 | return 0 311 | } 312 | 313 | # Load SQL clause expression 314 | SQLCLAUSE=`cat ${MODELSQLFNAME}` 315 | if [ ${#SQLCLAUSE} -lt 1 ]; then 316 | echo "Empty/missing scoring SQL..." 
317 | exit 1 318 | fi 319 | 320 | # Drop output table if it already exists 321 | set_delete_query $SQLTYPE 322 | if [ $VERBOSE -ge 2 ]; then 323 | echo ${DELETE_QUERY}; echo 324 | fi 325 | if [ $VERBOSE -ne 100 ]; then 326 | if [ $SQLTYPE = "SQLServer" ]; then 327 | sqlcmd -E -S $HOST -d $OUTDB -Q "${DELETE_QUERY}" 328 | elif [ $SQLTYPE = "MySQL" ]; then 329 | mysql -h $HOST -u $UNAME -p$UPWD -e "${DELETE_QUERY}" 330 | elif [ $SQLTYPE = "HiveQL" ]; then 331 | hive -S -e "${DELETE_QUERY}" 332 | elif [ $SQLTYPE = "Netezza" ]; then 333 | nzsql -h $HOST -db $OUTDB -u $UNAME -pw $UPWD -c "select * from _V_RELATION_COLUMN where NAME=UPPER('${OUTTBL}')" -A -o temp.txt 334 | if [ `grep -c "(0 rows)" temp.txt` -eq 0 ]; then 335 | nzsql -h $HOST -db $OUTDB -u $UNAME -pw $UPWD -c "${DELETE_QUERY1}" 336 | fi 337 | rm temp.txt 338 | nzsql -h $HOST -db $OUTDB -u $UNAME -pw $UPWD -c "select * from _V_RELATION_COLUMN where NAME=UPPER('${OUTTBL}_TEMP')" -A -o temp1.txt 339 | if [ `grep -c "(0 rows)" temp1.txt` -eq 0 ]; then 340 | nzsql -h $HOST -db $OUTDB -u $UNAME -pw $UPWD -c "${DELETE_QUERY2}" 341 | fi 342 | rm temp1.txt 343 | else 344 | echo "Invalid SQLtype value: "$SQLTYPE 345 | exit 1 346 | fi 347 | fi 348 | 349 | # Load "score = t0+t1+..." clauses (if needed) 350 | if [ $OUTTYPE -eq 3 ]; then 351 | MODELSQLFNAME2=${MODELSQLFNAME}${MODELSQLFNAME_SUFFIX1} 352 | SQLSUMCLAUSE_PART1=`cat ${MODELSQLFNAME2}` 353 | MODELSQLFNAME2=${MODELSQLFNAME}${MODELSQLFNAME_SUFFIX2} 354 | SQLSUMCLAUSE_PART2=`cat ${MODELSQLFNAME2}` 355 | fi 356 | 357 | # Instantiate query 358 | if [ $OUTTYPE -eq 1 ]; then 359 | set_score_query $SQLTYPE 360 | QUERY=${SCORE_QUERY} 361 | elif [ $OUTTYPE -eq 2 ]; then 362 | set_rules_only_query $SQLTYPE 363 | QUERY=${RULES_ONLY_QUERY} 364 | else 365 | set_rules_score_query $SQLTYPE 366 | QUERY=${RULES_SCORE_QUERY} 367 | fi 368 | if [ $VERBOSE -ge 2 ]; then 369 | echo ${QUERY}; echo 370 | fi 371 | # ... 
write SQL to a temp file 372 | echo ${QUERY} > ${RUNMODEL_TEMP_FILE} 373 | 374 | # Execute query 375 | if [ $VERBOSE -ne 100 ]; then 376 | if [ $SQLTYPE = "SQLServer" ]; then 377 | # sqlcmd -E -S $HOST -d $OUTDB -Q "${QUERY}" 378 | sqlcmd -E -S $HOST -d $OUTDB -i ${RUNMODEL_TEMP_FILE} 379 | elif [ $SQLTYPE = "MySQL" ]; then 380 | mysql -h $HOST -u $UNAME -p$UPWD < ${RUNMODEL_TEMP_FILE} 381 | elif [ $SQLTYPE = "HiveQL" ]; then 382 | hive -S -f "${RUNMODEL_TEMP_FILE}" 383 | # exit 1 384 | elif [ $SQLTYPE = "Netezza" ]; then 385 | echo "Create output table" 386 | nzsql -host $HOST -db $OUTDB -u $UNAME -pw $UPWD -f ${RUNMODEL_TEMP_FILE} 387 | else 388 | echo "Invalid SQLtype value: "$SQLTYPE 389 | exit 1 390 | fi 391 | fi 392 | 393 | # Diagnostic queries 394 | if [ $VERBOSE -ge 1 ] && [ $OUTTYPE -eq 1 ]; then 395 | set_diag_query $SQLTYPE 396 | echo "" 397 | if [ $VERBOSE -ne 100 ]; then 398 | if [ $SQLTYPE = "SQLServer" ]; then 399 | sqlcmd -E -S $HOST -d $OUTDB -Q "${DIAG_QUERY}" 400 | elif [ $SQLTYPE = "MySQL" ]; then 401 | mysql -h $HOST -u $UNAME -p$UPWD -e "${DIAG_QUERY}" 402 | elif [ $SQLTYPE = "HiveQL" ]; then 403 | hive -S -e "${DIAG_QUERY}" 404 | # exit 1 405 | elif [ $SQLTYPE = "Netezza" ]; then 406 | echo "Don't know how to run: "$SQLTYPE 407 | exit 1 408 | else 409 | echo "Invalid SQLtype value: "$SQLTYPE 410 | exit 1 411 | fi 412 | else 413 | echo ${DIAG_QUERY}; echo 414 | fi 415 | fi 416 | 417 | # Print out number of rows processed 418 | if [ $VERBOSE -ne 100 ]; then 419 | if [ $SQLTYPE = "SQLServer" ]; then 420 | nco=`sqlcmd -E -S $HOST -d $OUTDB -Q "SELECT COUNT(*) NumRows FROM ${OUTTBL}" | awk 'NR==3{print $1;}'` 421 | elif [ $SQLTYPE = "MySQL" ]; then 422 | nco=`mysql -h $HOST -u $UNAME -p$UPWD -e "SELECT COUNT(*) NumRows FROM ${OUTDB}.${OUTTBL}" | perl -ane '$. > 1 && print'` 423 | elif [ $SQLTYPE = "HiveQL" ]; then 424 | nco=`hive -S -e "SELECT COUNT(*) NumRows FROM ${OUTDB}.${OUTTBL}" | perl -ane '$. > 1 && print'` 425 | elif [ $SQLTYPE = "Netezza" ]; then 426 | nco=`nzsql -host $HOST -db $OUTDB -u $UNAME -pw $UPWD -c "SELECT COUNT(*) NumRows FROM ${OUTTBL}" | awk 'NR==3{print $1;}'` 427 | fi 428 | echo "Processed "$nco" rows..." 429 | fi 430 | 431 | exit 0 432 | -------------------------------------------------------------------------------- /src/rfAPI.R: -------------------------------------------------------------------------------- 1 | source(file.path(REGO_HOME, "/src/logger.R")) 2 | source(file.path(REGO_HOME, "/lib/RuleFit", ifelse(nchar(Sys.getenv("RF_API")) > 0, Sys.getenv("RF_API"), "rulefit.r"))) 3 | 4 | DefaultRFContext <- function() 5 | { 6 | # Populates a list of configuration parameters, with names and values as used 7 | # by RuleFit3. 8 | rf.ctxt <- vector(mode='list') 9 | rf.ctxt$xmiss <- 9.0e30 10 | rf.ctxt$rfmode <- "regress" 11 | rf.ctxt$sparse <- 1 12 | rf.ctxt$test.reps <- 0 13 | rf.ctxt$test.fract <- 0.2 14 | rf.ctxt$mod.sel <- 2 15 | rf.ctxt$model.type <- "both" 16 | rf.ctxt$tree.size <- 4 17 | rf.ctxt$max.rules <- 2000 18 | rf.ctxt$max.trms <- 500 19 | rf.ctxt$costs <- c(1,1) 20 | rf.ctxt$trim.qntl <- 0.025 21 | rf.ctxt$inter.supp <- 3.0 22 | rf.ctxt$memory.par <- 0.01 23 | rf.ctxt$conv.thr <- 1.0e-3 24 | rf.ctxt$mem.tree.store <- 10000000 25 | rf.ctxt$mem.cat.store <- 1000000 26 | rf.ctxt$quiet <- TRUE 27 | return(rf.ctxt) 28 | } 29 | 30 | InitRFContext <- function(model.conf.fname, data.conf) 31 | { 32 | # Parses a given configuration file, and uses it to initialize an RF "context." 
33 | # An "RF Context" is just a verbose version of RuleFit's config param set needed 34 | # to run it. 35 | # 36 | # Args: 37 | # model.conf.fname: model configuration file name 38 | # 39 | # Returns: 40 | # A list of pairs 41 | kKnownParamNames <- c("rf.platform", "rf.working.dir", "rf.export.dir", "task", "model.type", 42 | "model.max.rules", "model.max.terms", "te.tree.size", "te.sample.fraction", 43 | "te.interaction.suppress", "te.memory.param", "sparsity.method", 44 | "score.criterion", "crossvalidation.num.folds", "crossvalidation.fold.size", 45 | "misclassification.costs", "data.trim.quantile", "data.NA.value", 46 | "convergence.threshold", "mem.tree.store", "mem.cat.store", "elastic.net.param") 47 | 48 | rf.ctxt <- DefaultRFContext() 49 | 50 | # Read config file (two columns assumed: 'param' and 'value') 51 | tmp <- read.table(model.conf.fname, header=T, as.is=T) 52 | conf <- as.list(tmp$value) 53 | names(conf) <- tmp$param 54 | 55 | # Do we recognize all given param names? 56 | for (param in names(conf)) { 57 | if (!(param %in% kKnownParamNames)) { 58 | warn(logger, paste("Unrecognized parameter name:", param)) 59 | } 60 | } 61 | 62 | # Installation info 63 | # ...Auto-detect the platform: 64 | rf.ctxt$platform <- switch(.Platform$OS.type 65 | , windows = "windows" 66 | , unix = switch(Sys.info()["sysname"] 67 | , Linux = "linux" 68 | , Darwin = "mac")) 69 | 70 | if (is.null(rf.ctxt$platform)) error(logger, "Unable to detect platform") 71 | 72 | # ...path to RF working directory 73 | if (!("rf.working.dir" %in% names(conf))) { 74 | # Try to get parameter from the env 75 | conf$rf.working.dir <- Sys.getenv("RF_WORKING_DIR") 76 | } 77 | if (substr(conf$rf.working.dir, 1, 2) == "./") { 78 | # Avoid relative paths 79 | conf$rf.working.dir <- file.path(getwd(), conf$rf.working.dir) 80 | } 81 | if (file.access(conf$rf.working.dir) != 0) { 82 | error(logger, paste("You need to specify a valid RF working dir...", conf$rf.working.dir, "isn't good")) 83 | } else { 84 | rf.ctxt$working.dir <- conf$rf.working.dir 85 | } 86 | 87 | # ...path to RF export directory 88 | if (!("rf.export.dir" %in% names(conf))) { 89 | conf$rf.export.dir <- file.path(conf$rf.working.dir, "export") 90 | } 91 | if (!(file.exists(conf$rf.export.dir))) { 92 | dir.create(conf$rf.export.dir) 93 | } 94 | if (file.access(conf$rf.export.dir) != 0) { 95 | error(logger, paste("You need to specify a valid RF export dir...", conf$rf.export.dir, "isn't good")) 96 | } else { 97 | rf.ctxt$export.dir <- conf$rf.export.dir 98 | } 99 | 100 | # Regression or classification task? 101 | if (!("task" %in% names(conf))) { 102 | error(logger,"You need to specify a 'task' type: 'regression' or 'classification'") 103 | } else if (conf$task == "regression") { 104 | rf.ctxt$rfmode <- "regress" 105 | } else if (conf$task == "classification") { 106 | rf.ctxt$rfmode <- "class" 107 | } else { 108 | error(logger, paste("Unrecognized 'task' type:", conf$task, 109 | " ... expecting: 'regression' or 'classification'")) 110 | } 111 | 112 | # Model specification 113 | # ... 
type 114 | if (!("model.type" %in% names(conf)) | 115 | !(conf$model.type %in% c("linear", "rules", "both"))) { 116 | error(logger, "You need to specify a model type: 'linear', 'rules' or 'both'") 117 | } else { 118 | rf.ctxt$model.type <- conf$model.type 119 | } 120 | # ...number of rules generated for regression posprocessing 121 | if ("model.max.rules" %in% names(conf)) { 122 | rf.ctxt$max.rules <- as.numeric(conf$model.max.rules) 123 | } 124 | # ...maximum number of terms selected for final model 125 | if ("model.max.terms" %in% names(conf)) { 126 | rf.ctxt$max.trms <- as.numeric(conf$model.max.terms) 127 | } 128 | 129 | # Tree Ensemble control 130 | # ...average number of terminal nodes in generated trees 131 | if ("te.tree.size" %in% names(conf)) { 132 | rf.ctxt$tree.size <- as.numeric(conf$te.tree.size) 133 | } 134 | # ...fraction of randomly chosen training observations used to produce each tree 135 | if ("te.sample.fraction" %in% names(conf)) { 136 | rf.ctxt$samp.fract <- as.numeric(conf$te.sample.fraction) 137 | } 138 | # ...incentive factor for using fewer variables in tree based rules 139 | if ("te.interaction.suppress" %in% names(conf)) { 140 | rf.ctxt$inter.supp <- as.numeric(conf$te.interaction.suppress) 141 | } 142 | # ... learning rate applied to each new tree when sequentially induced 143 | if ("te.memory.param" %in% names(conf)) { 144 | rf.ctxt$memory.par <- as.numeric(conf$te.memory.param) 145 | } 146 | 147 | # Regularization (postptocessing) control 148 | if (!("sparsity.method" %in% names(conf)) | 149 | !(conf$sparsity.method %in% c("Lasso", "Lasso+FSR", "FSR", "ElasticNet"))) { 150 | error(logger, "You need to specify a sparsity method: 'Lasso', 'Lasso+FSR', 'FSR' or 'ElasticNet'") 151 | } else if (conf$sparsity.method == "Lasso") { 152 | rf.ctxt$sparse <- 1 153 | } else if (conf$sparsity.method == "Lasso+FSR") { 154 | rf.ctxt$sparse <- 2 155 | } else if (conf$sparsity.method == "FSR") { 156 | rf.ctxt$sparse <- 3 157 | } else if (conf$sparsity.method == "ElasticNet") { 158 | if ("elastic.net.param" %in% names(conf)) { 159 | rf.ctxt$sparse <- as.numeric(conf$elastic.net.param) 160 | } else { 161 | error(logger, "For 'ElasticNet' sparsity method, you need to specify 'elastic.net.param'") 162 | } 163 | } 164 | 165 | # Model selection 166 | # ...loss/score criterion 167 | if (!("score.criterion" %in% names(conf)) | 168 | !(conf$score.criterion %in% c("1-AUC", "AAE", "LS", "Misclassification"))) { 169 | error(logger, "You need to specify a score criterion: '1-AUC', 'AAE', 'LS' or 'Misclassification'") 170 | } else if (conf$score.criterion == "1-AUC" & conf$task == "classification") { 171 | rf.ctxt$mod.sel <- 1 172 | } else if (conf$score.criterion == "AAE" & conf$task == "regression") { 173 | rf.ctxt$mod.sel <- 1 174 | } else if (conf$score.criterion == "LS" & conf$task == "regression") { 175 | rf.ctxt$mod.sel <- 2 176 | } else if (conf$score.criterion == "LS" & conf$task == "classification") { 177 | # average squared-error loss on predicted probabilities 178 | rf.ctxt$mod.sel <- 2 179 | } else if (conf$score.criterion == "Misclassification" & conf$task == "classification") { 180 | rf.ctxt$mod.sel <- 3 181 | } else { 182 | error(logger, paste("Invalid score criterion specification -- task: '", 183 | conf$task, "', score.criterion: '", conf$score.criterion, "'", sep = "")) 184 | } 185 | # ...number of cross-validation replications 186 | if ("crossvalidation.num.folds" %in% names(conf)) { 187 | rf.ctxt$test.reps <- as.numeric(conf$crossvalidation.num.folds) 188 | } 189 | # 
...fraction of observations used in the test group 190 | if ("crossvalidation.fold.size" %in% names(conf)) { 191 | rf.ctxt$test.fract <- as.numeric(conf$crossvalidation.fold.size) 192 | } 193 | # ...misclassification costs 194 | if ("misclassification.costs" %in% names(conf)) { 195 | rf.ctxt$costs <- c() 196 | rf.ctxt$costs[1] <- as.numeric(strsplit(conf$misclassification.costs, ",")[[1]][1]) 197 | rf.ctxt$costs[2] <- as.numeric(strsplit(conf$misclassification.costs, ",")[[1]][2]) 198 | } 199 | 200 | # Data preprocessing 201 | # ...linear variable winsorizing factor 202 | if ("data.trim.quantile" %in% names(conf)) { 203 | rf.ctxt$trim.qntl <- as.numeric(conf$data.trim.quantile) 204 | } 205 | # ...numeric value indicating missingness in predictors 206 | if ("data.NA.value" %in% names(conf)) { 207 | rf.ctxt$xmiss <- as.numeric(conf$data.NA.value) 208 | } 209 | 210 | # Iteration Control 211 | # ...convergence threshold for regression postprocessing 212 | if ("convergence.threshold" %in% names(conf)) { 213 | rf.ctxt$conv.thr <- as.numeric(conf$convergence.threshold) 214 | } 215 | 216 | # Memory management 217 | # ...size of internal tree storage (decrease in response to allocation error; 218 | # increase value for very large values of max.rules and/or tree.size) 219 | if ("mem.tree.store" %in% names(conf)) { 220 | rf.ctxt$mem.tree.store <- as.numeric(conf$mem.tree.store) 221 | } 222 | # ... size of internal categorical value storage (decrease in response to 223 | # allocation error; increase for very large values of max.rules and/or 224 | # tree.size in the presence of many categorical variables with many levels) 225 | if ("mem.cat.store" %in% names(conf)) { 226 | rf.ctxt$mem.cat.store <- as.numeric(conf$mem.cat.store) 227 | } 228 | 229 | # Print RF's progress info 230 | rf.ctxt$quiet <- ifelse(data.conf$log.level <= kLogLevelDEBUG, FALSE, TRUE) 231 | 232 | return(rf.ctxt) 233 | } 234 | 235 | TrainRF <- function(x, y, wt, rf.context, cat.vars=NULL, not.used=NULL) 236 | { 237 | # Invokes RuleFit model building procedure. 238 | # 239 | # Args: 240 | # x: input data frame 241 | # y: response vector 242 | # wt: observation weights 243 | # cat.vars: categorical variables (column numbers or names) 244 | # rf.context: configuration parameters 245 | # 246 | # Returns: 247 | # RuleFit model object 248 | dbg(logger, "TrainRF:") 249 | 250 | ok <- 1 251 | if ("samp.fract" %in% names(rf.context)) { 252 | # User-specified "samp.fract" 253 | tryCatch(rfmod <- rulefit(x, y, wt, cat.vars, not.used 254 | ,xmiss = rf.context$xmiss 255 | ,rfmode = rf.context$rfmode 256 | ,sparse = rf.context$sparse 257 | ,test.reps = rf.context$test.reps 258 | ,test.fract = rf.context$test.fract 259 | ,mod.sel = rf.context$mod.sel 260 | ,model.type = rf.context$model.type 261 | ,tree.size = rf.context$tree.size 262 | ,max.rules = rf.context$max.rules 263 | ,max.trms = rf.context$max.trms 264 | ,costs = rf.context$costs 265 | ,trim.qntl = rf.context$trim.qntl 266 | ,samp.fract = rf.context$samp.fract 267 | ,inter.supp = rf.context$inter.supp 268 | ,memory.par = rf.context$memory.par 269 | ,conv.thr = rf.context$conv.thr 270 | ,quiet = rf.context$quiet 271 | ,tree.store = rf.context$mem.tree.store 272 | ,cat.store = rf.context$mem.cat.store), 273 | error = function(err){ok <<- 0; dbg(logger, paste("Error Message from RuleFit:", err))}) 274 | if (ok == 0) { 275 | error(logger, "TrainRF: got stuck in rulefit") 276 | } 277 | } else { 278 | # No mention of "samp.fract"... 
let rulefit set it based on data size 279 | tryCatch(rfmod <- rulefit(x, y, wt, cat.vars, not.used 280 | ,xmiss = rf.context$xmiss 281 | ,rfmode = rf.context$rfmode 282 | ,sparse = rf.context$sparse 283 | ,test.reps = rf.context$test.reps 284 | ,test.fract = rf.context$test.fract 285 | ,mod.sel = rf.context$mod.sel 286 | ,model.type = rf.context$model.type 287 | ,tree.size = rf.context$tree.size 288 | ,max.rules = rf.context$max.rules 289 | ,max.trms = rf.context$max.trms 290 | ,costs = rf.context$costs 291 | ,trim.qntl = rf.context$trim.qntl 292 | ,inter.supp = rf.context$inter.supp 293 | ,memory.par = rf.context$memory.par 294 | ,conv.thr = rf.context$conv.thr 295 | ,quiet = rf.context$quiet 296 | ,tree.store = rf.context$mem.tree.store 297 | ,cat.store = rf.context$mem.cat.store), 298 | error = function(err){ok <<- 0; dbg(logger, paste("Error Message from RuleFit:", err))}) 299 | if (ok == 0) { 300 | error(logger, "TrainRF: got stuck in rulefit") 301 | } 302 | } 303 | 304 | return(rfmod) 305 | } 306 | -------------------------------------------------------------------------------- /src/rfExport.R: -------------------------------------------------------------------------------- 1 | source(file.path(REGO_HOME, "/src/winsorize.R")) 2 | 3 | # Constants 4 | kMod.fname <- "rfmod.Rdata" 5 | kMod.xyw.fname <- "xywtrain.Rdata" 6 | kMod.x.levels.fname <- "xtrain_levels.txt" 7 | kMod.x.levels.lowcount.fname <- "xtrain_levels_lowcount.txt" 8 | kMod.x.trim.fname <- "xtrain_trim.txt" 9 | kMod.yHat.fname <- "xtrain_yHat.Rdata" 10 | kMod.varimp.fname <- "varimp.txt" 11 | kMod.rules.fname <- "rules.txt" 12 | kMod.intercept.fname <- "intercept.txt" 13 | 14 | GetColType <- function(cname, col.types) 15 | { 16 | # Returns an integer from the vector of column types matching the given column name. 17 | i.col.type <- grep(paste("^", cname, "$", sep=""), names(col.types), perl=T, ignore.case=T) 18 | if ( length(i.col.type) == 1 ) { 19 | return(col.types[i.col.type]) 20 | } else { 21 | return(NA) 22 | } 23 | } 24 | 25 | WriteLevels <- function (x.df, col.types, out.fname) 26 | { 27 | # Writes out "levels" for each column of type categorical in the given data. 28 | # 29 | # Args: 30 | # x.df : data frame 31 | # col.types : vector indicating column types -- 1:continuous 32 | # 2:categorical 33 | # out.fname : output file name 34 | # 35 | # Returns: 36 | # None. 
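# A sketch of the file this produces, assuming one categorical column 'cut' and one continuous column 'carat' (level values hypothetical): a categorical row lists the quoted name followed by its quoted levels, while a continuous row lists the quoted name alone: # 'cut','Fair','Good','Ideal' # 'carat'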
37 | stopifnot(is.character(out.fname)) 38 | stopifnot(class(x.df) == "data.frame") 39 | stopifnot(length(col.types) >= ncol(x.df)) 40 | 41 | # Make sure we have the appropriate quotation mark (non-directional single 42 | # quotation mark) 43 | op <- options("useFancyQuotes") 44 | options(useFancyQuotes = FALSE) 45 | 46 | outF <- file(out.fname, "w") 47 | for (i in 1:ncol(x.df)) { 48 | cname <- colnames(x.df)[i] 49 | ctype <- GetColType(cname, col.types) 50 | if (is.na(ctype)) { 51 | error(logger, paste("WriteLevels: don't know type for this var: ", cname)) 52 | } else { 53 | if ( ctype == 2 ) { 54 | clevels <- levels(as.factor(x.df[, i])) 55 | cat(sQuote(cname), sQuote(clevels), file = outF, sep=",", append = T) 56 | cat("\n", file = outF, append = T) 57 | } else if ( ctype == 1 ) { 58 | cat(sQuote(cname), file = outF, append = T) 59 | cat("\n", file = outF, append = T) 60 | } else { 61 | error(logger, paste("WriteLevels: don't know about this var type: ", ctype)) 62 | } 63 | } 64 | } 65 | close(outF) 66 | 67 | # Restore quotation mark 68 | options(useFancyQuotes = op) 69 | } 70 | 71 | WriteLevelsLowCount <- function (x.recoded.cat.vars, out.fname) 72 | { 73 | # Writes out low-count "levels" for columns of type categorical in the given data. 74 | # 75 | # Args: 76 | # x.recoded.cat.vars : list of <var, low.count.levels> pairs 77 | # out.fname : output file name 78 | # 79 | # Returns: 80 | # None. 81 | stopifnot(is.character(out.fname)) 82 | stopifnot(class(x.recoded.cat.vars) == "list") 83 | 84 | # Make sure we have the appropriate quotation mark (non-directional single 85 | # quotation mark) 86 | op <- options("useFancyQuotes") 87 | options(useFancyQuotes = FALSE) 88 | 89 | outF <- file(out.fname, "w") 90 | for (i.var in 1:length(x.recoded.cat.vars)) { 91 | var.name <- (x.recoded.cat.vars[[i.var]])$var 92 | var.lcount.levels <- (x.recoded.cat.vars[[i.var]])$low.count.levels 93 | cat(sQuote(var.name), sQuote(var.lcount.levels), file = outF, sep=",", append = T) 94 | cat("\n", file = outF, append = T) 95 | } 96 | close(outF) 97 | 98 | # Restore quotation mark 99 | options(useFancyQuotes = op) 100 | } 101 | 102 | WriteTrimQuantiles <- function(x.df, col.types, beta, out.fname, feat2winz) 103 | { 104 | # Writes out "trim" quantiles for each column of type numeric in the given 105 | # data. 106 | # 107 | # Args: 108 | # x.df : data frame 109 | # col.types : vector indicating column types -- 1:continuous 110 | # 2:categorical 111 | # out.fname : output file name 112 | # beta : the beta and (1-beta) quantiles of the data distribution 113 | # {x_ij} for each continuous variable x_j will be written out; if NA, precomputed trims are taken from 'feat2winz'. 114 | # Returns: 115 | # None. 
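# A sketch of the output format, one row per column, with hypothetical trim values: a continuous row carries <name, min2keep, max2keep, mean>, e.g. 'carat',0.23,2.18,0.7979 while a categorical row is written as 'cut',NA,NA,NA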
116 | stopifnot(is.character(out.fname)) 117 | stopifnot(length(col.types) >= ncol(x.df)) 118 | 119 | if (!is.na(beta)) { 120 | if (beta < 0 || beta > 0.5) { 121 | error(logger, paste("WriteTrimQuantiles: Invalid 'beta' value:", beta)) 122 | } 123 | 124 | outF <- file(out.fname, "w") 125 | for (i in 1:ncol(x.df)) { 126 | cname <- colnames(x.df)[i] 127 | ctype <- GetColType(cname, col.types) 128 | if (is.na(ctype)) { 129 | error(logger, paste("WriteTrimQuantiles: don't know type for this var:", cname)) 130 | } else { 131 | if (ctype == 1) { 132 | if (is.numeric(x.df[, i])) { 133 | l.x <- winsorize(x.df[, i], beta) 134 | if (length(l.x) != 3) { 135 | error(logger, paste("WriteTrimQuantiles: expecting 3 values from call to winsorize with this var:", cname)) 136 | } 137 | x.min2keep <- l.x$min2keep 138 | x.max2keep <- l.x$max2keep 139 | x.mean <- mean(l.x$x, na.rm = T) 140 | cat(sQuote(cname), x.min2keep, x.max2keep, x.mean, file = outF, sep=",", append = T) 141 | cat("\n", file = outF, append = T) 142 | } else { 143 | dbg(logger, paste("WriteTrimQuantiles: can't winsorize this var:", cname)) 144 | cat(sQuote(cname), NA, NA, NA, file = outF, sep=",", append = T) 145 | cat("\n", file = outF, append = T) 146 | } 147 | } else if (ctype == 2) { 148 | cat(sQuote(cname), NA, NA, NA, file = outF, sep=",", append = T) 149 | cat("\n", file = outF, append = T) 150 | } else { 151 | error(logger, paste("WriteTrimQuantiles: don't know about this var type: ", ctype)) 152 | } 153 | } 154 | } 155 | close(outF) 156 | } else { 157 | # trims have already been computed 158 | if ( missing(feat2winz)) { 159 | error(logger, "WriteTrimQuantiles: 'feat2winz' must not be missing when 'beta' isn't specified") 160 | } 161 | if (class(feat2winz) != "data.frame" || 162 | length(which((names(feat2winz) == c("vname", "beta", "min2keep", "max2keep", "mean"))==T)) != 5) { 163 | error(logger, "WriteTrimQuantiles: 'feat2winz' must be a data frame with columns: vname, beta, min2keep, max2keep, mean") 164 | } 165 | outF <- file(out.fname, "w") 166 | for (i in 1:ncol(x.df)) { 167 | cname <- colnames(x.df)[i] 168 | ctype <- GetColType(cname, col.types) 169 | if (is.na(ctype)) { 170 | error(logger, paste("WriteTrimQuantiles: don't know type for this var: ", cname)) 171 | } else { 172 | if (ctype == 1) { 173 | iRow <- grep(paste("^", cname, "$", sep=""), feat2winz$vname, perl=T) 174 | if ( length(iRow) == 1 ) { 175 | cat(sQuote(cname), feat2winz$min2keep[iRow], feat2winz$max2keep[iRow], feat2winz$mean[iRow], 176 | file = outF, sep=",", append = T) 177 | cat("\n", file = outF, append = T) 178 | } else { 179 | error(logger, paste("WriteTrimQuantiles: Didn't find: ", cname, "in trim list")) 180 | } 181 | } else if (ctype == 2) { 182 | cat(sQuote(cname), NA, NA, NA, file = outF, sep=",", append = T) 183 | cat("\n", file = outF, append = T) 184 | } else { 185 | error(logger, paste("WriteTrimQuantiles: don't know about this var type: ", ctype)) 186 | } 187 | } 188 | } 189 | close(outF) 190 | } 191 | } 192 | 193 | ReadTrimQuantiles <- function(in.fname) 194 | { 195 | # Reads in "trim" quantiles, <vname, min2keep, max2keep, mean> tuples, from the specified file. 196 | trims.df <- read.csv(in.fname, header = FALSE, stringsAsFactors = FALSE, quote="'") 197 | colnames(trims.df) <- c("vname", "min2keep", "max2keep", "mean") 198 | return(trims.df) 199 | } 200 | 201 | WriteObsIdYhat <- function(out.path, obs.id, y, y.hat, field.sep = ",", file.name = "id_y_yHat.csv") 202 | { 203 | # Writes out <id, y, yHat> tuples to a text file. Useful for loading scores into a db. 
204 | # 205 | # Args: 206 | # out.path : where output file will be written 207 | # obs.id : row-id values from training data-frame 208 | # y : training target vector 209 | # y.hat : predicted response on training data 210 | # Returns: 211 | # None. 212 | stopifnot(is.character(out.path)) 213 | stopifnot(length(y) == length(y.hat)) 214 | stopifnot(length(y) == length(obs.id)) 215 | 216 | out.df <- data.frame(cbind(obs.id, y, y.hat)) 217 | write.table(out.df, file = file.path(out.path, file.name), row.names = F, quote = F, na ="", sep = field.sep) 218 | } 219 | 220 | ExportModel <- function(rfmod, rfmod.path, x, y, wt, y.hat, out.path, x.df, col.types, winz=0.025, x.recoded.cat.vars=NULL) 221 | { 222 | # Writes out all components of a given RuleFit model required for a "restore" 223 | # within R, or an export as ASCII text rules. 224 | # 225 | # Args: 226 | # rfmod : rulefit model object 227 | # rfmod.path : path to rfmod related files 228 | # x : training data matrix 229 | # y : training target vector 230 | # wt : training weights 231 | # y.hat : predicted response on training data 232 | # out.path : where output (exported) files will be written 233 | # x.df : training data frame 234 | # col.types : vector indicating column types -- 1:continuous 235 | # 2:categorical 236 | # winz : a data-frame, if column-specific beta was used for winsorizing; 237 | # otherwise, the 'global' beta value used 238 | # x.recoded.cat.vars : list of <var, low.count.levels> pairs 239 | # Requires: 240 | # - REGO_HOME to be already set 241 | # 242 | # Returns: 243 | # None. 244 | stopifnot(is.character(out.path)) 245 | stopifnot(class(x) == "matrix") 246 | stopifnot(length(y) == nrow(x)) 247 | stopifnot(length(y) == length(wt)) 248 | stopifnot(length(y) == length(y.hat)) 249 | stopifnot(class(x.df) == "data.frame") 250 | stopifnot(nrow(x.df) == nrow(x) && ncol(x.df) == ncol(x)) 251 | stopifnot(length(col.types) >= ncol(x.df)) 252 | 253 | GetRules <- function(rfmod.path) { 254 | # Returns the RuleFit model rules in the given directory as text strings. 255 | if (GetRF_WORKING_DIR() == rfmod.path) { 256 | ruleList <- getrules() 257 | } else { 258 | warn(logger, paste("GetRules: Failed to retrieve rules from: ", rfmod.path)) 259 | ruleList <- NA 260 | } 261 | return(ruleList) 262 | } 263 | 264 | GetIntercept <- function(rfmod.path) { 265 | # Returns the intercept of the RuleFit model in the given directory. 266 | if (GetRF_WORKING_DIR() == rfmod.path) { 267 | c0 <- getintercept() 268 | } else { 269 | warn(logger, paste("GetIntercept: Failed to retrieve intercept from: ", rfmod.path)) 270 | c0 <- NA 271 | } 272 | return(c0) 273 | } 274 | 275 | # Save model 276 | save(rfmod, file = file.path(out.path, kMod.fname)) 277 | 278 | # Save train data; required to restore model 279 | save(x, y, wt, file = file.path(out.path, kMod.xyw.fname)) 280 | 281 | # Save yHat (on x train) -- optional 282 | save(y.hat, file = file.path(out.path, kMod.yHat.fname)) 283 | 284 | # Write out var imp statistic 285 | vi <- varimp(plot = FALSE) 286 | sink(file.path(out.path, kMod.varimp.fname)) 287 | for (i in 1:length(vi$ord)) { 288 | cat(vi$imp[i], "\t", colnames(x)[vi$ord[i]], "\n", sep = "") 289 | } 290 | sink() 291 | 292 | # Write out rules as text -- includes coefficients. 
293 | rulesStr <- GetRules(rfmod.path) 294 | sink(file.path(out.path, kMod.rules.fname)) 295 | for (i in 4:length(rulesStr)) { # skip header 296 | cat(rulesStr[i], "\n") 297 | } 298 | sink() 299 | 300 | # Write out intercept 301 | c0 <- GetIntercept(rfmod.path) 302 | sink(file.path(out.path, kMod.intercept.fname)) 303 | cat(c0, "\n") 304 | sink() 305 | 306 | # Write out levels for categ vars -- needed to decode rules later on 307 | WriteLevels(x.df, col.types, file.path(out.path, kMod.x.levels.fname)) 308 | 309 | # Write out recoded categorical variables (if any) 310 | if (!is.null(x.recoded.cat.vars) && length(x.recoded.cat.vars) > 0) { 311 | WriteLevelsLowCount(x.recoded.cat.vars, file.path(out.path, kMod.x.levels.lowcount.fname)) 312 | } 313 | 314 | # Write out trim-quantiles -- needed to run model outside R 315 | if ( class(winz) == "data.frame") { 316 | WriteTrimQuantiles(x.df, col.types, beta=NA, file.path(out.path, kMod.x.trim.fname), winz) 317 | } else { 318 | WriteTrimQuantiles(x.df, col.types, beta=winz, file.path(out.path, kMod.x.trim.fname)) 319 | } 320 | } 321 | 322 | LoadModel <- function(model.path) 323 | { 324 | # Loads previously exported RuleFit model components required for a "restore" 325 | # operation. 326 | # 327 | # Args: 328 | # model.path : path to RuleFit model exported files 329 | # 330 | # Returns: 331 | # Tuple <rfmod, x, y, wt> 332 | # 333 | # Notes: 334 | # - Should be followed by rfrestore(rfmod, x=x, y=y, wt=wt). 335 | stopifnot(is.character(model.path)) 336 | 337 | if (file.access(model.path, mode=4) != 0) { 338 | error(logger, paste("LoadModel: Path: '", model.path, "' is not accessible")) 339 | } 340 | if (!file.exists(file.path(model.path, kMod.xyw.fname))) { 341 | error(logger, paste("LoadModel: Can't find file: ", kMod.xyw.fname)) 342 | } 343 | if (!file.exists(file.path(model.path, kMod.fname))) { 344 | error(logger, paste("LoadModel: Can't find file: ", kMod.fname)) 345 | } 346 | 347 | # Load data 348 | saved.objs <- load(file = file.path(model.path, kMod.xyw.fname)) 349 | if (any(saved.objs != c("x", "y", "wt"))) { 350 | error(logger, paste("LoadModel: Failed to find required objects in: ", file.path(model.path, kMod.xyw.fname))) 351 | } 352 | 353 | # Load RuleFit model 354 | saved.objs <- load(file = file.path(model.path, kMod.fname)) 355 | if ( length(which((saved.objs == c("rfmod")) == T)) != 1 ) { 356 | error(logger, paste("LoadModel: Failed to find required objects in: ", file.path(model.path, kMod.fname))) 357 | } 358 | 359 | return(list(rfmod = rfmod, x = x, y = y, wt = wt)) 360 | } 361 | -------------------------------------------------------------------------------- /src/rfPredict_main.R: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/Rscript 2 | 3 | ############################################################################### 4 | # FILE: rfPredict_main.R 5 | # 6 | # USAGE: rfPredict_main.R -m <model.dir> -d <DATA.conf> 7 | # 8 | # DESCRIPTION: 9 | # Computes predictions from a fitted RuleFit model. 10 | # 11 | # ARGUMENTS: 12 | # model.dir: path to RuleFit model exported files 13 | # DATA.conf: specifies test data location. 
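# # EXAMPLE: a hypothetical invocation against the diamonds data shipped in examples/ (the paths assume the model was trained with rf.working.dir=/tmp/REgo/Diamonds_wd): # Rscript $REGO_HOME/src/rfPredict_main.R -m /tmp/REgo/Diamonds_wd/export -d predict_csv.conf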
14 | # 15 | # REQUIRES: 16 | # REGO_HOME: environment variable pointing to the directory where you 17 | # have placed this file (and its companion ones) 18 | # RF_HOME: environment variable pointing to appropriate RuleFit 19 | # executable -- e.g., export RF_HOME=$REGO_HOME/lib/RuleFit/mac 20 | # 21 | # AUTHOR: Giovanni Seni 22 | ############################################################################### 23 | REGO_HOME <- Sys.getenv("REGO_HOME") 24 | source(file.path(REGO_HOME, "/src/logger.R")) 25 | source(file.path(REGO_HOME, "/src/rfExport.R")) 26 | source(file.path(REGO_HOME, "/src/rfTrain.R")) 27 | source(file.path(REGO_HOME, "/lib/RuleFit", ifelse(nchar(Sys.getenv("RF_API")) > 0, Sys.getenv("RF_API"), "rulefit.r"))) 28 | source(file.path(REGO_HOME, "/src/rfGraphics.R")) 29 | source(file.path(REGO_HOME, "/src/rfRulesIO.R")) 30 | library(ROCR, verbose = FALSE, quietly=TRUE, warn.conflicts = FALSE) 31 | library(getopt) 32 | 33 | ValidateConfigArgs <- function(conf) 34 | { 35 | # Validates and initializes configuration parameters. 36 | # 37 | # Args: 38 | # conf: A list of <param, value> pairs 39 | # Returns: 40 | # A list of <param, value> pairs 41 | 42 | ## Must have a valid data source type 43 | stopifnot("data.source.type" %in% names(conf)) 44 | stopifnot(conf$data.source.type %in% c("csv", "db", "rdata")) 45 | if (conf$data.source.type == "db") { 46 | stopifnot("db.dsn" %in% names(conf) && "db.name" %in% names(conf) && 47 | "db.type" %in% names(conf) && "db.tbl.name" %in% names(conf)) 48 | } else if (conf$data.source.type == "csv") { 49 | stopifnot("csv.path" %in% names(conf) && "csv.fname" %in% names(conf)) 50 | if ("csv.sep" %in% names(conf)) { 51 | conf$csv.sep <- as.character(conf$csv.sep) 52 | } else { 53 | conf$csv.sep <- "," 54 | } 55 | } else { # rdata 56 | stopifnot(c("rdata.path", "rdata.fname") %in% names(conf)) 57 | } 58 | 59 | ## Auto-detect the platform: 60 | conf$rf.platform <- switch(.Platform$OS.type 61 | , windows = "windows" 62 | , unix = switch(Sys.info()["sysname"] 63 | , Linux = "linux" 64 | , Darwin = "mac")) 65 | 66 | if (is.null(conf$rf.platform)) error(logger, "Unable to detect platform") 67 | 68 | ## Did the user specify a log level? 69 | if (is.null(conf$log.level)) { 70 | conf$log.level <- kLogLevelINFO 71 | } else { 72 | conf$log.level <- get(conf$log.level) 73 | } 74 | 75 | ## Did the user specify an output file name? 76 | if (is.null(conf$out.fname)) { 77 | conf$out.fname <- "rfPredict_out.csv" 78 | } 79 | ## ... field separator? 80 | if (is.null(conf$out.sep)) { 81 | conf$out.sep <- "," 82 | } else if (nchar(conf$out.sep) == 2 && (conf$out.sep == paste("\\", "t", sep=""))) { 83 | conf$out.sep <- "\t" 84 | } 85 | 86 | ## Generate plots? 87 | if (is.null(conf$graph.plots.ROC)) { 88 | conf$graph.plots.ROC <- TRUE 89 | } else { 90 | conf$graph.plots.ROC <- (as.numeric(conf$graph.plots.ROC) == 1) 91 | } 92 | ## ... LIFT, Gain, etc. plots? 93 | if (is.null(conf$graph.plots.extra)) { 94 | conf$graph.plots.extra <- FALSE 95 | } else { 96 | conf$graph.plots.extra <- (as.numeric(conf$graph.plots.extra) == 1) 97 | } 98 | 99 | return(conf) 100 | } 101 | 102 | ValidateCmdArgs <- function(opt, args.m) 103 | { 104 | # Parses and validates command line arguments. 105 | # 106 | # Args: 107 | # opt: getopt() object 108 | # args.m: valid arguments spec passed to getopt(). 
109 | # 110 | # Returns: 111 | # A list of <param, value> pairs 112 | kUsageString <- "/path/to/rfPredict_main.R -m <model_path> -d <data_conf>" 113 | 114 | # Validate command line arguments 115 | if ( !is.null(opt$help) || is.null(opt$model_path) || is.null(opt$data_conf) ) { 116 | self <- commandArgs()[1] 117 | cat("Usage: ", kUsageString, "\n") 118 | q(status=1); 119 | } 120 | 121 | # Read config file (two columns assumed: 'param' and 'value') 122 | tmp <- read.table(opt$data_conf, header=T, as.is=T) 123 | conf <- as.list(tmp$value) 124 | names(conf) <- tmp$param 125 | conf <- ValidateConfigArgs(conf) 126 | 127 | # Check Model path 128 | if (!(file.exists(opt$model_path))) { 129 | stop("Didn't find model directory:", opt$model_path, "\n") 130 | } else { 131 | conf$model.path <- opt$model_path 132 | } 133 | 134 | # Do we have a log file name? "" will send messages to stdout 135 | if (is.null(opt$log)) { 136 | opt$log <- "" 137 | } 138 | conf$log.fname <- opt$log 139 | 140 | return(conf) 141 | } 142 | 143 | CheckFactorsEncoding <- function(x.test, x.train.levels, x.train.levels.lowcount=NULL) 144 | { 145 | # Check that the integer codes given to factors in x.test are the same as the ordering 146 | # used when the model was built. Otherwise, when the RuleFit::rfpred() function casts 147 | # the data frame to matrix, a different ordering would lead to incorrect predictions. 148 | # 149 | # Args: 150 | # x.test : data frame 151 | # x.train.levels : {<var, levels>} list 152 | # x.train.levels.lowcount : {<var, levels>} df (if exists) 153 | # Returns: 154 | # A copy of the given test data frame with transformed columns 155 | for (iVar in 1:length(x.train.levels)) { 156 | var.levels.train <- x.train.levels[[iVar]]$levels 157 | # Was iVar a factor at train time? 158 | if (!(is.null(var.levels.train))) { 159 | factor.name <- x.train.levels[[iVar]]$var 160 | factor.vals <- as.character(x.test[, factor.name]) 161 | 162 | # Were there low-count levels at train time? 
If so, replace them in x.test too 163 | if (!is.null(x.train.levels.lowcount)) { 164 | i.recoded.var <- grep(paste("^", factor.name, "$", sep=""), x.train.levels.lowcount$var, perl=T) 165 | if (length(i.recoded.var) > 0) { 166 | warn(logger, paste("CheckFactorsEncoding: replacing low-count levels for '", factor.name, "'")) 167 | low.count.levels <- unlist(x.train.levels.lowcount$levels[i.recoded.var]) 168 | factor.vals <- ifelse(factor.vals %in% low.count.levels, kLowCountLevelsName, factor.vals) 169 | } 170 | } 171 | # Check for presence of new levels and replace them with NA (if any) 172 | levels.diff <- setdiff(unique(factor.vals), var.levels.train) 173 | if (length(levels.diff) > 0) { 174 | warn(logger, paste("CheckFactorsEncoding: new levels found for '", factor.name, "' : ", 175 | paste(lapply(levels.diff, sQuote), collapse=","), 176 | "; replacing with NA")) 177 | factor.vals <- ifelse(factor.vals %in% levels.diff, NA, factor.vals) 178 | } 179 | 180 | # Lastly, make sure we have the same level ordering 181 | x.test[, factor.name] <- factor(factor.vals, levels = var.levels.train) 182 | } 183 | } 184 | return(x.test) 185 | } 186 | 187 | ############## 188 | ## Main 189 | # 190 | 191 | # Grab command-line arguments 192 | args.m <- matrix(c( 193 | 'model_path' ,'m', 1, "character", 194 | 'data_conf' ,'d', 1, "character", 195 | 'log' ,'l', 1, "character", 196 | 'help' ,'h', 0, "logical" 197 | ), ncol=4,byrow=TRUE) 198 | opt <- getopt(args.m) 199 | conf <- ValidateCmdArgs(opt, args.m) 200 | 201 | # Set global env variables required by RuleFit 202 | platform <- conf$rf.platform 203 | RF_HOME <- Sys.getenv("RF_HOME") 204 | RF_WORKING_DIR <- conf$rf.working.dir 205 | 206 | # Create logging object 207 | logger <- new("logger", log.level = conf$log.level, file.name = conf$log.fname) 208 | info(logger, paste("rfPredict_main args:", 'model.path =', conf$model.path, 209 | ', log.level =', conf$log.level, ', out.fname =', conf$out.fname)) 210 | 211 | ## Use own version of png() if necessary: 212 | if (isTRUE(conf$graph.dev == "Bitmap")) { 213 | png <- png_via_bitmap 214 | if (!CheckWorkingPNG(png)) error(logger, "cannot generate PNG graphics") 215 | } else { 216 | png <- GetWorkingPNG() 217 | if (is.null(png)) error(logger, "cannot generate PNG graphics") 218 | } 219 | 220 | # Load data 221 | if (conf$data.source.type == "db") { 222 | error(logger, paste("rfPredict_main.R: not yet implemented data source type ", conf$data.source.type)) 223 | } else if (conf$data.source.type == "csv") { 224 | data <- read.csv(file.path(conf$csv.path, conf$csv.fname), na.strings = "", sep=conf$csv.sep, check.names = FALSE) 225 | } else if (conf$data.source.type == "rdata") { 226 | envir <- new.env() 227 | load(file.path(conf$rdata.path, conf$rdata.fname), envir = envir) 228 | if (is.null(conf$rdata.dfname)) { 229 | dfname <- ls(envir) 230 | stopifnot(length(dfname) == 1) 231 | } else { 232 | dfname <- conf$rdata.dfname 233 | } 234 | data <- get(dfname, envir, inherits = FALSE) 235 | stopifnot(is.data.frame(data)) 236 | rm(envir) 237 | } else { 238 | error(logger, paste("rfPredict_main.R: unknown data source type ", conf$data.source.type)) 239 | } 240 | info(logger, paste("Data loaded: dim =", nrow(data), "x", ncol(data), "; NAs =", 241 | length(which(is.na(data) == T)), "(", 242 | round(100*length(which(is.na(data) == T))/(nrow(data)*ncol(data)), 2), 243 | "%)")) 244 | 245 | # Load & restore model 246 | mod <- LoadModel(conf$model.path) 247 | ok <- 1 248 | tryCatch(rfrestore(mod$rfmod, mod$x, mod$y, mod$wt), error = 
function(err){ok <<- 0}) 249 | if (ok == 0) { 250 | error(logger, "rfPredict_main.R: got stuck in rfrestore") 251 | } 252 | rf.mode <- ifelse(length(unique(mod$y)) == 2, "class", "regress") 253 | 254 | # Extract columns used to build model 255 | ok <- 1 256 | tryCatch(x.test <- data[,colnames(mod$x)], error = function(err){ok <<- 0}) 257 | if (ok == 0) { 258 | error(logger, "rfPredict_main.R: train/test column mismatch") 259 | } 260 | 261 | # Any preprocessing needed? 262 | # ... Ensure factor levels are encoded in the same order used at model building time 263 | # ... and substitute low-count levels (if appropriate) 264 | x.levels.fname <- file.path(conf$model.path, kMod.x.levels.fname) 265 | x.levels <- ReadLevels(x.levels.fname) 266 | x.levels.lowcount.fname <- file.path(conf$model.path, kMod.x.levels.lowcount.fname) 267 | if (file.exists(x.levels.lowcount.fname)) { 268 | x.levels.lowcount <- as.data.frame(do.call("rbind", ReadLevels(x.levels.lowcount.fname))) 269 | x.test <- CheckFactorsEncoding(x.test, x.levels, x.levels.lowcount) 270 | } else { 271 | x.test <- CheckFactorsEncoding(x.test, x.levels) 272 | } 273 | 274 | # Predict 275 | y.hat <- rfpred(x.test) 276 | if (rf.mode == "class") { 277 | # "classification" model... convert from log-odds to probability estimates 278 | y.hat <- 1.0/(1.0+exp(-y.hat)) 279 | } 280 | 281 | # Compute test error (if y is known) 282 | if ("col.y" %in% names(conf)) { 283 | y <- data[,conf$col.y] 284 | if (rf.mode == "class") { 285 | conf.m <- table(y, sign(y.hat - 0.5)) 286 | stopifnot("0" %in% rownames(conf.m)) 287 | stopifnot("1" %in% rownames(conf.m)) 288 | TN <- ifelse("-1" %in% colnames(conf.m), conf.m["0", "-1"], 0) 289 | FP <- ifelse("1" %in% colnames(conf.m), conf.m["0","1"], 0) 290 | FN <- ifelse("-1" %in% colnames(conf.m), conf.m["1", "-1"], 0) 291 | TP <- ifelse("1" %in% colnames(conf.m), conf.m["1","1"], 0) 292 | test.acc <- 100*(TN+TP)/length(y.hat) 293 | info(logger, paste("Test acc:", round(test.acc, 2))) 294 | info(logger, sprintf("Test confusion matrix - 0/0: %d, 0/1: %d, 1/0: %d, 1/1: %d", 295 | TN, FP, FN, TP)) 296 | # AUC 297 | pred <- prediction(y.hat, y) 298 | perf <- performance(pred, "auc") 299 | info(logger, paste("Area under the ROC curve:", as.numeric(perf@y.values))) 300 | 301 | kPlotWidth <- 620 302 | kPlotHeight <- 480 303 | # Generate ROC plot 304 | if (conf$graph.plots.ROC) { 305 | plot.fname <- "ROC.png" 306 | pred <- prediction(y.hat, y) 307 | perf <- performance(pred, "tpr", "fpr") 308 | png(file = file.path(conf$out.path, plot.fname), width=kPlotWidth, height=kPlotHeight) 309 | plot(perf, colorize=T, main="") 310 | lines(x=c(0, 1), y=c(0,1)) 311 | dev.off() 312 | } 313 | # Generate extra plots 314 | if (conf$graph.plots.extra) { 315 | # Generate LIFT plot 316 | perf <- performance(pred,"lift","rpp") 317 | plot.fname <- "LIFT.png" 318 | png(file = file.path(conf$out.path, plot.fname), width=kPlotWidth, height=kPlotHeight) 319 | plot(perf, main="Lift curve", colorize=T) 320 | lines(x=c(0, 1), y=c(1,1)) 321 | dev.off() 322 | 323 | # Generate Cumulative Gains plot 324 | perf <- performance(pred,"tpr","rpp") 325 | plot.fname <- "Gains.png" 326 | png(file = file.path(conf$out.path, plot.fname), width=kPlotWidth, height=kPlotHeight) 327 | plot(perf, main="Cumulative Gains curve", colorize=T) 328 | lines(x=c(0, 1), y=c(0,1)) 329 | dev.off() 330 | 331 | # Score histogram colored by y 332 | library(ggplot2) 333 | plot.fname = "score_hist1.png" 334 | hist.yresult <- data.frame(x = y.hat, category = factor(y)) 335 | p <- 
ggplot(hist.yresult, aes(x, colour = category)) 336 | p <- p + geom_freqpoly() 337 | p <- p + xlab('predicted score colored by true category') 338 | ggsave(file.path(conf$out.path, plot.fname), width = kPlotWidth/100, height = kPlotHeight/100, units = "in") 339 | 340 | # Score histogram (stacked) 341 | library(ggplot2) 342 | plot.fname = "score_hist2.png" 343 | p <- ggplot(hist.yresult, aes(x, fill = category)) 344 | p <- p + geom_bar(position = "fill") + stat_bin(binwidth=0.1) 345 | p <- p + xlab('predicted score colored by true category') + ylab('percentage') 346 | ggsave(file.path(conf$out.path, plot.fname), width = kPlotWidth/100, height = kPlotHeight/100, units = "in") 347 | } 348 | } else { 349 | re.test.error <- sum(abs(y.hat - y))/nrow(x.test) 350 | med.test.error <- sum(abs(y - median(y)))/nrow(x.test) 351 | aae.test <- re.test.error / med.test.error 352 | info(logger, sprintf("Test AAE: %f (RE:%f, Med:%f)", aae.test, re.test.error, med.test.error)) 353 | } 354 | } else { 355 | y <- rep(NA, nrow(data)) 356 | } 357 | 358 | # Plot histogram of yHat 359 | kPlotWidth <- 620 360 | kPlotHeight <- 480 361 | plot.fname <- "yHat_hist.png" 362 | png(file = file.path(conf$out.path, plot.fname), width=kPlotWidth, height=kPlotHeight) 363 | hist(y.hat, main="") 364 | dev.off() 365 | 366 | # Dump <id, y, yHat> tuples, as appropriate 367 | if ("col.id" %in% names(conf)) { 368 | obs.id <- data[,conf$col.id] 369 | } else { 370 | obs.id <- rep(NA, nrow(data)) 371 | } 372 | WriteObsIdYhat(out.path = conf$out.path, obs.id = obs.id, y = y, y.hat = y.hat, field.sep = conf$out.sep, file.name = conf$out.fname) 373 | info(logger, "Done!") 374 | 375 | q(status=0) 376 | -------------------------------------------------------------------------------- /src/rfRulesIO.R: -------------------------------------------------------------------------------- 1 | source(file.path(REGO_HOME, "/src/rfPreproc.R")) 2 | 3 | # Constants 4 | kRuleTypeLinear <- "linear" 5 | kRuleTypeSplit <- "split" 6 | kSplitTypeContinuous <- "continuous" 7 | kSplitTypeCategorical <- "categorical" 8 | kMinusInf <- -9.9e+35 9 | kPlusInf <- 9.9e+35 10 | kMissing <- 9e+30 11 | kCondIN <- "in" 12 | kCondNotIN <- "not" 13 | 14 | ReadRules <- function(file) 15 | { 16 | # Reads the given text file, assumed to contain RuleFit-generated rules, 17 | # and returns a corresponding list representation. 18 | # 19 | # Args: 20 | # file: a character string or connection (this file is typically 21 | # created by the ExportModel() function) 22 | # Returns: 23 | # A list of <type, supp, std, coeff, imp, splits, var> tuples, 24 | # where: 25 | # type: "split" or "linear" term 26 | # supp: rule support, if "split" term 27 | # std: standard deviation of predictor, if "linear" term 28 | # coeff: term's coefficient 29 | # imp: term's importance 30 | # splits: list of one or more splits defining "split" term; can be 31 | # <type, var, min, max> or <type, var, cond, levels> 32 | # var: predictor name, if "linear" term 33 | stopifnot(is.character(file) || inherits(file, "connection")) 34 | 35 | NextRuleFitToken <- function(file) { 36 | # Fetches the next "token" from the given RuleFit rules text file. 37 | # 38 | # Args: 39 | # file: an open file or connection 40 | # Returns: 41 | # A list with a single element, 'token' (a string; NULL at end of input). 
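# For instance, given a hypothetical rules-file fragment "support = 0.43", successive calls would yield the tokens "support" and "0.43", since whitespace and the '=' separator (kSep) are skipped.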
42 | kSep <- "=" # value separator 43 | 44 | # Skip "space" until beginning of token 45 | begunToken <- FALSE 46 | aChar <- readChar(file, 1) 47 | while (length(aChar) != 0 && begunToken == FALSE && aChar != "") { 48 | # skip whitespace 49 | while (regexpr("[[:space:]]", aChar) != -1 || aChar == kSep) { 50 | aChar <- readChar(file, 1) 51 | if ( length(aChar) == 0) break 52 | } 53 | aToken <- aChar 54 | begunToken <- TRUE 55 | } 56 | 57 | if (begunToken) { 58 | aChar <- readChar(file, 1) 59 | # Until a brace, whitespace, comma, or EOF 60 | while (length(aChar) != 0 61 | && aChar != "" 62 | && regexpr("[[:space:]]", aChar) == -1 63 | && !aChar %in% c(":", kSep) ) 64 | { 65 | aToken <- paste(aToken, aChar, sep="") 66 | aChar <- readChar(file, 1) 67 | } 68 | return(list(token=aToken)) 69 | } else { 70 | return(list(token=NULL)) 71 | } 72 | } 73 | 74 | ParseLinearRule <- function(file) { 75 | # Parses and returns a "linear" rule from the given RuleFit rules text file. 76 | # 77 | # Args: 78 | # file: an open file or connection 79 | # Returns: 80 | # A tuple . 81 | kStdStr <- "std" 82 | kCoefficientStr <- "coeff" 83 | kImportanceStr <- "impotance" # misspelling is correct 84 | 85 | # Get variable name 86 | res <- NextRuleFitToken(file) 87 | ruleVarName <- res$token 88 | 89 | # Get rule 'std' 90 | res <- NextRuleFitToken(file) 91 | if (res$token != kStdStr) { 92 | error(logger, paste("ParseLinearRule: '", support.str, "' token expected, got: ", res$token)) 93 | } 94 | res <- NextRuleFitToken(file) 95 | ruleStd <- as.numeric(res$token) 96 | 97 | # Get rule 'coefficient' 98 | res <- NextRuleFitToken(file) 99 | if (res$token != kCoefficientStr) { 100 | error(logger, paste("ParseLinearRule: '", kCoefficientStr, "' token expected, got: ", res$token)) 101 | } 102 | res <- NextRuleFitToken(file) 103 | ruleCoeff <- as.numeric(res$token) 104 | 105 | # Get rule 'importance' 106 | res <- NextRuleFitToken(file) 107 | if (res$token != kImportanceStr) { 108 | error(logger, paste("ParseLinearRule: '", kImportanceStr, "' token expected, got: ", res$token)) 109 | } 110 | res <- NextRuleFitToken(file) 111 | ruleImp <- as.numeric(res$token) 112 | 113 | return(list(type="linear", std = ruleStd, coeff = ruleCoeff, imp = ruleImp, var = ruleVarName)) 114 | } 115 | 116 | ParseSplitRule <- function(file, ruleLgth) { 117 | # Parse and return a "split" rule from the given RuleFit rules text file. 118 | # 119 | # Args: 120 | # file: an open file or connection 121 | # ruleLght: how many vars are in this rule? 
122 | # 123 | # Returns: 124 | # A tuple <type, supp, coeff, imp, splits> where 'splits' is a 125 | # list with tuples of the form <type, var, min, max> or <type, var, cond, levels> 126 | # 127 | support.str <- "support" 128 | kCoefficientStr <- "coeff" 129 | kImportanceStr <- "importance" 130 | kCont.split.id.str <- "range" 131 | kCateg.split.id1.str <- "in" 132 | kCateg.split.id2.str <- "not" 133 | kCateg.missing <- "0.9000E+31" 134 | 135 | stopifnot(ruleLgth > 0) 136 | 137 | # Get rule 'support' 138 | res <- NextRuleFitToken(file) 139 | if (res$token != support.str) { 140 | error(logger, paste("ParseSplitRule: '", support.str, "' token expected, got: ", res$token)) 141 | } 142 | res <- NextRuleFitToken(file) 143 | ruleSupp <- as.numeric(res$token) 144 | 145 | # Get rule 'coefficient' 146 | res <- NextRuleFitToken(file) 147 | if (res$token != kCoefficientStr) { 148 | error(logger, paste("ParseSplitRule: '", kCoefficientStr, "' token expected, got: ", res$token)) 149 | } 150 | res <- NextRuleFitToken(file) 151 | ruleCoeff <- as.numeric(res$token) 152 | 153 | # Get rule 'importance' 154 | res <- NextRuleFitToken(file) 155 | if (res$token != kImportanceStr) { 156 | error(logger, paste("ParseSplitRule: '", kImportanceStr, "' token expected, got: ", res$token)) 157 | } 158 | res <- NextRuleFitToken(file) 159 | ruleImp <- as.numeric(res$token) 160 | 161 | # Get splits 162 | splits = vector(mode = "list") 163 | splitVarsSeen <- c() 164 | iSplit <- 1 165 | res <- NextRuleFitToken(file) 166 | while (length(res$token) > 0 && res$token != "Rule") { 167 | # Get variable name 168 | splitVarName <- res$token 169 | # Parse split according to 'type' 170 | res <- NextRuleFitToken(file) 171 | if (res$token == kCont.split.id.str) { 172 | # "Continuous" split... Got range... need min & max 173 | res <- NextRuleFitToken(file) 174 | splitRangeMin <- as.numeric(res$token) 175 | res <- NextRuleFitToken(file) 176 | splitRangeMax <- as.numeric(res$token) 177 | split <- list(type=kSplitTypeContinuous, var = splitVarName, min = splitRangeMin, max = splitRangeMax) 178 | } else if (res$token == kCateg.split.id1.str || res$token == kCateg.split.id2.str) { 179 | # "Categorical" split... need levels 180 | if (res$token == kCateg.split.id2.str) { 181 | categ.cond <- kCateg.split.id2.str 182 | } else { 183 | categ.cond <- kCateg.split.id1.str 184 | } 185 | # ...skip until the end of the line 186 | readLines(file, n=1, ok=TRUE) 187 | # ... soak all levels, assumed to be in one single line 188 | level.line <- sub("^[ ]+", "", readLines(file, n = 1), perl=T) 189 | level.list.raw <- strsplit(level.line, "[ ]+", perl=T)[[1]] 190 | level.list <- c() 191 | for (iLevel in level.list.raw) { 192 | if (iLevel == kCateg.missing) { 193 | level.list <- c(level.list, NA) 194 | } else { 195 | level.list <- c(level.list, as.integer(iLevel)) 196 | } 197 | } 198 | split <- list(type=kSplitTypeCategorical, var = splitVarName, cond = categ.cond, levels = level.list) 199 | } else { 200 | error(logger, paste("ParseSplitRule: One of '", kCont.split.id.str, "', '", kCateg.split.id1.str, "', '", 201 | kCateg.split.id2.str, "' token expected, got: ", res$token)) 202 | } 203 | # Save split 204 | splits[[iSplit]] <- split 205 | iSplit <- iSplit + 1 206 | 207 | # Did we get a "new" variable, or one we had seen before?
208 | if (length(grep(paste("^", splitVarName, "$", sep = ""), splitVarsSeen, perl=T)) == 0) { 209 | splitVarsSeen <- c(splitVarsSeen, splitVarName) 210 | } 211 | 212 | # Get next token 213 | res <- NextRuleFitToken(file) 214 | } 215 | 216 | # Check distinct vars found matched input param 217 | if (length(splitVarsSeen) != ruleLgth) { 218 | error(logger, paste("ParseSplitRule: ruleLgth = ", ruleLgth, " given... found ", length(splitVarsSeen), " variables!")) 219 | } 220 | 221 | rule <- list(type="split", supp = ruleSupp, coeff = ruleCoeff, imp = ruleImp, splits = splits) 222 | return(list(rule=rule, lastToken=res)) 223 | } 224 | 225 | # ----------------------------------------------------------------------- 226 | # Open input file 227 | if (is.character(file)) { 228 | file <- file(file, "r") 229 | on.exit(close(file)) 230 | } 231 | if (!isOpen(file)) { 232 | open(file, "r") 233 | on.exit(close(file)) 234 | } 235 | 236 | # Read in the header information 237 | res <- NextRuleFitToken(file) 238 | 239 | # Read in rules 240 | rules = vector(mode = "list") 241 | while (length(res$token) > 0 && res$token == "Rule") { 242 | # Get rule number 243 | ruleNum <- as.integer(NextRuleFitToken(file)$token) 244 | 245 | # Parse rule according to 'type' 246 | res <- NextRuleFitToken(file) 247 | if (res$token == kRuleTypeLinear) { 248 | # "Linear" rule 249 | rule <- ParseLinearRule(file) 250 | # Get next token 251 | res <- NextRuleFitToken(file) 252 | } else { 253 | # "Split" rule 254 | ruleLgth <- res$token 255 | # ...skip until the end of the line 256 | readLines(file, n=1, ok=TRUE) 257 | # ... get rule info 258 | parseRes <- ParseSplitRule(file, as.integer(ruleLgth)) 259 | rule <- parseRes$rule 260 | # Get next token... 'ParseSplitRule' already advanced it, so just grab it 261 | res <- parseRes$lastToken 262 | } 263 | rules[[ruleNum]] <- rule 264 | } 265 | 266 | return(rules) 267 | } 268 | 269 | SplitRule2Char <- function(rule, x.levels, x.levels.lowcount) 270 | { 271 | # Turns a 'split' rule into a human 'readable' string. 272 | # 273 | # Args: 274 | # rule : list of <type, supp, coeff, imp, splits> 275 | # tuples, where the split sublist has elements of the form 276 | # <type, var, min, max> or <type, var, cond, levels> 277 | # 278 | # x.levels : level info used when rules were built so we can translate 279 | # level codes in categorical splits (optional). 280 | # x.levels.lowcount : low-count levels collapsed into a single one when 281 | # rules were built (optional). 282 | # Returns: 283 | # A character vector with a representation of the rule.
284 | kSplit.and.str <- "AND" 285 | stopifnot(length(rule) > 0) 286 | stopifnot(rule$type == "split") 287 | old.o <- options("useFancyQuotes" = FALSE) 288 | 289 | ContSplit2Char <- function(split) { 290 | # Turns a 'continuous' split into a human 'readable' string 291 | # 292 | # Args: 293 | # split : list of the form <type, var, min, max> 294 | # 295 | # Returns: 296 | # A character vector with one of these strings: "var == NA", 297 | # "var != NA", "var <= split-value", "var > split-value", or 298 | # "var between split-value-low and split-value-high" 299 | if (split$max == kPlusInf) { 300 | if (split$min == kMissing) { 301 | # "range = kMissing plus_inf" ==> "== NA" 302 | str <- paste(split$var, "== NA") 303 | } else { 304 | # "range = xxx plus_inf" ==> "> xxx" 305 | str <- paste(split$var, ">", split$min) 306 | } 307 | } else if (split$min == kMinusInf) { 308 | if (split$max == kMissing) { 309 | # "range = kMinusInf kMissing" ==> "!= NA" 310 | str <- paste(split$var, "!= NA") 311 | } else { 312 | # "range = kMinusInf xxx" ==> "<= xxx" 313 | str <- paste(split$var, "<=", split$max) 314 | } 315 | } else if (split$max == kMissing) { 316 | # "range = xxx kMissing" ==> "> xxx" 317 | str <- paste(split$var, ">", split$min) 318 | } else if (split$max != kPlusInf && split$min != kMinusInf && 319 | split$max != kMissing && split$min != kMissing) { 320 | # "range = xxx yyy" ==> ">= xxx and < yyy" 321 | str <- paste(split$var, ">=", split$min, "and", split$var, "<", split$max) 322 | } else { 323 | error(logger, paste("ContSplit2Char: don't know how to print split: ", split)) 324 | } 325 | 326 | return(str) 327 | } 328 | 329 | CategSplit2Char <- function (split, x.levels, x.levels.lowcount) { 330 | # Turns a 'categorical' split into a human 'readable' string. 331 | # 332 | # Args: 333 | # split : list of the form <type, var, cond, levels> 334 | # x.levels : (optional) - {<var, levels>} list 335 | # x.levels.lowcount : (optional) - {<var, levels>} df 336 | # 337 | # Returns: 338 | # A character vector with one of these strings: "var IN (level set)", 339 | # or "var NOT IN (level set)" 340 | if (is.null(x.levels)) { 341 | split.levels.str <- paste(split$levels, collapse=", ") 342 | } else { 343 | # Substitute level-code... locate var's possible values 344 | var.levels <- NULL 345 | for (iVar in 1:length(x.levels)) { 346 | if (x.levels[[iVar]]$var == split$var) { 347 | var.levels <- x.levels[[iVar]]$levels 348 | } 349 | } 350 | if (is.null(var.levels)) { 351 | error(logger, paste("CategSplit2Char: Failed to find level data for: '", split$var, "'")) 352 | } 353 | 354 | # Replace each level-code by corresponding level-string 355 | for (iLevel in 1:length(split$levels)) { 356 | if (is.na(split$levels[iLevel])) { 357 | level.str <- NA 358 | } else { 359 | level.str <- var.levels[split$levels[iLevel]] 360 | # Is this a factor with recoded levels?
361 | if (!is.null(x.levels.lowcount)) { 362 | i.recoded <- grep(paste("^", split$var, "$", sep=""), x.levels.lowcount$var, perl=T) 363 | if (length(i.recoded) == 1 && level.str == kLowCountLevelsName) { 364 | low.count.levels <- unlist(x.levels.lowcount$levels[i.recoded]) 365 | level.str <- paste(lapply(low.count.levels, sQuote), collapse=",") 366 | } else { 367 | level.str <- sQuote(level.str) 368 | } 369 | } else { 370 | level.str <- sQuote(level.str) 371 | } 372 | } 373 | if (iLevel == 1) { 374 | split.levels.str <- level.str 375 | } else { 376 | split.levels.str <- paste(split.levels.str, level.str, sep = ",") 377 | } 378 | } 379 | } 380 | 381 | if (split$cond == kCondIN) { 382 | str <- paste(split$var, "IN", "(", split.levels.str, ")") 383 | } else if (split$cond == kCondNotIN) { 384 | str <- paste(split$var, "NOT IN", "(", split.levels.str, ")") 385 | } else { 386 | error(logger, paste("CategSplit2Char: don't know how to print split: ", split)) 387 | } 388 | 389 | return(str) 390 | } 391 | 392 | # ----------------------------------------------------------------------- 393 | splits <- rule$splits 394 | nSplits <- length(splits) 395 | stopifnot(nSplits > 0) 396 | 397 | # Build string representation of the rule: conjunction of splits 398 | splitStr <- "" 399 | for (iSplit in 1:nSplits) { 400 | split <- splits[[iSplit]] 401 | if (split$type == kSplitTypeContinuous) { 402 | if (iSplit == 1) { 403 | splitStr <- ContSplit2Char(split) 404 | } else { 405 | splitStr <- paste(splitStr, kSplit.and.str, ContSplit2Char(split)) 406 | } 407 | } else if (split$type == kSplitTypeCategorical) { 408 | if (iSplit == 1) { 409 | splitStr <- CategSplit2Char(split, x.levels, x.levels.lowcount) 410 | } else { 411 | splitStr <- paste(splitStr, kSplit.and.str, CategSplit2Char(split, x.levels, x.levels.lowcount)) 412 | } 413 | } else { 414 | error(logger, paste("SplitRule2Char: unknown split type: ", split$type)) 415 | } 416 | } 417 | 418 | options("useFancyQuotes" = old.o) 419 | return(splitStr) 420 | } 421 | 422 | LinearRule2Char <- function(rule) 423 | { 424 | # Turns a 'linear' rule into a human 'readable' string. 425 | # 426 | # Args: 427 | # rule : list of the form <type, std, coeff, imp, var> 428 | # 429 | # Returns: 430 | # A character vector with just the variable name 431 | if (length(rule) == 0) { 432 | error(logger, "LinearRule2Char: 'rule' must not be empty") 433 | } 434 | if (rule$type != "linear") { 435 | error(logger, paste("LinearRule2Char: unexpected rule type: ", rule$type)) 436 | } 437 | # Simply return var name 438 | return(rule$var) 439 | } 440 | 441 | ReadLevels <- function(file) 442 | { 443 | # Parse and return a list of "levels" for each categorical variable.
444 | # 445 | # Args: 446 | # file - a file name or an open file or connection 447 | # Returns: 448 | # A list of <var, levels> pairs 449 | if (is.character(file)) { 450 | file <- file(file, "r") 451 | on.exit(close(file)) 452 | } 453 | if (!inherits(file, "connection")) 454 | error(logger, "ReadLevels: argument `file' must be a character string or connection") 455 | 456 | if (!isOpen(file)) { 457 | open(file, "r") 458 | on.exit(close(file)) 459 | } 460 | 461 | # Read in level info 462 | levels <- vector(mode = "list") 463 | iVar <- 1 464 | 465 | while (TRUE) { 466 | ## Read one line, which has one of these two forms: 467 | ## varname 468 | ## varname, level1, level2, ..., leveln 469 | ## (the varname and levels are optionally quoted) 470 | v <- scan(file, what = character(), sep = ",", nlines = 1, quiet = TRUE) 471 | if (length(v) == 0) break 472 | if (length(v) == 1) { 473 | levels[[iVar]] <- list(var = v, levels = NULL) 474 | } else { 475 | levels[[iVar]] <- list(var = v[1], levels = v[-1]) 476 | } 477 | iVar <- iVar + 1 478 | } 479 | return(levels) 480 | } 481 | 482 | PrintRules <- function(rules, x.levels.fname = "", x.levels.lowcount.fname = "", file = "") 483 | { 484 | # Outputs a 'readable' version of the given RuleFit rules. 485 | # 486 | # Args: 487 | # rules : list of <type, supp/std, coeff, imp, splits/var> tuples, 488 | # as generated by the ReadRules() function 489 | # x.levels.fname: (optional) - text file with <var, levels> pairs 490 | # x.levels.lowcount.fname: (optional) - text file with <var, levels> pairs 491 | # file: connection, or a character string naming the file to print 492 | # to; if "" (the default), prints to the standard output; if 493 | # NULL, prints to a data.frame 494 | # Returns: 495 | # None, or a data.frame with cols <type, supp.std, coeff, importance, def> 496 | stopifnot(length(rules) > 0) 497 | nRules <- length(rules) 498 | 499 | # Were we given data to translate categorical split levels?
500 | x.levels <- NULL 501 | x.levels.lowcount <- NULL 502 | if (nchar(x.levels.fname) > 0) { 503 | x.levels <- ReadLevels(x.levels.fname) 504 | if (nchar(x.levels.lowcount.fname) > 0) { 505 | x.levels.lowcount <- as.data.frame(do.call("rbind", ReadLevels(x.levels.lowcount.fname))) 506 | } 507 | } 508 | 509 | if (is.null(file)) { 510 | # Print to a data-frame instead 511 | out.df <- data.frame(type = rep(NA, nRules), supp.std = rep(NA, nRules), 512 | coeff = rep(NA, nRules), importance = rep(NA, nRules), 513 | def = rep(NA, nRules)) 514 | } 515 | 516 | # Print one rule at a time according to type 517 | for (iRule in 1:nRules) { 518 | rule <- rules[[iRule]] 519 | if (is.null(file)) { 520 | out.df$type[iRule] <- rule$type 521 | out.df$coeff[iRule] <- rule$coeff 522 | out.df$importance[iRule] <- rule$imp 523 | } 524 | if (rule$type == kRuleTypeLinear) { 525 | ruleStr <- LinearRule2Char(rule) 526 | if (is.null(file)) { 527 | out.df$supp.std[iRule] <- rule$std 528 | out.df$def[iRule] <- ruleStr 529 | } else { 530 | cat(ruleStr, "\n", file = file, append = T) 531 | } 532 | } else if (rule$type == kRuleTypeSplit) { 533 | ruleStr <- SplitRule2Char(rule, x.levels, x.levels.lowcount) 534 | if (is.null(file)) { 535 | out.df$supp.std[iRule] <- rule$supp 536 | out.df$def[iRule] <- ruleStr 537 | } else { 538 | cat(ruleStr, "\n", file = file, append = T) 539 | } 540 | } else { 541 | error(logger, paste("PrintRules: unknown rule type: ", rule$type)) 542 | } 543 | } 544 | 545 | if (is.null(file)) { 546 | return(out.df) 547 | } 548 | } 549 | -------------------------------------------------------------------------------- /src/rfPardep_main.R: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/Rscript 2 | 3 | ############################################################################### 4 | # FILE: rfPardep_main.R 5 | # 6 | # USAGE: rfPardep_main.R -c PARDEP.conf 7 | # 8 | # DESCRIPTION: 9 | # Computes single-variable and two-variable partial dependence plots, and 10 | # two-variable interaction statistics from a fitted RuleFit model. 11 | # 12 | # ARGUMENTS: 13 | # PARDEP.conf: specifies path to previously built RuleFit model and plot 14 | # control parameters 15 | # 16 | # REQUIRES: 17 | # REGO_HOME: environment variable pointing to the directory where you 18 | # have placed this file (and its companion ones) 19 | # RF_HOME: environment variable pointing to appropriate RuleFit 20 | # executable -- e.g., export RF_HOME=$REGO_HOME/lib/RuleFit/mac 21 | # 22 | # AUTHOR: Giovanni Seni 23 | ############################################################################### 24 | REGO_HOME <- Sys.getenv("REGO_HOME") 25 | source(file.path(REGO_HOME, "/src/logger.R")) 26 | source(file.path(REGO_HOME, "/src/rfExport.R")) 27 | source(file.path(REGO_HOME, "/src/rfRulesIO.R")) 28 | source(file.path(REGO_HOME, "/lib/RuleFit", ifelse(nchar(Sys.getenv("RF_API")) > 0, Sys.getenv("RF_API"), "rulefit.r"))) 29 | source(file.path(REGO_HOME, "/src/rfGraphics.R")) 30 | library(ROCR, verbose = FALSE, quietly=TRUE, warn.conflicts = FALSE) 31 | library(getopt) 32 | 33 | ValidateConfigArgs <- function(conf) 34 | { 35 | # Validates and initializes configuration parameters. 
36 | # 37 | # Args: 38 | # conf: A list of <param, value> pairs 39 | # Returns: 40 | # A list of <param, value> pairs 41 | stopifnot("var.name" %in% names(conf)) 42 | stopifnot("out.path" %in% names(conf)) 43 | 44 | ## Check Model path 45 | stopifnot("model.path" %in% names(conf)) 46 | if (!(file.exists(conf$model.path))) { 47 | stop("Didn't find model directory:", conf$model.path, "\n") 48 | } 49 | 50 | ## Number of observations to include in averaging calculation 51 | if (!("num.obs" %in% names(conf))) { 52 | conf$num.obs <- 500 53 | } else { 54 | conf$num.obs <- as.numeric(conf$num.obs) 55 | } 56 | 57 | ## Number of distinct variable evaluation points 58 | if (!("var.num.values" %in% names(conf))) { 59 | conf$var.num.values <- 200 60 | } else { 61 | conf$var.num.values <- as.numeric(conf$var.num.values) 62 | } 63 | 64 | ## Trim extreme values of variable 65 | if (!("var.trim.qntl" %in% names(conf))) { 66 | conf$var.trim.qntl <- 0.025 67 | } else { 68 | conf$var.trim.qntl <- as.numeric(conf$var.trim.qntl) 69 | } 70 | 71 | ## Rug quantile to show numeric variable data density 72 | if (!("var.rug.qntl" %in% names(conf))) { 73 | conf$var.rug.qntl <- 0.1 74 | } else { 75 | conf$var.rug.qntl <- as.numeric(conf$var.rug.qntl) 76 | } 77 | 78 | ## Text orientation of level names (for categorical variable) 79 | if (!("var.levels.las" %in% names(conf))) { 80 | conf$var.levels.las <- 1 81 | } else { 82 | conf$var.levels.las <- as.numeric(conf$var.levels.las) 83 | } 84 | 85 | ## Show partial dependence distribution 86 | if (!("show.pdep.dist" %in% names(conf))) { 87 | conf$show.pdep.dist <- FALSE 88 | } else { 89 | conf$show.pdep.dist <- (as.numeric(conf$show.pdep.dist) == 1) 90 | } 91 | 92 | ## Show partial dependence mean 93 | if (!("show.yhat.mean" %in% names(conf))) { 94 | conf$show.yhat.mean <- FALSE 95 | } else { 96 | conf$show.yhat.mean <- (as.numeric(conf$show.yhat.mean) == 1) 97 | } 98 | 99 | ## This determines how far the whiskers of a categorical variable extend out from 100 | ## the boxplot's box (this value times interquartile range gives whiskers range) 101 | if (!("var.boxplot.range" %in% names(conf))) { 102 | conf$var.boxplot.range <- 1.0e-4 103 | } else { 104 | conf$var.boxplot.range <- as.numeric(conf$var.boxplot.range) 105 | } 106 | 107 | ## Auto-detect the platform: 108 | conf$rf.platform <- switch(.Platform$OS.type 109 | , windows = "windows" 110 | , unix = switch(Sys.info()["sysname"] 111 | , Linux = "linux" 112 | , Darwin = "mac")) 113 | 114 | if (is.null(conf$rf.platform)) error(logger, "Unable to detect platform") 115 | 116 | ## Did the user specify a log level? 117 | if (is.null(conf$log.level)) { 118 | conf$log.level <- kLogLevelINFO 119 | } else { 120 | conf$log.level <- get(conf$log.level) 121 | } 122 | 123 | ## Did the user specify an output file? 124 | if (is.null(conf$out.fname)) { 125 | conf$out.fname <- paste(conf$var.name, ".PNG", sep="") 126 | } 127 | 128 | return(conf) 129 | } 130 | 131 | ValidateCmdArgs <- function(opt, args.m) 132 | { 133 | # Parses and validates command line arguments. 134 | # 135 | # Args: 136 | # opt: getopt() object 137 | # args.m: valid arguments spec passed to getopt().
138 | # 139 | # Returns: 140 | # A list of <param, value> pairs 141 | kUsageString <- "/path/to/rfPardep_main.R -c <PARDEP.conf>" 142 | 143 | # Validate command line arguments 144 | if ( !is.null(opt$help) || is.null(opt$conf) ) { 145 | self <- commandArgs()[1] 146 | cat("Usage: ", kUsageString, "\n") 147 | q(status=1); 148 | } 149 | 150 | # Read config file (two columns assumed: 'param' and 'value') 151 | tmp <- read.table(opt$conf, header=T, as.is=T) 152 | conf <- as.list(tmp$value) 153 | names(conf) <- tmp$param 154 | conf <- ValidateConfigArgs(conf) 155 | 156 | # Do we have a log file name? "" will send messages to stdout 157 | if (is.null(opt$log)) { 158 | opt$log <- "" 159 | } 160 | conf$log.fname <- opt$log 161 | 162 | return(conf) 163 | } 164 | 165 | ComputeSinglePDep <- function(var, x, max.var.vals=40, sample.size=500, qntl=0.025, var.levels=NULL) 166 | { 167 | # Compute single variable partial dependence plot. 168 | # 169 | # Args: 170 | # var: variable identifier 171 | # x: training data 172 | # max.var.vals: maximum number of abscissa evaluation points for numeric variables 173 | # sample.size: number of observations used for averaging calculations 174 | # qntl: trimming factor for plotting numeric variables 175 | # var.levels: var level names (if var is categorical) 176 | 177 | ## Get evaluation points 178 | if (is.factor(x[,var]) || length(var.levels) > 0) { 179 | ## Use all levels for categorical variables 180 | is.fact <- TRUE 181 | var.vals <- sort(unique(x[,var])) 182 | } else { 183 | ## Use 'max.var.vals' percentiles for numeric variables 184 | is.fact <- FALSE 185 | var.vals <- quantile(x[,var], na.rm=T, probs=seq(qntl, 1-qntl, 1.0/max.var.vals)) 186 | } 187 | num.var.vals <- length(var.vals) 188 | par.dep <- rep(0, num.var.vals) 189 | 190 | ## Get random sample of observations (to speed things up) 191 | x.sample <- x[sample(1:nrow(x), size=sample.size),] 192 | 193 | ## Compute partial dependence over selected random sample 194 | y.hat.m <- matrix(nrow=sample.size, ncol=num.var.vals) 195 | for (i.var.val in 1:num.var.vals) { 196 | ## Hold x[,var] constant 197 | x.sample[,var] <- var.vals[i.var.val] 198 | ## Compute y.hat 199 | y.hat.m[,i.var.val] <- rfpred(x.sample) 200 | ## Compute avg(y.hat) 201 | par.dep[i.var.val] <- mean(y.hat.m[,i.var.val]) 202 | } 203 | 204 | return(list(var.vals = var.vals, par.dep = par.dep, y.hat.m = y.hat.m)) 205 | } 206 | 207 | PlotSinglePDep <- function(var, x, max.var.vals=40, sample.size=500, qntl=0.025, rugqnt=0.1, main="", 208 | var.levels=NULL, var.boxplot.range=1.0e-4, var.levels.las=1, shift.pdep=FALSE, 209 | show.pdep.dist=FALSE, show.yhat.mean=FALSE) 210 | { 211 | # Generates single variable partial dependence plot.
212 | # 213 | # Args: 214 | # var: variable identifier 215 | # x: training data 216 | # max.var.vals: maximum number of abscissa evaluation points for numeric variables 217 | # sample.size: number of observations used for averaging calculations 218 | # qntl: trimming factor for plotting numeric variables 219 | # rugqnt: quantile for data density tick marks on numeric variables 220 | # var.levels: var level names (if var is categorical) 221 | # var.boxplot.range: 'range' parameter to boxplot 222 | # var.levels.las: orientation of var level names (if var is categorical) 223 | # shift.pdep: shift pdep so that min is zero 224 | # show.pdep.dist: show partial dependence distribution 225 | # show.yhat.mean: show line indicating avg(yHat) 226 | 227 | ## Compute partial dependence 228 | pdep <- ComputeSinglePDep(var, x, max.var.vals, sample.size, qntl, var.levels) 229 | num.var.vals <- length(pdep$var.vals) 230 | 231 | ## Do we want to "shift" the partial dependence? 232 | if (shift.pdep && show.pdep.dist == FALSE) { 233 | pdep$par.dep <- pdep$par.dep - min(pdep$par.dep) 234 | } 235 | 236 | ## Do we want to show global mean(yhat)? 237 | if (show.yhat.mean) { 238 | mean.y.hat <- mean(pdep$y.hat.m) 239 | } 240 | 241 | ## Is var a factor? 242 | if (is.factor(x[,var]) || length(var.levels) > 0) { 243 | is.fact <- TRUE 244 | } else { 245 | is.fact <- FALSE 246 | } 247 | 248 | ## Plot partial dependence according to variable type 249 | if (is.fact) { 250 | if (length(var.levels) > 0) { 251 | var.names <- var.levels 252 | } else { 253 | var.names <- pdep$var.vals 254 | } 255 | ## Adjust bar widths according to var's density 256 | bar.widths <- table(x[,var])/length(x[,var]) 257 | if (show.pdep.dist) { 258 | boxplot(pdep$y.hat.m, names=var.names, width=bar.widths, 259 | xlab=var, ylab='Partial dependence', 260 | outline=FALSE, range=var.boxplot.range) 261 | points(1:num.var.vals, pdep$par.dep, col='blue') ## indicate averages 262 | if (show.yhat.mean) { 263 | lines(x=c(0, ncol(pdep$y.hat.m)+1), y=c(mean.y.hat, mean.y.hat), col='blue', lty=3) 264 | } 265 | } else { 266 | barplot(pdep$par.dep, names=var.names, width=bar.widths, 267 | xlab=var, ylab='Partial dependence', cex.names=0.75, las=var.levels.las) 268 | } 269 | } else { 270 | if (show.pdep.dist) { 271 | ## Add boxplot stats (lower whisker, lower hinge, median, upper hinge, upper whisker) 272 | ## where 273 | ## lower whisker = max(min(x), Q_1 - 1.5 * IQR) 274 | ## upper whisker = min(max(x), Q_3 + 1.5 * IQR) 275 | var.vals.bp <- boxplot(pdep$y.hat.m, at = pdep$var.vals, plot = F) 276 | ymin <- min(var.vals.bp$stats) 277 | ymax <- max(var.vals.bp$stats) 278 | plot(c(min(pdep$var.vals), max(pdep$var.vals)), c(0,0), ylim=c(ymin, ymax), xlab=var, ylab='Partial dependence', 279 | type='l', lty=3, col='white') 280 | matlines(pdep$var.vals, t(var.vals.bp$stats), lty=c(3,2,1,2,3), col=c('black','black','white','black','black')) 281 | lines(pdep$var.vals, pdep$par.dep, type='l', col='red') 282 | if (show.yhat.mean) { 283 | lines(c(min(pdep$var.vals), max(pdep$var.vals)), c(mean.y.hat, mean.y.hat), col='blue', lty=3) 284 | } 285 | } else { 286 | ymin <- min(pdep$par.dep) 287 | ymax <- max(pdep$par.dep) 288 | plot(c(min(pdep$var.vals), max(pdep$var.vals)), c(0,0), ylim=c(ymin, ymax), xlab=var, ylab='Partial dependence', type='l', lty=3) 289 | lines(pdep$var.vals, pdep$par.dep, type='l', col='red') 290 | } 291 | ## Add var's density rug at top 292 | axis(3, quantile(x[,var], probs=seq(qntl, 1-qntl, rugqnt)), labels=F) 293 | } 294 | title(main) 295 | } 296 | 297 
| ComputePairPDep <- function(var1, var2, x, max.var.vals=40, sample.size=500, qntl=0.025, var1.levels=NULL, 298 | var2.levels=NULL) 299 | { 300 | # Compute two variable partial dependence plot. 301 | # 302 | # Args: 303 | # var1: variable identifier for the first variable to be plotted 304 | # var2: variable identifier for the second variable to be plotted 305 | # x: training data 306 | # max.var.vals: maximum number of abscissa evaluation points for numeric variables 307 | # sample.size: number of observations used for averaging calculations 308 | # qntl: trimming factor for plotting numeric variables 309 | # var1.levels: var1 level names (if var1 is categorical) 310 | # var2.levels: var2 level names (if var2 is categorical) 311 | 312 | ## Get evaluation points 313 | if (is.factor(x[,var1]) || length(var1.levels) > 0) { 314 | ## Use all levels for categorical variables 315 | is.var1.fact <- TRUE 316 | var1.vals <- sort(unique(x[,var1])) 317 | } else { 318 | ## Use 'max.var.vals' percentiles for numeric variables 319 | is.var1.fact <- FALSE 320 | var1.vals <- quantile(x[,var1], na.rm=T, probs=seq(qntl, 1-qntl, 1.0/max.var.vals)) 321 | } 322 | if (is.factor(x[,var2]) || length(var2.levels) > 0) { 323 | ## Use all levels for categorical variables 324 | is.var2.fact <- TRUE 325 | var2.vals <- sort(unique(x[,var2])) 326 | } else { 327 | ## Use 'max.var.vals' percentiles for numeric variables 328 | is.var2.fact <- FALSE 329 | var2.vals <- quantile(x[,var2], na.rm=T, probs=seq(qntl, 1-qntl, 1.0/max.var.vals)) 330 | } 331 | num.var1.vals <- length(var1.vals) 332 | num.var2.vals <- length(var2.vals) 333 | par.dep <- matrix(0, nrow = num.var1.vals, ncol = num.var2.vals) 334 | 335 | ## Get random sample of observations (to speed things up) 336 | x.sample <- x[sample(1:nrow(x), size=sample.size),] 337 | 338 | ## Compute partial dependence over selected random sample 339 | y.hat.a <- array(dim=c(sample.size, num.var1.vals, num.var2.vals)) 340 | for (i.var1.val in 1:num.var1.vals) { 341 | for (i.var2.val in 1:num.var2.vals) { 342 | ## Hold x[,var] constant 343 | x.sample[,var1] <- var1.vals[i.var1.val] 344 | x.sample[,var2] <- var2.vals[i.var2.val] 345 | ## Compute y.hat 346 | y.hat.a[, i.var1.val, i.var2.val] <- rfpred(x.sample) 347 | ## Compute avg(y.hat) 348 | par.dep[i.var1.val, i.var2.val] <- mean(y.hat.a[, i.var1.val, i.var2.val]) 349 | } 350 | } 351 | 352 | return(list(var1.vals = var1.vals, var2.vals = var2.vals, par.dep = par.dep, y.hat.m = y.hat.a)) 353 | } 354 | 355 | PlotPairPDep <- function(var1, var2, x, max.var.vals=40, sample.size=500, qntl=0.025, rugqnt=0.1, main="", 356 | var1.levels=NULL, var2.levels=NULL, var.boxplot.range=1.0e-4, var.levels.las=1, 357 | shift.pdep=FALSE, show.pdep.dist=FALSE, show.yhat.mean=FALSE) 358 | { 359 | # Generates two variable partial dependence plot. 
360 | # 361 | # Args: 362 | # var1: variable identifier for the first variable to be plotted 363 | # var2: variable identifier for the second variable to be plotted 364 | # x: training data 365 | # max.var.vals: maximum number of abscissa evaluation points for numeric variables 366 | # sample.size: number of observations used for averaging calculations 367 | # qntl: trimming factor for plotting numeric variables 368 | # rugqnt: quantile for data density tick marks on numeric variables 369 | # var1.levels: var1 level names (if var1 is categorical) 370 | # var2.levels: var2 level names (if var2 is categorical) 371 | # var.boxplot.range: 'range' parameter to boxplot 372 | # var.levels.las: orientation of var level names (if var is categorical) 373 | # shift.pdep: shift pdep so that min is zero 374 | # show.pdep.dist: show partial dependence distribution 375 | # show.yhat.mean: show line indicating avg(yHat) 376 | 377 | ## Compute partial dependence 378 | pdep <- ComputePairPDep(var1, var2, x, max.var.vals, sample.size, qntl, var1.levels, var2.levels) 379 | num.var1.vals <- length(pdep$var1.vals) 380 | num.var2.vals <- length(pdep$var2.vals) 381 | 382 | ## Do we want to "shift" the partial dependence? 383 | pdep.lim <- NULL 384 | if (shift.pdep && show.pdep.dist == FALSE) { 385 | for (i.var1.val in 1:num.var1.vals) { 386 | pdep$par.dep[i.var1.val,] <- pdep$par.dep[i.var1.val,] - min(pdep$par.dep[i.var1.val,]) 387 | } 388 | pdep.lim <- c(0, max(pdep$par.dep)) 389 | } 390 | 391 | ## Is either var a factor? 392 | if (is.factor(x[,var1]) || length(var1.levels) > 0) { 393 | is.var1.fact <- TRUE 394 | } else { 395 | is.var1.fact <- FALSE 396 | } 397 | if (is.factor(x[,var2]) || length(var2.levels) > 0) { 398 | is.var2.fact <- TRUE 399 | } else { 400 | is.var2.fact <- FALSE 401 | } 402 | 403 | ## Plot partial dependence according to variable type 404 | ## Figures arranged in nrow rows and 2 columns 405 | if (is.var1.fact) { 406 | nrow <- ceiling(num.var1.vals/2) 407 | oldparams <- par(mfrow=c(nrow, 2)) 408 | if (is.var2.fact) { 409 | ## Both variables are factors 410 | if (length(var2.levels) > 0) { 411 | var2.names <- var2.levels 412 | } else { 413 | var2.names <- pdep$var2.vals 414 | } 415 | for (i.var1.val in 1:num.var1.vals) { 416 | barplot(pdep$par.dep[i.var1.val,], names = var2.names, 417 | xlab = var2, ylab = 'Partial dependence', 418 | main = paste0(var1, " = ", ifelse(length(var1.levels)>0, var1.levels[i.var1.val], as.character(pdep$var1.vals[i.var1.val]))), 419 | ylim = pdep.lim, cex.names = 0.75, las = var.levels.las) 420 | } 421 | } else { 422 | ## Var1 is a factor; var2 is continuous 423 | for (i.var1.val in 1:num.var1.vals) { 424 | ymin <- min(pdep$par.dep[i.var1.val,]) 425 | ymax <- max(pdep$par.dep[i.var1.val,]) 426 | plot(c(min(pdep$var2.vals), max(pdep$var2.vals)), c(0,0), ylim = c(ymin, ymax), 427 | main = paste0(var1, " = ", ifelse(length(var1.levels)>0, var1.levels[i.var1.val], as.character(pdep$var1.vals[i.var1.val]))), 428 | xlab = var2, ylab = 'Partial dependence', type = 'l', lty = 3) 429 | lines(pdep$var2.vals, pdep$par.dep[i.var1.val,], type='l', col='red') 430 | } 431 | } 432 | par(oldparams) 433 | } else { 434 | if (is.var2.fact) { 435 | ## Var1 is continuous; var2 is a factor 436 | nrow <- ceiling(num.var2.vals/2) 437 | oldparams <- par(mfrow=c(nrow, 2)) 438 | if (length(var2.levels) > 0) { 439 | var2.names <- var2.levels 440 | } else { 441 | var2.names <- pdep$var2.vals 442 | } 443 | for (i.var2.val in 1:num.var2.vals) { 444 | ymin <- min(pdep$par.dep[,i.var2.val]) 445 
| ymax <- max(pdep$par.dep[,i.var2.val]) 446 | plot(c(min(pdep$var1.vals), max(pdep$var1.vals)), c(0,0), ylim = c(ymin, ymax), 447 | main = paste0(var2, " = ", ifelse(length(var2.levels)>0, var2.levels[i.var2.val], as.character(pdep$var2.vals[i.var2.val]))), 448 | xlab = var1, ylab = 'Partial dependence', type = 'l', lty = 3) 449 | lines(pdep$var1.vals, pdep$par.dep[,i.var2.val], type = 'l', col = 'red') 450 | } 451 | par(oldparams) 452 | } else { 453 | persp(pdep$par.dep, xlab = var1, ylab = var2, zlab = 'Partial dependence', 454 | col = 'blue', ticktype = 'detailed', shade = 0.5, ltheta = 30, lphi = 15) 455 | } 456 | } 457 | title(main) 458 | } 459 | 460 | PairInteract <- function(var1, var2, x, var1.levels=NULL, var2.levels=NULL, seed=135711) 461 | { 462 | # Computes two-variable interaction strength. 463 | # 464 | # Args: 465 | # var1: variable identifier for the first variable to be plotted 466 | # var2: variable identifier for the second variable to be plotted 467 | # x: training data 468 | # seed: random number seed (for sampling from x) 469 | 470 | ## Compute the single and pair (centered) partial dependences 471 | ## X1 472 | set.seed(seed) 473 | pdep.var1 <- ComputeSinglePDep(var1, x, var.levels = var1.levels) 474 | pdep.var1$par.dep <- pdep.var1$par.dep - mean(pdep.var1$par.dep) 475 | num.var1.vals <- length(pdep.var1$var.vals) 476 | ## X2 477 | set.seed(seed) 478 | pdep.var2 <- ComputeSinglePDep(var2, x, var.levels = var2.levels) 479 | pdep.var2$par.dep <- pdep.var2$par.dep - mean(pdep.var2$par.dep) 480 | num.var2.vals <- length(pdep.var2$var.vals) 481 | ## X1, X2 482 | set.seed(seed) 483 | pdep.var1.var2 <- ComputePairPDep(var1, var2, x, var1.levels = var1.levels, var2.levels = var2.levels) 484 | pdep.var1.var2$par.dep <- pdep.var1.var2$par.dep - mean(pdep.var1.var2$par.dep) 485 | 486 | ## Compute interaction statistic 487 | accum.num = 0 488 | accum.den = 0 489 | for (i.var1.val in 1:num.var1.vals) { 490 | for (i.var2.val in 1:num.var2.vals) { 491 | accum.num = accum.num + (pdep.var1.var2$par.dep[i.var1.val, i.var2.val] - 492 | pdep.var1$par.dep[i.var1.val] - pdep.var2$par.dep[i.var2.val])^2 493 | accum.den = accum.den + (pdep.var1.var2$par.dep[i.var1.val, i.var2.val])^2 494 | } 495 | } 496 | return(accum.num / accum.den) 497 | } 498 | 499 | PairInteract2 <- function(var1, var2, x, var1.levels=NULL, var2.levels=NULL) 500 | { 501 | # Computes two-variable interaction strength. 
502 | # 503 | # Args: 504 | # var1: variable identifier for the first variable to be plotted 505 | # var2: variable identifier for the second variable to be plotted 506 | # x: training data 507 | 508 | ## Compute the centered partial dependences 509 | ## X1 510 | set.seed(1357) 511 | pdep.var1 <- ComputeSinglePDep(var1, x, var.levels = var1.levels) 512 | pdep.var1$par.dep <- pdep.var1$par.dep - mean(pdep.var1$par.dep) 513 | num.var1.vals <- length(pdep.var1$var.vals) 514 | ## X2 515 | set.seed(1357) 516 | pdep.var2 <- ComputeSinglePDep(var2, x, var.levels = var2.levels) 517 | pdep.var2$par.dep <- pdep.var2$par.dep - mean(pdep.var2$par.dep) 518 | num.var2.vals <- length(pdep.var2$var.vals) 519 | ## X1, X2 520 | set.seed(1357) 521 | pdep.var1.var2 <- ComputePairPDep(var1, var2, x, var1.levels = var1.levels, var2.levels = var2.levels) 522 | pdep.var1.var2$par.dep <- pdep.var1.var2$par.dep - mean(pdep.var1.var2$par.dep) 523 | 524 | ## Compute interaction statistic 525 | accum.num = 0 526 | accum.den = 0 527 | for (i in 1:nrow(x)) { 528 | i.var1.val = which(pdep.var1$var.vals == x[i, var1]) 529 | i.var2.val = which(pdep.var2$var.vals == x[i, var2]) 530 | accum.num = accum.num + (pdep.var1.var2$par.dep[i.var1.val, i.var2.val] - 531 | pdep.var1$par.dep[i.var1.val] - pdep.var2$par.dep[i.var2.val])^2 532 | accum.den = accum.den + (pdep.var1.var2$par.dep[i.var1.val, i.var2.val])^2 533 | } 534 | return(accum.num / accum.den) 535 | } 536 | 537 | PlotPairInteract <- function(var1, vars2, x, var1.levels=NULL, vars2.levels=NULL, seed=135711) 538 | { 539 | # Generates plot of two-variable interaction strengths of given variable with selected 540 | # other variables. 541 | # 542 | # Args: 543 | # var1: variable identifier for the target variable to be plotted 544 | # vars2: list of variable identifiers for the other variables to be plotted 545 | # x: training data 546 | # var1.levels: var1 level names (if var1 is categorical) 547 | # vars2.levels: list of var2 level names (if var2 is categorical) 548 | two.var.int <- c() 549 | for (iVar2 in 1:length(vars2)) { 550 | if (length(vars2.levels) > 0) { 551 | var2.levels <- vars2.levels[[iVar2]] 552 | } else { 553 | var2.levels <- NULL 554 | } 555 | two.var.int[iVar2] <- PairInteract(var1, vars2[iVar2], x, var1.levels, var2.levels, seed) 556 | } 557 | names(two.var.int) <- vars2 558 | barplot(two.var.int, ylab = paste0("Interaction strength with ", var1), names = vars2, cex.names = 0.75) 559 | return(two.var.int) 560 | } 561 | 562 | ############## 563 | ## Main 564 | # 565 | 566 | # Grab command-line arguments 567 | args.m <- matrix(c( 568 | 'conf' ,'c', 1, "character", 569 | 'log' ,'l', 1, "character", 570 | 'help' ,'h', 0, "logical" 571 | ), ncol=4,byrow=TRUE) 572 | opt <- getopt(args.m) 573 | conf <- ValidateCmdArgs(opt, args.m) 574 | 575 | # Set global env variables required by RuleFit 576 | platform <- conf$rf.platform 577 | RF_HOME <- Sys.getenv("RF_HOME") 578 | RF_WORKING_DIR <- conf$rf.working.dir 579 | 580 | # Create logging object 581 | logger <- new("logger", log.level = conf$log.level, file.name = conf$log.fname) 582 | info(logger, paste("rfPardep_main args:", 'model.path =', conf$model.path, 583 | ', log.level =', conf$log.level, ', out.fname =', conf$out.fname)) 584 | 585 | # Use own version of png() if necessary: 586 | if (isTRUE(conf$graph.dev == "Bitmap")) { 587 | png <- png_via_bitmap 588 | if (!CheckWorkingPNG(png)) error(logger, "cannot generate PNG graphics") 589 | } else { 590 | png <- GetWorkingPNG() 591 | if (is.null(png)) error(logger, 
"cannot generate PNG graphics") 592 | } 593 | 594 | # Load & restore model 595 | mod <- LoadModel(conf$model.path) 596 | ok <- 1 597 | tryCatch(rfrestore(mod$rfmod, mod$x, mod$y, mod$wt), error = function(err){ok <<- 0}) 598 | if (ok == 0) { 599 | error(logger, "rfPardep_main.R: got stuck in rfrestore") 600 | } 601 | rf.mode <- ifelse(length(unique(mod$y)) == 2, "class", "regress") 602 | x.levels <- ReadLevels(file.path(conf$model.path, kMod.x.levels.fname)) 603 | var.levels <- Filter(function(x) {x$var == conf$var.name}, x.levels)[[1]]$levels 604 | 605 | # Check desired var is in columns used to build model 606 | if (!(conf$var.name %in% colnames(mod$x))) { 607 | error(logger, paste("rfPardep_main.R: don't know about", conf$var.name)) 608 | } 609 | 610 | # Generate partial dependence plot 611 | kPlotWidth <- 620 612 | kPlotHeight <- 480 613 | png(file = file.path(conf$out.path, conf$out.fname), width = kPlotWidth, height = kPlotHeight) 614 | SinglePlot(conf$var.name, mod$x, var.levels = var.levels, var.levels.las = conf$var.levels.las, 615 | max.var.vals = conf$var.num.values, sample.size = conf$num.obs, 616 | qntl = conf$var.trim.qntl, rugqnt = conf$var.rug.qntl, 617 | show.pdep.dist = conf$show.pdep.dist, var.boxplot.range = conf$var.boxplot.range, 618 | show.yhat.mean = conf$show.yhat.mean) 619 | dev.off() 620 | 621 | q(status=0) 622 | --------------------------------------------------------------------------------