├── .classpath ├── .gitignore ├── .project ├── LICENSE ├── README.md ├── bartMachine ├── CHANGELOG ├── DESCRIPTION ├── NAMESPACE ├── R │ ├── bartMachine.R │ ├── bart_arrays.R │ ├── bart_node_related_methods.R │ ├── bart_package_builders.R │ ├── bart_package_cross_validation.R │ ├── bart_package_data_preprocessing.R │ ├── bart_package_f_tests.R │ ├── bart_package_inits.R │ ├── bart_package_plots.R │ ├── bart_package_predicts.R │ ├── bart_package_summaries.R │ ├── bart_package_variable_selection.R │ └── zzz.R ├── data │ ├── automobile.RData │ └── benchmark_datasets.RData ├── inst │ ├── CITATION │ ├── COPYRIGHTS │ └── java │ │ └── bart_java.jar ├── man │ ├── automobile.Rd │ ├── bartMachine.Rd │ ├── bartMachineArr.Rd │ ├── bartMachineCV.Rd │ ├── bart_machine_get_posterior.Rd │ ├── bart_machine_num_cores.Rd │ ├── bart_predict_for_test_data.Rd │ ├── benchmark_datasets.Rd │ ├── calc_credible_intervals.Rd │ ├── calc_prediction_intervals.Rd │ ├── check_bart_error_assumptions.Rd │ ├── cov_importance_test.Rd │ ├── destroy_bart_machine.Rd │ ├── dummify_data.Rd │ ├── extract_raw_node_data.Rd │ ├── get_projection_weights.Rd │ ├── get_sigsqs.Rd │ ├── get_var_counts_over_chain.Rd │ ├── get_var_props_over_chain.Rd │ ├── interaction_investigator.Rd │ ├── investigate_var_importance.Rd │ ├── k_fold_cv.Rd │ ├── linearity_test.Rd │ ├── node_prediction_training_data_indices.Rd │ ├── pd_plot.Rd │ ├── plot_convergence_diagnostics.Rd │ ├── plot_y_vs_yhat.Rd │ ├── predict.bartMachine.Rd │ ├── predict_bartMachineArr.Rd │ ├── print.bartMachine.Rd │ ├── rmse_by_num_trees.Rd │ ├── set_bart_machine_num_cores.Rd │ ├── summary.bartMachine.Rd │ ├── var_selection_by_permute.Rd │ └── var_selection_by_permute_cv.Rd └── vignettes │ ├── bartMachine.Rnw │ ├── bart_normality_heteroskedasticity_2.pdf │ ├── convergence_diagnostics4.pdf │ ├── cov_test_body_style2.pdf │ ├── cov_test_omnibus2.pdf │ ├── cov_test_top_10_2.pdf │ ├── cov_test_width2.pdf │ ├── covariate_test_age3.pdf │ ├── friedman_function_interactions2.pdf │ ├── glucose_partial_dependence2.pdf │ ├── pdp_horsepower2.pdf │ ├── pdp_stroke2.pdf │ ├── plot_y_vs_y_hat_cred_ints2.pdf │ ├── plot_y_vs_y_hat_pred_ints2.pdf │ ├── rmse_num_trees_3.pdf │ ├── speed_full4.pdf │ ├── speed_zoomed4.pdf │ ├── v70i04.bib │ ├── var_imp_automobile_cc2.pdf │ └── var_selection_plot2.pdf ├── bartMachineJARs ├── DESCRIPTION ├── NAMESPACE ├── R │ └── onLoad.R ├── inst │ ├── COPYRIGHTS │ └── java │ │ ├── commons-math-2.1.jar │ │ ├── fastutil-core-8.5.8.jar │ │ └── trove-3.0.3.jar └── java │ └── README ├── bart_package_paper ├── appB.R ├── interaction_constraint_demo.R ├── sec3.1.R ├── sec4.1-4.9.R ├── sec4.10-4.12.R ├── sec5.R └── time_mat.RData ├── build.xml ├── commons-math-2.1.jar ├── datasets ├── c_breastcancer.csv ├── c_ionosphere.csv ├── c_letterrecognition.csv ├── r_CGMtoydata.csv ├── r_ankara.csv ├── r_automobile.csv ├── r_bands.csv ├── r_baseballsalary.csv ├── r_bbb.csv ├── r_bivariatelinear.csv ├── r_boston.csv ├── r_boston_half.csv ├── r_boston_half_missing.csv ├── r_boston_tiny_with_missing.csv ├── r_cars.csv ├── r_compactiv.csv ├── r_concretedata.csv ├── r_cpu.csv ├── r_digit_rec.csv ├── r_forestfires.csv ├── r_friedman.csv ├── r_friedman_hd.csv ├── r_german_credit.csv ├── r_glass.csv ├── r_ionosphere.csv ├── r_iris.csv ├── r_just_noise.csv ├── r_ozone.csv ├── r_parkinsonsmotor.csv ├── r_parkinsonstotal.csv ├── r_pole.csv ├── r_psoriasireduced.csv ├── r_simple.csv ├── r_simpledata.csv ├── r_stupiddata.csv ├── r_treemodel_high_n.csv ├── r_treemodel_high_p.csv ├── 
r_treemodel_high_p_low_n.csv ├── r_treemodel_low_n.csv ├── r_triazine.csv ├── r_univariatelinear.csv ├── r_waveform.csv ├── r_wine_red.csv └── r_wine_white.csv ├── fastutil-core-8.5.8.jar ├── java_code_documentation ├── AlgorithmTesting │ ├── DataAnalysis.html │ ├── DataSetupForCSVFile.html │ ├── class-use │ │ ├── DataAnalysis.html │ │ └── DataSetupForCSVFile.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── CustomLogging │ ├── StdOutErrLevel.html │ ├── SuperSimpleFormatter.html │ ├── class-use │ │ ├── StdOutErrLevel.html │ │ └── SuperSimpleFormatter.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── OpenSourceExtensions │ ├── StatUtil.html │ ├── TDoubleHashSetAndArray.html │ ├── UnorderedPair.html │ ├── class-use │ │ ├── StatUtil.html │ │ ├── TDoubleHashSetAndArray.html │ │ └── UnorderedPair.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── allclasses-frame.html ├── allclasses-noframe.html ├── bartMachine │ ├── Classifier.ErrorTypes.html │ ├── Classifier.html │ ├── StatToolbox.html │ ├── Tools.html │ ├── TreeArrayIllustration.html │ ├── TreeIllustration.ImageFileFilter.html │ ├── TreeIllustration.html │ ├── bartMachineClassification.html │ ├── bartMachineClassificationMultThread.html │ ├── bartMachineRegression.html │ ├── bartMachineRegressionMultThread.html │ ├── bartMachineTreeNode.html │ ├── bartMachine_a_base.html │ ├── bartMachine_b_hyperparams.html │ ├── bartMachine_c_debug.html │ ├── bartMachine_d_init.html │ ├── bartMachine_e_gibbs_base.html │ ├── bartMachine_f_gibbs_internal.html │ ├── bartMachine_g_mh.Steps.html │ ├── bartMachine_g_mh.html │ ├── bartMachine_h_eval.html │ ├── bartMachine_i_prior_cov_spec.html │ ├── class-use │ │ ├── Classifier.ErrorTypes.html │ │ ├── Classifier.html │ │ ├── StatToolbox.html │ │ ├── Tools.html │ │ ├── TreeArrayIllustration.html │ │ ├── TreeIllustration.ImageFileFilter.html │ │ ├── TreeIllustration.html │ │ ├── bartMachineClassification.html │ │ ├── bartMachineClassificationMultThread.html │ │ ├── bartMachineRegression.html │ │ ├── bartMachineRegressionMultThread.html │ │ ├── bartMachineTreeNode.html │ │ ├── bartMachine_a_base.html │ │ ├── bartMachine_b_hyperparams.html │ │ ├── bartMachine_c_debug.html │ │ ├── bartMachine_d_init.html │ │ ├── bartMachine_e_gibbs_base.html │ │ ├── bartMachine_f_gibbs_internal.html │ │ ├── bartMachine_g_mh.Steps.html │ │ ├── bartMachine_g_mh.html │ │ ├── bartMachine_h_eval.html │ │ └── bartMachine_i_prior_cov_spec.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── constant-values.html ├── deprecated-list.html ├── help-doc.html ├── index-files │ ├── index-1.html │ ├── index-10.html │ ├── index-11.html │ ├── index-12.html │ ├── index-13.html │ ├── index-14.html │ ├── index-15.html │ ├── index-16.html │ ├── index-17.html │ ├── index-18.html │ ├── index-19.html │ ├── index-2.html │ ├── index-20.html │ ├── index-21.html │ ├── index-3.html │ ├── index-4.html │ ├── index-5.html │ ├── index-6.html │ ├── index-7.html │ ├── index-8.html │ └── index-9.html ├── index.html ├── overview-frame.html ├── overview-summary.html ├── overview-tree.html ├── package-list ├── resources │ ├── background.gif │ ├── tab.gif │ ├── titlebar.gif │ └── titlebar_end.gif ├── serialized-form.html └── stylesheet.css ├── junit-4.10.jar ├── missing_data_paper ├── sec_4_mar.R ├── sec_4_mcar.R ├── sec_4_nmar.R ├── sec_4_pm_nmar.R 
├── sec_5_mar.R ├── sec_5_mar_with_bumpup.R ├── sec_5_mcar.R └── sec_5_nmar.R ├── simple_example.R ├── src ├── AlgorithmTesting │ ├── DataAnalysis.java │ ├── DataSetupForCSVFile.java │ └── package-info.java ├── CustomLogging │ ├── LoggingOutputStream.java │ ├── StdOutErrLevel.java │ ├── SuperSimpleFormatter.java │ └── package-info.java ├── OpenSourceExtensions │ ├── MersenneTwisterFast.java │ ├── StatUtil.java │ ├── TDoubleHashSetAndArray.java │ ├── UnorderedPair.java │ └── package-info.java └── bartMachine │ ├── Classifier.java │ ├── StatToolbox.java │ ├── Tools.java │ ├── TreeArrayIllustration.java │ ├── TreeIllustration.java │ ├── bartMachineClassification.java │ ├── bartMachineClassificationMultThread.java │ ├── bartMachineRegression.java │ ├── bartMachineRegressionMultThread.java │ ├── bartMachineTreeNode.java │ ├── bartMachine_a_base.java │ ├── bartMachine_b_hyperparams.java │ ├── bartMachine_c_debug.java │ ├── bartMachine_d_init.java │ ├── bartMachine_e_gibbs_base.java │ ├── bartMachine_f_gibbs_internal.java │ ├── bartMachine_g_mh.java │ ├── bartMachine_h_eval.java │ ├── bartMachine_i_prior_cov_spec.java │ └── package-info.java └── trove-3.0.3.jar
-------------------------------------------------------------------------------- /.classpath: --------------------------------------------------------------------------------
[13 lines of Eclipse .classpath XML; the tag content was stripped during text extraction and is not recoverable]
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | # Compiled source #
2 | ###################
3 | *.com
4 | *.class
5 | *.dll
6 | *.exe
7 | *.o
8 | *.so
9 |
10 | # Packages #
11 | ############
12 | # it's better to unpack these files and commit the raw source
13 | # git has its own built in compression methods
14 | *.7z
15 | *.dmg
16 | *.gz
17 | *.iso
18 | *.rar
19 | *.tar
20 |
21 | # Logs and databases #
22 | ######################
23 | *.log
24 | *.sql
25 | *.sqlite
26 |
27 | # OS generated files #
28 | ######################
29 | .DS_Store*
30 | ehthumbs.db
31 | Icon?
32 | Thumbs.db
33 |
34 | # other stuff project specific
35 | Rob_BayesTree.Rcheck/*
36 | *.swp
37 | *.log
38 | log
39 | log.lck
40 | log*
41 | log*.lck
42 | *.log.lck
43 | log~
44 | log*~
45 | bin/*
46 | .gitignore.swo
47 | .gitignore~
48 | .settings/*
49 | debug_output/desktop.ini
50 | simulation_results.csv
51 | r_scripts/debug_only/*
52 | r_scripts/bart_bakeoff_params.R
53 | mathematica_scripts/*
54 | datasets/r_treemodel.csv
55 | debug_output/*.csv
56 | debug_output/*.png
57 | output_plots/*.pdf
58 | output_plots/*.csv
59 | datasets/bart_data.csv
60 | .externalToolBuilders/*
61 | sweave_reports/*.aux
62 | sweave_reports/*.log
63 | sweave_reports/*.out
64 | sweave_reports/*.pdf
65 | sweave_reports/*.tex
66 | sweave_reports/*.toc
67 | sweave_reports/auxilary_latex_files/*
68 | r_log/*.txt
69 | *.aux
70 | *.log
71 | *.out
72 | *.dvi
73 | *.Rhistory
74 | bin/*
75 | randsamps/chisq*
76 | RWeka/
77 | rJAVA_examples/
78 | bartMachine.Rcheck
79 | bartMachine.Rcheck/**/*
80 | bartMachineJARs.Rcheck
81 | bartMachineJARs.Rcheck/**/*
82 | old_mar/*
83 | datasets/r_zach.csv
84 | datasets/c_crime.csv
85 | datasets/c_crime_big.csv
86 | r_scripts/examples/sample_code.R
87 | r_scripts/examples/working1.rdata
88 | bartMachine/install_instructions.txt
89 | /bin/
90 | bartMachine/vignettes/bartMachine.pdf
91 | bartMachine/vignettes/bartMachine.tex
-------------------------------------------------------------------------------- /.project: --------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <projectDescription>
3 |     <name>bartMachine</name>
4 |     <comment></comment>
5 |     <projects>
6 |     </projects>
7 |     <buildSpec>
8 |         <buildCommand>
9 |             <name>org.eclipse.jdt.core.javabuilder</name>
10 |             <arguments>
11 |             </arguments>
12 |         </buildCommand>
13 |         <buildCommand>
14 |             <name>de.walware.docmlet.tex.builders.Tex</name>
15 |             <arguments>
16 |             </arguments>
17 |         </buildCommand>
18 |         <buildCommand>
19 |             <name>de.walware.statet.r.builders.RSupport</name>
20 |             <arguments>
21 |             </arguments>
22 |         </buildCommand>
23 |     </buildSpec>
24 |     <natures>
25 |         <nature>org.eclipse.jdt.core.javanature</nature>
26 |         <nature>org.eclipse.statet.ide.resourceProjects.Statet</nature>
27 |         <nature>org.eclipse.statet.r.resourceProjects.R</nature>
28 |         <nature>org.eclipse.statet.docmlet.resourceProjects.Tex</nature>
29 |     </natures>
30 | </projectDescription>
31 |
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Adam Kapelner and Justin Bleich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
-------------------------------------------------------------------------------- /bartMachine/DESCRIPTION: --------------------------------------------------------------------------------
1 | Package: bartMachine
2 | Type: Package
3 | Title: Bayesian Additive Regression Trees
4 | Version: 1.3.4
5 | Date: 2023-06-25
6 | Author: Adam Kapelner and Justin Bleich (R package)
7 | Maintainer: Adam Kapelner <kapelner@qc.cuny.edu>
8 | Description: An advanced implementation of Bayesian Additive Regression Trees with expanded features for data analysis and visualization.
9 | License: GPL-3
10 | Depends: R (>= 2.14.0), rJava (>= 0.9-8), bartMachineJARs (>= 1.2.1), randomForest, missForest
11 | Imports: graphics, grDevices, stats
12 | SystemRequirements: Java (>= 8.0)
13 | BugReports: https://github.com/kapelner/bartMachine/issues
-------------------------------------------------------------------------------- /bartMachine/NAMESPACE: --------------------------------------------------------------------------------
1 | import(rJava)
2 | import(bartMachineJARs)
3 | import(randomForest)
4 | import(missForest)
5 |
6 | import(graphics)
7 | import(grDevices)
8 | import(stats)
9 | importFrom("utils", "packageVersion")
10 |
11 | export(bartMachine)
12 | export(build_bart_machine)
13 | export(bartMachineArr)
14 | export(predict_bartMachineArr)
15 |
16 | export(set_bart_machine_num_cores)
17 | export(bart_machine_num_cores)
18 |
19 | export(dummify_data)
20 | export(get_var_counts_over_chain)
21 | export(get_var_props_over_chain)
22 | export(node_prediction_training_data_indices)
23 | export(get_projection_weights)
24 |
25 | export(cov_importance_test)
26 | export(linearity_test)
27 |
28 | export(plot_y_vs_yhat)
29 | export(check_bart_error_assumptions)
30 | export(get_sigsqs)
31 |
32 | export(investigate_var_importance)
33 | export(plot_convergence_diagnostics)
34 | export(interaction_investigator)
35 | export(pd_plot)
36 | export(rmse_by_num_trees)
37 |
38 | export(bart_predict_for_test_data)
39 | export(bart_machine_get_posterior)
40 | export(calc_credible_intervals)
41 | export(calc_prediction_intervals)
42 |
43 | export(build_bart_machine_cv)
44 | export(bartMachineCV)
45 | export(k_fold_cv)
46 |
47 | export(var_selection_by_permute)
48 | export(var_selection_by_permute_cv)
49 |
50 | export(extract_raw_node_data)
51 |
52 | S3method(predict, bartMachine)
53 | S3method(print, bartMachine)
54 | S3method(summary, bartMachine) #alias for print
-------------------------------------------------------------------------------- /bartMachine/R/bartMachine.R: --------------------------------------------------------------------------------
1 | bartMachine = function(
2 |     X = NULL,
3 |     y = NULL,
4 |     Xy = NULL,
5 |     num_trees = 50, #found many times to not get better after this value... so let it be the default; it's faster too
6 |     num_burn_in = 250,
7 |     num_iterations_after_burn_in = 1000,
8 |     alpha = 0.95,
9 |     beta = 2,
10 |     k = 2,
11 |     q = 0.9,
12 |     nu = 3.0,
13 |     prob_rule_class = 0.5,
14 |     mh_prob_steps = c(2.5, 2.5, 4) / 9, #only the first two matter
15 |     debug_log = FALSE,
16 |     run_in_sample = TRUE,
17 |     s_sq_y = "mse", # "mse" or "var"
18 |     sig_sq_est = NULL,
19 |     print_tree_illustrations = FALSE, #POWER USERS ONLY
20 |     cov_prior_vec = NULL,
21 |     interaction_constraints = NULL,
22 |     use_missing_data = FALSE,
23 |     covariates_to_permute = NULL, #PRIVATE
24 |     num_rand_samps_in_library = 10000, #give the user the option to make a bigger library of random samples of normals and inv-gammas
25 |     use_missing_data_dummies_as_covars = FALSE,
26 |     replace_missing_data_with_x_j_bar = FALSE,
27 |     impute_missingness_with_rf_impute = FALSE,
28 |     impute_missingness_with_x_j_bar_for_lm = TRUE,
29 |     mem_cache_for_speed = TRUE,
30 |     flush_indices_to_save_RAM = TRUE,
31 |     serialize = FALSE,
32 |     seed = NULL,
33 |     verbose = TRUE
34 | ){
35 |     build_bart_machine(
36 |         X = X,
37 |         y = y,
38 |         Xy = Xy,
39 |         num_trees = num_trees,
40 |         num_burn_in = num_burn_in,
41 |         num_iterations_after_burn_in = num_iterations_after_burn_in,
42 |         alpha = alpha,
43 |         beta = beta,
44 |         k = k,
45 |         q = q,
46 |         nu = nu,
47 |         prob_rule_class = prob_rule_class,
48 |         mh_prob_steps = mh_prob_steps,
49 |         debug_log = debug_log,
50 |         run_in_sample = run_in_sample,
51 |         s_sq_y = s_sq_y,
52 |         sig_sq_est = sig_sq_est,
53 |         print_tree_illustrations = print_tree_illustrations,
54 |         cov_prior_vec = cov_prior_vec,
55 |         interaction_constraints = interaction_constraints,
56 |         use_missing_data = use_missing_data,
57 |         covariates_to_permute = covariates_to_permute,
58 |         num_rand_samps_in_library = num_rand_samps_in_library, #give the user the option to make a bigger library of random samples of normals and inv-gammas
59 |         use_missing_data_dummies_as_covars = use_missing_data_dummies_as_covars,
60 |         replace_missing_data_with_x_j_bar = replace_missing_data_with_x_j_bar,
61 |         impute_missingness_with_rf_impute = impute_missingness_with_rf_impute,
62 |         impute_missingness_with_x_j_bar_for_lm = impute_missingness_with_x_j_bar_for_lm,
63 |         mem_cache_for_speed = mem_cache_for_speed,
64 |         flush_indices_to_save_RAM = flush_indices_to_save_RAM,
65 |         serialize = serialize,
66 |         seed = seed,
67 |         verbose = verbose
68 |     )
69 | }
70 |
71 |
72 | bartMachineCV = function(X = NULL, y = NULL, Xy = NULL,
73 |         num_tree_cvs = c(50, 200),
74 |         k_cvs = c(2, 3, 5),
75 |         nu_q_cvs = NULL,
76 |         k_folds = 5,
77 |         folds_vec = NULL,
78 |         verbose = FALSE, ...){
79 |
80 |     build_bart_machine_cv(X, y, Xy,
81 |         num_tree_cvs,
82 |         k_cvs,
83 |         nu_q_cvs,
84 |         k_folds,
85 |         folds_vec, ...)
86 | }
-------------------------------------------------------------------------------- /bartMachine/R/bart_arrays.R: --------------------------------------------------------------------------------
1 |
2 | bartMachineArr = function(bart_machine, R = 10){
3 |     arr = list()
4 |     arr[[1]] = bart_machine
5 |     for (i in 2 : R){
6 |         arr[[i]] = bart_machine_duplicate(bart_machine)
7 |     }
8 |     class(arr) = "bartMachineArr"
9 |     arr
10 | }
11 |
12 | predict_bartMachineArr = function(object, new_data, ...){
13 |     R = length(object)
14 |     n_star = nrow(new_data)
15 |     predicts = matrix(NA, nrow = n_star, ncol = R)
16 |     for (r in 1 : R){
17 |         predicts[, r] = predict(object[[r]], new_data, ...)
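        #each column r holds the posterior-mean predictions of replicate model r; rowMeans() below averages the R replicates into one ensemble prediction per observation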
18 | } 19 | rowMeans(predicts) 20 | } -------------------------------------------------------------------------------- /bartMachine/R/bart_package_f_tests.R: -------------------------------------------------------------------------------- 1 | ##function to permute columns of X and check BART's performance 2 | cov_importance_test = function(bart_machine, covariates = NULL, num_permutation_samples = 100, plot = TRUE){ 3 | check_serialization(bart_machine) #ensure the Java object exists and fire an error if not 4 | #be able to handle regular expressions to find the covariates 5 | 6 | all_covariates = bart_machine$training_data_features_with_missing_features 7 | 8 | if (is.null(covariates)){ 9 | title = "bartMachine omnibus test for covariate importance\n" 10 | } else if (length(covariates) <= 3){ 11 | if (inherits(covariates[1], "numeric")){ 12 | cov_names = paste(all_covariates[covariates], collapse = ", ") 13 | } else { 14 | cov_names = paste(covariates, collapse = ", ") 15 | } 16 | title = paste("bartMachine test for importance of covariate(s):", cov_names, "\n") 17 | } else { 18 | title = paste("bartMachine test for importance of", length(covariates), "covariates", "\n") 19 | } 20 | cat(title) 21 | observed_error_estimate = ifelse(bart_machine$pred_type == "regression", bart_machine$PseudoRsq, bart_machine$misclassification_error) 22 | 23 | permutation_samples_of_error = array(NA, num_permutation_samples) 24 | for (nsim in 1 : num_permutation_samples){ 25 | cat(".") 26 | if (nsim %% 50 == 0){ 27 | cat("\n") 28 | } 29 | #omnibus F-like test - just permute y (same as permuting ALL the columns of X and it's faster) 30 | if (is.null(covariates)){ 31 | bart_machine_samp = bart_machine_duplicate(bart_machine, y = sample(bart_machine$y), run_in_sample = TRUE, verbose = FALSE) #we have to turn verbose off otherwise there would be too many outputs 32 | #partial F-like test - permute the columns that we're interested in seeing if they matter 33 | } else { 34 | X_samp = bart_machine$X #copy original design matrix 35 | 36 | covariates_left_to_permute = c() 37 | for (cov in covariates){ 38 | if (cov %in% colnames(X_samp)){ 39 | X_samp[, cov] = sample(X_samp[, cov]) 40 | } else { 41 | covariates_left_to_permute = c(covariates_left_to_permute, cov) 42 | } 43 | } 44 | 45 | bart_machine_samp = bart_machine_duplicate(bart_machine, X = X_samp, covariates_to_permute = covariates_left_to_permute, run_in_sample = TRUE, verbose = FALSE) #we have to turn verbose off otherwise there would be too many outputs 46 | } 47 | #record permutation result 48 | permutation_samples_of_error[nsim] = ifelse(bart_machine$pred_type == "regression", bart_machine_samp$PseudoRsq, bart_machine_samp$misclassification_error) 49 | } 50 | cat("\n") 51 | 52 | ##compute p-value 53 | pval = ifelse(bart_machine$pred_type == "regression", sum(observed_error_estimate < permutation_samples_of_error), sum(observed_error_estimate > permutation_samples_of_error)) / (num_permutation_samples + 1) 54 | 55 | if (plot){ 56 | hist(permutation_samples_of_error, 57 | xlim = c(min(permutation_samples_of_error, 0.99 * observed_error_estimate), max(permutation_samples_of_error, 1.01 * observed_error_estimate)), 58 | xlab = paste("permutation samples\n pval = ", round(pval, 3)), 59 | br = num_permutation_samples / 10, 60 | main = paste(title, "Null Samples of", ifelse(bart_machine$pred_type == "regression", "Pseudo-R^2's", "Misclassification Errors"))) 61 | abline(v = observed_error_estimate, col = "blue", lwd = 3) 62 | } 63 | cat("p_val = ", pval, "\n") 64 
| invisible(list(permutation_samples_of_error = permutation_samples_of_error, observed_error_estimate = observed_error_estimate, pval = pval)) 65 | } 66 | 67 | linearity_test = function(lin_mod = NULL, X = NULL, y = NULL, num_permutation_samples = 100, plot = TRUE, ...){ 68 | if (is.null(lin_mod)){ 69 | lin_mod = lm(y ~ as.matrix(X)) 70 | } 71 | y_hat = predict(lin_mod, X) 72 | bart_mod = bartMachine(X, y - y_hat, ...) 73 | cov_importance_test(bart_mod, num_permutation_samples = num_permutation_samples, plot = plot) 74 | } 75 | 76 | -------------------------------------------------------------------------------- /bartMachine/R/bart_package_inits.R: -------------------------------------------------------------------------------- 1 | ##color array 2 | COLORS = array(NA, 500) 3 | for (i in 1 : 500){ 4 | COLORS[i] = rgb(runif(1, 0, 0.7), runif(1, 0, 0.7), runif(1, 0, 0.7)) 5 | } 6 | 7 | ##set number of cores in use 8 | set_bart_machine_num_cores = function(num_cores){ 9 | assign("BART_NUM_CORES", num_cores, bartMachine_globals) 10 | cat("bartMachine now using", num_cores, "cores.\n") 11 | } 12 | 13 | ##get number of cores in use 14 | DEFAULT_BART_NUM_CORES = 1 15 | bart_machine_num_cores = function(){ 16 | if (exists("BART_NUM_CORES", envir = bartMachine_globals)){ 17 | get("BART_NUM_CORES", bartMachine_globals) 18 | } else { 19 | DEFAULT_BART_NUM_CORES 20 | } 21 | } 22 | 23 | set_bart_machine_memory = function(bart_max_mem){ 24 | cat("This method has been deprecated. Please use 'options(java.parameters = \"-Xmx", bart_max_mem, "m\")' instead.\n", sep = "") 25 | } 26 | 27 | ##get variable counts 28 | get_var_counts_over_chain = function(bart_machine, type = "splits"){ 29 | check_serialization(bart_machine) #ensure the Java object exists and fire an error if not 30 | 31 | if (!(type %in% c("trees", "splits"))){ 32 | stop("type must be \"trees\" or \"splits\"") 33 | } 34 | C = .jcall(bart_machine$java_bart_machine, "[[I", "getCountsForAllAttribute", type, simplify = TRUE) 35 | colnames(C) = colnames(bart_machine$model_matrix_training_data)[1 : bart_machine$p] 36 | C 37 | } 38 | 39 | #get variable inclusion proportions 40 | get_var_props_over_chain = function(bart_machine, type = "splits"){ 41 | check_serialization(bart_machine) #ensure the Java object exists and fire an error if not 42 | 43 | if (!(type %in% c("trees", "splits"))){ 44 | stop("type must be \"trees\" or \"splits\"") 45 | } 46 | attribute_props = .jcall(bart_machine$java_bart_machine, "[D", "getAttributeProps", type) 47 | names(attribute_props) = colnames(bart_machine$model_matrix_training_data)[1 : bart_machine$p] 48 | attribute_props 49 | } 50 | 51 | ##private function called in summary() 52 | sigsq_est = function(bart_machine){ 53 | sigsqs = .jcall(bart_machine$java_bart_machine, "[D", "getGibbsSamplesSigsqs") 54 | sigsqs_after_burnin = sigsqs[(length(sigsqs) - bart_machine$num_iterations_after_burn_in) : length(sigsqs)] 55 | mean(sigsqs_after_burnin) 56 | } 57 | 58 | #There's no standard R function for this. 
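#A sketch of why this works (illustrative note, not part of the package source):
#table() tallies each distinct value, sorting the negated table puts the most
#frequent value first, and its name is coerced back to numeric. For example,
#sample_mode(c(1, 2, 2, 5)) returns 2.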
59 | sample_mode = function(data){ 60 | as.numeric(names(sort(-table(data)))[1]) 61 | } 62 | 63 | check_serialization = function(bart_machine){ 64 | if (is.jnull(bart_machine$java_bart_machine)){ 65 | stop("This bartMachine object was loaded from an R image but was not serialized.\n Please build bartMachine using the option \"serialize = TRUE\" next time.\n") 66 | } 67 | } -------------------------------------------------------------------------------- /bartMachine/R/bart_package_summaries.R: -------------------------------------------------------------------------------- 1 | ##give summary info about bart 2 | summary.bartMachine = function(object, ...){ 3 | cat(paste("bartMachine v", packageVersion("bartMachine"), ifelse(object$pred_type == "regression", " for regression", " for classification"), "\n\n", sep = "")) 4 | if (object$use_missing_data){ 5 | cat("Missing data feature ON\n") 6 | } 7 | #first print out characteristics of the training data 8 | if (!is.null(object$interaction_constraints)){ 9 | cat(paste0("number of specified covariate interactivity constraints = ", length(object$interaction_constraints), ".\ntraining data size: n = ", object$n, " and p = ", object$p, "\n")) 10 | } else { 11 | cat(paste("training data size: n =", object$n, "and p =", object$p, "\n")) 12 | } 13 | 14 | ##build time 15 | ttb = as.numeric(object$time_to_build, units = "secs") 16 | if (ttb > 60){ 17 | ttb = as.numeric(object$time_to_build, units = "mins") 18 | cat(paste("built in", round(ttb, 2), "mins on", object$num_cores, ifelse(object$num_cores == 1, "core,", "cores,"), object$num_trees, "trees,", object$num_burn_in, "burn-in and", object$num_iterations_after_burn_in, "post. samples\n")) 19 | } else { 20 | cat(paste("built in", round(ttb, 1), "secs on", object$num_cores, ifelse(object$num_cores == 1, "core,", "cores,"), object$num_trees, "trees,", object$num_burn_in, "burn-in and", object$num_iterations_after_burn_in, "post. 
samples\n"))
21 |     }
22 |
23 |     if (object$pred_type == "regression"){
24 |         sigsq_est = sigsq_est(object) ##call private function
25 |         cat(paste("\nsigsq est for y beforehand:", round(object$sig_sq_est, 3), "\n"))
26 |         cat(paste("avg sigsq estimate after burn-in:", round(sigsq_est, 5), "\n"))
27 |
28 |         if (object$run_in_sample){
29 |             cat("\nin-sample statistics:\n")
30 |             cat(paste(" L1 =", round(object$L1_err_train, 2), "\n",
31 |                 "L2 =", round(object$L2_err_train, 2), "\n",
32 |                 "rmse =", round(object$rmse_train, 2), "\n"),
33 |                 "Pseudo-Rsq =", round(object$PseudoRsq, 4))
34 |
35 |             es = object$residuals
36 |             if (length(es) > 5000){
37 |                 normal_p_val = shapiro.test(sample(es, 5000))$p.value
38 |             } else {
39 |                 normal_p_val = shapiro.test(es)$p.value
40 |             }
41 |             cat("\np-val for Shapiro-Wilk test of normality of residuals:", round(normal_p_val, 5), "\n")
42 |
43 |             centered_p_val = t.test(es)$p.value
44 |             cat("p-val for zero-mean noise:", round(centered_p_val, 5), "\n")
45 |         } else {
46 |             cat("\nno in-sample information available (use option run_in_sample = TRUE next time)\n")
47 |         }
48 |     } else if (object$pred_type == "classification"){
49 |         if (object$run_in_sample){
50 |             cat("\nconfusion matrix:\n\n")
51 |             print(object$confusion_matrix)
52 |         } else {
53 |             cat("\nno in-sample information available (use option run_in_sample = TRUE next time)\n")
54 |         }
55 |     }
56 |     cat("\n")
57 | }
58 |
59 | #alias for summary
60 | print.bartMachine = function(x, ...){summary(x)}
61 |
-------------------------------------------------------------------------------- /bartMachine/R/zzz.R: --------------------------------------------------------------------------------
1 | .onLoad = function(libname, pkgname) {
2 |     .jpackage(pkgname, lib.loc = libname)
3 |
4 |     #need to check if proper Java is installed by special request of Prof Brian Ripley
5 |     jv = .jcall("java/lang/System", "S", "getProperty", "java.runtime.version")
6 |     if (substr(jv, 1L, 2L) == "1.") {
7 |         jvn = as.numeric(paste0(strsplit(jv, "[.]")[[1L]][1:2], collapse = "."))
8 |         if (jvn < 1.7){
9 |             warning("Java 7 (at minimum) is needed for this package but it does not seem to be available. This message may be in error; apologies if it is.")
10 |         }
11 |     }
12 |
13 |     assign("bartMachine_globals", new.env(), envir = parent.env(environment()))
14 | }
15 |
16 | .onAttach = function(libname, pkgname){
17 |     num_gigs_ram_available = .jcall(.jnew("java/lang/Runtime"), "J", "maxMemory") / 1e9
18 |     packageStartupMessage(
19 |         paste("Welcome to bartMachine v", packageVersion("bartMachine"),
20 |             "! 
You have ", round(num_gigs_ram_available, 2), 21 | "GB memory available.\n\n", 22 | "If you run out of memory, restart R, and use e.g.\n'options(java.parameters = \"-Xmx5g\")' for 5GB of RAM before you call\n'library(bartMachine)'.\n", 23 | sep = "") 24 | ) 25 | } -------------------------------------------------------------------------------- /bartMachine/data/automobile.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/data/automobile.RData -------------------------------------------------------------------------------- /bartMachine/data/benchmark_datasets.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/data/benchmark_datasets.RData -------------------------------------------------------------------------------- /bartMachine/inst/CITATION: -------------------------------------------------------------------------------- 1 | bibentry(bibtype = "Article", 2 | title = "{bartMachine}: Machine Learning with {B}ayesian Additive Regression Trees", 3 | author = c(person(given = "Adam", 4 | family = "Kapelner", 5 | email = "kapelner@qc.cuny.edu"), 6 | person(given = "Justin", 7 | family = "Bleich")), 8 | journal = "Journal of Statistical Software", 9 | year = "2016", 10 | volume = "70", 11 | number = "4", 12 | pages = "1--40", 13 | doi = "10.18637/jss.v070.i04", 14 | 15 | header = "To cite bartMachine/bartMachineJARs in publications use:", 16 | textVersion = 17 | paste("Adam Kapelner, Justin Bleich (2016).", 18 | "bartMachine: Machine Learning with Bayesian Additive Regression Trees.", 19 | "Journal of Statistical Software, 70(4), 1-40.", 20 | "doi:10.18637/jss.v070.i04") 21 | ) 22 | 23 | -------------------------------------------------------------------------------- /bartMachine/inst/COPYRIGHTS: -------------------------------------------------------------------------------- 1 | The following open source Java libraries are included in this package through the package bartMachineJARs: 2 | 3 | - JUnit under the Eclipse Public License 1.0, 4 | Version 4.10, January 2004, 5 | http://junit.org/, 6 | Copyright 2002-2014 JUnit 7 | 8 | - Commons Math: The Apache Commons Mathematics Library under the Apache License 2.0, 9 | Version 2.1, March 2010, 10 | http://commons.apache.org/proper/commons-math/, 11 | Copyright 2003-2014 The Apache Software Foundation 12 | 13 | - Trove under Lesser GNU Public License (LGPL) 2.1, 14 | Version 3.0.3, February 2012, 15 | http://trove.starlight-systems.com/ 16 | 17 | License 18 | ======= 19 | 20 | You may obtain a copy of the Eclipse Public License 1.0 at 21 | https://www.eclipse.org/legal/epl-v10.html 22 | 23 | You may obtain a copy of the Apache License, Version 2.0 at 24 | http://www.apache.org/licenses/LICENSE-2.0 25 | 26 | You may obtain a copy of the Lesser GNU Public License (LGPL) 2.1 at 27 | http://www.gnu.org/licenses/lgpl-2.1.html -------------------------------------------------------------------------------- /bartMachine/inst/java/bart_java.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/inst/java/bart_java.jar -------------------------------------------------------------------------------- /bartMachine/man/automobile.Rd: 
-------------------------------------------------------------------------------- 1 | \name{automobile} 2 | \alias{automobile} 3 | \title{Data concerning automobile prices.} 4 | \description{ 5 | The \code{automobile} data frame has 201 rows and 25 columns and 6 | concerns automobiles in the 1985 Auto Imports Database. The response 7 | variable, \code{price}, is the log selling price of the automobile. There 8 | are 7 categorical predictors and 17 continuous / integer predictors which 9 | are features of the automobiles. 41 automobiles have missing data in one 10 | or more of the feature entries. This dataset is true to the original except 11 | with a few of the predictors dropped. 12 | } 13 | \usage{ 14 | data(automobile) 15 | } 16 | \source{ 17 | K Bache and M Lichman. UCI machine learning repository, 2013. 18 | http://archive.ics.uci.edu/ml/datasets/Automobile 19 | } 20 | \keyword{datasets} 21 | -------------------------------------------------------------------------------- /bartMachine/man/bartMachineArr.Rd: -------------------------------------------------------------------------------- 1 | \name{bartMachineArr} 2 | \alias{bartMachineArr} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Create an array of BART models for the same data. 6 | } 7 | \description{ 8 | If BART creates models that are variable, 9 | running many on the same dataset and averaging is a good strategy. 10 | This function is a convenience method for this procedure. 11 | } 12 | \usage{ 13 | bartMachineArr(bart_machine, R = 10) 14 | } 15 | %- maybe also 'usage' for other objects documented here. 16 | \arguments{ 17 | \item{bart_machine}{ 18 | An object of class ``bartMachine''. 19 | } 20 | \item{R}{ 21 | The number of replicated BART models in the array. 22 | } 23 | } 24 | 25 | \value{ 26 | A \code{bartMachineArr} object which is just a list of the \code{R} bartMachine models. 27 | } 28 | 29 | \author{ 30 | Adam Kapelner 31 | } 32 | 33 | 34 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 35 | 36 | \examples{ 37 | #Regression example 38 | \dontrun{ 39 | #generate Friedman data 40 | set.seed(11) 41 | n = 200 42 | p = 5 43 | X = data.frame(matrix(runif(n * p), ncol = p)) 44 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 45 | 46 | ##build BART regression model 47 | bart_machine = bartMachine(X, y) 48 | bart_machine_arr = bartMachineArr(bart_machine) 49 | 50 | #Classification example 51 | data(iris) 52 | iris2 = iris[51 : 150, ] #do not include the third type of flower for this example 53 | iris2$Species = factor(iris2$Species) 54 | bart_machine = bartMachine(iris2[ ,1:4], iris2$Species) 55 | bart_machine_arr = bartMachineArr(bart_machine) 56 | } 57 | 58 | 59 | } 60 | -------------------------------------------------------------------------------- /bartMachine/man/bartMachineCV.Rd: -------------------------------------------------------------------------------- 1 | \name{bartMachineCV} 2 | \alias{bartMachineCV} 3 | \alias{build_bart_machine_cv} 4 | 5 | \title{ 6 | Build BART-CV 7 | } 8 | \description{ 9 | Builds a BART-CV model by cross-validating over a grid of hyperparameter choices. 10 | } 11 | \usage{ 12 | bartMachineCV(X = NULL, y = NULL, Xy = NULL, 13 | num_tree_cvs = c(50, 200), k_cvs = c(2, 3, 5), 14 | nu_q_cvs = NULL, k_folds = 5, folds_vec = NULL, verbose = FALSE, ...) 
15 |
16 | build_bart_machine_cv(X = NULL, y = NULL, Xy = NULL,
17 |     num_tree_cvs = c(50, 200), k_cvs = c(2, 3, 5),
18 |     nu_q_cvs = NULL, k_folds = 5, folds_vec = NULL, verbose = FALSE, ...)
19 | }
20 | %- maybe also 'usage' for other objects documented here.
21 | \arguments{
22 | \item{X}{
23 | Data frame of predictors. Factors are automatically converted to dummies internally.
24 | }
25 | \item{y}{
26 | Vector of response variable. If \code{y} is \code{numeric} or \code{integer}, a BART model for regression is built. If \code{y} is a factor with two levels, a BART model for classification is built.
27 | }
28 | \item{Xy}{
29 | A data frame of predictors and the response. The response column must be named ``y''.
30 | }
31 | \item{num_tree_cvs}{
32 | Vector of sizes for the sum-of-trees models to cross-validate over.
33 | }
34 | \item{k_cvs}{
35 | Vector of choices for the hyperparameter \code{k} to cross-validate over.
36 | }
37 | \item{nu_q_cvs}{
38 | Only for regression. List of vectors containing (\code{nu}, \code{q}) ordered pair choices to cross-validate over. If \code{NULL}, then it defaults to the three values \code{list(c(3, 0.9), c(3, 0.99), c(10, 0.75))}.
39 | }
40 | \item{k_folds}{
41 | Number of folds for cross-validation.
42 | }
43 | \item{folds_vec}{
44 | An integer vector of indices specifying which fold each observation belongs to.
45 | }
46 | \item{verbose}{
47 | Prints information about the progress of the algorithm to the screen.
48 | }
49 | \item{\dots}{
50 | Additional arguments to be passed to \code{bartMachine}.
51 | }
52 | }
53 |
54 | \value{
55 | Returns an object of class ``bartMachine'' with the set of hyperparameters chosen via cross-validation. We also return a matrix ``cv_stats''
56 | which contains the out-of-sample RMSE for each hyperparameter set tried, and ``folds'', which gives the fold in which each observation fell across the k folds.
57 | }
58 | \references{
59 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
60 | with Bayesian Additive Regression Trees. Journal of Statistical
61 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04
62 | }
63 | \author{
64 | Adam Kapelner and Justin Bleich
65 | }
66 | \note{
67 | This function may require significant run-time.
68 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}} via calling \code{\link{bartMachine}}.
69 | }
70 |
71 | %% ~Make other sections like Warning with \section{Warning }{....} ~
72 |
73 | \seealso{
74 | \code{\link{bartMachine}}
75 | }
76 | \examples{
77 | \dontrun{
78 | #generate Friedman data
79 | set.seed(11)
80 | n = 200
81 | p = 5
82 | X = data.frame(matrix(runif(n * p), ncol = p))
83 | y = 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - .5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(n)
84 |
85 | ##build BART regression model
86 | bart_machine_cv = bartMachineCV(X, y)
87 |
88 | #information about cross-validated model
89 | summary(bart_machine_cv)
90 | }
91 |
92 | }
-------------------------------------------------------------------------------- /bartMachine/man/bart_machine_get_posterior.Rd: --------------------------------------------------------------------------------
1 | \name{bart_machine_get_posterior}
2 | \alias{bart_machine_get_posterior}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Get Full Posterior Distribution
6 | }
7 | \description{
8 | Generates draws from the posterior distribution of \eqn{\hat{f}(x)} for a specified set of observations.
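These draws are the raw material for the credible and prediction intervals computed by \code{\link{calc_credible_intervals}} and \code{\link{calc_prediction_intervals}}.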
9 | } 10 | \usage{ 11 | bart_machine_get_posterior(bart_machine, new_data) 12 | } 13 | %- maybe also 'usage' for other objects documented here. 14 | \arguments{ 15 | \item{bart_machine}{ 16 | An object of class ``bartMachine''. 17 | } 18 | \item{new_data}{ 19 | A data frame containing observations at which draws from posterior distribution of \eqn{\hat{f}(x)} are to be obtained. 20 | } 21 | } 22 | 23 | \value{ 24 | Returns a list with the following components: 25 | %% If it is a LIST, use 26 | \item{y_hat}{Posterior mean estimates. For regression, the estimates have the same units as the response. For classification, the estimates are probabilities.} 27 | \item{new_data}{The data frame with rows at which the posterior draws are to be generated. Column names should match that of the training data.} 28 | \item{y_hat_posterior_samples}{The full set of posterior samples of size \code{num_iterations_after_burn_in} for each observation. For regression, the estimates have the same units as the response. For classification, the estimates are probabilities.} 29 | %% ... 30 | } 31 | \author{ 32 | Adam Kapelner and Justin Bleich 33 | } 34 | \note{ 35 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}. 36 | } 37 | 38 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 39 | 40 | \seealso{ 41 | \code{\link{calc_credible_intervals}}, \code{\link{calc_prediction_intervals}} 42 | } 43 | \examples{ 44 | \dontrun{ 45 | #Regression example 46 | 47 | #generate Friedman data 48 | set.seed(11) 49 | n = 200 50 | p = 5 51 | X = data.frame(matrix(runif(n * p), ncol = p)) 52 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 53 | 54 | ##build BART regression model 55 | bart_machine = bartMachine(X, y) 56 | 57 | #get posterior distribution 58 | posterior = bart_machine_get_posterior(bart_machine, X) 59 | print(posterior$y_hat) 60 | 61 | 62 | #Classification example 63 | 64 | #get data and only use 2 factors 65 | data(iris) 66 | iris2 = iris[51:150,] 67 | iris2$Species = factor(iris2$Species) 68 | 69 | #build BART classification model 70 | bart_machine = bartMachine(iris2[ ,1 : 4], iris2$Species) 71 | 72 | #get posterior distribution 73 | posterior = bart_machine_get_posterior(bart_machine, iris2[ ,1 : 4]) 74 | print(posterior$y_hat) 75 | } 76 | 77 | 78 | } 79 | -------------------------------------------------------------------------------- /bartMachine/man/bart_machine_num_cores.Rd: -------------------------------------------------------------------------------- 1 | \name{bart_machine_num_cores} 2 | \alias{bart_machine_num_cores} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Get Number of Cores Used by BART 6 | } 7 | \description{ 8 | Returns number of cores used by BART 9 | } 10 | \usage{ 11 | bart_machine_num_cores() 12 | } 13 | %- maybe also 'usage' for other objects documented here. 14 | \details{ 15 | Returns the number of cores currently being used by parallelized BART functions 16 | } 17 | \value{ 18 | %% ~Describe the value returned 19 | %% If it is a LIST, use 20 | Number of cores currently being used by parallelized BART functions. 21 | %% \item{comp2 }{Description of 'comp2'} 22 | %% ... 
23 | } 24 | \author{ 25 | Adam Kapelner and Justin Bleich 26 | } 27 | 28 | 29 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 30 | 31 | \seealso{ 32 | \code{\link{set_bart_machine_num_cores}} 33 | } 34 | \examples{ 35 | \dontrun{ 36 | bart_machine_num_cores() 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /bartMachine/man/bart_predict_for_test_data.Rd: -------------------------------------------------------------------------------- 1 | \name{bart_predict_for_test_data} 2 | \alias{bart_predict_for_test_data} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Predict for Test Data with Known Outcomes 6 | } 7 | \description{ 8 | Utility wrapper function for computing out-of-sample metrics for a BART model when the test set outcomes are known. 9 | } 10 | \usage{ 11 | bart_predict_for_test_data(bart_machine, Xtest, ytest, prob_rule_class = NULL) 12 | } 13 | %- maybe also 'usage' for other objects documented here. 14 | \arguments{ 15 | \item{bart_machine}{ 16 | An object of class ``bartMachine''. 17 | } 18 | \item{Xtest}{ 19 | Data frame for test data containing rows at which predictions are to be made. Colnames should match that of the training data. 20 | } 21 | \item{ytest}{ 22 | Actual outcomes for test data. 23 | } 24 | \item{prob_rule_class}{ 25 | Threshold for classification. 26 | } 27 | } 28 | 29 | \value{ 30 | For regression models, a list with the following components is returned: 31 | 32 | \item{y_hat}{Predictions (as posterior means) for the test observations.} 33 | \item{L1_err}{L1 error for predictions.} 34 | \item{L2_err}{L2 error for predictions.} 35 | \item{rmse}{RMSE for predictions.} 36 | 37 | For classification models, a list with the following components is returned: 38 | 39 | \item{y_hat}{Class predictions for the test observations.} 40 | \item{p_hat}{Probability estimates for the test observations.} 41 | \item{confusion_matrix}{A confusion matrix for the test observations.} 42 | 43 | %% ... 44 | } 45 | \author{ 46 | Adam Kapelner and Justin Bleich 47 | } 48 | 49 | 50 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 51 | 52 | \seealso{ 53 | \code{\link{predict}} 54 | } 55 | \examples{ 56 | \dontrun{ 57 | #generate Friedman data 58 | set.seed(11) 59 | n = 250 60 | p = 5 61 | X = data.frame(matrix(runif(n * p), ncol = p)) 62 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 63 | 64 | ##split into train and test 65 | train_X = X[1 : 200, ] 66 | test_X = X[201 : 250, ] 67 | train_y = y[1 : 200] 68 | test_y = y[201 : 250] 69 | 70 | ##build BART regression model 71 | bart_machine = bartMachine(train_X, train_y) 72 | 73 | #explore performance on test data 74 | oos_perf = bart_predict_for_test_data(bart_machine, test_X, test_y) 75 | print(oos_perf$rmse) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /bartMachine/man/benchmark_datasets.Rd: -------------------------------------------------------------------------------- 1 | \name{benchmark_datasets} 2 | \alias{ankara} 3 | \alias{baseball} 4 | \alias{boston} 5 | \alias{compactiv} 6 | \alias{ozone} 7 | \alias{pole} 8 | \alias{triazine} 9 | \alias{wine.red} 10 | \alias{wine.white} 11 | \title{benchmark_datasets} 12 | \description{ 13 | Nine diverse datasets which were used for benchmarking bartMachine's out of sample performance in 14 | the vignette for this package. 
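The datasets are \code{ankara}, \code{baseball}, \code{boston}, \code{compactiv}, \code{ozone}, \code{pole}, \code{triazine}, \code{wine.red} and \code{wine.white}.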
15 | } 16 | \usage{ 17 | data(benchmark_datasets) 18 | } 19 | \source{ 20 | See vignette for details. 21 | } 22 | \keyword{datasets} 23 | -------------------------------------------------------------------------------- /bartMachine/man/calc_credible_intervals.Rd: -------------------------------------------------------------------------------- 1 | \name{calc_credible_intervals} 2 | \alias{calc_credible_intervals} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Calculate Credible Intervals 6 | } 7 | \description{ 8 | Generates credible intervals for \eqn{\hat{f}(x)} for a specified set of observations. 9 | } 10 | \usage{ 11 | calc_credible_intervals(bart_machine, new_data, 12 | ci_conf = 0.95) 13 | } 14 | %- maybe also 'usage' for other objects documented here. 15 | \arguments{ 16 | \item{bart_machine}{ 17 | An object of class ``bartMachine''. 18 | } 19 | \item{new_data}{ 20 | A data frame containing observations at which credible intervals for \eqn{\hat{f}(x)} are to be computed. 21 | } 22 | \item{ci_conf}{ 23 | Confidence level for the credible intervals. The default is 95\%. 24 | } 25 | } 26 | \details{ 27 | This interval is the appropriate quantiles based on the confidence level, \code{ci_conf}, of the predictions 28 | for each of the Gibbs samples post-burn in. 29 | } 30 | \value{ 31 | Returns a matrix of the lower and upper bounds of the credible intervals for each observation in \code{new_data}. 32 | } 33 | 34 | \author{ 35 | Adam Kapelner and Justin Bleich 36 | } 37 | \note{ 38 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}. 39 | } 40 | 41 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 42 | 43 | \seealso{ 44 | \code{\link{calc_prediction_intervals}}, \code{\link{bart_machine_get_posterior}} 45 | } 46 | \examples{ 47 | 48 | \dontrun{ 49 | #generate Friedman data 50 | set.seed(11) 51 | n = 200 52 | p = 5 53 | X = data.frame(matrix(runif(n * p), ncol = p)) 54 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 55 | 56 | ##build BART regression model 57 | bart_machine = bartMachine(X, y) 58 | 59 | #get credible interval 60 | cred_int = calc_credible_intervals(bart_machine, X) 61 | print(head(cred_int)) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /bartMachine/man/calc_prediction_intervals.Rd: -------------------------------------------------------------------------------- 1 | \name{calc_prediction_intervals} 2 | \alias{calc_prediction_intervals} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Calculate Prediction Intervals 6 | } 7 | \description{ 8 | Generates prediction intervals for \eqn{\hat{y}} for a specified set of observations. 9 | } 10 | \usage{ 11 | calc_prediction_intervals(bart_machine, new_data, 12 | pi_conf = 0.95, num_samples_per_data_point = 1000) 13 | } 14 | %- maybe also 'usage' for other objects documented here. 15 | \arguments{ 16 | \item{bart_machine}{ 17 | An object of class ``bartMachine''. 18 | } 19 | \item{new_data}{ 20 | A data frame containing observations at which prediction intervals for \eqn{\hat{y}} are to be computed. 21 | } 22 | \item{pi_conf}{ 23 | Confidence level for the prediction intervals. The default is 95\%. 24 | } 25 | \item{num_samples_per_data_point}{ 26 | The number of samples taken from the predictive distribution. The default is 1000. 
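Larger values give more stable interval endpoints at the cost of additional computation.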
27 | }
28 | }
29 | \details{
30 | Credible intervals (see \code{\link{calc_credible_intervals}}) are the appropriate quantiles of the prediction
31 | for each of the Gibbs samples post burn-in. Prediction intervals also make use of the noise estimate at each Gibbs
32 | sample and hence are wider. For each Gibbs sample, we record the \eqn{\hat{y}} estimate of the response and the
33 | \eqn{\hat{\sigma^2}} estimate of the noise variance. We then sample \code{num_samples_per_data_point} times
34 | from a \eqn{N(\hat{y}, \hat{\sigma^2})} random variable to simulate many possible disturbances for that Gibbs sample.
35 | Then, all \code{num_samples_per_data_point} draws across the post-burn-in Gibbs samples are pooled and the
36 | appropriate quantiles are taken based on the confidence level, \code{pi_conf}.
37 | }
38 | \value{
39 | Returns a matrix of the lower and upper bounds of the prediction intervals for each observation in \code{new_data}.
40 | }
41 | \references{
42 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
43 | with Bayesian Additive Regression Trees. Journal of Statistical
44 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04
45 | }
46 | \author{
47 | Adam Kapelner and Justin Bleich
48 | }
49 | \note{
50 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
51 | }
52 |
53 | %% ~Make other sections like Warning with \section{Warning }{....} ~
54 |
55 | \seealso{
56 | \code{\link{calc_credible_intervals}}, \code{\link{bart_machine_get_posterior}}
57 | }
58 | \examples{
59 | \dontrun{
60 | #generate Friedman data
61 | set.seed(11)
62 | n = 200
63 | p = 5
64 | X = data.frame(matrix(runif(n * p), ncol = p))
65 | y = 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - .5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(n)
66 |
67 | ##build BART regression model
68 | bart_machine = bartMachine(X, y)
69 |
70 | #get prediction interval
71 | pred_int = calc_prediction_intervals(bart_machine, X)
72 | print(head(pred_int))
73 | }
74 |
75 | }
-------------------------------------------------------------------------------- /bartMachine/man/check_bart_error_assumptions.Rd: --------------------------------------------------------------------------------
1 | \name{check_bart_error_assumptions}
2 | \alias{check_bart_error_assumptions}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Check BART Error Assumptions
6 | }
7 | \description{
8 | Diagnostic tools to assess whether the errors of the BART model for regression are normally distributed and homoskedastic, as assumed by the model. This function generates a normal quantile plot of the residuals with a Shapiro-Wilk p-value as well as a residual plot.
9 | }
10 | \usage{
11 | check_bart_error_assumptions(bart_machine, hetero_plot = "yhats")
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{bart_machine}{
16 | An object of class ``bartMachine''.
17 | }
18 |
19 | \item{hetero_plot}{
20 | If ``yhats'', the residuals are plotted against the fitted values of the response. If ``ys'', the residuals are plotted against the actual values of the response.
21 | }
22 | }
23 |
24 | \value{
25 | None.
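The function is called purely for its side effect of producing the two diagnostic plots.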
26 | }
27 |
28 | \author{
29 | Adam Kapelner and Justin Bleich
30 | }
31 |
32 |
33 | %% ~Make other sections like Warning with \section{Warning }{....} ~
34 |
35 | \seealso{
36 | \code{\link{plot_convergence_diagnostics}}
37 | }
38 | \examples{
39 | \dontrun{
40 | #generate Friedman data
41 | set.seed(11)
42 | n = 300
43 | p = 5
44 | X = data.frame(matrix(runif(n * p), ncol = p))
45 | y = 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - .5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(n)
46 |
47 | ##build BART regression model
48 | bart_machine = bartMachine(X, y)
49 |
50 | #check error diagnostics
51 | check_bart_error_assumptions(bart_machine)
52 | }
53 |
54 | }
-------------------------------------------------------------------------------- /bartMachine/man/cov_importance_test.Rd: --------------------------------------------------------------------------------
1 | \name{cov_importance_test}
2 | \alias{cov_importance_test}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Importance Test for Covariate(s) of Interest
6 | }
7 | \description{
8 | This function tests the null hypothesis \eqn{H_0}: these covariates of interest
9 | do not affect the response under the assumptions of the BART
10 | model.
11 | }
12 | \usage{
13 | cov_importance_test(bart_machine, covariates = NULL,
14 | num_permutation_samples = 100, plot = TRUE)
15 | }
16 | %- maybe also 'usage' for other objects documented here.
17 | \arguments{
18 | \item{bart_machine}{
19 | An object of class ``bartMachine''.
20 | }
21 | \item{covariates}{
22 | A vector of names of covariates of interest to be tested for having an effect on the response. A value of \code{NULL}
23 | indicates an omnibus test for all covariates having an effect on the response. If the name of a covariate is a factor,
24 | the entire factor will be permuted. We do not recommend entering the names of factor covariate dummies.
25 | }
26 | \item{num_permutation_samples}{
27 | The number of times to permute the covariates of interest and create a corresponding new BART model (see details).
28 | }
29 | \item{plot}{
30 | If \code{TRUE}, this produces a histogram of the Pseudo-Rsq's / total misclassification error rates from
31 | the \code{num_permutation_samples} BART models created with the \code{covariates} permuted. The plot also illustrates
32 | the observed Pseudo-Rsq / total misclassification error rate from the original training data and indicates
33 | the test's p-value.
34 | }
35 | }
36 | \details{
37 | To test the importance of a covariate or a set of covariates of interest on the response, this function generates
38 | \code{num_permutation_samples} BART models with the covariate(s) of interest permuted (differently each time).
39 | On each run, a measure of fit is recorded. For regression, the metric is Pseudo-Rsq; for classification, it is
40 | total misclassification error.\cr A
41 | p-value can then be generated as follows. For regression, the p-value is the number of
42 | permutation-sampled Pseudo-Rsq's greater than the observed Pseudo-Rsq divided by
43 | \code{num_permutation_samples + 1}. For classification, the p-value is the number of permutation-sampled
44 | total misclassification errors less than the observed total misclassification error divided by \code{num_permutation_samples + 1}.
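For example, with \code{num_permutation_samples = 100}, if 3 of the 100 permuted-data Pseudo-Rsq's exceed the observed Pseudo-Rsq, the reported p-value is 3/101, approximately 0.03.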
45 | } 46 | \value{ 47 | \item{permutation_samples_of_error}{A vector which records the error metric of the BART models with the covariates permuted (see details).} 48 | \item{observed_error_estimate}{For regression, this is the Pseudo-Rsq on the original 49 | training data set. For classification, this is the observed total misclassification error 50 | on the original training data set.} 51 | \item{pval}{The approximate p-value for this test (see details). 52 | } 53 | } 54 | \references{ 55 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning 56 | with Bayesian Additive Regression Trees. Journal of Statistical 57 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04 58 | } 59 | \author{ 60 | Adam Kapelner and Justin Bleich 61 | } 62 | \note{ 63 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}. 64 | } 65 | 66 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 67 | 68 | 69 | \examples{ 70 | \dontrun{ 71 | ##regression example 72 | 73 | ##generate Friedman data 74 | set.seed(11) 75 | n = 200 76 | p = 5 77 | X = data.frame(matrix(runif(n * p), ncol = p)) 78 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 79 | 80 | ##build BART regression model 81 | bart_machine = bartMachine(X, y) 82 | 83 | ##now test if X[, 1] affects Y nonparametrically under the BART model assumptions 84 | cov_importance_test(bart_machine, covariates = c(1)) 85 | ## note the plot and the printed p-value 86 | 87 | } 88 | 89 | } 90 | -------------------------------------------------------------------------------- /bartMachine/man/destroy_bart_machine.Rd: -------------------------------------------------------------------------------- 1 | \name{destroy_bart_machine} 2 | \alias{destroy_bart_machine} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Destroy BART Model (deprecated --- do not use!) 6 | } 7 | \description{ 8 | A deprecated function that previously was responsible for cleaning up the RAM 9 | associated with a BART model. This is now handled natively by R's garbage collection. 10 | } 11 | \usage{ 12 | destroy_bart_machine(bart_machine) 13 | } 14 | %- maybe also 'usage' for other objects documented here. 15 | \arguments{ 16 | \item{bart_machine}{ 17 | deprecated --- do not use! 18 | } 19 | } 20 | \details{ 21 | Removing a ``bart_machine'' object from \code{R} previously did not free heap space from Java. 22 | Since BART objects can consume a large amount of RAM, it is important to remove 23 | these objects by calling this function if they are no longer needed or many BART 24 | objects are being created. This operation is now taken care of by R's garbage collection. 25 | This function is deprecated and should not be used. However, running it is harmless. 26 | } 27 | \value{ 28 | None. 29 | } 30 | 31 | \author{ 32 | Adam Kapelner and Justin Bleich 33 | } 34 | 35 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 36 | 37 | 38 | \examples{ 39 | ##None 40 | } -------------------------------------------------------------------------------- /bartMachine/man/dummify_data.Rd: -------------------------------------------------------------------------------- 1 | \name{dummify_data} 2 | \alias{dummify_data} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Dummify Design Matrix 6 | } 7 | \description{ 8 | Create a data frame with factors converted to dummies. 
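For comparison, a rough base-R sketch of the same idea via \code{model.matrix} (an analogue only: \code{dummify_data} instead uses the ``FactorName_LevelName'' naming described in the details below and keeps non-factor columns in place):

\preformatted{
x2 = factor(c("A", "B", "A"))
model.matrix(~ 0 + x2)   # one indicator column per level: x2A, x2B
# dummify_data would name these columns x2_A and x2_B
}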
9 | }
10 | \usage{
11 | dummify_data(data)
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{data}{
16 | Data frame to be dummified.
17 | }
18 | }
19 | \details{
20 | The column names of the dummy variables are given by ``FactorName_LevelName'' and the dummies are appended to the end of the design matrix. See the example below.
21 | }
22 | \value{
23 | Returns a data frame with factors converted to dummy indicator variables.
24 | }
25 | 
26 | \author{
27 | Adam Kapelner and Justin Bleich
28 | }
29 | \note{
30 | BART handles dummification internally. This function is provided as a utility function.
31 | }
32 | 
33 | %% ~Make other sections like Warning with \section{Warning }{....} ~
34 | 
35 | 
36 | \examples{
37 | \dontrun{
38 | #generate data
39 | set.seed(11)
40 | x1 = rnorm(20)
41 | x2 = as.factor(ifelse(x1 > 0, "A", "B"))
42 | x3 = runif(20)
43 | X = data.frame(x1,x2,x3)
44 | #dummify data
45 | X_dummified = dummify_data(X)
46 | print(X_dummified)
47 | }
48 | }
49 | -------------------------------------------------------------------------------- /bartMachine/man/extract_raw_node_data.Rd: --------------------------------------------------------------------------------
1 | \name{extract_raw_node_data}
2 | \alias{extract_raw_node_data}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Gets Raw Node Data
6 | }
7 | \description{
8 | Returns a list object that contains all the information for all trees in a given Gibbs sample. Daughter nodes are nested
9 | in the list structure recursively.
10 | }
11 | \usage{
12 | extract_raw_node_data(bart_machine, g = 1)
13 | }
14 | %- maybe also 'usage' for other objects documented here.
15 | \arguments{
16 | \item{bart_machine}{
17 | An object of class ``bartMachine''.
18 | }
19 | \item{g}{
20 | The Gibbs sample number. It must be a natural number between 1 and the number of iterations after burn-in. Default is 1.
21 | }
22 | }
23 | 
24 | \value{
25 | Returns a list object that contains all the information for all trees in a given Gibbs sample.
26 | }
27 | 
28 | \examples{
29 | \dontrun{
30 | options(java.parameters = "-Xmx10g")
31 | pacman::p_load(bartMachine)
32 | 
33 | seed = 1984
34 | set.seed(seed)
35 | n = 100
36 | x = rnorm(n, 0, 1)
37 | sigma = 0.1
38 | y = x + rnorm(n, 0, sigma)
39 | 
40 | num_trees = 200
41 | num_iterations_after_burn_in = 1000
42 | bart_mod = bartMachine(data.frame(x = x), y,
43 | flush_indices_to_save_RAM = FALSE,
44 | num_trees = num_trees,
45 | num_iterations_after_burn_in = num_iterations_after_burn_in,
46 | seed = seed)
47 | 
48 | raw_node_data = extract_raw_node_data(bart_mod)
49 | 
50 | }
51 | }
52 | -------------------------------------------------------------------------------- /bartMachine/man/get_projection_weights.Rd: --------------------------------------------------------------------------------
1 | \name{get_projection_weights}
2 | \alias{get_projection_weights}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Gets Training Sample Projection / Weights
6 | }
7 | \description{
8 | Returns the matrix H such that yhat is approximately equal to H y, where yhat denotes the predicted values for \code{new_data}. If \code{new_data} is unspecified, yhat will be the in-sample fits.
9 | If BART were the same as OLS, H would be an orthogonal projection matrix. Here it is a projection matrix, but clearly non-orthogonal.
Unfortunately, I cannot get
10 | this function to work correctly because of three possible reasons: (1) BART does not work by averaging tree predictions; it is a sum-of-trees model where each tree sees the residuals
11 | via backfitting; (2) the prediction in each node is a Bayesian posterior draw which is close to ybar of the observations contained in the node if the noise is gauged to be small; and
12 | (3) there are transformations of the original y variable. I believe I got close and I think I'm off by a constant multiple which is a function of the number of trees. I can
13 | use regression to estimate the constant multiple and correct for it. Set \code{regression_kludge} to \code{TRUE} for this. Note that the weights do not add up to one here.
14 | The intuition is that, due to the backfitting, there is multiple counting; but I'm not entirely sure.
15 | }
16 | \usage{
17 | get_projection_weights(bart_machine, new_data = NULL, regression_kludge = FALSE)
18 | }
19 | %- maybe also 'usage' for other objects documented here.
20 | \arguments{
21 | \item{bart_machine}{
22 | An object of class ``bartMachine''.
23 | }
24 | \item{new_data}{
25 | Data for which you wish to investigate the training sample projection / weights. If \code{NULL}, the original training data is used.
26 | }
27 | \item{regression_kludge}{
28 | See explanation in the description. Default is \code{FALSE}.
29 | }
30 | }
31 | 
32 | \value{
33 | Returns a matrix of proportions with number of rows equal to the number of rows of \code{new_data} and number of columns equal to the number of rows of the original training data, n.
34 | }
35 | 
36 | \examples{
37 | \dontrun{
38 | options(java.parameters = "-Xmx10g")
39 | pacman::p_load(bartMachine, tidyverse)
40 | 
41 | seed = 1984
42 | set.seed(seed)
43 | n = 100
44 | x = rnorm(n, 0, 1)
45 | sigma = 0.1
46 | y = x + rnorm(n, 0, sigma)
47 | 
48 | num_trees = 200
49 | num_iterations_after_burn_in = 1000
50 | bart_mod = bartMachine(data.frame(x = x), y,
51 | flush_indices_to_save_RAM = FALSE,
52 | num_trees = num_trees,
53 | num_iterations_after_burn_in = num_iterations_after_burn_in,
54 | seed = seed)
55 | bart_mod
56 | 
57 | n_star = 100
58 | x_star = rnorm(n_star)
59 | y_star = as.numeric(x_star + rnorm(n_star, 0, sigma))
60 | yhat_star_bart = predict(bart_mod, data.frame(x = x_star))
61 | 
62 | Hstar = get_projection_weights(bart_mod, data.frame(x = x_star))
63 | rowSums(Hstar)
64 | yhat_star_projection = as.numeric(Hstar %*% y)
65 | 
66 | ggplot(data.frame(
67 | yhat_star = yhat_star_bart,
68 | yhat_star_projection = yhat_star_projection,
69 | y_star = y_star)) +
70 | geom_point(aes(x = yhat_star, y = yhat_star_projection), col = "green") +
71 | geom_abline(slope = 1, intercept = 0)
72 | 
73 | Hstar = get_projection_weights(bart_mod, data.frame(x = x_star), regression_kludge = TRUE)
74 | rowSums(Hstar)
75 | yhat_star_projection = as.numeric(Hstar %*% y)
76 | 
77 | ggplot(data.frame(
78 | yhat_star = yhat_star_bart,
79 | yhat_star_projection = yhat_star_projection,
80 | y_star = y_star)) +
81 | geom_point(aes(x = yhat_star, y = yhat_star_projection), col = "green") +
82 | geom_abline(slope = 1, intercept = 0)
83 | 
84 | }
85 | }
86 | -------------------------------------------------------------------------------- /bartMachine/man/get_sigsqs.Rd: --------------------------------------------------------------------------------
1 | \name{get_sigsqs}
2 | \alias{get_sigsqs}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Get Posterior Error Variance Estimates
6 | }
7 | \description{
8 | Returns the posterior estimates of the error variance from the Gibbs samples with an option to create a histogram of the posterior estimates of the error variance with a credible interval overlaid.
9 | }
10 | \usage{
11 | get_sigsqs(bart_machine, after_burn_in = T,
12 | plot_hist = F, plot_CI = .95, plot_sigma = F)
13 | }
14 | %- maybe also 'usage' for other objects documented here.
15 | \arguments{
16 | \item{bart_machine}{
17 | An object of class ``bartMachine''.
18 | }
19 | \item{after_burn_in}{
20 | If TRUE, only the \eqn{\sigma^2} draws after the burn-in period are returned.
21 | }
22 | \item{plot_hist}{
23 | If TRUE, a histogram of the posterior \eqn{\sigma^2} draws is generated.
24 | }
25 | \item{plot_CI}{
26 | Confidence level for the credible interval on the histogram.
27 | }
28 | \item{plot_sigma}{
29 | If TRUE, plots \eqn{\sigma} instead of \eqn{\sigma^2}.
30 | }
31 | }
32 | 
33 | \value{
34 | Returns a vector of posterior \eqn{\sigma^2} draws (with or without the burn-in samples).
35 | }
36 | 
37 | \author{
38 | Adam Kapelner and Justin Bleich
39 | }
40 | 
41 | 
42 | %% ~Make other sections like Warning with \section{Warning }{....} ~
43 | 
44 | \seealso{
45 | \code{\link{get_sigsqs}}
46 | }
47 | \examples{
48 | \dontrun{
49 | #generate Friedman data
50 | set.seed(11)
51 | n = 300
52 | p = 5
53 | X = data.frame(matrix(runif(n * p), ncol = p))
54 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
55 | 
56 | ##build BART regression model
57 | bart_machine = bartMachine(X, y)
58 | 
59 | #get posterior sigma^2's after burn-in and plot
60 | sigsqs = get_sigsqs(bart_machine, plot_hist = TRUE)
61 | }
62 | 
63 | }
64 | -------------------------------------------------------------------------------- /bartMachine/man/get_var_counts_over_chain.Rd: --------------------------------------------------------------------------------
1 | \name{get_var_counts_over_chain}
2 | \alias{get_var_counts_over_chain}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Get the Variable Inclusion Counts
6 | }
7 | \description{
8 | Computes the variable inclusion counts for a BART model.
9 | }
10 | \usage{
11 | get_var_counts_over_chain(bart_machine, type = "splits")
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{bart_machine}{
16 | An object of class ``bartMachine''.
17 | }
18 | \item{type}{
19 | If ``splits'', then the number of times each variable is chosen for a splitting rule is computed. If ``trees'', then the number of times each variable appears in a tree is computed.
20 | }
21 | }
22 | 
23 | \value{
24 | Returns a matrix of counts of each predictor across all trees by Gibbs sample. Thus, the dimension is \code{num_iterations_after_burn_in}
25 | by \code{p} (where \code{p} is the number of predictors after dummifying factors and adding missingness dummies if specified by \code{use_missing_data_dummies_as_covars}).
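Since each row of the returned matrix is one Gibbs sample, per-variable summaries are simple column operations; a minimal sketch:

\preformatted{
var_counts = get_var_counts_over_chain(bart_machine)   # one row per Gibbs sample
colMeans(var_counts)       # average number of splits on each predictor
apply(var_counts, 2, sd)   # sample-to-sample variability of those counts
}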
26 | } 27 | 28 | \author{ 29 | Adam Kapelner and Justin Bleich 30 | } 31 | 32 | 33 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 34 | 35 | \seealso{ 36 | \code{\link{get_var_props_over_chain}} 37 | } 38 | \examples{ 39 | \dontrun{ 40 | 41 | #generate Friedman data 42 | set.seed(11) 43 | n = 200 44 | p = 10 45 | X = data.frame(matrix(runif(n * p), ncol = p)) 46 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 47 | 48 | ##build BART regression model 49 | bart_machine = bartMachine(X, y, num_trees = 20) 50 | 51 | #get variable inclusion counts 52 | var_counts = get_var_counts_over_chain(bart_machine) 53 | print(var_counts) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /bartMachine/man/get_var_props_over_chain.Rd: -------------------------------------------------------------------------------- 1 | \name{get_var_props_over_chain} 2 | \alias{get_var_props_over_chain} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Get the Variable Inclusion Proportions 6 | } 7 | \description{ 8 | Computes the variable inclusion proportions for a BART model. 9 | } 10 | \usage{ 11 | get_var_props_over_chain(bart_machine, type = "splits") 12 | } 13 | %- maybe also 'usage' for other objects documented here. 14 | \arguments{ 15 | \item{bart_machine}{ 16 | An object of class ``bartMachine''. 17 | } 18 | \item{type}{ 19 | If ``splits'', then the proportion of times each variable is chosen for a splitting rule versus all splitting rules is computed. If ``trees'', then the proportion of times each variable appears in a tree versus all appearances of variables in trees is computed. 20 | } 21 | } 22 | 23 | \value{ 24 | Returns a vector of the variable inclusion proportions. 25 | } 26 | 27 | \author{ 28 | Adam Kapelner and Justin Bleich 29 | } 30 | 31 | 32 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 33 | 34 | \seealso{ 35 | \code{\link{get_var_counts_over_chain}} 36 | } 37 | \examples{ 38 | \dontrun{ 39 | #generate Friedman data 40 | set.seed(11) 41 | n = 200 42 | p = 10 43 | X = data.frame(matrix(runif(n * p), ncol = p)) 44 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n) 45 | 46 | ##build BART regression model 47 | bart_machine = bartMachine(X, y, num_trees = 20) 48 | 49 | #Get variable inclusion proportions 50 | var_props = get_var_props_over_chain(bart_machine) 51 | print(var_props) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /bartMachine/man/interaction_investigator.Rd: -------------------------------------------------------------------------------- 1 | \name{interaction_investigator} 2 | \alias{interaction_investigator} 3 | %- Also NEED an '\alias' for EACH other topic documented here. 4 | \title{ 5 | Explore Pairwise Interactions in BART Model 6 | } 7 | \description{ 8 | Explore the pairwise interaction counts for a BART model to learn about interactions fit by the model. This function includes an option to generate a plot of the pairwise interaction counts. 9 | } 10 | \usage{ 11 | interaction_investigator(bart_machine, plot = TRUE, 12 | num_replicates_for_avg = 5, num_trees_bottleneck = 20, 13 | num_var_plot = 50, cut_bottom = NULL, bottom_margin = 10) 14 | } 15 | %- maybe also 'usage' for other objects documented here. 16 | \arguments{ 17 | \item{bart_machine}{ 18 | An object of class ``bartMachine''. 
19 | }
20 | \item{plot}{
21 | If TRUE, a plot of the pairwise interaction counts is generated.
22 | }
23 | \item{num_replicates_for_avg}{
24 | The number of replicates of BART to be used to generate pairwise interaction inclusion counts.
25 | Averaging across multiple BART models improves stability of the estimates.
26 | }
27 | \item{num_trees_bottleneck}{
28 | Number of trees to be used in the sum-of-trees model for computing pairwise interaction counts.
29 | A small number of trees should be used to force the variables to compete for entry into the model.
30 | }
31 | \item{num_var_plot}{
32 | Number of variables to be shown on the plot. If ``Inf,'' all variables are plotted (not recommended if
33 | the number of predictors is large). Default is 50.
34 | }
35 | \item{cut_bottom}{
36 | A display parameter between 0 and 1 that controls where the y-axis is plotted. A value of 0 would begin the y-axis at 0; a value of 1 begins
37 | the y-axis at the minimum of the average pairwise interaction inclusion count (the smallest bar in the bar plot). Values between 0 and 1 begin the
38 | y-axis as a percentage of that minimum.
39 | }
40 | \item{bottom_margin}{
41 | A display parameter that adjusts the bottom margin of the graph if labels are clipped. The scale of this parameter is the same as set with \code{par(mar = c(....))} in R.
42 | Higher values allow for more space if the crossed covariate names are long. Note that making this parameter too large will prevent plotting and the plot function in R will throw an error.
43 | }
44 | }
45 | \details{
46 | An interaction between two variables is considered to occur whenever a path from any node of a tree to
47 | any of its terminal nodes contains splits using those two variables. See Kapelner and Bleich, 2013, Section 4.11.
48 | }
49 | \value{
50 | \item{interaction_counts}{For each of the \eqn{p \times p}{p times p} interactions, what is the count across all \code{num_replicates_for_avg}
51 | BART model replicates' post burn-in Gibbs samples in all trees.}
52 | \item{interaction_counts_avg}{For each of the \eqn{p \times p}{p times p} interactions, what is the average count across all \code{num_replicates_for_avg}
53 | BART model replicates' post burn-in Gibbs samples in all trees.}
54 | \item{interaction_counts_sd}{For each of the \eqn{p \times p}{p times p} interactions, what is the sd of the interaction counts across the \code{num_replicates_for_avg}
55 | BART model replicates.}
56 | \item{interaction_counts_avg_and_sd_long}{For each of the \eqn{p \times p}{p times p} interactions, what is the average and sd of the interaction counts across the \code{num_replicates_for_avg}
57 | BART model replicates. The output is organized as a convenient long table of class \code{data.frame}.}
58 | }
59 | \references{
60 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
61 | with Bayesian Additive Regression Trees. Journal of Statistical
62 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04
63 | }
64 | \author{
65 | Adam Kapelner and Justin Bleich
66 | }
67 | \note{
68 | In the plot, the red bars correspond to the standard error of the average interaction count estimates (since multiple replicates were used).
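As a usage sketch, the strongest pairs can be pulled directly from the returned \code{interaction_counts_avg} (assuming here that it is a symmetric \eqn{p \times p}{p times p} matrix carrying the predictor names as its dimnames):

\preformatted{
out = interaction_investigator(bart_machine, plot = FALSE)
avg = out$interaction_counts_avg
avg[lower.tri(avg, diag = TRUE)] = NA   # symmetry assumed: keep each pair once
top = order(avg, decreasing = TRUE, na.last = NA)[1 : 5]
data.frame(var_1 = rownames(avg)[row(avg)[top]],
           var_2 = colnames(avg)[col(avg)[top]],
           avg_count = avg[top])
}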
69 | }
70 | 
71 | %% ~Make other sections like Warning with \section{Warning }{....} ~
72 | 
73 | \seealso{
74 | \code{\link{investigate_var_importance}}
75 | }
76 | \examples{
77 | \dontrun{
78 | #generate Friedman data
79 | set.seed(11)
80 | n = 200
81 | p = 10
82 | X = data.frame(matrix(runif(n * p), ncol = p))
83 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
84 | 
85 | ##build BART regression model
86 | bart_machine = bartMachine(X, y, num_trees = 20)
87 | 
88 | #investigate interactions
89 | interaction_investigator(bart_machine)
90 | }
91 | 
92 | }
93 | -------------------------------------------------------------------------------- /bartMachine/man/investigate_var_importance.Rd: --------------------------------------------------------------------------------
1 | \name{investigate_var_importance}
2 | \alias{investigate_var_importance}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Explore Variable Inclusion Proportions in BART Model
6 | }
7 | \description{
8 | Explore the variable inclusion proportions for a BART model to learn about the relative influence of the different covariates. This function includes an option to generate a plot of the variable inclusion proportions.
9 | }
10 | \usage{
11 | investigate_var_importance(bart_machine, type = "splits",
12 | plot = TRUE, num_replicates_for_avg = 5, num_trees_bottleneck = 20,
13 | num_var_plot = Inf, bottom_margin = 10)
14 | }
15 | %- maybe also 'usage' for other objects documented here.
16 | \arguments{
17 | \item{bart_machine}{
18 | An object of class ``bartMachine''.
19 | }
20 | \item{type}{
21 | If ``splits'', then the proportion of times each variable is chosen for a splitting rule is computed. If ``trees'', then the proportion of times each variable appears in a tree is computed.
22 | }
23 | \item{plot}{
24 | If TRUE, a plot of the variable inclusion proportions is generated.
25 | }
26 | \item{num_replicates_for_avg}{
27 | The number of replicates of BART to be used to generate variable inclusion proportions. Averaging across multiple BART models improves stability of the estimates. See Bleich et al. (2013) for more details.
28 | }
29 | \item{num_trees_bottleneck}{
30 | Number of trees to be used in the sum-of-trees model for computing the variable inclusion proportions. A small number of trees should be used to force the variables to compete for entry into the model. Chipman et al. (2010) recommend 20. See this reference for more details.
31 | }
32 | \item{num_var_plot}{
33 | Number of variables to be shown on the plot. If ``Inf'', all variables are plotted.
34 | }
35 | \item{bottom_margin}{
36 | A display parameter that adjusts the bottom margin of the graph if labels are clipped. The scale of this parameter is the same as set with \code{par(mar = c(....))} in R.
37 | Higher values allow for more space if the covariate names are long. Note that making this parameter too large will prevent plotting and the plot function in R will throw an error.
38 | }
39 | }
40 | \details{
41 | In the plot, the red bars correspond to the standard error of the variable inclusion proportion estimates.
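This averaging over \code{num_replicates_for_avg} replicates can also be sketched by hand with \code{\link{get_var_props_over_chain}}, which is essentially the quantity this function aggregates (using the \code{X} and \code{y} from the example below):

\preformatted{
num_replicates = 5
props = sapply(1 : num_replicates, function(r) {
    bm = bartMachine(X, y, num_trees = 20, verbose = FALSE)
    get_var_props_over_chain(bm, type = "splits")
})
rowMeans(props)                              # average inclusion proportions
apply(props, 1, sd) / sqrt(num_replicates)   # their standard errors (the red bars)
}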
42 | }
43 | \value{
44 | Invisibly, returns a list with the following components:
45 | \item{avg_var_props}{The average variable inclusion proportions for each variable\cr (across \code{num_replicates_for_avg})}
46 | \item{sd_var_props}{The standard deviation of the variable inclusion proportions for each variable (across \code{num_replicates_for_avg})}
47 | 
48 | }
49 | \references{
50 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
51 | with Bayesian Additive Regression Trees. Journal of Statistical
52 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04
53 | 
54 | J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian
55 | Additive Regression Trees. ArXiv e-prints, 2013.
56 | 
57 | HA Chipman, EI George, and RE McCulloch. BART: Bayesian Additive Regression Trees.
58 | The Annals of Applied Statistics, 4(1): 266--298, 2010.
59 | }
60 | \author{
61 | Adam Kapelner and Justin Bleich
62 | }
63 | \note{
64 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
65 | }
66 | 
67 | %% ~Make other sections like Warning with \section{Warning }{....} ~
68 | 
69 | \seealso{
70 | \code{\link{interaction_investigator}}
71 | }
72 | \examples{
73 | \dontrun{
74 | #generate Friedman data
75 | set.seed(11)
76 | n = 200
77 | p = 10
78 | X = data.frame(matrix(runif(n * p), ncol = p))
79 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
80 | 
81 | ##build BART regression model
82 | bart_machine = bartMachine(X, y, num_trees = 20)
83 | 
84 | #investigate variable inclusion proportions
85 | investigate_var_importance(bart_machine)
86 | }
87 | 
88 | }
89 | -------------------------------------------------------------------------------- /bartMachine/man/k_fold_cv.Rd: --------------------------------------------------------------------------------
1 | \name{k_fold_cv}
2 | \alias{k_fold_cv}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Estimate Out-of-sample Error with K-fold Cross-validation
6 | }
7 | \description{
8 | Builds a BART model using a specified set of arguments to \code{build_bart_machine} and estimates the out-of-sample performance by using k-fold cross-validation.
9 | }
10 | \usage{
11 | k_fold_cv(X, y, k_folds = 5, folds_vec = NULL, verbose = FALSE, ...)
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{X}{
16 | Data frame of predictors. Factors are automatically converted to dummies internally.
17 | }
18 | \item{y}{
19 | Vector of response variable. If \code{y} is \code{numeric} or \code{integer}, a BART model for regression is built. If \code{y} is a factor with two levels, a BART model for classification is built.
20 | }
21 | \item{k_folds}{
22 | Number of folds to cross-validate over. This argument is ignored if \code{folds_vec} is non-null.
23 | }
24 | \item{folds_vec}{
25 | An integer vector of indices specifying which fold each observation belongs to.
26 | }
27 | \item{verbose}{
28 | Prints information about progress of the algorithm to the screen.
29 | }
30 | \item{\dots}{
31 | Additional arguments to be passed to \code{build_bart_machine}.
32 | }
33 | }
34 | \details{
35 | For each fold, a new BART model is trained (using the same set of arguments) and its performance is evaluated on the holdout piece of that fold.
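A minimal sketch of this procedure for regression (the function itself also handles classification, custom fold vectors, and the aggregation of the returned components):

\preformatted{
k = 5
folds = sample(rep(1 : k, length.out = nrow(X)))
oos_sq_err = 0
for (j in 1 : k) {
    holdout = which(folds == j)
    bm_j = bartMachine(X[-holdout, ], y[-holdout], verbose = FALSE)
    y_hat_j = predict(bm_j, X[holdout, ])
    oos_sq_err = oos_sq_err + sum((y[holdout] - y_hat_j)^2)
}
sqrt(oos_sq_err / nrow(X))   # the out-of-sample rmse that k_fold_cv estimates
}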
36 | }
37 | \value{
38 | For regression models, a list with the following components is returned:
39 | \item{y_hat}{Predictions for the observations computed on the fold for which the observation was omitted from the training set.}
40 | \item{L1_err}{Aggregate L1 error across the folds.}
41 | \item{L2_err}{Aggregate L2 error across the folds.}
42 | \item{rmse}{Aggregate RMSE across the folds.}
43 | \item{folds}{Vector of indices specifying which fold each observation belonged to.}
44 | 
45 | For classification models, a list with the following components is returned:
46 | 
47 | \item{y_hat}{Class predictions for the observations computed on the fold for which the observation was omitted from the training set.}
48 | \item{p_hat}{Probability estimates for the observations computed on the fold for which the observation was omitted from the training set.}
49 | \item{confusion_matrix}{Aggregate confusion matrix across the folds.}
50 | \item{misclassification_error}{Total misclassification error across the folds.}
51 | \item{folds}{Vector of indices specifying which fold each observation belonged to.}
52 | }
53 | 
54 | \author{
55 | Adam Kapelner and Justin Bleich
56 | }
57 | \note{
58 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
59 | }
60 | 
61 | %% ~Make other sections like Warning with \section{Warning }{....} ~
62 | 
63 | \seealso{
64 | \code{\link{bartMachine}}
65 | }
66 | \examples{
67 | \dontrun{
68 | #generate Friedman data
69 | set.seed(11)
70 | n = 200
71 | p = 5
72 | X = data.frame(matrix(runif(n * p), ncol = p))
73 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
74 | 
75 | #evaluate default BART on 5 folds
76 | k_fold_val = k_fold_cv(X, y)
77 | print(k_fold_val$rmse)
78 | }
79 | 
80 | }
81 | -------------------------------------------------------------------------------- /bartMachine/man/linearity_test.Rd: --------------------------------------------------------------------------------
1 | \name{linearity_test}
2 | \alias{linearity_test}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Test of Linearity
6 | }
7 | \description{
8 | Test to investigate \eqn{H_0:} the functional relationship between the response and the
9 | regressors is linear. We fit a linear model and then test if the residuals are a function
10 | of the regressors using the \code{\link{cov_importance_test}}.
11 | }
12 | \usage{
13 | linearity_test(lin_mod = NULL, X = NULL, y = NULL,
14 | num_permutation_samples = 100, plot = TRUE, ...)
15 | }
16 | %- maybe also 'usage' for other objects documented here.
17 | \arguments{
18 | \item{lin_mod}{
19 | A linear model you can pass in if you do not want to use the default which is \code{lm(y ~ X)}. Default is \code{NULL} which should be used if you pass in \code{X} and \code{y}.
20 | }
21 | \item{X}{
22 | Data frame of predictors. Factors are automatically converted to dummies internally. Default is \code{NULL} which should be used if you pass in \code{lin_mod}.
23 | }
24 | \item{y}{
25 | Vector of response variable. If \code{y} is \code{numeric} or \code{integer}, a BART model for regression is built. If \code{y} is a factor with two levels, a BART model for classification is built.
26 | Default is \code{NULL} which should be used if you pass in \code{lin_mod}.
27 | }
28 | \item{num_permutation_samples}{
29 | This function relies on \code{\link{cov_importance_test}} (see documentation there for details).
30 | }
31 | \item{plot}{
32 | This function relies on \code{\link{cov_importance_test}} (see documentation there for details).
33 | }
34 | \item{...}{
35 | Additional parameters to be passed to \code{bartMachine}, the model constructed on the residuals of the linear model.
36 | }
37 | }
38 | \value{
39 | \item{permutation_samples_of_error}{ This function relies on \code{\link{cov_importance_test}} (see documentation there for details).
40 | }
41 | \item{observed_error_estimate}{ This function relies on \code{\link{cov_importance_test}} (see documentation there for details).
42 | }
43 | \item{pval}{The approximate p-value for this test. See the documentation at \code{\link{cov_importance_test}}.
44 | }
45 | }
46 | \author{
47 | Adam Kapelner
48 | }
49 | 
50 | %% ~Make other sections like Warning with \section{Warning }{....} ~
51 | \seealso{
52 | \code{\link{cov_importance_test}}
53 | }
54 | 
55 | \examples{
56 | \dontrun{
57 | ##regression example
58 | 
59 | ##generate Friedman data i.e. a nonlinear response model
60 | set.seed(11)
61 | n = 200
62 | p = 5
63 | X = data.frame(matrix(runif(n * p), ncol = p))
64 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
65 | 
66 | ##now test if there is a nonlinear relationship between X1, ..., X5 and y.
67 | linearity_test(X = X, y = y)
68 | ## note the plot and the printed p-value... should be approx 0
69 | 
70 | #generate a linear response model
71 | y = 1 * X[ ,1] + 3 * X[,2] + 5 * X[,3] + 7 * X[ ,4] + 9 * X[,5] + rnorm(n)
72 | linearity_test(X = X, y = y)
73 | ## note the plot and the printed p-value... should be > 0.05
74 | 
75 | }
76 | 
77 | }
78 | -------------------------------------------------------------------------------- /bartMachine/man/node_prediction_training_data_indices.Rd: --------------------------------------------------------------------------------
1 | \name{node_prediction_training_data_indices}
2 | \alias{node_prediction_training_data_indices}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Gets node prediction indices of the training data for new data.
6 | }
7 | \description{
8 | This returns a binary tensor for all Gibbs samples after burn-in for all trees and for all training observations.
9 | }
10 | \usage{
11 | node_prediction_training_data_indices(bart_machine, new_data = NULL)
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{bart_machine}{
16 | An object of class ``bartMachine''.
17 | }
18 | \item{new_data}{
19 | Data for which you wish to investigate the training sample weights. If \code{NULL}, the original training data is used.
20 | }
21 | }
22 | 
23 | \value{
24 | Returns a binary tensor indicating whether the prediction node contained a training datum or not. For each observation in the new data, the size of this tensor is the number of Gibbs samples after burn-in
25 | times the number of trees times the number of training data observations. Thus, the size of the full tensor is the number of observations in the new data times the three-dimensional object just explained.
26 | }
-------------------------------------------------------------------------------- /bartMachine/man/pd_plot.Rd: --------------------------------------------------------------------------------
1 | \name{pd_plot}
2 | \alias{pd_plot}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Partial Dependence Plot
6 | }
7 | \description{
8 | Creates a partial dependence plot for a BART model for regression or classification.
9 | }
10 | \usage{
11 | pd_plot(bart_machine, j,
12 | levs = c(0.05, seq(from = 0.1, to = 0.9, by = 0.1), 0.95),
13 | lower_ci = 0.025, upper_ci = 0.975, prop_data = 1)
14 | }
15 | %- maybe also 'usage' for other objects documented here.
16 | \arguments{
17 | \item{bart_machine}{
18 | An object of class ``bartMachine''.
19 | }
20 | \item{j}{
21 | The number or name of the column in the design matrix for which the partial dependence plot is to be created.
22 | }
23 | \item{levs}{
24 | Quantiles at which the partial dependence function should be evaluated. Linear extrapolation is performed between these points.
25 | }
26 | \item{lower_ci}{
27 | Lower limit for the credible interval.
28 | }
29 | \item{upper_ci}{
30 | Upper limit for the credible interval.
31 | }
32 | \item{prop_data}{
33 | The proportion of the training data to use. Default is 1. Use a lower proportion for speedier pd_plots. The closer to 1, the more resolution
34 | the PD plot will have; the closer to 0, the lower the resolution, but the faster the computation.
35 | }
36 | }
37 | \details{
38 | For regression models, the units on the y-axis are the same as the units of the response. For classification models, the units on the y-axis are probits.
39 | }
40 | \value{
41 | Invisibly, returns a list with the following components:
42 | 
43 | \item{x_j_quants}{Quantiles at which the partial dependence function is evaluated}
44 | \item{bart_avg_predictions_by_quantile_by_gibbs}{All samples of \eqn{\hat{f}(x)}}
45 | \item{bart_avg_predictions_by_quantile}{Posterior means for \eqn{\hat{f}(x)} at \code{x_j_quants}}
46 | \item{bart_avg_predictions_lower}{Lower bound of the credible interval of \eqn{\hat{f}(x)} at the desired confidence level}
47 | \item{bart_avg_predictions_upper}{Upper bound of the credible interval of \eqn{\hat{f}(x)} at the desired confidence level}
48 | \item{prop_data}{The proportion of the training data to use as specified when this function was executed}
49 | %% ...
50 | }
51 | \references{
52 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
53 | with Bayesian Additive Regression Trees. Journal of Statistical
54 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04
55 | 
56 | HA Chipman, EI George, and RE McCulloch. BART: Bayesian Additive Regression Trees.
57 | The Annals of Applied Statistics, 4(1): 266--298, 2010.
58 | }
59 | \author{
60 | Adam Kapelner and Justin Bleich
61 | }
62 | \note{
63 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
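Conceptually, the partial dependence value at a quantile can be sketched by hand: fix column \eqn{j} at that value for every training row, predict, and average. A rough illustration using the Friedman data and the \code{"X3"} column from the example below:

\preformatted{
X_mod = X
X_mod[["X3"]] = quantile(X[["X3"]], 0.5)   # pin X3 at its median everywhere
mean(predict(bart_machine, X_mod))         # compare to pd_plot at the 0.5 quantile
}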
64 | }
65 | 
66 | 
67 | \examples{
68 | \dontrun{
69 | #Regression example
70 | 
71 | #generate Friedman data
72 | set.seed(11)
73 | n = 200
74 | p = 5
75 | X = data.frame(matrix(runif(n * p), ncol = p))
76 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
77 | 
78 | ##build BART regression model
79 | bart_machine = bartMachine(X, y)
80 | 
81 | #partial dependence plot for quadratic term
82 | pd_plot(bart_machine, "X3")
83 | 
84 | 
85 | #Classification example
86 | 
87 | #get data and only use 2 factors
88 | data(iris)
89 | iris2 = iris[51:150,]
90 | iris2$Species = factor(iris2$Species)
91 | 
92 | #build BART classification model
93 | bart_machine = bartMachine(iris2[ ,1:4], iris2$Species)
94 | 
95 | #partial dependence plot
96 | pd_plot(bart_machine, "Petal.Width")
97 | }
98 | 
99 | 
100 | }
101 | -------------------------------------------------------------------------------- /bartMachine/man/plot_convergence_diagnostics.Rd: --------------------------------------------------------------------------------
1 | \name{plot_convergence_diagnostics}
2 | \alias{plot_convergence_diagnostics}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Plot Convergence Diagnostics
6 | }
7 | \description{
8 | A suite of plots to assess convergence diagnostics and features of the BART model.
9 | }
10 | \usage{
11 | plot_convergence_diagnostics(bart_machine,
12 | plots = c("sigsqs", "mh_acceptance", "num_nodes", "tree_depths"))
13 | }
14 | %- maybe also 'usage' for other objects documented here.
15 | \arguments{
16 | \item{bart_machine}{
17 | An object of class ``bartMachine''.
18 | }
19 | \item{plots}{
20 | The list of plots to be displayed. The four options are: "sigsqs", "mh_acceptance", "num_nodes", "tree_depths".
21 | }
22 | }
23 | \details{
24 | The ``sigsqs'' option plots the posterior error variance estimates by the Gibbs sample number. This is a standard tool to assess convergence of MCMC algorithms. This option is not applicable to classification BART models.\cr
25 | The ``mh_acceptance'' option plots the proportion of Metropolis-Hastings steps accepted for each Gibbs sample (number accepted divided by number of trees).\cr
26 | The ``num_nodes'' option plots the average number of nodes across each tree in the sum-of-trees model by the Gibbs sample number (for post burn-in only). The blue line
27 | is the average number of nodes over all trees.\cr
28 | The ``tree_depths'' option plots the average tree depth across each tree in the sum-of-trees model by the Gibbs sample number (for post burn-in only). The blue line
29 | is the average tree depth over all trees.
30 | }
31 | \value{
32 | None.
33 | }
34 | 
35 | \author{
36 | Adam Kapelner and Justin Bleich
37 | }
38 | \note{
39 | The ``sigsqs'' plot separates the burn-in \eqn{\sigma^2}'s for the first core from the post burn-in \eqn{\sigma^2} estimates for all cores by grey vertical lines.
40 | The ``mh_acceptance'' plot separates burn-in from post-burn-in by a grey vertical line. Post burn-in, the different core proportions plot in different colors.
41 | The ``num_nodes'' plot separates different core estimates by vertical lines (post burn-in only).
42 | The ``tree_depths'' plot separates different core estimates by vertical lines (post burn-in only).
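The ``sigsqs'' panel can be approximated by hand from \code{\link{get_sigsqs}}; a minimal sketch (the \code{num_burn_in} slot name on the ``bartMachine'' object is an assumption here):

\preformatted{
sigsqs = get_sigsqs(bart_machine, after_burn_in = FALSE)
plot(sigsqs, type = "l", xlab = "Gibbs sample", ylab = "sigsq draw")
abline(v = bart_machine$num_burn_in, col = "gray")   # burn-in boundary (assumed slot)
}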
43 | }
44 | 
45 | %% ~Make other sections like Warning with \section{Warning }{....} ~
46 | 
47 | 
48 | \examples{
49 | \dontrun{
50 | #generate Friedman data
51 | set.seed(11)
52 | n = 200
53 | p = 5
54 | X = data.frame(matrix(runif(n * p), ncol = p))
55 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
56 | 
57 | ##build BART regression model
58 | bart_machine = bartMachine(X, y)
59 | 
60 | #plot convergence diagnostics
61 | plot_convergence_diagnostics(bart_machine)
62 | }
63 | 
64 | }
65 | -------------------------------------------------------------------------------- /bartMachine/man/plot_y_vs_yhat.Rd: --------------------------------------------------------------------------------
1 | \name{plot_y_vs_yhat}
2 | \alias{plot_y_vs_yhat}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Plot the Fitted Versus Actual Response
6 | }
7 | \description{
8 | Generates a plot of actual versus fitted values and corresponding credible intervals or prediction intervals for the fitted values.
9 | }
10 | \usage{
11 | plot_y_vs_yhat(bart_machine, Xtest = NULL, ytest = NULL,
12 | credible_intervals = FALSE, prediction_intervals = FALSE,
13 | interval_confidence_level = 0.95)
14 | }
15 | %- maybe also 'usage' for other objects documented here.
16 | \arguments{
17 | \item{bart_machine}{
18 | An object of class ``bartMachine''.
19 | }
20 | \item{Xtest}{
21 | Optional argument for test data. If included, BART computes fitted values at the rows of \code{Xtest}. Else, the fitted values from the training data are used.
22 | }
23 | \item{ytest}{
24 | Optional argument for test data. Vector of observed values corresponding to the rows of \code{Xtest} to be plotted against the predictions for the rows of \code{Xtest}.
25 | }
26 | \item{credible_intervals}{
27 | If TRUE, Bayesian credible intervals are computed using the quantiles of the posterior distribution of \eqn{\hat{f}(x)}. See \code{\link{calc_credible_intervals}} for details.
28 | }
29 | \item{prediction_intervals}{
30 | If TRUE, Bayesian prediction intervals are computed using draws from \eqn{N(\hat{f}(x), \hat{\sigma}^2)}. See \code{\link{calc_prediction_intervals}} for details.
31 | }
32 | \item{interval_confidence_level}{
33 | Desired level of confidence for credible or prediction intervals.
34 | }
35 | }
36 | 
37 | \value{
38 | None.
39 | }
40 | 
41 | \author{
42 | Adam Kapelner and Justin Bleich
43 | }
44 | \note{
45 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
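A bare-bones version of the training-data plot, without the interval machinery, can be sketched from the model's stored in-sample quantities (the \code{y} and \code{y_hat_train} slot names are assumptions about the ``bartMachine'' object):

\preformatted{
plot(bart_machine$y_hat_train, bart_machine$y,
     xlab = "fitted (y_hat)", ylab = "actual (y)")   # slot names assumed
abline(a = 0, b = 1, col = "blue")                   # the perfect-fit line
}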
46 | }
47 | 
48 | %% ~Make other sections like Warning with \section{Warning }{....} ~
49 | 
50 | \seealso{
51 | \code{\link{bart_machine_get_posterior}}, \code{\link{calc_credible_intervals}}, \code{\link{calc_prediction_intervals}}
52 | }
53 | 
54 | \examples{
55 | \dontrun{
56 | #generate linear data
57 | set.seed(11)
58 | n = 500
59 | p = 3
60 | X = data.frame(matrix(runif(n * p), ncol = p))
61 | y = 3*X[ ,1] + 2*X[ ,2] +X[ ,3] + rnorm(n)
62 | 
63 | ##build BART regression model
64 | bart_machine = bartMachine(X, y)
65 | 
66 | ##generate plot
67 | plot_y_vs_yhat(bart_machine)
68 | 
69 | #generate plot with prediction bands
70 | plot_y_vs_yhat(bart_machine, prediction_intervals = TRUE)
71 | }
72 | 
73 | }
74 | -------------------------------------------------------------------------------- /bartMachine/man/predict.bartMachine.Rd: --------------------------------------------------------------------------------
1 | \name{predict.bartMachine}
2 | \alias{predict.bartMachine}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Make a prediction on data using a BART object
6 | }
7 | \description{
8 | Makes a prediction on new data given a fitted BART model for regression or classification.
9 | }
10 | \usage{
11 | \method{predict}{bartMachine}(object, new_data, type = "prob", prob_rule_class = NULL, verbose = TRUE, ...)
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{object}{
16 | An object of class ``bartMachine''.
17 | }
18 | \item{new_data}{
19 | A data frame where each row is an observation to predict. The column names
20 | should be the same as the column names of the training data.
21 | }
22 | \item{type}{
23 | Only relevant if the bartMachine model is classification. The type can be ``prob'', which will
24 | return the estimate of \eqn{P(Y = 1)} (the ``positive'' class), or ``class'', which will return the best guess as to the
25 | class of the object, in the original label, based on whether the probability estimate is greater
26 | than \code{prob_rule_class}. Default is ``prob.''
27 | }
28 | \item{prob_rule_class}{
29 | The rule to determine when the class estimate is \eqn{Y = 1} (the ``positive'' class) based on the probability estimate. This
30 | defaults to what was originally specified in the \code{bart_machine} object.
31 | }
32 | \item{verbose}{
33 | Prints out prediction-related messages. Currently in use only for probability predictions to let the user know which class
34 | is being predicted. Default is \code{TRUE}.
35 | }
36 | \item{...}{
37 | Parameters that are ignored.
38 | }
39 | }
40 | 
41 | \value{
42 | If regression, a numeric vector of \code{y_hat}, the best guess as to the response. If classification and \code{type = ``prob''},
43 | a numeric vector of \code{p_hat}, the best guess as to the probability of the response class being the ``positive'' class. If classification and
44 | \code{type = ``class''}, a character vector of the best guess of the response's class labels.
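The relationship between the two \code{type}s can be sketched as a simple thresholding step (\code{positive_class} and \code{negative_class} are hypothetical stand-ins for the two training labels; the verbose output reports which label is treated as \eqn{Y = 1}):

\preformatted{
p_hat = predict(bart_machine, new_data, type = "prob")   # estimates of P(Y = 1)
# the placeholders below are illustrative, not actual package objects
y_hat_class = ifelse(p_hat > 0.5, positive_class, negative_class)
# equivalent in spirit to predict(..., type = "class", prob_rule_class = 0.5)
}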
45 | }
46 | 
47 | \author{
48 | Adam Kapelner and Justin Bleich
49 | }
50 | 
51 | 
52 | %% ~Make other sections like Warning with \section{Warning }{....} ~
53 | 
54 | \seealso{
55 | \code{\link{bart_predict_for_test_data}}
56 | }
57 | \examples{
58 | #Regression example
59 | \dontrun{
60 | #generate Friedman data
61 | set.seed(11)
62 | n = 200
63 | p = 5
64 | X = data.frame(matrix(runif(n * p), ncol = p))
65 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
66 | 
67 | ##build BART regression model
68 | bart_machine = bartMachine(X, y)
69 | 
70 | ##make predictions on the training data
71 | y_hat = predict(bart_machine, X)
72 | 
73 | #Classification example
74 | data(iris)
75 | iris2 = iris[51 : 150, ] #do not include the third type of flower for this example
76 | iris2$Species = factor(iris2$Species)
77 | bart_machine = bartMachine(iris2[ ,1:4], iris2$Species)
78 | 
79 | ##make probability predictions on the training data
80 | p_hat = predict(bart_machine, iris2[ ,1:4])
81 | 
82 | ##make class predictions on the training data
83 | y_hat_class = predict(bart_machine, iris2[ ,1:4], type = "class")
84 | 
85 | ##make class predictions on the training data conservatively for ``versicolor''
86 | y_hat_class_conservative = predict(bart_machine, iris2[ ,1:4], type = "class", prob_rule_class = 0.9)
87 | }
88 | 
89 | 
90 | }
91 | -------------------------------------------------------------------------------- /bartMachine/man/predict_bartMachineArr.Rd: --------------------------------------------------------------------------------
1 | \name{predict_bartMachineArr}
2 | \alias{predict_bartMachineArr}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Make a prediction on data using a BART array object
6 | }
7 | \description{
8 | Makes a prediction on new data given an array of fitted BART models for
9 | regression or classification. Since BART model fits can vary from run to run,
10 | building many models and averaging their predictions is a good strategy. It is well known that the
11 | Gibbs sampler gets locked into local modes at times. This is a way
12 | to average over many chains.
13 | }
14 | \usage{
15 | predict_bartMachineArr(object, new_data, ...)
16 | }
17 | %- maybe also 'usage' for other objects documented here.
18 | \arguments{
19 | \item{object}{
20 | An object of class ``bartMachineArr''.
21 | }
22 | \item{new_data}{
23 | A data frame where each row is an observation to predict. The column names
24 | should be the same as the column names of the training data.
25 | }
26 | \item{...}{
27 | Not supported. Note that parameters \code{type} and \code{prob_rule_class} for
28 | \code{\link{predict.bartMachine}} are not supported.
29 | }
30 | }
31 | 
32 | \value{
33 | If regression, a numeric vector of \code{y_hat}, the best guess as to the response, averaged across the models in the array. If classification,
34 | a numeric vector of \code{p_hat}, the best guess as to the probability of the response class being the ``positive'' class, averaged across the models
35 | in the array. (The \code{type} and \code{prob_rule_class} options of \code{\link{predict.bartMachine}} are not supported here.)
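Conceptually, the array prediction is a per-model prediction averaged across the independent chains; a minimal sketch, assuming the ``bartMachineArr'' object can be treated as a plain list of ``bartMachine'' models (an assumption, not a documented API):

\preformatted{
preds_per_model = sapply(bart_machine_arr, predict, new_data = new_data)
rowMeans(preds_per_model)   # one averaged prediction per row of new_data
}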
36 | }
37 | 
38 | \author{
39 | Adam Kapelner
40 | }
41 | 
42 | 
43 | %% ~Make other sections like Warning with \section{Warning }{....} ~
44 | 
45 | \seealso{
46 | \code{\link{predict.bartMachine}}
47 | }
48 | \examples{
49 | #Regression example
50 | \dontrun{
51 | #generate Friedman data
52 | set.seed(11)
53 | n = 200
54 | p = 5
55 | X = data.frame(matrix(runif(n * p), ncol = p))
56 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
57 | 
58 | ##build BART regression model
59 | bart_machine = bartMachine(X, y)
60 | bart_machine_arr = bartMachineArr(bart_machine)
61 | 
62 | ##make predictions on the training data
63 | y_hat = predict_bartMachineArr(bart_machine_arr, X)
64 | 
65 | #Classification example
66 | data(iris)
67 | iris2 = iris[51 : 150, ] #do not include the third type of flower for this example
68 | iris2$Species = factor(iris2$Species)
69 | bart_machine = bartMachine(iris2[ ,1:4], iris2$Species)
70 | bart_machine_arr = bartMachineArr(bart_machine)
71 | 
72 | ##make probability predictions on the training data
73 | p_hat = predict_bartMachineArr(bart_machine_arr, iris2[ ,1:4])
74 | }
75 | 
76 | 
77 | }
78 | -------------------------------------------------------------------------------- /bartMachine/man/print.bartMachine.Rd: --------------------------------------------------------------------------------
1 | \name{print.bartMachine}
2 | \alias{print.bartMachine}
3 | \title{
4 | Summarizes information about a \code{bartMachine} object.
5 | }
6 | \description{
7 | This is an alias for the \code{\link{summary.bartMachine}} function. See the description in that section.
8 | }
9 | \usage{
10 | \method{print}{bartMachine}(x, ...)
11 | }
12 | %- maybe also 'usage' for other objects documented here.
13 | \arguments{
14 | \item{x}{
15 | An object of class ``bartMachine''.
16 | }
17 | \item{...}{
18 | Parameters that are ignored.
19 | }
20 | }
21 | 
22 | \value{
23 | None.
24 | }
25 | 
26 | \author{
27 | Adam Kapelner and Justin Bleich
28 | }
29 | 
30 | 
31 | 
32 | \examples{
33 | 
34 | \dontrun{
35 | #Regression example
36 | 
37 | #generate Friedman data
38 | set.seed(11)
39 | n = 200
40 | p = 5
41 | X = data.frame(matrix(runif(n * p), ncol = p))
42 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
43 | 
44 | ##build BART regression model
45 | bart_machine = bartMachine(X, y)
46 | 
47 | ##print out details
48 | print(bart_machine)
49 | 
50 | ##Also, the default print works too
51 | bart_machine
52 | }
53 | }
54 | -------------------------------------------------------------------------------- /bartMachine/man/rmse_by_num_trees.Rd: --------------------------------------------------------------------------------
1 | \name{rmse_by_num_trees}
2 | \alias{rmse_by_num_trees}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Assess the Out-of-sample RMSE by Number of Trees
6 | }
7 | \description{
8 | Assess out-of-sample RMSE of a BART model for varying numbers of trees in the sum-of-trees model.
9 | }
10 | \usage{
11 | rmse_by_num_trees(bart_machine, tree_list = c(5, seq(10, 50, 10), 100, 150, 200),
12 | in_sample = FALSE, plot = TRUE, holdout_pctg = 0.3, num_replicates = 4, ...)
13 | }
14 | %- maybe also 'usage' for other objects documented here.
15 | \arguments{
16 | \item{bart_machine}{
17 | An object of class ``bartMachine''.
18 | }
19 | \item{tree_list}{
20 | List of sizes for the sum-of-trees models.
21 | }
22 | \item{in_sample}{
23 | If TRUE, the RMSE is computed on in-sample data rather than an out-of-sample holdout.
24 | }
25 | \item{plot}{
26 | If TRUE, a plot of the RMSE by the number of trees in the ensemble is created.
27 | }
28 | \item{holdout_pctg}{
29 | Percentage of the data to be treated as an out-of-sample holdout.
30 | }
31 | \item{num_replicates}{
32 | Number of replicates to average the results over. Each replicate uses a randomly sampled holdout of the data (which could overlap across replicates).
33 | }
34 | \item{...}{
35 | Other arguments to be passed to the plot function.
36 | }
37 | }
38 | 
39 | \value{
40 | Invisibly, returns the out-of-sample average RMSEs for each tree size.
41 | }
42 | 
43 | \author{
44 | Adam Kapelner and Justin Bleich
45 | }
46 | \note{
47 | Since using a large number of trees can substantially increase computation time, this plot can help assess whether a smaller ensemble size is sufficient to obtain desirable predictive performance.
48 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}.
49 | }
50 | 
51 | %% ~Make other sections like Warning with \section{Warning }{....} ~
52 | 
53 | 
54 | \examples{
55 | \dontrun{
56 | #generate Friedman data
57 | set.seed(11)
58 | n = 200
59 | p = 10
60 | X = data.frame(matrix(runif(n * p), ncol = p))
61 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
62 | 
63 | ##build BART regression model
64 | bart_machine = bartMachine(X, y, num_trees = 20)
65 | 
66 | #explore RMSE by number of trees
67 | rmse_by_num_trees(bart_machine)
68 | }
69 | 
70 | }
71 | -------------------------------------------------------------------------------- /bartMachine/man/set_bart_machine_num_cores.Rd: --------------------------------------------------------------------------------
1 | \name{set_bart_machine_num_cores}
2 | \alias{set_bart_machine_num_cores}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Set the Number of Cores for BART
6 | }
7 | \description{
8 | Sets the number of cores to be used for all parallelized BART functions.
9 | }
10 | \usage{
11 | set_bart_machine_num_cores(num_cores)
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{num_cores}{
16 | Number of cores to use. If the number of cores is more than 1, model construction cannot be made deterministic by setting the seed.
17 | }
18 | }
19 | 
20 | \value{
21 | None.
22 | }
23 | 
24 | \author{
25 | Adam Kapelner and Justin Bleich
26 | }
27 | 
28 | 
29 | %% ~Make other sections like Warning with \section{Warning }{....} ~
30 | 
31 | \seealso{
32 | \code{\link{bart_machine_num_cores}}
33 | }
34 | \examples{
35 | \dontrun{
36 | #set all parallelized functions to use 4 cores
37 | set_bart_machine_num_cores(4)
38 | }
39 | }
40 | -------------------------------------------------------------------------------- /bartMachine/man/summary.bartMachine.Rd: --------------------------------------------------------------------------------
1 | \name{summary.bartMachine}
2 | \alias{summary.bartMachine}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Summarizes information about a \code{bartMachine} object.
6 | }
7 | \description{
8 | Provides a quick summary of the BART model.
9 | }
10 | \usage{
11 | \method{summary}{bartMachine}(object, ...)
12 | }
13 | %- maybe also 'usage' for other objects documented here.
14 | \arguments{
15 | \item{object}{
16 | An object of class ``bartMachine''.
17 | }
18 | \item{...}{
19 | Parameters that are ignored.
20 | }
21 | }
22 | \details{
23 | Gives the version number of the \code{bartMachine} package used to build this \code{bartMachine} object and whether the object
24 | models ``regression'' or ``classification.'' Gives the amount of training data and the dimension of feature space. Prints
25 | the amount of time it took to build the model, how many processor cores were used during its construction, as well as the
26 | number of burn-in and posterior Gibbs samples used.
27 | 
28 | If the model is for regression, it prints the estimate of \eqn{\sigma^2} before the model was constructed as well as after so
29 | the user can inspect how much variance was explained.
30 | 
31 | If the model was built using the \code{run_in_sample = TRUE} parameter in \code{\link{build_bart_machine}} and is for regression, the in-sample L1,
32 | L2, rmse, and Pseudo-\eqn{R^2} summaries are printed as well as the p-value for the tests of normality and zero-mean noise. If the model is for classification, a confusion matrix is printed.
33 | }
34 | \value{
35 | None.
36 | }
37 | 
38 | \author{
39 | Adam Kapelner
40 | }
41 | 
42 | \examples{
43 | \dontrun{
44 | #Regression example
45 | 
46 | #generate Friedman data
47 | set.seed(11)
48 | n = 200
49 | p = 5
50 | X = data.frame(matrix(runif(n * p), ncol = p))
51 | y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)
52 | 
53 | ##build BART regression model
54 | bart_machine = bartMachine(X, y)
55 | 
56 | ##print out details
57 | summary(bart_machine)
58 | 
59 | ##Also, the default print works too
60 | bart_machine
61 | }
62 | }
63 | -------------------------------------------------------------------------------- /bartMachine/man/var_selection_by_permute_cv.Rd: --------------------------------------------------------------------------------
1 | \name{var_selection_by_permute_cv}
2 | \alias{var_selection_by_permute_cv}
3 | %- Also NEED an '\alias' for EACH other topic documented here.
4 | \title{
5 | Perform Variable Selection Using Cross-validation Procedure
6 | }
7 | \description{
8 | Performs variable selection by cross-validating over the three threshold-based procedures outlined in Bleich et al. (2013) and selecting the single procedure that returns the lowest cross-validation RMSE.
9 | }
10 | \usage{
11 | var_selection_by_permute_cv(bart_machine, k_folds = 5, folds_vec = NULL,
12 | num_reps_for_avg = 5, num_permute_samples = 100,
13 | num_trees_for_permute = 20, alpha = 0.05, num_trees_pred_cv = 50)
14 | }
15 | %- maybe also 'usage' for other objects documented here.
16 | \arguments{
17 | \item{bart_machine}{
18 | An object of class ``bartMachine''.
19 | }
20 | \item{k_folds}{
21 | Number of folds to be used in cross-validation.
22 | }
23 | \item{folds_vec}{
24 | An integer vector of indices specifying which fold each observation belongs to.
25 | }
26 | \item{num_reps_for_avg}{
27 | Number of replicates to average over for the BART model's variable inclusion proportions.
28 | }
29 | \item{num_permute_samples}{
30 | Number of permutations of the response to be made to generate the ``null'' permutation distribution.
31 | }
32 | \item{num_trees_for_permute}{
33 | Number of trees to use in the variable selection procedure. As with \cr \code{\link{investigate_var_importance}}, a small number of trees should be used to force variables to compete for entry into the model. Note that this number is used to estimate both the ``true'' and ``null'' variable inclusion proportions.
34 | } 35 | \item{alpha}{ 36 | Cut-off level for the thresholds. 37 | } 38 | \item{num_trees_pred_cv}{ 39 | Number of trees to use for prediction on the hold-out portion of each fold. Once variables have been selected using the training portion of each fold, a new model is built using only those variables with \code{num_trees_pred_cv} trees in the sum-of-trees model. Forecasts for the holdout sample are made using this model. A larger number of trees is recommended to exploit the full forecasting power of BART. 40 | } 41 | } 42 | \details{ 43 | See Bleich et al. (2013) for a complete description of the procedures outlined above as well as the corresponding vignette for a brief summary with examples. 44 | } 45 | \value{ 46 | Returns a list with the following components: 47 | 48 | \item{best_method}{The name of the best variable selection procedure, as chosen via cross-validation.} 49 | \item{important_vars_cv}{The variables chosen by the \code{best_method} above.} 50 | } 51 | \references{ 52 | J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian 53 | Additive Regression Trees. ArXiv e-prints, 2013. 54 | 55 | Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning 56 | with Bayesian Additive Regression Trees. Journal of Statistical 57 | Software, 70(4), 1-40. doi:10.18637/jss.v070.i04 58 | } 59 | \author{ 60 | Adam Kapelner and Justin Bleich 61 | } 62 | \note{ 63 | This function can have substantial run-time. 64 | This function is parallelized by the number of cores set in \code{\link{set_bart_machine_num_cores}}. 65 | } 66 | 67 | %% ~Make other sections like Warning with \section{Warning }{....} ~ 68 | 69 | \seealso{ 70 | \code{\link{var_selection_by_permute}}, \code{\link{investigate_var_importance}} 71 | } 72 | \examples{ 73 | \dontrun{ 74 | #generate Friedman data 75 | set.seed(11) 76 | n = 150 77 | p = 100 ##95 useless predictors 78 | X = data.frame(matrix(runif(n * p), ncol = p)) 79 | y = 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - .5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(n) 80 | 81 | ##build BART regression model (not actually used in variable selection) 82 | bart_machine = bartMachine(X, y) 83 | 84 | #variable selection via cross-validation 85 | var_sel_cv = var_selection_by_permute_cv(bart_machine, k_folds = 3) 86 | print(var_sel_cv$best_method) 87 | print(var_sel_cv$important_vars_cv) 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /bartMachine/vignettes/bart_normality_heteroskedasticity_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/bart_normality_heteroskedasticity_2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/convergence_diagnostics4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/convergence_diagnostics4.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/cov_test_body_style2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/cov_test_body_style2.pdf
-------------------------------------------------------------------------------- /bartMachine/vignettes/cov_test_omnibus2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/cov_test_omnibus2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/cov_test_top_10_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/cov_test_top_10_2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/cov_test_width2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/cov_test_width2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/covariate_test_age3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/covariate_test_age3.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/friedman_function_interactions2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/friedman_function_interactions2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/glucose_partial_dependence2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/glucose_partial_dependence2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/pdp_horsepower2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/pdp_horsepower2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/pdp_stroke2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/pdp_stroke2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/plot_y_vs_y_hat_cred_ints2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/plot_y_vs_y_hat_cred_ints2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/plot_y_vs_y_hat_pred_ints2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/plot_y_vs_y_hat_pred_ints2.pdf 
-------------------------------------------------------------------------------- /bartMachine/vignettes/rmse_num_trees_3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/rmse_num_trees_3.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/speed_full4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/speed_full4.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/speed_zoomed4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/speed_zoomed4.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/var_imp_automobile_cc2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/var_imp_automobile_cc2.pdf -------------------------------------------------------------------------------- /bartMachine/vignettes/var_selection_plot2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachine/vignettes/var_selection_plot2.pdf -------------------------------------------------------------------------------- /bartMachineJARs/DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: bartMachineJARs 2 | Version: 1.2.1 3 | Title: bartMachine JARs 4 | Type: Package 5 | Date: 2022-09-13 6 | Author: Adam Kapelner and Justin Bleich (R package), see COPYRIGHTS file for the authors of the java libraries 7 | Maintainer: Adam Kapelner 8 | Description: These are bartMachine's Java dependency libraries. Note: this package has no functionality of its own and should not be installed as a standalone package without bartMachine. 
9 | License: GPL-3 10 | Depends: R (>= 2.14.0), rJava (>= 0.9-8) 11 | SystemRequirements: Java (>=8.0) 12 | -------------------------------------------------------------------------------- /bartMachineJARs/NAMESPACE: -------------------------------------------------------------------------------- 1 | importFrom("rJava", ".jpackage") 2 | -------------------------------------------------------------------------------- /bartMachineJARs/R/onLoad.R: -------------------------------------------------------------------------------- 1 | .onLoad <- 2 | function(libname, pkgname) 3 | .jpackage(pkgname, lib.loc = libname) 4 | -------------------------------------------------------------------------------- /bartMachineJARs/inst/COPYRIGHTS: -------------------------------------------------------------------------------- 1 | The following open source Java libraries are included in this package: 2 | 3 | - Commons Math: The Apache Commons Mathematics Library under the Apache License 2.0, 4 | Version 2.1, March 2010, 5 | http://commons.apache.org/proper/commons-math/ 6 | 7 | - Trove under Lesser GNU Public License (LGPL) 2.1, 8 | Version 3.0.3, February 2012, 9 | http://trove.starlight-systems.com/ 10 | 11 | - Fastutil under Apache-2.0 license, 12 | Version 8.5.8, February 2022, 13 | https://fastutil.di.unimi.it/ 14 | 15 | License 16 | ======= 17 | 18 | You may obtain a copy of the Apache License, Version 2.0 at 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | You may obtain a copy of the Lesser GNU Public License (LGPL) 2.1 at 22 | http://www.gnu.org/licenses/lgpl-2.1.html -------------------------------------------------------------------------------- /bartMachineJARs/inst/java/commons-math-2.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachineJARs/inst/java/commons-math-2.1.jar -------------------------------------------------------------------------------- /bartMachineJARs/inst/java/fastutil-core-8.5.8.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachineJARs/inst/java/fastutil-core-8.5.8.jar -------------------------------------------------------------------------------- /bartMachineJARs/inst/java/trove-3.0.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bartMachineJARs/inst/java/trove-3.0.3.jar -------------------------------------------------------------------------------- /bartMachineJARs/java/README: -------------------------------------------------------------------------------- 1 | Code for these JARs can be found below: 2 | 3 | - Commons Math: The Apache Commons Mathematics Library Version 2.1, March 2010, 4 | http://archive.apache.org/dist/commons/math/source/commons-math-2.1-src.tar.gz 5 | 6 | - Trove Version 3.0.3, February 2012, 7 | https://bitbucket.org/robeden/trove/downloads/trove-3.0.3.tar.gz 8 | -------------------------------------------------------------------------------- /bart_package_paper/interaction_constraint_demo.R: -------------------------------------------------------------------------------- 1 | options(java.parameters = c("-Xmx20000m")) 2 | library(bartMachine) 3 | set_bart_machine_num_cores(10) 4 | 5 | 6 | cov_dgp = function(n, p){ 7 | 
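# returns an n x p data frame of iid Uniform(0, 1) draws for the covariates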
data.frame(matrix(runif(n * p), ncol = p)) 8 | } 9 | 10 | response_function = function(X, sigma = 0.3){ 11 | # 10 * sin(pi * X[ ,1] * X[,2]) + 20 * (X[,3] -.5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(nrow(X), 0, 0.01) #Friedman 12 | # X[ ,1] * X[,2] + X[ ,3] * X[,4] + X[ ,5] + rnorm(nrow(X), 0, sigma) 13 | X[ ,1] + X[,2] + X[ ,3] + X[,4] + X[ ,5] + rnorm(nrow(X), 0, sigma) 14 | } 15 | 16 | SEED = 1984 17 | set.seed(SEED) 18 | ntrain = 500 19 | p = 5 20 | Xtrain = cov_dgp(ntrain, p) 21 | ytrain = response_function(Xtrain) 22 | ?bartMachine 23 | 24 | gam = bartMachine(Xtrain, ytrain, interaction_constraints = as.list(seq(1 : p))) 25 | additive_bart_machine = bartMachine(Xtrain, ytrain, interaction_constraints = list(c(1, 2), c(3, 4), 5)) 26 | bart_machine = bartMachine(Xtrain, ytrain) 27 | 28 | summary(gam) 29 | summary(additive_bart_machine) 30 | summary(bart_machine) 31 | 32 | 33 | ####now oos 34 | ntest = 1000 35 | Xtest = cov_dgp(ntest, p) 36 | 37 | y_hat_test_gam = predict(gam, Xtest) 38 | y_hat_test_additive = predict(additive_bart_machine, Xtest) 39 | y_hat_test = predict(bart_machine, Xtest) 40 | 41 | ytest = response_function(Xtest) 42 | sqrt(sum((y_hat_test_gam - ytest)^2) / ntest) 43 | sqrt(sum((y_hat_test_additive - ytest)^2) / ntest) 44 | sqrt(sum((y_hat_test - ytest)^2) / ntest) 45 | 46 | -------------------------------------------------------------------------------- /bart_package_paper/sec4.10-4.12.R: -------------------------------------------------------------------------------- 1 | options(java.parameters = "-Xmx2500m") 2 | library(bartMachine) 3 | set_bart_machine_num_cores(4) 4 | 5 | 6 | gen_friedman_data = function(n, p, sigma){ 7 | if (p < 5){ 8 | stop("p must be greater than or equal to 5") 9 | } 10 | X = matrix(runif(n * p), nrow = n, ncol = p) 11 | y = 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - .5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(n, 0, sigma) 12 | data.frame(y, X) 13 | } 14 | 15 | ##### section 4.10 16 | 17 | #set up prior 18 | p = 5 19 | p0 = 95 20 | prior = c(rep(5, times = p), rep(1, times = p0)) 21 | 22 | #make training and test data 23 | ntrain = 500 24 | sigma = 1 25 | fr_data = gen_friedman_data(ntrain, p + p0, sigma) 26 | y = fr_data$y 27 | X = fr_data[, 2 : (p + p0 + 1)] 28 | ntest = 500 29 | fr_data = gen_friedman_data(ntest, p + p0, sigma) 30 | Xtest = fr_data[, 2 : (p + p0 + 1)] 31 | ytest = fr_data$y 32 | 33 | #build uninformed and informed models 34 | bart_machine = bartMachine(X, y) 35 | bart_machine_informed = bartMachine(X, y, cov_prior_vec = prior, run_in_sample = FALSE) 36 | #test out of sample 37 | bart_predict_for_test_data(bart_machine, Xtest, ytest)$rmse 38 | bart_predict_for_test_data(bart_machine_informed, Xtest, ytest)$rmse 39 | 40 | ##### section 4.11 41 | 42 | fr_data = gen_friedman_data(500, 10, 1) 43 | y = fr_data$y 44 | X = fr_data[, 2 : 11] 45 | 46 | bart_machine = bartMachine(X, y) 47 | 48 | # Figure 10 49 | interaction_investigator(bart_machine, num_replicates_for_avg = 25, num_var_plot = 10, bottom_margin = 5) 50 | 51 | ##### section 4.12 52 | 53 | #bartMachine models can be saved and can persist across R sessions 54 | bart_machine = bartMachine(X, y, serialize = TRUE) 55 | save.image("bart_demo.RData") 56 | q("no") 57 | R 58 | options(java.parameters = "-Xmx7000m") 59 | library(bartMachine) 60 | load("bart_demo.RData") 61 | predict(bart_machine, X) 62 | 63 | #Demonstrate that serialization can be very expensive 64 | options(java.parameters = "-Xmx7000m") 65 | library(bartMachine) 66 | fr_data = gen_friedman_data(4000, 1000, 1)
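# aside, not in the original script: after the save.image() call below completes,
# file.info("bart_demo.RData")$size / 1e6 reports the serialized file size in MB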
67 | y = fr_data$y 68 | X = fr_data[, 2 : 1001] 69 | bart_machine = bartMachine(X, y, serialize = TRUE, num_iterations_after_burn_in = 4000, num_trees = 100, run_in_sample = FALSE, mem_cache_for_speed = FALSE) 70 | save.image("bart_demo.RData") 71 | q("no") 72 | 73 | #demonstrate you cannot save a bartMachine model in an RData file 74 | #without using the serialize option 75 | options(java.parameters = "-Xmx6000m") 76 | library(bartMachine) 77 | bart_machine = bartMachine(X, y) 78 | save.image("bart_demo.RData") 79 | q("no") 80 | R 81 | options(java.parameters = "-Xmx6000m") 82 | library(bartMachine) 83 | load("bart_demo.RData") 84 | predict(bart_machine, X) 85 | -------------------------------------------------------------------------------- /bart_package_paper/sec5.R: -------------------------------------------------------------------------------- 1 | options(java.parameters = "-Xmx2500m") 2 | library(bartMachine) 3 | library(MASS) 4 | 5 | data(Pima.te) 6 | X = data.frame(Pima.te[, -8]) 7 | y = Pima.te[, 8] 8 | 9 | set_bart_machine_num_cores(4) 10 | 11 | bart_machine_cv = bartMachineCV(X, y) 12 | bart_machine_cv 13 | 14 | bart_machine = bartMachine(X, y, prob_rule_class = 0.3) 15 | bart_machine 16 | 17 | oos_stats = k_fold_cv(X, y, k_folds = 10) 18 | oos_stats$confusion_matrix 19 | 20 | round(predict(bart_machine_cv, X[1 : 2, ], type = "prob"), 3) 21 | predict(bart_machine_cv, X[1 : 2, ], type = "class") 22 | 23 | # Figure 11 24 | cov_importance_test(bart_machine_cv, covariates = c("age")) 25 | 26 | # Figure 12 27 | pd_plot(bart_machine_cv, j = "glu") 28 | 29 | round(calc_credible_intervals(bart_machine_cv, X[1 : 2, ]), 3) 30 | -------------------------------------------------------------------------------- /bart_package_paper/time_mat.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/bart_package_paper/time_mat.RData -------------------------------------------------------------------------------- /build.xml: -------------------------------------------------------------------------------- [Ant build file; its XML markup was stripped in this dump and only the description survives: "This script builds the bartMachine code into the bart_java.jar file which gets called from R in the R package".] -------------------------------------------------------------------------------- /commons-math-2.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/commons-math-2.1.jar -------------------------------------------------------------------------------- /datasets/r_boston_tiny_with_missing.csv: -------------------------------------------------------------------------------- 1 | TOWNNO,LON,LAT,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,y 2 | 0,-70.955,42.255,0.00632,18,2.31,0,,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24 3 | 1,-70.95,42.2875,0.02731,0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6 4 | 1,-70.936,42.283,0.02729,,7.07,0,,7.185,61.1,,2,242,17.8,392.83,,34.7 5 | 2,-70.928,42.293,,,,0,,6.998,,,3,,18.7,,,33.4 6 | 2,,42.298,,0,,0,,7.147,54.2,,,222,18.7,396.9,,36.2 7 | 2,-70.9165,42.304,,0,,0,0.458,,58.7,,3,222,18.7,394.12,5.21,28.7 8 | 3,-70.936,42.297,,,7.87,,0.524,6.012,66.6,5.5605,5,311,,395.6,12.43,22.9
3,-70.9375,,,12.5,7.87,0,0.524,6.172,96.1,5.9505,5,311,15.2,,19.15,27.1 10 | ,-70.933,42.312,,,7.87,0,0.524,5.631,100,,5,311,15.2,,29.93,16.5 11 | 3,-70.929,42.316,0.17004,12.5,7.87,0,0.524,6.004,85.9,6.5921,5,311,15.2,,,18.9 12 | 3,-70.935,42.316,0.22489,12.5,7.87,0,0.524,,94.3,6.3467,5,311,15.2,392.52,20.45,15 13 | -------------------------------------------------------------------------------- /datasets/r_simple.csv: -------------------------------------------------------------------------------- 1 | x_1,y 2 | 8.857364673,9.603232396 3 | 17.75328233,18.82030956 4 | 29.79429595,30.93360107 5 | 30.33565907,30.71412372 6 | 48.92067772,50.44072685 7 | 54.95933525,56.7350879 8 | 59.54938415,61.04297103 9 | 62.27725963,62.50824406 10 | 89.61743563,91.14035798 11 | 98.82344913,101.1717397 12 | -------------------------------------------------------------------------------- /datasets/r_simpledata.csv: -------------------------------------------------------------------------------- 1 | 0,0,1 2 | 0,0,1 3 | 0,0,1 4 | 0,0,1 5 | 0,0,1 6 | 0,0,1 7 | 0,0,1 8 | 0,0,1 9 | 0,0,1 10 | 0,0,1 11 | 0,0,1 12 | 0,0,1 13 | 0,0,1 14 | 0,0,1 15 | 0,0,1 16 | 0,0,1 17 | 0,0,1 18 | 0,0,1 19 | 0,0,1 20 | 0,0,1 21 | 0,0,1 22 | 0,0,1 23 | 0,0,1 24 | 0,0,1 25 | 0,0,1 26 | 0,1,2 27 | 0,1,2 28 | 0,1,2 29 | 0,1,2 30 | 0,1,2 31 | 0,1,2 32 | 0,1,2 33 | 0,1,2 34 | 0,1,2 35 | 0,1,2 36 | 0,1,2 37 | 0,1,2 38 | 0,1,2 39 | 0,1,2 40 | 0,1,2 41 | 0,1,2 42 | 0,1,2 43 | 0,1,2 44 | 0,1,2 45 | 0,1,2 46 | 0,1,2 47 | 0,1,2 48 | 0,1,2 49 | 0,1,2 50 | 0,1,2 51 | 1,0,3 52 | 1,0,3 53 | 1,0,3 54 | 1,0,3 55 | 1,0,3 56 | 1,0,3 57 | 1,0,3 58 | 1,0,3 59 | 1,0,3 60 | 1,0,3 61 | 1,0,3 62 | 1,0,3 63 | 1,0,3 64 | 1,0,3 65 | 1,0,3 66 | 1,0,3 67 | 1,0,3 68 | 1,0,3 69 | 1,0,3 70 | 1,0,3 71 | 1,0,3 72 | 1,0,3 73 | 1,0,3 74 | 1,0,3 75 | 1,0,3 76 | 1,1,4 77 | 1,1,4 78 | 1,1,4 79 | 1,1,4 80 | 1,1,4 81 | 1,1,4 82 | 1,1,4 83 | 1,1,4 84 | 1,1,4 85 | 1,1,4 86 | 1,1,4 87 | 1,1,4 88 | 1,1,4 89 | 1,1,4 90 | 1,1,4 91 | 1,1,4 92 | 1,1,4 93 | 1,1,4 94 | 1,1,4 95 | 1,1,4 96 | 1,1,4 97 | 1,1,4 98 | 1,1,4 99 | 1,1,4 100 | 1,1,4 101 | -------------------------------------------------------------------------------- /datasets/r_stupiddata.csv: -------------------------------------------------------------------------------- 1 | 1,1,1,1 2 | 1,2,2,2 3 | 1,3,3,3 4 | 1,4,4,4 5 | 1,5,5,5 6 | 1,6,6,6 7 | 1,7,7,7 8 | 1,8,8,8 9 | 1,9,9,9 10 | 1,10,10,10 11 | 1,11,11,11 12 | 1,12,12,12 13 | 1,13,13,13 14 | 2,14,14,-1 15 | 2,15,15,-2 16 | 2,16,16,-3 17 | 2,17,17,-4 18 | 2,18,18,-5 19 | 2,19,19,-6 20 | 2,20,20,-7 21 | 2,21,21,-8 22 | 2,22,22,-9 23 | 2,23,23,-10 24 | 2,24,24,-11 25 | 2,25,25,-12 26 | 2,26,1,1 27 | 2,27,2,2 28 | 2,28,3,3 29 | 2,29,4,4 30 | 2,30,5,5 31 | 2,31,6,6 32 | 2,32,7,7 33 | 2,33,8,8 34 | 2,34,9,9 35 | 2,35,10,10 36 | 2,36,11,11 37 | 2,37,12,12 38 | 2,38,13,13 39 | 4,39,14,-1 40 | 4,40,15,-2 41 | 4,41,16,-3 42 | 4,42,17,-4 43 | 4,43,18,-5 44 | 4,44,19,-6 45 | 4,45,20,-7 46 | 4,46,21,-8 47 | 4,47,22,-9 48 | 4,48,23,-10 49 | 4,49,24,-11 50 | 4,50,25,-12 51 | -2000,51,1,1 52 | -2000,52,2,2 53 | -2000,53,3,3 54 | -2000,54,4,4 55 | -2000,55,5,5 56 | -2000,56,6,6 57 | -2000,57,7,7 58 | -2000,58,8,8 59 | -2000,59,9,9 60 | -2000,60,10,10 61 | -2000,61,11,11 62 | -2000,62,12,12 63 | -2000,63,13,13 64 | -4000,64,14,-1 65 | -4000,65,15,-2 66 | -4000,66,16,-3 67 | -4000,67,17,-4 68 | -4000,68,18,-5 69 | -4000,69,19,-6 70 | -4000,70,20,-7 71 | -4000,71,21,-8 72 | -4000,72,22,-9 73 | -4000,73,23,-10 74 | -4000,74,24,-11 75 | -4000,75,25,-12 76 | 100,76,1,1 77 | 100,77,2,2 78 | 
100,78,3,3 79 | 100,79,4,4 80 | 100,80,5,5 81 | 100,81,6,6 82 | 100,82,7,7 83 | 100,83,8,8 84 | 100,84,9,9 85 | 100,85,10,10 86 | 100,86,11,11 87 | 100,87,12,12 88 | 100,88,13,13 89 | 200,89,14,-1 90 | 200,90,15,-2 91 | 200,91,16,-3 92 | 200,92,17,-4 93 | 200,93,18,-5 94 | 200,94,19,-6 95 | 200,95,20,-7 96 | 200,96,21,-8 97 | 200,97,22,-9 98 | 200,98,23,-10 99 | 200,99,24,-11 100 | 200,100,25,-12 101 | -------------------------------------------------------------------------------- /datasets/r_treemodel_low_n.csv: -------------------------------------------------------------------------------- 1 | "x_1","x_2","x_3","y" 2 | 31.5,3.1,25.7,48.8 3 | 92.9,79.8,43.9,50.3 4 | 6.7,56.8,6.3,9.7 5 | 52.6,38.7,5,51.9 6 | 2.3,99.2,65.3,28.2 7 | 28.5,57.6,99.4,30.8 8 | 88.8,61.6,8.8,49.6 9 | 15.9,38,49.3,30.7 10 | 17.7,87,13.4,30.7 11 | 53.2,4.9,12,50.9 12 | 39.2,1.2,26.6,51.8 13 | 36.8,78.5,25.6,49.1 14 | 52.6,43.7,12,50.6 15 | 38.9,53.3,72.7,50 16 | 50.9,72.9,54.1,51.6 17 | -------------------------------------------------------------------------------- /fastutil-core-8.5.8.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/fastutil-core-8.5.8.jar -------------------------------------------------------------------------------- /java_code_documentation/ (all HTML files): -------------------------------------------------------------------------------- [Generated javadoc for the AlgorithmTesting, CustomLogging, OpenSourceExtensions and bartMachine packages: package frames, class-use pages, index files, and the overview, help, deprecated-list and serialized-form pages. The HTML markup of these pages was stripped in this dump, leaving only navigation boilerplate, so their contents are omitted; the package-list file below enumerates the documented packages.]
-------------------------------------------------------------------------------- /java_code_documentation/package-list: -------------------------------------------------------------------------------- 1 | AlgorithmTesting 2 | CustomLogging 3 | OpenSourceExtensions 4 | bartMachine 5 | -------------------------------------------------------------------------------- /java_code_documentation/resources/background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/background.gif -------------------------------------------------------------------------------- /java_code_documentation/resources/tab.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/tab.gif -------------------------------------------------------------------------------- /java_code_documentation/resources/titlebar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/titlebar.gif -------------------------------------------------------------------------------- /java_code_documentation/resources/titlebar_end.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/java_code_documentation/resources/titlebar_end.gif -------------------------------------------------------------------------------- /junit-4.10.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/junit-4.10.jar -------------------------------------------------------------------------------- /simple_example.R: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## try using built-in jcache 4 | options(java.parameters = "-Xmx1500m") 5 | library(bartMachine) 6 | n = 500 7 | x = 1 : n; y = x + rnorm(n) 8 | bart_machine = build_bart_machine(as.data.frame(x), y, serialize = TRUE) 9 | save.image("test_bart_machine.RData") 10 | q("no") 11 | #close R, open R 12 | R 13 | options(java.parameters = "-Xmx1500m") 14 | library(bartMachine) 15 | load("test_bart_machine.RData") 16 | x = 1 : n 17 | predict(bart_machine, as.data.frame(x)) 18 | 19 | 20 | ##how big is this stuff?
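# aside, not in the original script: a one-line check of the same ranking, assuming the bart_machine object loaded above
sort(sapply(bart_machine, object.size), decreasing = TRUE)[1 : 5] / 1e6 # top five components, in MB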
21 | data_names = names(bart_machine) 22 | sizes = matrix(NA, ncol = 1, nrow = length(data_names)) 23 | rownames(sizes) = data_names 24 | for (i in 1 : length(data_names)){ 25 | sizes[i, ] = object.size(bart_machine[[data_names[i]]]) / 1e6 26 | } 27 | t(t(sizes[order(-sizes), ])) 28 | q("no") 29 | 30 | ####ensure no more memory leak 31 | R 32 | options(java.parameters = "-Xmx1000m") 33 | library(bartMachine) 34 | x = 1 : 100; y = x + rnorm(100) 35 | for (i in 1 : 10000){ 36 | bart_machine = build_bart_machine(as.data.frame(x), y) 37 | } 38 | q("no") 39 | 40 | ## If it helps, this may 41 | 42 | #get some data 43 | R 44 | options(java.parameters = "-Xmx1000m") 45 | library(bartMachine) 46 | 47 | library(MASS) 48 | data(Boston) 49 | X = Boston 50 | y = X$medv 51 | X$medv = NULL 52 | 53 | Xtrain = X[1 : (nrow(X) / 2), ] 54 | ytrain = y[1 : (nrow(X) / 2)] 55 | Xtest = X[(nrow(X) / 2 + 1) : nrow(X), ] 56 | ytest = y[(nrow(X) / 2 + 1) : nrow(X)] 57 | 58 | set_bart_machine_num_cores(4) 59 | bart_machine = build_bart_machine(Xtrain, ytrain, 60 | num_trees = 200, 61 | num_burn_in = 300, 62 | num_iterations_after_burn_in = 1000, 63 | use_missing_data = TRUE, 64 | debug_log = TRUE, 65 | verbose = TRUE) 66 | bart_machine 67 | 68 | plot_y_vs_yhat(bart_machine) 69 | 70 | yhat = predict(bart_machine, Xtest) 71 | q("no") 72 | 73 | 74 | options(java.parameters = "-Xmx1500m") 75 | library(bartMachine) 76 | data("Pima.te", package = "MASS") 77 | X <- data.frame(Pima.te[, -8]) 78 | y <- Pima.te[, 8] 79 | bart_machine = bartMachine(X, y) 80 | bart_machine 81 | table(y, predict(bart_machine, X, type = "class")) 82 | 83 | raw_node_data = extract_raw_node_data(bart_machine, g = 37) 84 | raw_node_data[[17]] 85 | 86 | 87 | 88 | 89 | 90 | options(java.parameters = "-Xmx20000m") 91 | library(bartMachine) 92 | set_bart_machine_num_cores(10) 93 | set.seed(1) 94 | data(iris) 95 | iris2 = iris[51 : 150, ] #do not include the third type of flower for this example 96 | iris2$Species = factor(iris2$Species) 97 | X = iris2[ ,1:4] 98 | y = iris2$Species 99 | 100 | 101 | 102 | bart_machine = bartMachine(X, y, num_trees = 50, seed = 1) 103 | bart_machine 104 | ##make probability predictions on the training data 105 | p_hat = predict(bart_machine, iris2[ ,1:4]) 106 | p_hat -------------------------------------------------------------------------------- /src/AlgorithmTesting/DataSetupForCSVFile.java: -------------------------------------------------------------------------------- 1 | /* 2 | BART - Bayesian Additive Regressive Trees 3 | Software for Supervised Statistical Learning 4 | 5 | Copyright (C) 2012 Professor Ed George & Adam Kapelner, 6 | Dept of Statistics, The Wharton School of the University of Pennsylvania 7 | 8 | This program is free software; you can redistribute it and/or modify 9 | it under the terms of the GNU General Public License as published by 10 | the Free Software Foundation; either version 2 of the License, or 11 | (at your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, 14 | but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | GNU General Public License for more details: 17 | 18 | http://www.gnu.org/licenses/gpl-2.0.txt 19 | 20 | You should have received a copy of the GNU General Public License along 21 | with this program; if not, write to the Free Software Foundation, Inc., 22 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
23 | */ 24 | 25 | package AlgorithmTesting; 26 | 27 | import java.io.BufferedReader; 28 | import java.io.File; 29 | import java.io.FileReader; 30 | import java.io.IOException; 31 | import java.util.ArrayList; 32 | import java.util.HashSet; 33 | 34 | import bartMachine.Classifier; 35 | 36 | 37 | public class DataSetupForCSVFile { 38 | 39 | private final ArrayList<double[]> X_y; 40 | private int K; 41 | private boolean file_has_header; 42 | // private int p; 43 | // private ArrayList<FeatureType> feature_types; 44 | // private ArrayList<String> feature_names; 45 | 46 | public DataSetupForCSVFile(File file, boolean header) { 47 | X_y = new ArrayList<double[]>(); 48 | this.file_has_header = header; 49 | try { 50 | LoadDataIntoXyFormatAndFindFeatureNamesAndP(file); 51 | } catch (IOException e) { 52 | e.printStackTrace(); 53 | } 54 | extractNumClassesFromDataMatrix(); 55 | } 56 | 57 | private void extractNumClassesFromDataMatrix() { 58 | HashSet<Double> set_of_classes = new HashSet<Double>(); 59 | for (double[] record : X_y){ 60 | double k = record[record.length - 1]; 61 | // System.out.println("y = " + k); 62 | set_of_classes.add(k); //the last position is the response 63 | } 64 | K = set_of_classes.size(); 65 | // System.out.println("num classes: " + K); 66 | } 67 | 68 | private void LoadDataIntoXyFormatAndFindFeatureNamesAndP(File file) throws IOException{ 69 | //begin by iterating over the file 70 | BufferedReader in = new BufferedReader(new FileReader(file)); 71 | int line_num = 0; 72 | while (true){ 73 | String datum = in.readLine(); 74 | if (datum == null) { 75 | break; 76 | } 77 | String[] datums = datum.split(","); 78 | // p = datums.length - 1; 79 | 80 | 81 | if (line_num == 0 && file_has_header){ 82 | // feature_types = new ArrayList<FeatureType>(p); 83 | // for (int i = 0; i < p; i++){ 84 | // feature_types.add(FeatureType.NUMBER); //default for now 85 | // } 86 | // feature_names = new ArrayList<String>(p); 87 | // for (int i = 0; i < p; i++){ 88 | // feature_names.add(datums[i]); //default for now 89 | // System.out.println("feature " + (i + 1) + " " + datums[i]); 90 | // } 91 | } 92 | else { 93 | final double[] record = new double[datums.length]; 94 | for (int i = 0; i < datums.length; i++){ 95 | try { 96 | record[i] = Double.parseDouble(datums[i]); 97 | } catch(NumberFormatException e){ 98 | record[i] = Classifier.MISSING_VALUE; 99 | } 100 | } 101 | X_y.add(record); 102 | // System.out.println("record: " + Tools.StringJoin(record, ", ")); 103 | } 104 | line_num++; 105 | } 106 | in.close(); 107 | } 108 | 109 | 110 | public int getK() { 111 | return K; 112 | } 113 | 114 | 115 | public ArrayList<double[]> getX_y() { 116 | return X_y; 117 | } 118 | 119 | } 120 | -------------------------------------------------------------------------------- /src/AlgorithmTesting/package-info.java: -------------------------------------------------------------------------------- 1 | package AlgorithmTesting; 2 | 3 | /** 4 | * This just contains files necessary for loading up datasets and running BART. 5 | * This should only be used for debugging since the preferred way of using this 6 | * BART implementation is via the R package bartMachine. Thus, the methods in 7 | * this package are not documented.
8 | */ 9 | -------------------------------------------------------------------------------- /src/CustomLogging/LoggingOutputStream.java: -------------------------------------------------------------------------------- 1 | package CustomLogging; 2 | 3 | import java.io.ByteArrayOutputStream; 4 | import java.io.IOException; 5 | import java.util.logging.Level; 6 | import java.util.logging.Logger; 7 | 8 | /** 9 | * An OutputStream that writes contents to a Logger upon each call to flush() 10 | */ 11 | public class LoggingOutputStream extends ByteArrayOutputStream { 12 | 13 | private String lineSeparator; 14 | 15 | private Logger logger; 16 | private Level level; 17 | 18 | /** 19 | * Constructor 20 | * @param logger Logger to write to 21 | * @param level Level at which to write the log message 22 | */ 23 | public LoggingOutputStream(Logger logger, Level level) { 24 | super(); 25 | this.logger = logger; 26 | this.level = level; 27 | lineSeparator = System.getProperty("line.separator"); 28 | } 29 | 30 | /** 31 | * upon flush() write the existing contents of the OutputStream 32 | * to the logger as a log record. 33 | * @throws java.io.IOException in case of error 34 | */ 35 | public void flush() throws IOException { 36 | 37 | String record; 38 | synchronized(this) { 39 | super.flush(); 40 | record = this.toString(); 41 | super.reset(); 42 | 43 | if (record.length() == 0 || record.equals(lineSeparator)) { 44 | // avoid empty records 45 | return; 46 | } 47 | 48 | logger.logp(level, "", "", record); 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /src/CustomLogging/StdOutErrLevel.java: -------------------------------------------------------------------------------- 1 | package CustomLogging; 2 | 3 | import java.io.InvalidObjectException; 4 | import java.io.ObjectStreamException; 5 | import java.util.logging.Level; 6 | 7 | /** 8 | * Class defining 2 new Logging levels, one for STDOUT, one for STDERR, 9 | * used when multiplexing STDOUT and STDERR into the same rolling log file 10 | * via the Java Logging APIs. 11 | */ 12 | public class StdOutErrLevel extends Level { 13 | private static final long serialVersionUID = -9122466300490214950L; 14 | 15 | /** 16 | * Private constructor 17 | */ 18 | private StdOutErrLevel(String name, int value) { 19 | super(name, value); 20 | } 21 | /** 22 | * Level for STDOUT activity. 23 | */ 24 | public static Level STDOUT = 25 | new StdOutErrLevel("STDOUT", Level.INFO.intValue()+53); 26 | /** 27 | * Level for STDERR activity 28 | */ 29 | public static Level STDERR = 30 | new StdOutErrLevel("STDERR", Level.INFO.intValue()+54); 31 | 32 | /** 33 | * Method to avoid creating duplicate instances when deserializing the 34 | * object.
35 | * @return the singleton instance of this Level value in this 36 | * classloader 37 | * @throws ObjectStreamException If unable to deserialize 38 | */ 39 | protected Object readResolve() 40 | throws ObjectStreamException { 41 | if (this.intValue() == STDOUT.intValue()) 42 | return STDOUT; 43 | if (this.intValue() == STDERR.intValue()) 44 | return STDERR; 45 | throw new InvalidObjectException("Unknown instance :" + this); 46 | } 47 | 48 | } -------------------------------------------------------------------------------- /src/CustomLogging/SuperSimpleFormatter.java: -------------------------------------------------------------------------------- 1 | package CustomLogging; 2 | 3 | import java.io.*; 4 | import java.util.logging.*; 5 | 6 | /** 7 | * Print a brief summary of the LogRecord in a human readable 8 | * format. The summary will typically be 1 or 2 lines. 9 | * 10 | * @version %I%, %G% 11 | * @since 1.4 12 | */ 13 | 14 | public class SuperSimpleFormatter extends Formatter { 15 | 16 | // Date dat = new Date(); 17 | // private final static String format = "{0,date} {0,time}"; 18 | // private MessageFormat formatter; 19 | // 20 | // private Object args[] = new Object[1]; 21 | 22 | // Line separator string. This is the value of the line.separator 23 | // property at the moment that the SimpleFormatter was created. 24 | private String lineSeparator = "\n"; 25 | 26 | /** 27 | * Format the given LogRecord. 28 | * @param record the log record to be formatted. 29 | * @return a formatted log record 30 | */ 31 | public synchronized String format(LogRecord record) { 32 | StringBuffer sb = new StringBuffer(); 33 | // Minimize memory allocations here. 34 | // dat.setTime(record.getMillis()); 35 | // args[0] = dat; 36 | // StringBuffer text = new StringBuffer(); 37 | // if (formatter == null) { 38 | // formatter = new MessageFormat(format); 39 | // } 40 | // formatter.format(args, text, null); 41 | // sb.append(text); 42 | // sb.append(" "); 43 | if (record.getSourceClassName() != null) { 44 | sb.append(record.getSourceClassName()); 45 | } else { 46 | sb.append(record.getLoggerName()); 47 | } 48 | if (record.getSourceMethodName() != null) { 49 | sb.append(" "); 50 | sb.append(record.getSourceMethodName()); 51 | } 52 | sb.append(lineSeparator); 53 | String message = formatMessage(record); 54 | if (record.getLevel().getLocalizedName().equals("STDERR")){ 55 | sb.append("ERROR: "); 56 | } 57 | sb.append(message); 58 | // sb.append(lineSeparator); 59 | if (record.getThrown() != null) { 60 | try { 61 | StringWriter sw = new StringWriter(); 62 | PrintWriter pw = new PrintWriter(sw); 63 | record.getThrown().printStackTrace(pw); 64 | pw.close(); 65 | sb.append(sw.toString()); 66 | } catch (Exception ex) { 67 | } 68 | } 69 | return sb.toString(); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/CustomLogging/package-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * This is for debugging only, please ignore. Thus, 3 | * these methods are not documented. 4 | */ 5 | package CustomLogging; -------------------------------------------------------------------------------- /src/OpenSourceExtensions/UnorderedPair.java: -------------------------------------------------------------------------------- 1 | package OpenSourceExtensions; 2 | 3 | /** 4 | * An unordered immutable pair of elements. Unordered means the pair (a,b) equals the 5 | * pair (b,a). An element of a pair cannot be null. 6 | * <p>
7 | * The class implements {@link Comparable}. Pairs are compared by their smallest elements, and if 8 | * those are equal, they are compared by their maximum elements. 9 | * <p>
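* Illustrative usage (an aside, not from the original source): {@code new UnorderedPair<Integer>(5, 2)} stores 2 as {@code getFirst()} and 5 as {@code getSecond()}, and it equals {@code new UnorderedPair<Integer>(2, 5)} because order is irrelevant. * <p>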
10 | * To work correctly, the underlying element type must implement {@link Object#equals} and 11 | * {@link Comparable#compareTo} in a consistent fashion. 12 | */ 13 | public final class UnorderedPair<E extends Comparable<E>> implements Comparable<UnorderedPair<E>> { 14 | private final E first; 15 | private final E second; 16 | 17 | /** 18 | * Creates an unordered pair of the specified elements. The order of the arguments is irrelevant, 19 | * so the first argument is not guaranteed to be returned by {@link #getFirst()}, for example. 20 | * @param a one element of the pair. Must not be null. 21 | * @param b one element of the pair. Must not be null. May be the same as a. 22 | */ 23 | public UnorderedPair(E a, E b) { 24 | if (a.compareTo(b) < 0) { 25 | this.first = a; 26 | this.second = b; 27 | } else { 28 | this.first = b; 29 | this.second = a; 30 | } 31 | } 32 | 33 | /** 34 | * Gets the smallest element of the pair (according to its {@link Comparable} implementation). 35 | * @return an element of the pair. null is never returned. 36 | */ 37 | public E getFirst() { 38 | return first; 39 | } 40 | 41 | /** 42 | * Gets the largest element of the pair (according to its {@link Comparable} implementation). 43 | * @return an element of the pair. null is never returned. 44 | */ 45 | public E getSecond() { 46 | return second; 47 | } 48 | 49 | @Override 50 | public int hashCode() { 51 | return 31 * first.hashCode() + 173 * second.hashCode(); 52 | } 53 | 54 | @Override 55 | public boolean equals(Object obj) { 56 | if (this == obj) 57 | return true; 58 | if (obj == null) 59 | return false; 60 | if (getClass() != obj.getClass()) 61 | return false; 62 | UnorderedPair<?> other = (UnorderedPair<?>) obj; 63 | if (!first.equals(other.first)) 64 | return false; 65 | if (!second.equals(other.second)) 66 | return false; 67 | return true; 68 | } 69 | 70 | public int compareTo(UnorderedPair<E> o) { 71 | int firstCmp = first.compareTo(o.first); 72 | if (firstCmp != 0) 73 | return firstCmp; 74 | return second.compareTo(o.second); 75 | } 76 | 77 | @Override 78 | public String toString() { 79 | return "(" + first + "," + second + ")"; 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/OpenSourceExtensions/package-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * This is the code for BART that makes use of open source libraries that 3 | * needed to be modified slightly. Only the code that is changed is 4 | * documented. 5 | */ 6 | package OpenSourceExtensions; -------------------------------------------------------------------------------- /src/bartMachine/bartMachineClassificationMultThread.java: -------------------------------------------------------------------------------- 1 | package bartMachine; 2 | 3 | import java.io.Serializable; 4 | 5 | import OpenSourceExtensions.StatUtil; 6 | 7 | /** 8 | * This class handles the parallelization of many Gibbs chains over many CPU cores 9 | * to create one BART classification model. It also handles all operations on the completed model.
10 | * @author Adam Kapelner and Justin Bleich 11 | * 12 | */ 13 | @SuppressWarnings("serial") 14 | public class bartMachineClassificationMultThread extends bartMachineRegressionMultThread implements Serializable{ 15 | 16 | /** The default value of the classification_rule */ 17 | private static double DEFAULT_CLASSIFICATION_RULE = 0.5; 18 | /** The threshold of the classification rule: if the probability estimate of Y = 1 exceeds it, we predict 1 */ 19 | private double classification_rule; 20 | 21 | /** Set up an array of binary classification BARTs with length equal to num_cores, the number of CPU cores requested */ 22 | protected void SetupBARTModels() { 23 | bart_gibbs_chain_threads = new bartMachineClassification[num_cores]; 24 | for (int t = 0; t < num_cores; t++){ 25 | SetupBartModel(new bartMachineClassification(), t); 26 | } 27 | classification_rule = DEFAULT_CLASSIFICATION_RULE; 28 | } 29 | 30 | /** 31 | * Predicts the best guess of the class for an observation 32 | * 33 | * @param record The record whose class we wish to predict 34 | * @param num_cores_evaluate The number of CPU cores to use during this operation 35 | * @return The best guess of the class based on the probability estimate evaluated against the {@link #classification_rule} 36 | */ 37 | public double Evaluate(double[] record, int num_cores_evaluate) { 38 | return EvaluateViaSampAvg(record, num_cores_evaluate) > classification_rule ? 1 : 0; 39 | } 40 | 41 | /** 42 | * This returns the Gibbs sample predictions for all trees and all posterior samples. 43 | * This differs from the parent implementation because we convert the response value to 44 | * a probability estimate using the normal CDF. 45 | * 46 | * @param data The data for which to generate predictions 47 | * @param num_cores_evaluate The number of CPU cores to use during this operation 48 | * @return The predictions as a vector of size number of posterior samples of vectors of size number of trees 49 | */ 50 | protected double[][] getGibbsSamplesForPrediction(double[][] data, int num_cores_evaluate){ 51 | double[][] y_gibbs_samples = super.getGibbsSamplesForPrediction(data, num_cores_evaluate); 52 | double[][] y_gibbs_samples_probs = new double[y_gibbs_samples.length][y_gibbs_samples[0].length]; 53 | for (int g = 0; g < y_gibbs_samples.length; g++){ 54 | for (int i = 0; i < y_gibbs_samples[0].length; i++){ 55 | y_gibbs_samples_probs[g][i] = StatUtil.normal_cdf(y_gibbs_samples[g][i]); 56 | } 57 | } 58 | return y_gibbs_samples_probs; 59 | } 60 | 61 | public void setClassificationRule(double classification_rule) { 62 | this.classification_rule = classification_rule; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/bartMachine/bartMachineRegression.java: -------------------------------------------------------------------------------- 1 | package bartMachine; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * The class that is instantiated to build a Regression BART model 7 | * 8 | * @author Adam Kapelner and Justin Bleich 9 | * 10 | */ 11 | @SuppressWarnings("serial") 12 | public class bartMachineRegression extends bartMachine_i_prior_cov_spec implements Serializable{ 13 | 14 | /** 15 | * Constructs the BART classifier for regression.
16 | */ 17 | public bartMachineRegression() { 18 | super(); 19 | } 20 | 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/bartMachine/bartMachine_a_base.java: -------------------------------------------------------------------------------- 1 | package bartMachine; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * The base class for any BART implementation. Contains 7 | * mostly instance variables and settings for the algorithm 8 | * 9 | * @author Adam Kapelner and Justin Bleich 10 | */ 11 | @SuppressWarnings("serial") 12 | public abstract class bartMachine_a_base extends Classifier implements Serializable { 13 | 14 | /** all Gibbs samples for burn-in and post burn-in where each entry is a vector of pointers to the num_trees trees in the sum-of-trees model */ 15 | protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees; 16 | /** Gibbs samples post burn-in where each entry is a vector of pointers to the num_trees trees in the sum-of-trees model */ 17 | protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees_after_burn_in; 18 | /** Gibbs samples for burn-in and post burn-in of the variances */ 19 | protected double[] gibbs_samples_of_sigsq; 20 | /** Gibbs samples for post burn-in of the variances */ 21 | protected double[] gibbs_samples_of_sigsq_after_burn_in; 22 | /** a record of whether each Gibbs sample accepted or rejected the MH step within each of the num_trees trees */ 23 | protected boolean[][] accept_reject_mh; 24 | /** a record of the proposal of each Gibbs sample within each of the m trees: G, P or C for "grow", "prune", "change" */ 25 | protected char[][] accept_reject_mh_steps; 26 | 27 | /** the number of trees in our sum-of-trees model */ 28 | protected int num_trees; 29 | /** how many Gibbs samples we burn-in for */ 30 | protected int num_gibbs_burn_in; 31 | /** how many total Gibbs samples in a BART model creation */ 32 | protected int num_gibbs_total_iterations; 33 | 34 | /** the current thread being used to run this Gibbs sampler */ 35 | protected int threadNum; 36 | /** how many CPU cores to use during model creation */ 37 | protected int num_cores; 38 | /** 39 | * whether or not we use the memory cache feature 40 | * 41 | * @see Section 3.1 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013 42 | */ 43 | protected boolean mem_cache_for_speed; 44 | /** saves indices in nodes (useful for computing weights) */ 45 | protected boolean flush_indices_to_save_ram; 46 | /** should we print stuff out to screen? 
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_a_base.java:
--------------------------------------------------------------------------------
package bartMachine;

import java.io.Serializable;

/**
 * The base class for any BART implementation. Contains
 * mostly instance variables and settings for the algorithm
 *
 * @author Adam Kapelner and Justin Bleich
 */
@SuppressWarnings("serial")
public abstract class bartMachine_a_base extends Classifier implements Serializable {

    /** all Gibbs samples for burn-in and post burn-in where each entry is a vector of pointers to the num_trees trees in the sum-of-trees model */
    protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees;
    /** Gibbs samples post burn-in where each entry is a vector of pointers to the num_trees trees in the sum-of-trees model */
    protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees_after_burn_in;
    /** Gibbs samples for burn-in and post burn-in of the variances */
    protected double[] gibbs_samples_of_sigsq;
    /** Gibbs samples post burn-in of the variances */
    protected double[] gibbs_samples_of_sigsq_after_burn_in;
    /** a record of whether each Gibbs sample accepted or rejected the MH step within each of the num_trees trees */
    protected boolean[][] accept_reject_mh;
    /** a record of the proposal type of each Gibbs sample within each of the num_trees trees: 'G', 'P' or 'C' for "grow", "prune", "change" */
    protected char[][] accept_reject_mh_steps;

    /** the number of trees in our sum-of-trees model */
    protected int num_trees;
    /** how many Gibbs samples we burn in for */
    protected int num_gibbs_burn_in;
    /** how many total Gibbs samples in a BART model creation */
    protected int num_gibbs_total_iterations;

    /** the current thread being used to run this Gibbs sampler */
    protected int threadNum;
    /** how many CPU cores to use during model creation */
    protected int num_cores;
    /**
     * whether or not we use the memory cache feature
     *
     * @see Section 3.1 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
     */
    protected boolean mem_cache_for_speed;
    /** saves indices in nodes (useful for computing weights) */
    protected boolean flush_indices_to_save_ram;
    /** should we print stuff out to screen? */
    protected boolean verbose = true;

    /** Remove unnecessary data from the Gibbs chain to conserve RAM */
    protected void FlushData() {
        for (bartMachineTreeNode[] bart_trees : gibbs_samples_of_bart_trees){
            FlushDataForSample(bart_trees);
        }
    }

    /** Remove unnecessary data from an individual Gibbs sample */
    protected void FlushDataForSample(bartMachineTreeNode[] bart_trees) {
        for (bartMachineTreeNode tree : bart_trees){
            tree.flushNodeData();
        }
    }

    /** Must be implemented, but does nothing */
    public void StopBuilding(){}

    public void setThreadNum(int threadNum) {
        this.threadNum = threadNum;
    }

    public void setVerbose(boolean verbose){
        this.verbose = verbose;
    }

    public void setTotalNumThreads(int num_cores) {
        this.num_cores = num_cores;
    }

    public void setMemCacheForSpeed(boolean mem_cache_for_speed){
        this.mem_cache_for_speed = mem_cache_for_speed;
    }

    public void setFlushIndicesToSaveRAM(boolean flush_indices_to_save_ram) {
        this.flush_indices_to_save_ram = flush_indices_to_save_ram;
    }

    public void setNumTrees(int m){
        this.num_trees = m;
    }
}
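The accept_reject_mh and accept_reject_mh_steps arrays above are the raw material for Metropolis-Hastings diagnostics. A sketch of collapsing them into acceptance rates per proposal type; the helper class name is hypothetical, and the arrays would in practice come from a fitted chain via getAcceptRejectMH (shown in bartMachine_d_init below).

// Sketch: summarizing the MH bookkeeping arrays declared in bartMachine_a_base
// into acceptance rates per proposal type ('G'row, 'P'rune, 'C'hange).
public class MhAcceptanceSketch {

    public static void printAcceptanceRates(boolean[][] accept_reject_mh,
                                            char[][] accept_reject_mh_steps,
                                            int num_gibbs_burn_in) {
        java.util.Map<Character, int[]> tallies = new java.util.HashMap<>(); // type -> {accepted, proposed}
        for (int g = num_gibbs_burn_in; g < accept_reject_mh.length; g++){   // post burn-in samples only
            for (int t = 0; t < accept_reject_mh[g].length; t++){
                int[] tally = tallies.computeIfAbsent(accept_reject_mh_steps[g][t], k -> new int[2]);
                if (accept_reject_mh[g][t]) tally[0]++;
                tally[1]++;
            }
        }
        tallies.forEach((step, counts) ->
            System.out.printf("%c: %d / %d accepted (%.1f%%)%n",
                step, counts[0], counts[1], 100.0 * counts[0] / counts[1]));
    }
}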
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_c_debug.java:
--------------------------------------------------------------------------------
package bartMachine;

import java.io.Serializable;

/**
 * This portion of the code used to have many debug functions. These have
 * been removed during the tidy up for release.
 *
 * @author Adam Kapelner and Justin Bleich
 */
@SuppressWarnings("serial")
public abstract class bartMachine_c_debug extends bartMachine_b_hyperparams implements Serializable {

    /** should we create illustrations of the trees and save the images to the debug directory? */
    protected boolean tree_illust = false;

    /** the hook that gets called to save the tree illustrations when the Gibbs sampler begins */
    protected void InitTreeIllustrations() {
        bartMachineTreeNode[] initial_trees = gibbs_samples_of_bart_trees[0];
        TreeArrayIllustration tree_array_illustration = new TreeArrayIllustration(0, unique_name);

        for (bartMachineTreeNode tree : initial_trees){
            tree_array_illustration.AddTree(tree);
            tree_array_illustration.addLikelihood(0);
        }
        tree_array_illustration.CreateIllustrationAndSaveImage();
    }

    /** the hook that gets called to save the tree illustrations for each Gibbs sample */
    protected void illustrate(TreeArrayIllustration tree_array_illustration) {
        if (tree_illust){
            tree_array_illustration.CreateIllustrationAndSaveImage();
        }
    }

    /**
     * Get the untransformed samples of the sigsqs from the Gibbs chain
     *
     * @return The vector of untransformed variances over all the Gibbs samples
     */
    public double[] getGibbsSamplesSigsqs(){
        double[] sigsqs_to_export = new double[gibbs_samples_of_sigsq.length];
        for (int n_g = 0; n_g < gibbs_samples_of_sigsq.length; n_g++){
            sigsqs_to_export[n_g] = un_transform_sigsq(gibbs_samples_of_sigsq[n_g]);
        }
        return sigsqs_to_export;
    }

    /**
     * Queries the depths of the num_trees trees between a range of Gibbs samples
     *
     * @param n_g_i The Gibbs sample number to start querying
     * @param n_g_f The Gibbs sample number (+1) to stop querying
     * @return      The depths of all num_trees trees for each Gibbs sample specified
     */
    public int[][] getDepthsForTrees(int n_g_i, int n_g_f){
        int[][] all_depths = new int[n_g_f - n_g_i][num_trees];
        for (int g = n_g_i; g < n_g_f; g++){
            for (int t = 0; t < num_trees; t++){
                all_depths[g - n_g_i][t] = gibbs_samples_of_bart_trees[g][t].deepestNode();
            }
        }
        return all_depths;
    }

    /**
     * Queries the number of nodes (terminal and non-terminal) in the num_trees trees between a range of Gibbs samples
     *
     * @param n_g_i The Gibbs sample number to start querying
     * @param n_g_f The Gibbs sample number (+1) to stop querying
     * @return      The number of nodes of all num_trees trees for each Gibbs sample specified
     */
    public int[][] getNumNodesAndLeavesForTrees(int n_g_i, int n_g_f){
        int[][] all_new_nodes = new int[n_g_f - n_g_i][num_trees];
        for (int g = n_g_i; g < n_g_f; g++){
            for (int t = 0; t < num_trees; t++){
                all_new_nodes[g - n_g_i][t] = gibbs_samples_of_bart_trees[g][t].numNodesAndLeaves();
            }
        }
        return all_new_nodes;
    }
}
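Both query methods above return a samples-by-trees matrix, so posterior summaries reduce to simple double loops. A sketch of collapsing a depth matrix, such as the one returned by getDepthsForTrees, into the chain-wide mean depth; the 3-by-2 matrix is hypothetical.

// Sketch: reducing a [num samples][num_trees] depth matrix, as returned by
// getDepthsForTrees in bartMachine_c_debug, to a single average depth.
public class TreeDepthSummarySketch {
    public static void main(String[] args) {
        int[][] all_depths = {{1, 2}, {2, 2}, {3, 1}}; // hypothetical: 3 Gibbs samples, 2 trees
        long total = 0;
        int count = 0;
        for (int[] depths_one_sample : all_depths){
            for (int d : depths_one_sample){
                total += d;
                count++;
            }
        }
        System.out.printf("mean tree depth over the chain: %.2f%n", (double) total / count); // 1.83
    }
}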
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_d_init.java:
--------------------------------------------------------------------------------
package bartMachine;

import java.io.Serializable;

/**
 * This portion of the code initializes the Gibbs sampler
 *
 * @author Adam Kapelner and Justin Bleich
 */
@SuppressWarnings("serial")
public abstract class bartMachine_d_init extends bartMachine_c_debug implements Serializable {

    /** during debugging, we may want to fix sigsq */
    protected transient double fixed_sigsq;
    /** the number of the current Gibbs sample */
    protected int gibbs_sample_num;
    /** cached current sum of residuals vector */
    protected transient double[] sum_resids_vec;

    /** Initializes the Gibbs sampler: fills in the zeroth-sample entries and moves the counter to the first sample */
    protected void SetupGibbsSampling(){
        InitGibbsSamplingData();
        InitizializeSigsq();
        InitializeTrees();
        InitializeMus();
        if (tree_illust){
            InitTreeIllustrations();
        }
        //the zeroth gibbs sample is the initialization we just did; now we're onto the first in the chain
        gibbs_sample_num = 1;

        sum_resids_vec = new double[n];
    }

    /** Initializes the vectors that hold information across the Gibbs sampler */
    protected void InitGibbsSamplingData(){
        //now initialize the gibbs sampler arrays for trees and error variances
        gibbs_samples_of_bart_trees = new bartMachineTreeNode[num_gibbs_total_iterations + 1][num_trees];
        gibbs_samples_of_bart_trees_after_burn_in = new bartMachineTreeNode[num_gibbs_total_iterations - num_gibbs_burn_in + 1][num_trees];
        gibbs_samples_of_sigsq = new double[num_gibbs_total_iterations + 1];
        gibbs_samples_of_sigsq_after_burn_in = new double[num_gibbs_total_iterations - num_gibbs_burn_in];

        accept_reject_mh = new boolean[num_gibbs_total_iterations + 1][num_trees];
        accept_reject_mh_steps = new char[num_gibbs_total_iterations + 1][num_trees];
    }

    /** Initializes the tree structures in the zeroth Gibbs sample to be merely stumps */
    protected void InitializeTrees() {
        //create the array of trees for the zeroth gibbs sample
        bartMachineTreeNode[] bart_trees = new bartMachineTreeNode[num_trees];
        for (int i = 0; i < num_trees; i++){
            bartMachineTreeNode stump = new bartMachineTreeNode(this);
            stump.setStumpData(X_y, y_trans, p);
            bart_trees[i] = stump;
        }
        gibbs_samples_of_bart_trees[0] = bart_trees;
    }

    /** Initializes the leaf structure (the mean predictions) by setting them to zero (in the transformed scale, this is the center of the range) */
    protected void InitializeMus() {
        for (bartMachineTreeNode stump : gibbs_samples_of_bart_trees[0]){
            stump.y_pred = 0;
        }
    }

    /** Initializes the first variance value by drawing from the prior */
    protected void InitizializeSigsq() {
        gibbs_samples_of_sigsq[0] = StatToolbox.sample_from_inv_gamma(hyper_nu / 2, 2 / (hyper_nu * hyper_lambda));
    }

    /** the number of posterior Gibbs samples after burn-in (thinning was never implemented) */
    public int numSamplesAfterBurningAndThinning(){
        return num_gibbs_total_iterations - num_gibbs_burn_in;
    }

    public void setNumGibbsBurnIn(int num_gibbs_burn_in){
        this.num_gibbs_burn_in = num_gibbs_burn_in;
    }

    public void setNumGibbsTotalIterations(int num_gibbs_total_iterations){
        this.num_gibbs_total_iterations = num_gibbs_total_iterations;
    }

    public void setSigsq(double fixed_sigsq){
        this.fixed_sigsq = fixed_sigsq;
    }

    public boolean[][] getAcceptRejectMH(){
        return accept_reject_mh;
    }
}
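The indexing convention in InitGibbsSamplingData is worth spelling out: slot 0 of every chain array holds the initialization, so the arrays carry num_gibbs_total_iterations + 1 entries while the usable posterior has num_gibbs_total_iterations - num_gibbs_burn_in draws. A worked check; the counts mirror the package defaults but are illustrative here.

// Sketch: the array-size bookkeeping implied by InitGibbsSamplingData above.
public class GibbsBookkeepingSketch {
    public static void main(String[] args) {
        int num_gibbs_total_iterations = 1250; // illustrative value
        int num_gibbs_burn_in = 250;           // illustrative value

        int chain_len = num_gibbs_total_iterations + 1;                         // 1251: slot 0 is the initialization
        int posterior_samples = num_gibbs_total_iterations - num_gibbs_burn_in; // 1000: numSamplesAfterBurningAndThinning

        System.out.println("tree/sigsq chain arrays hold " + chain_len + " entries");
        System.out.println("posterior (post burn-in) samples: " + posterior_samples);
    }
}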
--------------------------------------------------------------------------------
/src/bartMachine/bartMachine_i_prior_cov_spec.java:
--------------------------------------------------------------------------------
package bartMachine;

import java.io.Serializable;

import gnu.trove.list.array.TIntArrayList;

/**
 * This portion of the code implements the informed prior information on covariates feature.
 *
 * @author Adam Kapelner and Justin Bleich
 * @see Section 4.10 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
 */
@SuppressWarnings("serial")
public class bartMachine_i_prior_cov_spec extends bartMachine_h_eval implements Serializable {

    /** Do we use this feature in this BART model? */
    protected boolean use_prior_cov_spec;
    /** A probability vector serving as the prior on which covariates to split on, replacing the default uniform discrete distribution */
    protected double[] cov_split_prior;

    /**
     * Picks one predictor from the set of valid predictors that can be part of a split rule at a node,
     * weighting the choice by the covariate prior.
     *
     * @param node  The node of interest
     * @return      The index of the column to split on
     */
    private int pickRandomPredictorThatCanBeAssignedF1(bartMachineTreeNode node){
        TIntArrayList predictors = node.predictorsThatCouldBeUsedToSplitAtNode();
        //get probs of split prior based on predictors that can be used and weight it accordingly
        double[] weighted_cov_split_prior_subset = getWeightedCovSplitPriorSubset(predictors);
        //choose predictor based on random prior value
        return StatToolbox.multinomial_sample(predictors, weighted_cov_split_prior_subset);
    }

    /**
     * The prior-adjusted number of covariates available to be split on at this node
     *
     * @param node  The node of interest
     * @return      The prior-adjusted number of covariates that can be split on
     */
    private double pAdjF1(bartMachineTreeNode node) {
        if (node.padj == null){
            node.padj = node.predictorsThatCouldBeUsedToSplitAtNode().size();
        }
        if (node.padj == 0){
            return 0;
        }
        if (node.isLeaf){
            return node.padj;
        }
        //pull out the weighted cov split prior subset vector
        TIntArrayList predictors = node.predictorsThatCouldBeUsedToSplitAtNode();
        //get probs of split prior based on predictors that can be used and weight it accordingly
        double[] weighted_cov_split_prior_subset = getWeightedCovSplitPriorSubset(predictors);

        //find the index of this node's split attribute inside the predictor vector
        int index = bartMachineTreeNode.BAD_FLAG_int;
        for (int i = 0; i < predictors.size(); i++){
            if (predictors.get(i) == node.splitAttributeM){
                index = i;
                break;
            }
        }

        //return the inverse probability
        return 1 / weighted_cov_split_prior_subset[index];
    }

    /**
     * Given a set of valid predictors, returns the probability vector that corresponds to the
     * elements of cov_split_prior, re-normalized because some entries may be deleted
     *
     * @param predictors    The indices of the valid covariates
     * @return              The updated and renormalized prior probability vector on the covariates to split on
     */
    private double[] getWeightedCovSplitPriorSubset(TIntArrayList predictors) {
        double[] weighted_cov_split_prior_subset = new double[predictors.size()];
        for (int i = 0; i < predictors.size(); i++){
            weighted_cov_split_prior_subset[i] = cov_split_prior[predictors.get(i)];
        }
        Tools.normalize_array(weighted_cov_split_prior_subset);
        return weighted_cov_split_prior_subset;
    }

    public void setCovSplitPrior(double[] cov_split_prior) {
        this.cov_split_prior = cov_split_prior;
        //if we're setting the vector, we're using this feature
        use_prior_cov_spec = true;
    }

    /////////////nothing but scaffold code below, do not alter!

    public int pickRandomPredictorThatCanBeAssigned(bartMachineTreeNode node){
        if (use_prior_cov_spec){
            return pickRandomPredictorThatCanBeAssignedF1(node);
        }
        return super.pickRandomPredictorThatCanBeAssigned(node);
    }

    public double pAdj(bartMachineTreeNode node){
        if (use_prior_cov_spec){
            return pAdjF1(node);
        }
        return super.pAdj(node);
    }
}
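Strip away the Trove containers and the feature above reduces to two steps: renormalize the prior mass over the predictors still usable at a node, then make one multinomial draw. A self-contained sketch with plain arrays standing in for TIntArrayList and StatToolbox.multinomial_sample; all weights and indices are hypothetical.

import java.util.Random;

// Sketch: the subset-renormalize-and-sample pattern behind
// getWeightedCovSplitPriorSubset and pickRandomPredictorThatCanBeAssignedF1.
public class CovSplitPriorSketch {

    public static void main(String[] args) {
        double[] cov_split_prior = {0.1, 0.4, 0.2, 0.3}; // hypothetical prior over all p = 4 covariates
        int[] usable_predictors = {0, 1, 3};             // covariate 2 cannot split at this node

        // pull out the prior mass of the usable predictors and renormalize it
        double[] weights = new double[usable_predictors.length];
        double total = 0;
        for (int i = 0; i < usable_predictors.length; i++){
            weights[i] = cov_split_prior[usable_predictors[i]];
            total += weights[i];
        }
        for (int i = 0; i < weights.length; i++){
            weights[i] /= total; // {0.125, 0.5, 0.375}
        }

        // one multinomial draw: walk the cumulative weights until u falls inside a cell
        double u = new Random().nextDouble();
        double cum = 0;
        int chosen = usable_predictors[usable_predictors.length - 1];
        for (int i = 0; i < weights.length; i++){
            cum += weights[i];
            if (u < cum){
                chosen = usable_predictors[i];
                break;
            }
        }
        System.out.println("split on covariate " + chosen);
    }
}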
--------------------------------------------------------------------------------
/src/bartMachine/package-info.java:
--------------------------------------------------------------------------------
/**
 * The code for the implementation of the BART algorithm as described in detail in:
 *
 * @see Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine Learning in R, ArXiv e-prints, 2013
 * @since 1.0
 * @author Adam Kapelner and Justin Bleich
 */
package bartMachine;
--------------------------------------------------------------------------------
/trove-3.0.3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kapelner/bartMachine/6d1a223a2f37e34adda861268bc181ca29c27215/trove-3.0.3.jar
--------------------------------------------------------------------------------