22 | 23 | 24 | -------------------------------------------------------------------------------- /java_code_documentation/package-list: -------------------------------------------------------------------------------- 1 | AlgorithmTesting 2 | CustomLogging 3 | OpenSourceExtensions 4 | bartMachine 5 | -------------------------------------------------------------------------------- /java_code_documentation/resources/background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/background.gif -------------------------------------------------------------------------------- /java_code_documentation/resources/tab.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/tab.gif -------------------------------------------------------------------------------- /java_code_documentation/resources/titlebar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/titlebar.gif -------------------------------------------------------------------------------- /java_code_documentation/resources/titlebar_end.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/titlebar_end.gif -------------------------------------------------------------------------------- /junit-4.10.jar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/junit-4.10.jar -------------------------------------------------------------------------------- /simple_example.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## try using built-in jcache 4 | options(java.parameters = "-Xmx1500m") 5 | library(bartMachine) 6 | n = 500 7 | x = 1 : n; y = x + rnorm(n) 8 | bart_machine = build_bart_machine(as.data.frame(x), y, serialize = TRUE) 9 | save.image("test_bart_machine.RData") 10 | q("no") 11 | #close R, open R 12 | R 13 | options(java.parameters = "-Xmx1500m") 14 | library(bartMachine) 15 | load("test_bart_machine.RData") 16 | x = 1 : n 17 | predict(bart_machine, as.data.frame(x)) 18 | 19 | 20 | ##how big is this stuff? 21 | data_names = names(bart_machine) 22 | sizes = matrix(NA, ncol = 1, nrow = length(data_names)) 23 | rownames(sizes) = data_names 24 | for (i in 1 : length(data_names)){ 25 | sizes[i, ] = object.size(bart_machine[[data_names[i]]]) / 1e6 26 | } 27 | t(t(sizes[order(-sizes), ])) 28 | q("no") 29 | 30 | ####ensure no more memory leak 31 | R 32 | options(java.parameters = "-Xmx1000m") 33 | library(bartMachine) 34 | x = 1 : 100; y = x + rnorm(100) 35 | for (i in 1 : 10000){ 36 | bart_machine = build_bart_machine(as.data.frame(x), y) 37 | } 38 | q("no") 39 | 40 | ## If it helps, this may 41 | 42 | #get some data 43 | R 44 | options(java.parameters = "-Xmx1000m") 45 | library(bartMachine) 46 | 47 | library(MASS) 48 | data(Boston) 49 | X = Boston 50 | y = X$medv 51 | X$medv = NULL 52 | 53 | Xtrain = X[1 : (nrow(X) / 2), ] 54 | ytrain = y[1 : (nrow(X) / 2)] 55 | Xtest = X[(nrow(X) / 2 + 1) : nrow(X), ] 56 | ytest = y[(nrow(X) / 2 + 1) : nrow(X)] 57 | 58 | set_bart_machine_num_cores(4) 59 | bart_machine = build_bart_machine(Xtrain, ytrain, 60 | num_trees = 200, 61 | num_burn_in = 300, 62 | num_iterations_after_burn_in = 1000, 63 | use_missing_data = TRUE, 64 | debug_log = 
TRUE, 65 | verbose = TRUE) 66 | bart_machine 67 | 68 | plot_y_vs_yhat(bart_machine) 69 | 70 | yhat = predict(bart_machine, Xtest) 71 | q("no") 72 | 73 | 74 | options(java.parameters = "-Xmx1500m") 75 | library(bartMachine) 76 | data("Pima.te", package = "MASS") 77 | X <- data.frame(Pima.te[, -8]) 78 | y <- Pima.te[, 8] 79 | bart_machine = bartMachine(X, y) 80 | bart_machine 81 | table(y, predict(bart_machine, X, type = "class")) 82 | 83 | raw_node_data = extract_raw_node_data(bart_machine, g = 37) 84 | raw_node_data[[17]] 85 | 86 | 87 | 88 | 89 | 90 | options(java.parameters = "-Xmx20000m") 91 | library(bartMachine) 92 | set_bart_machine_num_cores(10) 93 | set.seed(1) 94 | data(iris) 95 | iris2 = iris[51 : 150, ] #do not include the third type of flower for this example 96 | iris2$Species = factor(iris2$Species) 97 | X = iris2[ ,1:4] 98 | y = iris2$Species 99 | 100 | 101 | 102 | bart_machine = bartMachine(X, y, num_trees = 50, seed = 1) 103 | bart_machine 104 | ##make probability predictions on the training data 105 | p_hat = predict(bart_machine, iris2[ ,1:4]) 106 | p_hat -------------------------------------------------------------------------------- /src/AlgorithmTesting/DataSetupForCSVFile.java: -------------------------------------------------------------------------------- 1 | /* 2 | BART - Bayesian Additive Regressive Trees 3 | Software for Supervised Statistical Learning 4 | 5 | Copyright (C) 2012 Professor Ed George & Adam Kapelner, 6 | Dept of Statistics, The Wharton School of the University of Pennsylvania 7 | 8 | This program is free software; you can redistribute it and/or modify 9 | it under the terms of the GNU General Public License as published by 10 | the Free Software Foundation; either version 2 of the License, or 11 | (at your option) any later version. 
12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU General Public License for more details: 17 | 18 | http://www.gnu.org/licenses/gpl-2.0.txt 19 | 20 | You should have received a copy of the GNU General Public License along 21 | with this program; if not, write to the Free Software Foundation, Inc., 22 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 23 | */ 24 | 25 | package AlgorithmTesting; 26 | 27 | import java.io.BufferedReader; 28 | import java.io.File; 29 | import java.io.FileReader; 30 | import java.io.IOException; 31 | import java.util.ArrayList; 32 | import java.util.HashSet; 33 | 34 | import bartMachine.Classifier; 35 | 36 | 37 | public class DataSetupForCSVFile { 38 | 39 | private final ArrayList
Level
value in this
36 | \* classloader
37 | \* @throws ObjectStreamException If unable to deserialize
38 | \*/
	protected Object readResolve()
		throws ObjectStreamException {
		// Map a deserialized copy back onto the canonical singleton so that
		// identity comparisons (==) on these Level instances stay valid
		// after a serialization round-trip.
		if (this.intValue() == STDOUT.intValue())
			return STDOUT;
		if (this.intValue() == STDERR.intValue())
			return STDERR;
		// Any other int value means the stream carried an unknown instance.
		throw new InvalidObjectException("Unknown instance :" + this);
	}
47 |
48 | }
--------------------------------------------------------------------------------
/src/CustomLogging/SuperSimpleFormatter.java:
--------------------------------------------------------------------------------
1 | package CustomLogging;
2 |
3 | import java.io.*;
4 | import java.util.logging.*;
5 |
6 | /**
7 | * Print a brief summary of the LogRecord in a human readable
8 | * format. The summary will typically be 1 or 2 lines.
9 | *
10 | * @version %I%, %G%
11 | * @since 1.4
12 | */
13 |
public class SuperSimpleFormatter extends Formatter {

	/** Separator written between the source line and the message line. */
	private String lineSeparator = "\n";

	/**
	 * Format the given LogRecord.
	 *
	 * Output shape: the source class name (or the logger name when the source
	 * class is unknown), optionally followed by the source method name, then a
	 * newline, then the localized/parameterized message. Records logged at a
	 * level whose localized name is "STDERR" get an "ERROR: " prefix, and any
	 * attached throwable's stack trace is appended.
	 *
	 * @param record the log record to be formatted.
	 * @return a formatted log record
	 */
	public synchronized String format(LogRecord record) {
		// StringBuilder suffices here: the method itself is synchronized, so
		// StringBuffer's per-call locking would be pure overhead.
		StringBuilder sb = new StringBuilder();
		if (record.getSourceClassName() != null) {
			sb.append(record.getSourceClassName());
		} else {
			sb.append(record.getLoggerName());
		}
		if (record.getSourceMethodName() != null) {
			sb.append(" ");
			sb.append(record.getSourceMethodName());
		}
		sb.append(lineSeparator);
		String message = formatMessage(record);
		if (record.getLevel().getLocalizedName().equals("STDERR")) {
			sb.append("ERROR: ");
		}
		sb.append(message);
		if (record.getThrown() != null) {
			// Render the stack trace into the buffer via an in-memory writer.
			try {
				StringWriter sw = new StringWriter();
				PrintWriter pw = new PrintWriter(sw);
				record.getThrown().printStackTrace(pw);
				pw.close();
				sb.append(sw.toString());
			} catch (Exception ignored) {
				// Best-effort: failing to render the stack trace must not
				// prevent the main log line from being emitted.
			}
		}
		return sb.toString();
	}
}
72 |
--------------------------------------------------------------------------------
/src/CustomLogging/package-info.java:
--------------------------------------------------------------------------------
1 | /**
2 | * This is for debugging only, please ignore. Thus,
3 | * these methods are not documented.
4 | */
5 | package CustomLogging;
--------------------------------------------------------------------------------
/src/OpenSourceExtensions/UnorderedPair.java:
--------------------------------------------------------------------------------
1 | package OpenSourceExtensions;
2 |
3 | /**
4 | * An unordered immutable pair of elements. Unordered means the pair (a,b) equals the
5 | * pair (b,a). An element of a pair cannot be null.
6 | *
7 | * The class implements {@link Comparable}. Pairs are compared by their smallest elements, and if
8 | * those are equal, they are compared by their maximum elements.
9 | *
10 | * To work correctly, the underlying element type must implement {@link Object#equals} and
11 | * {@link Comparable#compareTo} in a consistent fashion.
12 | */
13 | public final class UnorderedPairclassification_rule
*/
17 | private static double DEFAULT_CLASSIFICATION_RULE = 0.5;
18 | /** The value of the classification rule which if the probability estimate of Y = 1 is greater than, we predict 1 */
19 | private double classification_rule;
20 |
21 | /** Set up an array of binary classification BARTs with length equal to num_cores
, the number of CPU cores requested */
22 | protected void SetupBARTModels() {
23 | bart_gibbs_chain_threads = new bartMachineClassification[num_cores];
24 | for (int t = 0; t < num_cores; t++){
25 | SetupBartModel(new bartMachineClassification(), t);
26 | }
27 | classification_rule = DEFAULT_CLASSIFICATION_RULE;
28 | }
29 |
30 | /**
31 | * Predicts the best guess of the class for an observation
32 | *
33 | * @param record The record who's class we wish to predict
34 | * @param num_cores_evaluate The number of CPU cores to use during this operation
35 | * @return The best guess of the class based on the probability estimate evaluated against the {@link classification_rule}
36 | */
37 | public double Evaluate(double[] record, int num_cores_evaluate) {
38 | return EvaluateViaSampAvg(record, num_cores_evaluate) > classification_rule ? 1 : 0;
39 | }
40 |
41 | /**
42 | * This returns the Gibbs sample predictions for all trees and all posterior samples.
43 | * This differs from the parent implementation because we convert the response value to
44 | * a probability estimate using the normal CDF.
45 | *
46 | * @param data The data for which to generate predictions
47 | * @param num_cores_evaluate The number of CPU cores to use during this operation
48 | * @return The predictions as a vector of size number of posterior samples of vectors of size number of trees
49 | */
50 | protected double[][] getGibbsSamplesForPrediction(double[][] data, int num_cores_evaluate){
51 | double[][] y_gibbs_samples = super.getGibbsSamplesForPrediction(data, num_cores_evaluate);
52 | double[][] y_gibbs_samples_probs = new double[y_gibbs_samples.length][y_gibbs_samples[0].length];
53 | for (int g = 0; g < y_gibbs_samples.length; g++){
54 | for (int i = 0; i < y_gibbs_samples[0].length; i++){
55 | y_gibbs_samples_probs[g][i] = StatUtil.normal_cdf(y_gibbs_samples[g][i]);
56 | }
57 | }
58 | return y_gibbs_samples_probs;
59 | }
60 |
	/** Overrides the default threshold that Evaluate uses to turn a probability estimate into a 0/1 class prediction. */
	public void setClassificationRule(double classification_rule) {
		this.classification_rule = classification_rule;
	}
64 | }
65 |
--------------------------------------------------------------------------------
/src/bartMachine/bartMachineRegression.java:
--------------------------------------------------------------------------------
1 | package bartMachine;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 | * The class that is instantiated to build a Regression BART model
7 | *
8 | * @author Adam Kapelner and Justin Bleich
9 | *
10 | */
@SuppressWarnings("serial")
public class bartMachineRegression extends bartMachine_i_prior_cov_spec implements Serializable{

	/**
	 * Constructs the BART machine for regression. No configuration happens
	 * here; all model settings are supplied later via inherited setters.
	 */
	public bartMachineRegression() {
		super();
	}


}
23 |
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_a_base.java:
--------------------------------------------------------------------------------
1 | package bartMachine;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 | * The base class for any BART implementation. Contains
7 | * mostly instance variables and settings for the algorithm
8 | *
9 | * @author Adam Kapelner and Justin Bleich
10 | */
11 | @SuppressWarnings("serial")
12 | public abstract class bartMachine_a_base extends Classifier implements Serializable {
13 |
14 | /** all Gibbs samples for burn-in and post burn-in where each entry is a vector of pointers to the num_trees
trees in the sum-of-trees model */
15 | protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees;
16 | /** Gibbs samples post burn-in where each entry is a vector of pointers to the num_trees
trees in the sum-of-trees model */
17 | protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees_after_burn_in;
18 | /** Gibbs samples for burn-in and post burn-in of the variances */
19 | protected double[] gibbs_samples_of_sigsq;
20 | /** Gibbs samples for post burn-in of the variances */
21 | protected double[] gibbs_samples_of_sigsq_after_burn_in;
22 | /** a record of whether each Gibbs sample accepted or rejected the MH step within each of the num_trees
trees */
23 | protected boolean[][] accept_reject_mh;
24 | /** a record of the proposal of each Gibbs sample within each of the m
trees: G, P or C for "grow", "prune", "change" */
25 | protected char[][] accept_reject_mh_steps;
26 |
27 | /** the number of trees in our sum-of-trees model */
28 | protected int num_trees;
29 | /** how many Gibbs samples we burn-in for */
30 | protected int num_gibbs_burn_in;
31 | /** how many total Gibbs samples in a BART model creation */
32 | protected int num_gibbs_total_iterations;
33 |
34 | /** the current thread being used to run this Gibbs sampler */
35 | protected int threadNum;
36 | /** how many CPU cores to use during model creation */
37 | protected int num_cores;
38 | /**
39 | * whether or not we use the memory cache feature
40 | *
41 | * @see Section 3.1 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
42 | */
43 | protected boolean mem_cache_for_speed;
44 | /** saves indices in nodes (useful for computing weights) */
45 | protected boolean flush_indices_to_save_ram;
46 | /** should we print stuff out to screen? */
47 | protected boolean verbose = true;
48 |
49 |
50 |
51 | /** Remove unnecessary data from the Gibbs chain to conserve RAM */
52 | protected void FlushData() {
53 | for (bartMachineTreeNode[] bart_trees : gibbs_samples_of_bart_trees){
54 | FlushDataForSample(bart_trees);
55 | }
56 | }
57 |
58 | /** Remove unnecessary data from an individual Gibbs sample */
59 | protected void FlushDataForSample(bartMachineTreeNode[] bart_trees) {
60 | for (bartMachineTreeNode tree : bart_trees){
61 | tree.flushNodeData();
62 | }
63 | }
64 |
65 | /** Must be implemented, but does nothing */
66 | public void StopBuilding(){}
67 |
68 | public void setThreadNum(int threadNum) {
69 | this.threadNum = threadNum;
70 | }
71 |
72 | public void setVerbose(boolean verbose){
73 | this.verbose = verbose;
74 | }
75 |
76 | public void setTotalNumThreads(int num_cores) {
77 | this.num_cores = num_cores;
78 | }
79 |
80 | public void setMemCacheForSpeed(boolean mem_cache_for_speed){
81 | this.mem_cache_for_speed = mem_cache_for_speed;
82 | }
83 |
84 | public void setFlushIndicesToSaveRAM(boolean flush_indices_to_save_ram) {
85 | this.flush_indices_to_save_ram = flush_indices_to_save_ram;
86 | }
87 |
88 | public void setNumTrees(int m){
89 | this.num_trees = m;
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_c_debug.java:
--------------------------------------------------------------------------------
1 | package bartMachine;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 | * This portion of the code used to have many debug functions. These have
7 | * been removed during the tidy up for release.
8 | *
9 | * @author Adam Kapelner and Justin Bleich
10 | */
11 | @SuppressWarnings("serial")
12 | public abstract class bartMachine_c_debug extends bartMachine_b_hyperparams implements Serializable{
13 |
14 | /** should we create illustrations of the trees and save the images to the debug directory? */
15 | protected boolean tree_illust = false;
16 |
17 | /** the hook that gets called to save the tree illustrations when the Gibbs sampler begins */
18 | protected void InitTreeIllustrations() {
19 | bartMachineTreeNode[] initial_trees = gibbs_samples_of_bart_trees[0];
20 | TreeArrayIllustration tree_array_illustration = new TreeArrayIllustration(0, unique_name);
21 |
22 | for (bartMachineTreeNode tree : initial_trees){
23 | tree_array_illustration.AddTree(tree);
24 | tree_array_illustration.addLikelihood(0);
25 | }
26 | tree_array_illustration.CreateIllustrationAndSaveImage();
27 | }
28 |
29 | /** the hook that gets called to save the tree illustrations for each Gibbs sample */
30 | protected void illustrate(TreeArrayIllustration tree_array_illustration) {
31 | if (tree_illust){ //
32 | tree_array_illustration.CreateIllustrationAndSaveImage();
33 | }
34 | }
35 |
36 | /**
37 | * Get the untransformed samples of the sigsqs from the Gibbs chaing
38 | *
39 | * @return The vector of untransformed variances over all the Gibbs samples
40 | */
41 | public double[] getGibbsSamplesSigsqs(){
42 | double[] sigsqs_to_export = new double[gibbs_samples_of_sigsq.length];
43 | for (int n_g = 0; n_g < gibbs_samples_of_sigsq.length; n_g++){
44 | sigsqs_to_export[n_g] = un_transform_sigsq(gibbs_samples_of_sigsq[n_g]);
45 | }
46 | return sigsqs_to_export;
47 | }
48 |
49 | /**
50 | * Queries the depths of the num_trees
trees between a range of Gibbs samples
51 | *
52 | * @param n_g_i The Gibbs sample number to start querying
53 | * @param n_g_f The Gibbs sample number (+1) to stop querying
54 | * @return The depths of all num_trees
trees for each Gibbs sample specified
55 | */
56 | public int[][] getDepthsForTrees(int n_g_i, int n_g_f){
57 | int[][] all_depths = new int[n_g_f - n_g_i][num_trees];
58 | for (int g = n_g_i; g < n_g_f; g++){
59 | for (int t = 0; t < num_trees; t++){
60 | all_depths[g - n_g_i][t] = gibbs_samples_of_bart_trees[g][t].deepestNode();
61 | }
62 | }
63 | return all_depths;
64 | }
65 |
66 | /**
67 | * Queries the number of nodes (terminal and non-terminal) in the num_trees
trees between a range of Gibbs samples
68 | *
69 | * @param n_g_i The Gibbs sample number to start querying
70 | * @param n_g_f The Gibbs sample number (+1) to stop querying
71 | * @return The number of nodes of all num_trees
trees for each Gibbs sample specified
72 | */
73 | public int[][] getNumNodesAndLeavesForTrees(int n_g_i, int n_g_f){
74 | int[][] all_new_nodes = new int[n_g_f - n_g_i][num_trees];
75 | for (int g = n_g_i; g < n_g_f; g++){
76 | for (int t = 0; t < num_trees; t++){
77 | all_new_nodes[g - n_g_i][t] = gibbs_samples_of_bart_trees[g][t].numNodesAndLeaves();
78 | }
79 | }
80 | return all_new_nodes;
81 | }
82 |
83 |
84 | }
85 |
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_d_init.java:
--------------------------------------------------------------------------------
1 | package bartMachine;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 | * This portion of the code initializes the Gibbs sampler
7 | *
8 | * @author Adam Kapelner and Justin Bleich
9 | */
10 | @SuppressWarnings("serial")
11 | public abstract class bartMachine_d_init extends bartMachine_c_debug implements Serializable{
12 |
13 | /** during debugging, we may want to fix sigsq */
14 | protected transient double fixed_sigsq;
15 | /** the number of the current Gibbs sample */
16 | protected int gibbs_sample_num;
17 | /** cached current sum of residuals vector */
18 | protected transient double[] sum_resids_vec;
19 |
20 | /** Initializes the Gibbs sampler setting all zero entries and moves the counter to the first sample */
21 | protected void SetupGibbsSampling(){
22 | InitGibbsSamplingData();
23 | InitizializeSigsq();
24 | InitializeTrees();
25 | InitializeMus();
26 | if (tree_illust){
27 | InitTreeIllustrations();
28 | }
29 | //the zeroth gibbs sample is the initialization we just did; now we're onto the first in the chain
30 | gibbs_sample_num = 1;
31 |
32 | sum_resids_vec = new double[n];
33 | }
34 |
35 | /** Initializes the vectors that hold information across the Gibbs sampler */
36 | protected void InitGibbsSamplingData(){
37 | //now initialize the gibbs sampler array for trees and error variances
38 | gibbs_samples_of_bart_trees = new bartMachineTreeNode[num_gibbs_total_iterations + 1][num_trees];
39 | gibbs_samples_of_bart_trees_after_burn_in = new bartMachineTreeNode[num_gibbs_total_iterations - num_gibbs_burn_in + 1][num_trees];
40 | gibbs_samples_of_sigsq = new double[num_gibbs_total_iterations + 1];
41 | gibbs_samples_of_sigsq_after_burn_in = new double[num_gibbs_total_iterations - num_gibbs_burn_in];
42 |
43 | accept_reject_mh = new boolean[num_gibbs_total_iterations + 1][num_trees];
44 | accept_reject_mh_steps = new char[num_gibbs_total_iterations + 1][num_trees];
45 | }
46 |
47 | /** Initializes the tree structures in the zeroth Gibbs sample to be merely stumps */
48 | protected void InitializeTrees() {
49 | //create the array of trees for the zeroth gibbs sample
50 | bartMachineTreeNode[] bart_trees = new bartMachineTreeNode[num_trees];
51 | for (int i = 0; i < num_trees; i++){
52 | bartMachineTreeNode stump = new bartMachineTreeNode(this);
53 | stump.setStumpData(X_y, y_trans, p);
54 | bart_trees[i] = stump;
55 | }
56 | gibbs_samples_of_bart_trees[0] = bart_trees;
57 | }
58 |
59 |
60 | /** Initializes the leaf structure (the mean predictions) by setting them to zero (in the transformed scale, this is the center of the range) */
61 | protected void InitializeMus() {
62 | for (bartMachineTreeNode stump : gibbs_samples_of_bart_trees[0]){
63 | stump.y_pred = 0;
64 | }
65 | }
66 |
67 | /** Initializes the first variance value by drawing from the prior */
68 | protected void InitizializeSigsq() {
69 | gibbs_samples_of_sigsq[0] = StatToolbox.sample_from_inv_gamma(hyper_nu / 2, 2 / (hyper_nu * hyper_lambda));
70 | }
71 |
72 | /** this is the number of posterior Gibbs samples afte burn-in (thinning was never implemented) */
73 | public int numSamplesAfterBurningAndThinning(){
74 | return num_gibbs_total_iterations - num_gibbs_burn_in;
75 | }
76 |
77 | public void setNumGibbsBurnIn(int num_gibbs_burn_in){
78 | this.num_gibbs_burn_in = num_gibbs_burn_in;
79 | }
80 |
81 | public void setNumGibbsTotalIterations(int num_gibbs_total_iterations){
82 | this.num_gibbs_total_iterations = num_gibbs_total_iterations;
83 | }
84 |
85 | public void setSigsq(double fixed_sigsq){
86 | this.fixed_sigsq = fixed_sigsq;
87 | }
88 |
89 | public boolean[][] getAcceptRejectMH(){
90 | return accept_reject_mh;
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_i_prior_cov_spec.java:
--------------------------------------------------------------------------------
1 | package bartMachine;
2 |
3 | import java.io.Serializable;
4 |
5 | import gnu.trove.list.array.TIntArrayList;
6 |
7 | /**
8 | * This portion of the code implements the informed prior information on covariates feature.
9 | *
10 | * @author Adam Kapelner and Justin Bleich
11 | * @see Section 4.10 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
12 | */
13 | @SuppressWarnings("serial")
14 | public class bartMachine_i_prior_cov_spec extends bartMachine_h_eval implements Serializable{
15 |
16 | /** Do we use this feature in this BART model? */
17 | protected boolean use_prior_cov_spec;
18 | /** This is a probability vector which is the prior on which covariates to split instead of the uniform discrete distribution by default */
19 | protected double[] cov_split_prior;
20 |
21 |
22 | /**
23 | * Pick one predictor from a set of valid predictors that can be part of a split rule at a node
24 | * while accounting for the covariate prior.
25 | *
26 | * @param node The node of interest
27 | * @return The index of the column to split on
28 | */
29 | private int pickRandomPredictorThatCanBeAssignedF1(bartMachineTreeNode node){
30 | TIntArrayList predictors = node.predictorsThatCouldBeUsedToSplitAtNode();
31 | //get probs of split prior based on predictors that can be used and weight it accordingly
32 | double[] weighted_cov_split_prior_subset = getWeightedCovSplitPriorSubset(predictors);
33 | //choose predictor based on random prior value
34 | return StatToolbox.multinomial_sample(predictors, weighted_cov_split_prior_subset);
35 | }
36 |
37 | /**
38 | * The prior-adjusted number of covariates available to be split at this node
39 | *
40 | * @param node The node of interest
41 | * @return The prior-adjusted number of covariates that can be split
42 | */
43 | private double pAdjF1(bartMachineTreeNode node) {
44 | if (node.padj == null){
45 | node.padj = node.predictorsThatCouldBeUsedToSplitAtNode().size();
46 | }
47 | if (node.padj == 0){
48 | return 0;
49 | }
50 | if (node.isLeaf){
51 | return node.padj;
52 | }
53 | //pull out weighted cov split prior subset vector
54 | TIntArrayList predictors = node.predictorsThatCouldBeUsedToSplitAtNode();
55 | //get probs of split prior based on predictors that can be used and weight it accordingly
56 | double[] weighted_cov_split_prior_subset = getWeightedCovSplitPriorSubset(predictors);
57 |
58 | //find index inside predictor vector
59 | int index = bartMachineTreeNode.BAD_FLAG_int;
60 | for (int i = 0; i < predictors.size(); i++){
61 | if (predictors.get(i) == node.splitAttributeM){
62 | index = i;
63 | break;
64 | }
65 | }
66 |
67 | //return inverse probability
68 | return 1 / weighted_cov_split_prior_subset[index];
69 | }
70 |
71 | /**
72 | * Given a set of valid predictors return the probability vector that corresponds to the
73 | * elements of cov_split_prior
re-normalized because some entries may be deleted
74 | *
75 | * @param predictors The indices of the valid covariates
76 | * @return The updated and renormalized prior probability vector on the covariates to split
77 | */
78 | private double[] getWeightedCovSplitPriorSubset(TIntArrayList predictors) {
79 | double[] weighted_cov_split_prior_subset = new double[predictors.size()];
80 | for (int i = 0; i < predictors.size(); i++){
81 | weighted_cov_split_prior_subset[i] = cov_split_prior[predictors.get(i)];
82 | }
83 | Tools.normalize_array(weighted_cov_split_prior_subset);
84 | return weighted_cov_split_prior_subset;
85 | }
86 |
87 | public void setCovSplitPrior(double[] cov_split_prior) {
88 | this.cov_split_prior = cov_split_prior;
89 | //if we're setting the vector, we're using this feature
90 | use_prior_cov_spec = true;
91 | }
92 |
93 | /////////////nothing but scaffold code below, do not alter!
94 |
95 | public int pickRandomPredictorThatCanBeAssigned(bartMachineTreeNode node){
96 | if (use_prior_cov_spec){
97 | return pickRandomPredictorThatCanBeAssignedF1(node);
98 | }
99 | return super.pickRandomPredictorThatCanBeAssigned(node);
100 | }
101 |
102 | public double pAdj(bartMachineTreeNode node){
103 | if (use_prior_cov_spec){
104 | return pAdjF1(node);
105 | }
106 | return super.pAdj(node);
107 | }
108 |
109 | }
110 |
--------------------------------------------------------------------------------
/src/bartMachine/package-info.java:
--------------------------------------------------------------------------------
1 | /**
2 | * The code for the implementation of the BART algorithm as described in detail in:
3 | *
4 | * @see Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R, ArXiv e-prints, 2013
5 | * @since 1.0
6 | * @author Adam Kapelner and Justin Bleich
7 | */
8 | package bartMachine;
--------------------------------------------------------------------------------
/trove-3.0.3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/trove-3.0.3.jar
--------------------------------------------------------------------------------