├── LICENSE ├── README.md ├── data └── README.txt ├── matlab_blend ├── do_blend.m └── wexpreg.m ├── model_output └── README.txt ├── python_train_predict ├── higgsml_functions.py ├── make_predictions.py ├── start_rgf_models_7f_s.py ├── start_rgf_models_7f_s_exp.py ├── start_rgf_models_7f_w.py ├── start_rgf_models_7f_w_exp.py ├── start_rgf_models_7f_wl.py └── start_rgf_models_7f_wl_exp.py ├── rgf1.2 ├── BUILDLOG ├── COPYING ├── README ├── bin │ ├── rgf │ └── rgf.exe.file ├── makefile ├── proj_vc2010 │ └── rgf │ │ ├── rgf.sln │ │ └── rgf.vcxproj ├── rgf1.2-guide.pdf ├── src │ ├── com │ │ ├── AzBmat.hpp │ │ ├── AzDmat.cpp │ │ ├── AzDmat.hpp │ │ ├── AzException.hpp │ │ ├── AzHelp.hpp │ │ ├── AzIntPool.cpp │ │ ├── AzIntPool.hpp │ │ ├── AzLoss.cpp │ │ ├── AzLoss.hpp │ │ ├── AzMemTempl.hpp │ │ ├── AzOut.hpp │ │ ├── AzParam.cpp │ │ ├── AzParam.hpp │ │ ├── AzPerfResult.hpp │ │ ├── AzPrint.hpp │ │ ├── AzReadOnlyMatrix.hpp │ │ ├── AzSmat.cpp │ │ ├── AzSmat.hpp │ │ ├── AzStrArray.hpp │ │ ├── AzStrPool.cpp │ │ ├── AzStrPool.hpp │ │ ├── AzSvDataS.cpp │ │ ├── AzSvDataS.hpp │ │ ├── AzSvFeatInfo.hpp │ │ ├── AzSvFeatInfoClone.hpp │ │ ├── AzTaskTools.cpp │ │ ├── AzTaskTools.hpp │ │ ├── AzTimer.hpp │ │ ├── AzTools.cpp │ │ ├── AzTools.hpp │ │ ├── AzUtil.cpp │ │ └── AzUtil.hpp │ └── tet │ │ ├── AzDataForTrTree.hpp │ │ ├── AzFindSplit.cpp │ │ ├── AzFindSplit.hpp │ │ ├── AzFsinfo.hpp │ │ ├── AzOptOnTree.cpp │ │ ├── AzOptOnTree.hpp │ │ ├── AzOptOnTree_TreeReg.cpp │ │ ├── AzOptOnTree_TreeReg.hpp │ │ ├── AzOptimizerT.hpp │ │ ├── AzRegDepth.hpp │ │ ├── AzReg_TreeReg.hpp │ │ ├── AzReg_TreeRegArr.hpp │ │ ├── AzReg_TreeRegArrImp.hpp │ │ ├── AzReg_TsrOpt.cpp │ │ ├── AzReg_TsrOpt.hpp │ │ ├── AzReg_TsrSib.cpp │ │ ├── AzReg_TsrSib.hpp │ │ ├── AzReg_Tsrbase.cpp │ │ ├── AzReg_Tsrbase.hpp │ │ ├── AzRgfTrainerSel.hpp │ │ ├── AzRgfTree.cpp │ │ ├── AzRgfTree.hpp │ │ ├── AzRgfTreeEnsImp.hpp │ │ ├── AzRgfTreeEnsemble.hpp │ │ ├── AzRgf_FindSplit.hpp │ │ ├── AzRgf_FindSplit_Dflt.cpp │ │ ├── 
AzRgf_FindSplit_Dflt.hpp │ │ ├── AzRgf_FindSplit_TreeReg.cpp │ │ ├── AzRgf_FindSplit_TreeReg.hpp │ │ ├── AzRgf_Optimizer.hpp │ │ ├── AzRgf_Optimizer_Dflt.cpp │ │ ├── AzRgf_Optimizer_Dflt.hpp │ │ ├── AzRgf_Optimizer_TreeReg.hpp │ │ ├── AzRgf_kw.hpp │ │ ├── AzRgforest.cpp │ │ ├── AzRgforest.hpp │ │ ├── AzRgforest_TreeReg.hpp │ │ ├── AzSortedFeat.cpp │ │ ├── AzSortedFeat.hpp │ │ ├── AzTET_Eval.hpp │ │ ├── AzTET_Eval_Dflt.hpp │ │ ├── AzTETmain.cpp │ │ ├── AzTETmain.hpp │ │ ├── AzTETmain_kw.hpp │ │ ├── AzTETproc.cpp │ │ ├── AzTETproc.hpp │ │ ├── AzTETrainer.hpp │ │ ├── AzTETselector.hpp │ │ ├── AzTE_ModelInfo.hpp │ │ ├── AzTrTree.cpp │ │ ├── AzTrTree.hpp │ │ ├── AzTrTreeEnsemble.hpp │ │ ├── AzTrTreeEnsemble_ReadOnly.hpp │ │ ├── AzTrTreeFeat.cpp │ │ ├── AzTrTreeFeat.hpp │ │ ├── AzTrTreeNode.hpp │ │ ├── AzTrTree_ReadOnly.hpp │ │ ├── AzTrTsplit.hpp │ │ ├── AzTrTtarget.hpp │ │ ├── AzTree.cpp │ │ ├── AzTree.hpp │ │ ├── AzTreeEnsemble.cpp │ │ ├── AzTreeEnsemble.hpp │ │ ├── AzTreeNodes.hpp │ │ ├── AzTreeRule.hpp │ │ └── driv_rgf.cpp └── test │ ├── call_exe.pl │ ├── output │ ├── sample.model-01 │ ├── sample.model-02 │ ├── sample.model-03 │ ├── sample.model-04 │ └── sample.model-05 │ └── sample │ ├── predict.inp │ ├── regress.test.x │ ├── regress.test.y │ ├── regress.train.x │ ├── regress.train.y │ ├── regress_train_test.inp │ ├── test.data.x │ ├── test.data.y │ ├── train.data.sparse.x │ ├── train.data.x │ ├── train.data.y │ ├── train.inp │ ├── train_predict.inp │ └── train_test.inp └── slides_Cambridge_Nov_14.pdf /data/README.txt: -------------------------------------------------------------------------------- 1 | The data files "training.csv" and "test.csv" go here. 2 | See http://www.kaggle.com/c/higgs-boson/data. 
% --------------------------------------------------------------------------------
% /matlab_blend/do_blend.m:
% --------------------------------------------------------------------------------
% This performs non-negative linear blending of several RGF models for the
% HiggsML Challenge
% Run this in Matlab 2014 for support of data tables, or update the
% 'readtable' statements below when using an older version
% Author: Tim Salimans

% load all the necessary data
model_dir = [pwd '/../model_output'];
data_dir = [pwd '/../data'];
traindat = readtable([data_dir '/training.csv']);
trainsel = traindat.DER_mass_MMC>0;             % only rows with DER_mass_MMC > 0 are blended
raw_weights = traindat.Weight(trainsel);
target = strcmp(traindat.Label(trainsel),'s');  % true where the label is 's'

% rescale the weights so that they average to 1 within each class
train_weights = raw_weights;
train_weights(target) = train_weights(target)/mean(train_weights(target));
train_weights(~target) = train_weights(~target)/mean(train_weights(~target));

% model families, in the column order used for the blend. The 'w' and
% 'w_exp' prediction files hold two stacked prediction vectors; the
% feature used for blending is (first half) - (second half).
families = {'w','s','wl','w_exp','s_exp','wl_exp'};
is_stacked = [true false false true false false];
nfam = numel(families);
nmod = 70;     % models m-01 ... m-70 per family
nfold = 7;     % cross validation folds cv0 ... cv6

% load cross validated predictions
n = length(target);
cvind = mod(0:(n-1),nfold);   % fold index of every selected training row
x_cv = cell(1,nfam);
for m=1:nfam
    xm = zeros(n,nmod);
    for i=1:nmod
        nrs = sprintf('%02d',i);   % zero-padded model number: 01, 02, ..., 70
        for j=0:(nfold-1)
            fn = [model_dir '/' families{m} '_cv' int2str(j) '_output/m-' nrs '.pred'];
            if exist(fn,'file')
                preds = csvread(fn);
                if is_stacked(m)
                    np = length(preds)/2;
                    preds = preds(1:np)-preds((np+1):end);
                end
                xm(cvind==j,i) = preds;
            end
        end
    end
    x_cv{m} = xm;
end

% load predictions on the test set
% (the number of test rows is taken from the first wl_exp prediction file)
fn = [model_dir '/wl_exp_full_output/m-01.pred'];
preds = csvread(fn);
ntest = length(preds);
x_test = cell(1,nfam);
for m=1:nfam
    xm = zeros(ntest,nmod);
    for i=1:nmod
        nrs = sprintf('%02d',i);
        fn = [model_dir '/' families{m} '_full_output/m-' nrs '.pred'];
        if exist(fn,'file')
            preds = csvread(fn);
            if is_stacked(m)
                np = length(preds)/2;
                preds = preds(1:np)-preds((np+1):end);
            end
            xm(:,i) = preds;
        end
    end
    x_test{m} = xm;
end

% combine all the predictions, and only keep the complete ones
% (a column is dropped as soon as it contains an exact zero in either the
% cross validated or the test predictions, i.e. when a file was missing)
x_all = [ones(n,1) x_cv{:}];
x_full = [ones(ntest,1) x_test{:}];
sel = ((sum(x_all==0)+sum(x_full==0))==0);
x_all = x_all(:,sel);
x_full = x_full(:,sel);
k = size(x_all,2);

% perform non-negative blending with exponential loss
% (the intercept, first coefficient, is unbounded; all others are >= 0)
obj = @(b)wexpreg(b,double(target),x_all,train_weights);
opt = optimset('GradObj','on','Hessian','on','Algorithm','trust-region-reflective','Display','iter','TolFun',1e-15);
b = fmincon(obj,zeros(k,1),[],[],[],[],[-Inf; zeros(k-1,1)],[],[],opt);

% define percentage cutoff
p = 0.175;

% write predictions
testdat = readtable([data_dir '/test.csv']);
testsel = testdat.DER_mass_MMC>0;
ftest = x_full*b;
n_all = length(testsel);
n_sel = length(ftest);
rank_order = zeros(n_all,1);
% rows without a blended score get the lowest ranks
rank_order(~testsel) = 1:(n_all-n_sel);
% sort() returns the permutation that orders ftest, not the rank of each
% element; invert the permutation so that the row with the k-th smallest
% score receives rank k
[~,ind] = sort(ftest);
ranks = zeros(n_sel,1);
ranks(ind) = 1:n_sel;
rank_order(testsel) = ranks + (n_all-n_sel);
% the top fraction p of the scored rows is labelled 's', everything else 'b'
pred_s = false(n_all,1);
pred_s(testsel) = ftest>=quantile(ftest,1-p);
pred_strings = repmat('b',length(pred_s),1);
pred_strings(pred_s) = 's';
to_predict = table(testdat.EventId,rank_order,cellstr(pred_strings),'VariableNames',{'EventId' 'RankOrder' 'Class'});
writetable(to_predict,['preds_' strrep(num2str(p),'.','_') '.csv']);

% --------------------------------------------------------------------------------
% /matlab_blend/wexpreg.m:
% --------------------------------------------------------------------------------
% linear classification with exponential loss
function [nllh,grad,hess] = wexpreg(b,y,x,w)
% WEXPREG  Weighted exponential loss of the linear score f = x*b.
%   b    coefficient vector
%   y    labels, 1 for the positive class and 0 for the negative class
%   x    design matrix, one row per observation
%   w    observation weights
%   nllh sum(w.*y.*exp(-f)) + sum(w.*(1-y).*exp(f))
%   grad gradient of nllh with respect to b (computed when requested)
%   hess Hessian of nllh with respect to b (computed when requested)

f = x*b;
pos = (w.*y).*exp(-f);
neg = (w.*(1-y)).*exp(f);
nllh = sum(pos)+sum(neg);

if nargout>=2
    df = neg-pos;
    grad = x'*df;
end

if nargout>=3
    dh = pos+neg;
    hess = x'*bsxfun(@times,x,dh);
end
% --------------------------------------------------------------------------------
% /model_output/README.txt:
# --------------------------------------------------------------------------------
# This folder holds the model output. The model predictions take over 3GB to store and many CPU days to create.
# The complete contents of this folder can be downloaded from: https://drive.google.com/file/d/0B4Zly9eEgwFsbUx5cm15UHpJZTg/edit?usp=sharing
# --------------------------------------------------------------------------------
# /python_train_predict/make_predictions.py:
# --------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch location for temporary files; handed to myMakeDir before use.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Location of the previously trained models.
save_dir = '../model_output'

# Read and preprocess the test set.
test_dat = loadAndProcessData("../data/test.csv")

# Produce predictions for every model found under 'save_dir'.
makePredictions(test_dat, temp_dir, save_dir)

# --------------------------------------------------------------------------------
# /python_train_predict/start_rgf_models_7f_s.py:
# --------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch location for temporary files; handed to myMakeDir before use.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Location the trained models are stored in.
save_dir = '../model_output'
myMakeDir(save_dir)

# Read and preprocess the training set.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# Model settings: RGF with 'Log' loss.
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Log',
    'test_interval': 1000,
    'max_leaf_forest': 50000,
}

# Kick off training / cross validation for the 's' family.
# None is passed in the weights position: train_weights is loaded but not used here.
startTraining(train_dat, train_target, None, 's', temp_dir, save_dir, param)

# --------------------------------------------------------------------------------
# /python_train_predict/start_rgf_models_7f_s_exp.py:
# --------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch location for temporary files; handed to myMakeDir before use.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Location the trained models are stored in.
save_dir = '../model_output'
myMakeDir(save_dir)

# Read and preprocess the training set.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# Model settings: RGF with 'Expo' loss.
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Expo',
    'test_interval': 1000,
    'max_leaf_forest': 70000,
}

# Kick off training / cross validation for the 's_exp' family.
# None is passed in the weights position: train_weights is loaded but not used here.
startTraining(train_dat, train_target, None, 's_exp', temp_dir, save_dir, param)


# --------------------------------------------------------------------------------
# /python_train_predict/start_rgf_models_7f_w.py:
# --------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch location for temporary files; handed to myMakeDir before use.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Location the trained models are stored in.
save_dir = '../model_output'
myMakeDir(save_dir)

# Read and preprocess the training set.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# Model settings: RGF with 'Log' loss.
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Log',
    'test_interval': 1000,
    'max_leaf_forest': 50000,
}

# Kick off training / cross validation for the 'w' family,
# passing the training weights and predict_weights=True.
startTraining(train_dat, train_target, train_weights, 'w', temp_dir, save_dir, param, predict_weights=True)

# --------------------------------------------------------------------------------
# /python_train_predict/start_rgf_models_7f_w_exp.py:
# --------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch location for temporary files; handed to myMakeDir before use.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Location the trained models are stored in.
save_dir = '../model_output'
myMakeDir(save_dir)

# Read and preprocess the training set.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# Model settings: RGF with 'Expo' loss.
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Expo',
    'test_interval': 1000,
    'max_leaf_forest': 70000,
}

# Kick off training / cross validation for the 'w_exp' family,
# passing the training weights and predict_weights=True.
startTraining(train_dat, train_target, train_weights, 'w_exp', temp_dir, save_dir, param, predict_weights=True)


# --------------------------------------------------------------------------------
# /python_train_predict/start_rgf_models_7f_wl.py:
# --------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch location for temporary files; handed to myMakeDir before use.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Location the trained models are stored in.
save_dir = '../model_output'
myMakeDir(save_dir)

# Read and preprocess the training set.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# Model settings: RGF with 'Log' loss.
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Log',
    'test_interval': 1000,
    'max_leaf_forest': 50000,
}

# Kick off training / cross validation for the 'wl' family, passing the training weights.
startTraining(train_dat, train_target, train_weights, 'wl', temp_dir, save_dir, param)


-------------------------------------------------------------------------------- /python_train_predict/start_rgf_models_7f_wl_exp.py: -------------------------------------------------------------------------------- 1 | from higgsml_functions import * 2 | 3 | # define the directory to store everything temporary 4 | temp_dir = 'temp' 5 | myMakeDir(temp_dir) 6 | 7 | # define the directory to store the models 8 | save_dir = '../model_output' 9 | myMakeDir(save_dir) 10 | 11 | # load training data & process 12 | train_dat,train_target,train_weights = loadAndProcessData("../data/training.csv") 13 | 14 | # parameters for model 15 | param = {} 16 | param['algorithm'] = 'RGF' 17 | param['reg_L2'] = 0.1 18 | param['reg_sL2'] = 0.001 19 | param['loss'] = 'Expo' 20 | param['test_interval'] = 1000 21 | param['max_leaf_forest'] = 34000 22 | 23 | # start training / cross validation 24 | startTraining(train_dat,train_target,train_weights,'wl_exp',temp_dir,save_dir,param) 25 | 26 | -------------------------------------------------------------------------------- /rgf1.2/BUILDLOG: -------------------------------------------------------------------------------- 1 | 09/24/2012: Version 1.2 first release. 2 | 10/15/2012: To remove a compile error with g++ on OS X. 3 | -------------------------------------------------------------------------------- /rgf1.2/README: -------------------------------------------------------------------------------- 1 | ************************************************************************ 2 | 3 | README 4 | 5 | RGF Version 1.2 6 | September 2012 7 | 8 | C++ programs for regularized greedy forest 9 | 10 | ************************************************************************ 11 | 12 | Contents: 13 | --------- 14 | 15 | 1. Introduction 16 | 1.1. System Requirements 17 | 2. Download and Installation 18 | 3. Creating the Executable 19 | 3.1. Windows 20 | 3.2. Unix-like Systems 21 | 3.3. [Optional] Endianness Consideration 22 | 4. Documentation 23 | 5. 
Contact 24 | 6. Copyright 25 | 7. References 26 | 27 | --------------- 28 | 1. Introduction 29 | This software package provides implementation of regularized greedy forest 30 | (RGF) described in [1]. 31 | 32 | 1.1 System Requirement 33 | Executables are provided only for some versions of Windows (see Section 3 for 34 | detail). If provided executables do not work for your environment, you need 35 | to compile C++ code. 36 | 37 | To use the provided tools and to go through the examples in the user guide, 38 | Perl is required. 39 | 40 | ----------- 41 | 2. Download and Installation 42 | The package is provided as a zip file: 43 | 44 | rgf1.2.zip 45 | 46 | Download and extract the content. 47 | 48 | The top directory of the extracted content is "rgf1.2". Below all the 49 | path expressions are relative to "rgf1.2". 50 | 51 | -------------------------- 52 | 3 Creating the Executable 53 | 54 | To go through the examples in the user guide, your executable needs to be 55 | at the "bin" directory. Otherwise, your executable can be anywhere you like. 56 | 57 | ----------- 58 | 3.1 Windows 59 | A 64-bit executable is provided with the filename "rgf.exe.file" at the 60 | "bin" directory. It was tested on Windows 7 with the latest service pack as 61 | of 9/17/2012. You can either rename it to "rgf.exe" and use it, or you can 62 | rebuild it by yourself using the provided solution file for MS Visual C++ 63 | 2010 Express: "proj_vc2010\rgf\rgf.sln". 64 | 65 | --------------------- 66 | 3.2 Unix-like Systems 67 | You need to build your executable from the source code. A make file 68 | "makefile" is provided at the top directory. It is configured to use 69 | "g++" and always compile everything. You may need to customize "makefile" 70 | for your environment. 71 | 72 | To build the executable, change the current directory to the top 73 | directory "rgf1.2" and enter in the command line "make". Check the 74 | "bin" directory to make sure that your new executable "rgf" is there. 
75 | 76 | ---------------------------------------- 77 | 3.3 [Optional] Endianness Consideration 78 | The models obtained by RGF training can be saved to files. 79 | The model files are essentially snap shots of memory that include 80 | numerical values. Therefore, the model files are sensitive to 81 | "endianness" of the environments. For this reason, if you wish to 82 | share model files among environments of different endianness, you need 83 | to follow the instructions below. Otherwise, you can skip this section. 84 | 85 | To share model files among the environments of different endianness, 86 | build your executable for the environment with big-endian with the 87 | compile option: 88 | 89 | /D_AZ_BIG_ENDIAN_ 90 | 91 | By doing so, the executable in your big-endian environment swaps the 92 | byte order of numerical values before writing and after reading the 93 | model files. 94 | 95 | ---------------- 96 | 4. Documentation 97 | rgf1.2-guide.pdf "Regularized Greedy Forest Version 1.2: User Guide" is included. 98 | 99 | ---------- 100 | 5. Contact 101 | riejohnson@gmail.com 102 | 103 | ------------ 104 | 6. Copyright 105 | RGF Version 1.2 is distributed under the GNU public license. Please read 106 | the file COPYING. 107 | 108 | ------------- 109 | 7. References 110 | 111 | [1] Rie Johnson and Tong Zhang. Learning nonlinear functions using 112 | regularized greedy forest. Technical report: arXiv:1109.0887v5, 2011. 
113 | -------------------------------------------------------------------------------- /rgf1.2/bin/rgf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/bin/rgf -------------------------------------------------------------------------------- /rgf1.2/bin/rgf.exe.file: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/bin/rgf.exe.file -------------------------------------------------------------------------------- /rgf1.2/makefile: -------------------------------------------------------------------------------- 1 | 2 | BIN_NAME = rgf 3 | BIN_DIR = bin 4 | TARGET = $(BIN_DIR)/$(BIN_NAME) 5 | CFLAGS = -Isrc/com -Isrc/tet_tools -O2 6 | 7 | CPP_FILES= \ 8 | src/tet/driv_rgf.cpp \ 9 | src/com/AzDmat.cpp \ 10 | src/tet/AzFindSplit.cpp \ 11 | src/com/AzIntPool.cpp \ 12 | src/com/AzLoss.cpp \ 13 | src/tet/AzOptOnTree_TreeReg.cpp \ 14 | src/tet/AzOptOnTree.cpp \ 15 | src/com/AzParam.cpp \ 16 | src/tet/AzReg_Tsrbase.cpp \ 17 | src/tet/AzReg_TsrOpt.cpp \ 18 | src/tet/AzReg_TsrSib.cpp \ 19 | src/tet/AzRgf_FindSplit_Dflt.cpp \ 20 | src/tet/AzRgf_FindSplit_TreeReg.cpp \ 21 | src/tet/AzRgf_Optimizer_Dflt.cpp \ 22 | src/tet/AzRgforest.cpp \ 23 | src/tet/AzRgfTree.cpp \ 24 | src/com/AzSmat.cpp \ 25 | src/tet/AzSortedFeat.cpp \ 26 | src/com/AzStrPool.cpp \ 27 | src/com/AzSvDataS.cpp \ 28 | src/com/AzTaskTools.cpp \ 29 | src/tet/AzTETmain.cpp \ 30 | src/tet/AzTETproc.cpp \ 31 | src/com/AzTools.cpp \ 32 | src/tet/AzTree.cpp \ 33 | src/tet/AzTreeEnsemble.cpp \ 34 | src/tet/AzTrTree.cpp \ 35 | src/tet/AzTrTreeFeat.cpp \ 36 | src/com/AzUtil.cpp 37 | 38 | #$(TARGET): $(CPP_FILES) 39 | all: 40 | /bin/rm -f $(TARGET) 41 | g++ $(CPP_FILES) $(CFLAGS) -o $(TARGET) 42 | 43 | clean: 44 | /bin/rm -f $(TARGET) 45 | 
-------------------------------------------------------------------------------- /rgf1.2/proj_vc2010/rgf/rgf.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 11.00 3 | # Visual C++ Express 2010 4 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "rgf", "rgf.vcxproj", "{27341E76-36FB-4CA4-848C-1782869FE39B}" 5 | EndProject 6 | Global 7 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 8 | Debug|Win32 = Debug|Win32 9 | Debug|x64 = Debug|x64 10 | Release|Win32 = Release|Win32 11 | Release|x64 = Release|x64 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|Win32.ActiveCfg = Release|Win32 15 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|Win32.Build.0 = Release|Win32 16 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|x64.ActiveCfg = Release|x64 17 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|x64.Build.0 = Release|x64 18 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|Win32.ActiveCfg = Release|Win32 19 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|Win32.Build.0 = Release|Win32 20 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|x64.ActiveCfg = Release|x64 21 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|x64.Build.0 = Release|x64 22 | EndGlobalSection 23 | GlobalSection(SolutionProperties) = preSolution 24 | HideSolutionNode = FALSE 25 | EndGlobalSection 26 | EndGlobal 27 | -------------------------------------------------------------------------------- /rgf1.2/rgf1.2-guide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/rgf1.2-guide.pdf -------------------------------------------------------------------------------- /rgf1.2/src/com/AzBmat.hpp: -------------------------------------------------------------------------------- 1 | 
/* * * * * 2 | * AzBmat.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_BMAT_HPP_ 20 | #define _AZ_BMAT_HPP_ 21 | #include "AzUtil.hpp" 22 | 23 | //! binary matrix 24 | class AzBmat { 25 | protected: 26 | int row_num; 27 | AzDataArray a; 28 | 29 | public: 30 | AzBmat() : row_num(0) {} 31 | AzBmat(int inp_row_num, int inp_col_num) : row_num(0) { 32 | resize(inp_col_num); 33 | } 34 | AzBmat(const AzBmat *inp) : row_num(0) { 35 | set(inp); 36 | } 37 | inline void set(const AzBmat *inp) { 38 | row_num = inp->row_num; 39 | a.reset(&inp->a); 40 | } 41 | AzBmat(const AzBmat &inp) : row_num(0) { 42 | set(&inp); 43 | } 44 | AzBmat & operator =(const AzBmat &inp) { 45 | if (this == &inp) return *this; 46 | set(&inp); 47 | return *this; 48 | } 49 | inline void reform(int inp_row_num, int inp_col_num) { 50 | reset(); 51 | row_num = inp_row_num; 52 | resize(inp_col_num); 53 | } 54 | inline void resize(int new_col_num) { 55 | a.resz(new_col_num); 56 | } 57 | 58 | inline void reset() { 59 | row_num = 0; 60 | a.reset(); 61 | } 62 | inline int rowNum() const { 63 | return row_num; 64 | } 65 | inline int colNum() const { 66 | return a.cursor(); 67 | } 68 | 69 | inline const AzIntArr *on_rows(int col) const { 70 | return a.point(col); 71 | } 72 | inline void clear(int fx) { 73 | 
a.point_u(fx)->reset(); 74 | } 75 | 76 | inline void load(int col, const AzIntArr *ia_on_rows) { 77 | if (ia_on_rows == NULL || ia_on_rows->size() <= 0) return; 78 | 79 | if (ia_on_rows->min() < 0 || 80 | ia_on_rows->max() >= row_num) { 81 | throw new AzException("AzBmat::load", "wrong row#"); 82 | } 83 | a.point_u(col)->reset(ia_on_rows); 84 | } 85 | }; 86 | #endif 87 | 88 | 89 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzException.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzException.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_EXCEPTION_HPP_ 20 | #define _AZ_EXCEPTION_HPP_ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | using namespace std; 28 | 29 | enum AzRetCode { 30 | AzNormal=0, 31 | AzAllocError=10, 32 | AzFileIOError=20, 33 | AzInputError=30, 34 | AzInputMissing=31, 35 | AzInputNotValid=32, 36 | AzConflict=100, /* all others */ 37 | }; 38 | 39 | /*-----------------------------------------------------*/ 40 | class AzException { 41 | public: 42 | AzException(const char *string1, 43 | const char *string2, 44 | const char *string3=NULL) 45 | { 46 | reset(AzConflict, string1, string2, string3); 47 | } 48 | 49 | AzException(AzRetCode retcode, 50 | const char *string1, 51 | const char *string2, 52 | const char *string3=NULL) 53 | { 54 | reset(retcode, string1, string2, string3); 55 | } 56 | 57 | template 58 | AzException(AzRetCode retcode, 59 | const char *string1, 60 | const char *string2, 61 | const char *string3, 62 | T anything) 63 | { 64 | reset(retcode, string1, string2, string3); 65 | s3 << "; " << anything; 66 | } 67 | 68 | void reset(AzRetCode retcode, 69 | const char *str1, 70 | const char *str2, 71 | const char *str3) 72 | { 73 | this->retcode = retcode; 74 | if (str1 != NULL) s1 << str1; 75 | if (str2 != NULL) s2 << str2; 76 | if (str3 != NULL) s3 << str3; 77 | } 78 | 79 | AzRetCode getReturnCode() { 80 | return retcode; 81 | } 82 | 83 | string getMessage() 84 | { 85 | if (retcode == AzNormal) { 86 | 87 | } 88 | else if (retcode == AzAllocError) { 89 | message << "!Memory alloc error!"; 90 | } 91 | else if (retcode == AzFileIOError) { 92 | message << "!File I/O error!"; 93 | } 94 | else if (retcode == AzInputError) { 95 | message << "!Input error!"; 96 | } 97 | else if (retcode == AzInputMissing) { 98 | message << "!Missing input!"; 99 | } 100 | else if (retcode == AzInputNotValid) { 101 | message << "!Input value is not valid!"; 102 | } 103 | else if (retcode == AzConflict) { 104 | message << "Conflict"; 
105 | } 106 | else { 107 | message << "Unknown error"; 108 | } 109 | 110 | message << ": "; 111 | if (s1.str().find("Az") == 0) { 112 | message << "(Detected in " << s1.str() << ") " << endl; 113 | } 114 | else { 115 | message << s1.str() << " "; 116 | } 117 | message << s2.str(); 118 | if (s3.str().length() > 0) { 119 | message << " " << s3.str(); 120 | } 121 | message << endl; 122 | return message.str(); 123 | } 124 | 125 | protected: 126 | AzRetCode retcode; 127 | 128 | stringstream s1, s2, s3; 129 | stringstream message; 130 | }; 131 | 132 | #endif 133 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzIntPool.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzIntPool.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_INT_POOL_HPP_ 20 | #define _AZ_INT_POOL_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | 24 | class AzIpEnt { 25 | public: 26 | int offs; 27 | const int *ints; 28 | int num; 29 | int count; 30 | int value; 31 | AzIpEnt() { 32 | offs = 0; 33 | ints = NULL; 34 | num = count = 0; 35 | value = -1; 36 | } 37 | }; 38 | 39 | //! Store integer arrays. Searchable after committed. 
40 | class AzIntPool { 41 | protected: 42 | AzIpEnt *ent; 43 | AzBaseArray a_ent; 44 | int ent_num; 45 | 46 | AzBaseArray a_data; 47 | int *data; 48 | int data_num; 49 | 50 | bool isCommitted; 51 | 52 | public: 53 | AzIntPool() : ent(NULL), ent_num(0), data(NULL), data_num(0), isCommitted(true) {} 54 | AzIntPool(AzFile *file) 55 | : ent(NULL), ent_num(0), data(NULL), data_num(0), isCommitted(true) { 56 | _read(file); 57 | } 58 | AzIntPool(const AzIntPool *inp) /* copy */ 59 | : ent(NULL), ent_num(0), data(NULL), data_num(0), isCommitted(true) { 60 | reset(inp); 61 | } 62 | ~AzIntPool() {} 63 | 64 | void reset() { 65 | a_ent.free(&ent); ent_num = 0; 66 | a_data.free(&data); data_num = 0; 67 | isCommitted = true; 68 | } 69 | void reset(const AzIntPool *ip); 70 | 71 | void read(AzFile *file) { 72 | reset(); 73 | _read(file); 74 | } 75 | 76 | int write(AzFile *file); 77 | 78 | inline void update(int ex, 79 | const AzIntArr *iq, 80 | int count=1, 81 | int val=-1) { 82 | update(ex, iq->point(), iq->size(), count, val); 83 | } 84 | void update(int ex, 85 | const int *ints, 86 | int ints_num, 87 | int count=1, 88 | int val=-1); 89 | 90 | int put(const int *ints, int ints_num, 91 | int count=1, 92 | int value=-1); 93 | int put(const AzIntArr *intq, int 94 | count=1) { 95 | return this->put(intq->point(), intq->size(), count); 96 | } 97 | 98 | /*--- put with a value ---*/ 99 | int putv(const AzIntArr *intq, 100 | int value) { 101 | int count = 1; 102 | return this->put(intq->point(), intq->size(), count, value); 103 | } 104 | 105 | inline int getValue(int idx) const { 106 | checkRange(idx, "AzIntPool::getValue"); 107 | return ent[idx].value; 108 | } 109 | 110 | void commit(); 111 | int size() const {return ent_num;} 112 | const int *point(int ent_no, int *out_ints_num=NULL) const; 113 | int getCount(int ent_no) const; 114 | void setCount(int ent_no, int new_count); 115 | int find(const int *ints, int ints_num) const; 116 | int find(const AzIntArr *intq) const { 117 | return 
find(intq->point(), intq->size()); 118 | } 119 | 120 | void get(int ent_no, AzIntArr *intq) const 121 | { 122 | int num; 123 | const int *ints = point(ent_no, &num); 124 | intq->concat(ints, num); 125 | } 126 | 127 | inline void erase(int ent_no) { 128 | shorten(ent_no, 0); 129 | } 130 | void shorten(int ent_no, int new_len); 131 | 132 | void dump(const AzOut &, const char *header) const; 133 | 134 | bool isThisCommitted() const { 135 | return isCommitted; 136 | } 137 | 138 | inline void clear() { 139 | reset(); 140 | } 141 | 142 | void concat(const AzIntPool *inp) { 143 | int num = inp->size(); 144 | int ix; 145 | for (ix = 0; ix < num; ++ix) { 146 | int len; 147 | const int *ints = inp->point(ix, &len); 148 | int count = inp->getCount(ix); 149 | int value = inp->getValue(ix); 150 | put(ints, len, count, value); 151 | } 152 | } 153 | 154 | /*--- prohibit assign operator ---*/ 155 | AzIntPool(const AzIntPool &) { 156 | throw new AzException("AzIntPool =", "no support"); 157 | } 158 | AzIntPool & operator =(const AzIntPool &inp) { 159 | if (this == &inp) return *this; 160 | throw new AzException("AzIntPool =", "no support"); 161 | } 162 | 163 | 164 | protected: 165 | void _swap(); 166 | void _read(AzFile *file); 167 | 168 | int inc_ent(); 169 | int inc_data(int min_inc); 170 | 171 | inline void checkRange(int ent_no, const char *eyec) const 172 | { 173 | if (ent_no < 0 || ent_no >= ent_num) { 174 | throw new AzException(eyec, "out of range"); 175 | } 176 | } 177 | }; 178 | 179 | #endif 180 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzOut.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOut.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the 
/* Thin handle on an output stream.
 * The handle counts as "null" when it has been deactivated or when
 * no stream is attached; it also carries an integer level that is
 * only stored and returned here.
 */
class AzOut {
protected:
  bool isActive;
  int level;
public:
  std::ostream *o;

  inline AzOut() : isActive(true), level(0), o(NULL) {}
  inline AzOut(std::ostream *o_ptr) : isActive(true), level(0), o(o_ptr) {}

  /* Attach a stream (may be NULL) and re-activate. */
  inline void reset(std::ostream *o_ptr) {
    o = o_ptr;
    isActive = true;
  }

  inline void deactivate() { isActive = false; }
  inline void activate()   { isActive = true; }

  /* Point at standard output and re-activate. */
  inline void setStdout() {
    o = &std::cout;
    isActive = true;
  }
  /* Point at standard error and re-activate. */
  inline void setStderr() {
    o = &std::cerr;
    isActive = true;
  }

  /* True when deactivated or when no stream is attached. */
  inline bool isNull() const {
    return !isActive || o == NULL;
  }

  /* Flush the attached stream, if any (regardless of activation). */
  inline void flush() const {
    if (o == NULL) return;
    o->flush();
  }

  inline void setLevel(int inp_level) { level = inp_level; }
  inline int getLevel() const { return level; }
};
option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #include "AzParam.hpp" 20 | 21 | /*-------------------------------------------------------------*/ 22 | void AzParam::check(const AzOut &out, 23 | AzBytArr *s_unused) 24 | { 25 | if (param == NULL) return; 26 | 27 | AzStrPool sp_unused, sp_kw; 28 | analyze(&sp_unused, &sp_kw); 29 | 30 | bool error = false; 31 | 32 | int ix; 33 | for (ix = 0; ix < sp_unused.size(); ++ix) { 34 | if (sp_unused.getLen(ix) <= 0) continue; 35 | 36 | /*--- this parameter wasn't used by anyone. ---*/ 37 | if (s_unused != NULL) { 38 | /*--- may be used by someone else ---*/ 39 | if (s_unused->length() > 0) s_unused->concat(dlm); 40 | s_unused->concat(sp_unused.c_str(ix)); 41 | } 42 | else { 43 | /*--- it could be a typo ---*/ 44 | AzBytArr s; 45 | s.concat("!Warning! Unknown parameter: \""); 46 | s.concat(sp_unused.c_str(ix)); s.concat("\""); 47 | AzPrint::writeln(out, s); 48 | } 49 | } 50 | 51 | AzBytArr s_err; 52 | int kw_num = sp_kw.size(); 53 | sp_kw.commit(); 54 | if (sp_kw.size() != kw_num) { 55 | for (ix = 0; ix < sp_kw.size(); ++ix) { 56 | int count = sp_kw.getCount(ix); 57 | if (count > 1) { 58 | /*--- this keyword appears more than once. 
---*/ 59 | if (s_err.length() <= 0) s_err.newline(); 60 | s_err.concat(" Duplicated keyword: "); 61 | s_err.concat(sp_kw.c_str(ix)); 62 | s_err.newline(); 63 | error = true; 64 | } 65 | } 66 | } 67 | 68 | if (error) { 69 | throw new AzException(AzInputError, "AzParam::check", s_err.c_str()); 70 | } 71 | } 72 | 73 | /*-------------------------------------------------------------*/ 74 | void AzParam::analyze(AzStrPool *sp_unused, 75 | AzStrPool *sp_kw) 76 | { 77 | if (param == NULL) return; 78 | sp_used_kw.commit(); 79 | 80 | AzStrPool sp_kwval; 81 | AzTools::getStrings(param, dlm, &sp_kwval); 82 | int ix; 83 | for (ix = 0; ix < sp_kwval.size(); ++ix) { 84 | int len; 85 | const AzByte *ptr = sp_kwval.point(ix, &len); 86 | if (len <= 0) continue; 87 | 88 | AzBytArr s_kw; 89 | AzTools::getString(&ptr, ptr+len, kwval_dlm, &s_kw); 90 | if (s_kw.length() < len) s_kw.concat(kwval_dlm); 91 | if (sp_used_kw.find(&s_kw) < 0) { 92 | /*--- this parameter wasn't used by anyone. ---*/ 93 | sp_unused->put(sp_kwval.c_str(ix)); 94 | } 95 | if (sp_kw != NULL) sp_kw->put(&s_kw); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzParam.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzParam.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 
14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_PARAM_HPP_ 20 | #define _AZ_PARAM_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzLoss.hpp" 24 | #include "AzStrPool.hpp" 25 | #include "AzPrint.hpp" 26 | #include "AzTools.hpp" 27 | 28 | //! Parse parameters 29 | class AzParam { 30 | protected: 31 | const char *param; 32 | char dlm, kwval_dlm; 33 | AzStrPool sp_used_kw; 34 | bool doCheck; /* check unknown/duplicated keywords */ 35 | public: 36 | AzParam(const char *inp_param, 37 | bool inp_doCheck=true, 38 | char inp_dlm=',', 39 | char inp_kwval_dlm='=') : sp_used_kw(100, 30) 40 | { 41 | param = inp_param; 42 | doCheck = inp_doCheck; 43 | dlm = inp_dlm; 44 | kwval_dlm = inp_kwval_dlm; 45 | } 46 | 47 | inline const char *c_str() const { return param; } 48 | inline void swOn(bool *swch, const char *kw, 49 | bool doCheckKw=true) { 50 | if (param == NULL) return; 51 | if (doCheckKw) { 52 | if (strstr(kw, "Dont") == kw || 53 | (strstr(kw, "No") == kw && strstr(kw, "Normalize") == NULL)) { 54 | throw new AzException("AzParam::swOn", 55 | "On-kw shouldn't begin with \"Dont\" or \"No\"", kw); 56 | } 57 | } 58 | const char *ptr = pointAfterKw(param, kw); 59 | if (ptr != NULL && 60 | (*ptr == '\0' || *ptr == dlm)) { 61 | *swch = true; 62 | } 63 | if (doCheck) sp_used_kw.put(kw); 64 | } 65 | inline void swOff(bool *swch, const char *kw, 66 | bool doCheckKw=true) { 67 | if (param == NULL) return; 68 | if (doCheckKw) { 69 | if (strstr(kw, "Dont") != kw && 70 | (strstr(kw, "No") != kw && strstr(kw, "Normalize") == NULL)) { 71 | throw new AzException("AzParam::swOff", 72 | "Off-kw should start with \"dont\" or \"No\"", kw); 73 | } 74 | } 75 | const char *ptr = pointAfterKw(param, kw); 76 | if (ptr != NULL && 77 | (*ptr == '\0' || *ptr == dlm) ) { 78 | *swch = false; 79 | } 80 | if (doCheck) sp_used_kw.put(kw); 81 | } 82 | inline void vStr(const char *kw, 
AzBytArr *s) { 83 | if (param == NULL) return; 84 | const char *bp = pointAfterKw(param, kw); 85 | if (bp == NULL) return; 86 | const char *ep = pointAt(bp, dlm); 87 | s->reset(); 88 | s->concat(bp, Az64::ptr_diff(ep-bp, "AzParam::vStr")); 89 | if (doCheck) sp_used_kw.put(kw); 90 | } 91 | 92 | inline void vFloat(const char *kw, double *out_value) { 93 | if (param == NULL) return; 94 | const char *ptr = pointAfterKw(param, kw); 95 | if (ptr == NULL) return; 96 | *out_value = atof(ptr); 97 | if (doCheck) sp_used_kw.put(kw); 98 | } 99 | inline void vInt(const char *kw, int *out_value) { 100 | if (param == NULL) return; 101 | const char *ptr = pointAfterKw(param, kw); 102 | if (ptr == NULL) return; 103 | *out_value = atol(ptr); 104 | if (doCheck) sp_used_kw.put(kw); 105 | } 106 | 107 | inline void vLoss(const char *kw, AzLossType *out_loss) { 108 | if (param == NULL) return; 109 | AzBytArr s; 110 | vStr(kw, &s); 111 | if (s.length() > 0) { 112 | *out_loss = AzLoss::lossType(s.c_str()); 113 | } 114 | if (doCheck) sp_used_kw.put(kw); 115 | } 116 | 117 | void check(const AzOut &out, AzBytArr *s_unused_param=NULL); 118 | 119 | protected: 120 | inline const char *pointAfterKw(const char *inp_inp, const char *kw) const { 121 | const char *ptr = NULL; 122 | const char *inp = inp_inp; 123 | for ( ; ; ) { 124 | ptr = strstr(inp, kw); 125 | if (ptr == NULL) return NULL; 126 | if (ptr == inp || *(ptr-1) == dlm) { 127 | break; 128 | } 129 | inp = ptr + strlen(kw); 130 | } 131 | ptr += strlen(kw); 132 | return ptr; 133 | } 134 | inline static const char *pointAt(const char *inp, const char *kw) { 135 | const char *ptr = strstr(inp, kw); 136 | if (ptr == NULL) return inp + strlen(inp); 137 | return ptr; 138 | } 139 | inline static const char *pointAt(const char *inp, char ch) { 140 | const char *ptr = strchr(inp, ch); 141 | if (ptr == NULL) return inp + strlen(inp); 142 | return ptr; 143 | } 144 | 145 | void analyze(AzStrPool *sp_unused, 146 | AzStrPool *sp_kw); 147 | }; 148 | 
#define AzOtherCat "_X_"

/*--------------------------------------------------*/
/* Which performance figure getPerf()/isBetter() refer to. */
enum AzPerfType {
  AzPerfType_Acc = 0,
  AzPerfType_RMSE = 1
};
#define AzPerfType_Num 2
static const char *perf_str[AzPerfType_Num] = {
  "acc", "rmse",
};

/*--------------------------------------------------*/
/* Plain record of performance figures.
 * A freshly constructed record holds -1 in every field; isBetter()
 * treats a negative figure as "unset".
 */
class AzPerfResult {
public:
  AzPerfResult() {
    p=r=f=acc=breakEven_f=breakEven_acc=rmse=loss=-1;
  }
  double p, r, f, acc, breakEven_f, breakEven_acc, rmse, loss;

  /* Assign all eight figures at once. */
  inline void put(double inp_p, double inp_r, double inp_f, double inp_acc,
                  double inp_be_f, double inp_be_acc,
                  double inp_rmse, double inp_loss) {
    p = inp_p;
    r = inp_r;
    f = inp_f;
    acc = inp_acc;
    breakEven_f = inp_be_f;
    breakEven_acc = inp_be_acc;
    rmse = inp_rmse;
    loss = inp_loss;
  }

  /* Return acc or rmse according to p_type; -1 for any other type. */
  double getPerf(AzPerfType p_type) {
    if (p_type == AzPerfType_Acc) return acc;
    if (p_type == AzPerfType_RMSE) return rmse;
    return -1;
  }

  /* Short name of the performance type; "???" when out of range. */
  static const char *getPerfStr(AzPerfType p_type) {
    if (p_type < 0 ||
        p_type >= AzPerfType_Num) return "???";
    return perf_str[p_type];
  }

  /* Is p better than comp_p?  For RMSE smaller is better; for any
   * other type larger is better.  The declared return type is double
   * (kept for compatibility); the value is always 1 or 0. */
  static double isBetter(AzPerfType p_type,
                         double p, double comp_p) {
    /*--- negative means unset ---*/
    if (p < 0) return false;
    if (comp_p < 0) return true;

    if (p_type == AzPerfType_RMSE) {
      if (p < comp_p) return true;
    }
    else {
      if (p > comp_p) return true;
    }
    return false;
  }

  void zeroOut() {
    p=r=f=acc=breakEven_f=breakEven_acc=rmse=loss=0;
  }

  /* Accumulate every figure of *inp into this record.
   * The break-even figures are summed like all the others, so that
   * zeroOut() + add()... + multiply(1/n) averages them as well
   * (they used to be overwritten by the last operand). */
  void add(const AzPerfResult *inp) {
    p += inp->p;
    r += inp->r;
    f += inp->f;
    acc += inp->acc;
    breakEven_f += inp->breakEven_f;
    breakEven_acc += inp->breakEven_acc;
    rmse += inp->rmse;
    loss += inp->loss;
  }

  /* Scale every figure by val. */
  void multiply(double val) {
    p *= val;
    r *= val;
    f *= val;
    acc *= val;
    breakEven_f *= val;
    breakEven_acc *= val;
    rmse *= val;
    loss *= val;
  }
};
14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_READONLY_MATRIX_HPP_ 20 | #define _AZ_READONLY_MATRIX_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzStrArray.hpp" 24 | 25 | //! Abstract class: interface for read-only vectors. 26 | class AzReadOnlyVector { 27 | public: 28 | virtual int rowNum() const = 0; 29 | virtual double get(int row_no) const = 0; 30 | virtual int next(AzCursor &cursor, double &out_val) const = 0; 31 | virtual bool isZero() const = 0; 32 | virtual void dump(const AzOut &out, const char *header, 33 | const AzStrArray *sp_row = NULL, 34 | int cut_num = -1) const = 0; 35 | virtual double selfInnerProduct() const = 0; 36 | virtual int nonZeroRowNum() const = 0; 37 | virtual double sum() const = 0; 38 | 39 | /*--------------------------------------*/ 40 | virtual void writeText(const char *fn, int digits) const { 41 | AzIntArr ia; 42 | ia.range(0, rowNum()); 43 | writeText(fn, &ia, digits); 44 | } 45 | 46 | /*--------------------------------------*/ 47 | virtual void writeText(const char *fn, const AzIntArr *ia, int digits) const { 48 | AzFile file(fn); 49 | file.open("wb"); 50 | AzBytArr s; 51 | int ix; 52 | for (ix = 0; ix < ia->size(); ++ix) { 53 | int row = ia->get(ix); 54 | double val = get(row); 55 | s.cn(val, digits); 56 | s.nl(); 57 | } 58 | s.writeText(&file); 59 | file.close(true); 60 | } 61 | 62 | /*--------------------------------------*/ 63 | virtual void to_sparse(AzBytArr *s, int digits) const { 64 | AzCursor cur; 65 | for ( ; ; ) { 66 | double val; 67 | int row = next(cur, val); 68 | if (row < 0) break; 69 | if (s->length() > 0) { 70 | s->c(' '); 71 | } 72 | s->cn(row); 73 | if (val != 1) { 74 | s->c(':'); s->cn(val, digits); 75 | } 76 | } 77 | } 78 | 79 | /*--------------------------------------*/ 80 | virtual void to_dense(AzBytArr *s, int digits) const { 81 | int row; 82 | for (row = 0; row < rowNum(); ++row) { 
83 | double val = get(row); 84 | if (row > 0) { 85 | s->c(" "); 86 | } 87 | s->cn(val, digits); 88 | } 89 | } 90 | }; 91 | 92 | //! Abstract class: interface for read-only matrices. 93 | class AzReadOnlyMatrix { 94 | public: 95 | virtual int rowNum() const = 0; 96 | virtual int colNum() const = 0; 97 | 98 | virtual void destroy() = 0; 99 | #if 0 100 | virtual void destroy(int col) = 0; 101 | #endif 102 | 103 | virtual double get(int row_no, int col_no) const = 0; 104 | 105 | virtual const AzReadOnlyVector *col(int col_no) const = 0; 106 | 107 | virtual bool isZero() const = 0; 108 | virtual bool isZero(int col) const = 0; 109 | 110 | virtual void dump(const AzOut &out, const char *header, 111 | const AzStrArray *sp_row = NULL, const AzStrArray *sp_col = NULL, 112 | int cut_num = -1) const = 0; 113 | 114 | /*--------------------------------------*/ 115 | virtual void writeText(const char *fn, int digits, 116 | bool doSparse=false) const { 117 | AzIntArr ia; 118 | ia.range(0, colNum()); 119 | writeText(fn, &ia, digits, doSparse); 120 | } 121 | 122 | /*--------------------------------------*/ 123 | virtual void writeText(const char *fn, const AzIntArr *ia, 124 | int digits, 125 | bool doSparse=false) const { 126 | if (doSparse) { 127 | writeText_sparse(fn, ia, digits); 128 | return; 129 | } 130 | AzFile file(fn); 131 | file.open("wb"); 132 | int num; 133 | const int *cxs = ia->point(&num); 134 | int ix; 135 | for (ix = 0; ix < num; ++ix) { 136 | int cx = cxs[ix]; 137 | AzBytArr s; 138 | col(cx)->to_dense(&s, digits); 139 | s.nl(); 140 | s.writeText(&file); 141 | } 142 | file.close(true); 143 | } 144 | 145 | /*--------------------------------------*/ 146 | virtual void writeText_sparse(const char *fn, const AzIntArr *ia, int digits) const { 147 | AzFile file(fn); 148 | file.open("wb"); 149 | AzBytArr s_header("sparse "); s_header.cn(rowNum()); s_header.nl(); 150 | s_header.writeText(&file); 151 | 152 | int num; 153 | const int *cxs = ia->point(&num); 154 | int ix; 155 
| for (ix = 0; ix < num; ++ix) { 156 | int cx = cxs[ix]; 157 | AzBytArr s; 158 | col(cx)->to_sparse(&s, digits); 159 | s.nl(); 160 | s.writeText(&file); 161 | } 162 | file.close(true); 163 | } 164 | }; 165 | #endif 166 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzStrArray.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzStrArray.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_STR_ARRAY_HPP_ 20 | #define _AZ_STR_ARRAY_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | 24 | class AzStrArray { 25 | public: 26 | virtual int size() const = 0; 27 | virtual const char *c_str(int no) const = 0; 28 | void get(int no, AzBytArr *byteq) const { 29 | byteq->reset(); 30 | byteq->concat(c_str(no)); 31 | } 32 | 33 | virtual bool isSame(const AzStrArray *inp) const { 34 | if (size() != inp->size()) { 35 | return false; 36 | } 37 | int ix; 38 | for (ix = 0; ix < size(); ++ix) { 39 | AzBytArr s0; 40 | get(ix, &s0); 41 | if (s0.compare(inp->c_str(ix)) != 0) { 42 | return false; 43 | } 44 | } 45 | return true; 46 | } 47 | 48 | /*---*/ 49 | virtual void writeText(const char *fn) const { 50 | AzFile file(fn); 51 | file.open("wb"); 52 | int ix; 53 | for (ix = 0; ix < size(); ++ix) { 54 | AzBytArr s; 55 | get(ix, &s); 56 | s.nl(); 57 | s.writeText(&file); 58 | } 59 | file.close(true); 60 | } 61 | }; 62 | 63 | #endif 64 | 65 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzSvFeatInfo.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzSvFeatInfo.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_SV_FEAT_INFO_HPP_ 20 | #define _AZ_SV_FEAT_INFO_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzStrPool.hpp" 24 | #include "AzPrint.hpp" 25 | 26 | //! Abstract class: interfact to access feature descriptions. 27 | class AzSvFeatInfo { 28 | public: 29 | // Concatenate feature description to str_desc. 30 | virtual void concatDesc(int ex, //!< feature id 31 | AzBytArr *str_desc) const = 0; 32 | 33 | //! Return number of features. 34 | virtual int featNum() const = 0; /* # of features */ 35 | 36 | void desc(int ex, AzBytArr *str_desc) const { 37 | str_desc->reset(); 38 | concatDesc(ex, str_desc); 39 | } 40 | 41 | int desc2fno(const char *fnm) const { 42 | int fx; 43 | for (fx = 0; fx < featNum(); ++fx) { 44 | AzBytArr s; 45 | desc(fx, &s); 46 | if (s.compare(fnm) == 0) { 47 | return fx; 48 | } 49 | } 50 | return -1; 51 | } 52 | 53 | void show(const AzOut &out, const AzIntArr *ia_fxs) const { 54 | int ix; 55 | for (ix = 0; ix < ia_fxs->size(); ++ix) { 56 | int fx = ia_fxs->get(ix); 57 | AzBytArr s("???"); 58 | if (fx>=0 && fxreset(); 80 | if (sp_kw->size()==0) return; 81 | int fx; 82 | for (fx = 0; fx < featNum(); ++fx) { 83 | AzBytArr s; 84 | desc(fx, &s); 85 | int ix; 86 | for (ix = 0; ix < sp_kw->size(); ++ix) { 87 | if (s.beginsWith(sp_kw->c_str(ix))) { 88 | ia_fxs->put(fx); 89 | break; 90 | } 91 | } 92 | } 93 | } 94 | 95 | void contains(const AzStrArray *sp_kw, 96 | AzIntArr *ia_fxs) const { 97 | ia_fxs->reset(); 98 | if (sp_kw->size()==0) return; 99 | int fx; 100 | for (fx = 0; fx < featNum(); ++fx) { 101 | AzBytArr s; 102 | desc(fx, &s); 103 | int ix; 104 | for (ix = 0; ix < sp_kw->size(); ++ix) { 105 | if (s.contains(sp_kw->c_str(ix))) { 106 | ia_fxs->put(fx); 107 | break; 108 | } 109 | } 110 | } 111 | } 112 | 113 | int equals(const char *kw) const { 114 | int fx; 115 | for (fx = 0; fx < featNum(); ++fx) { 116 | AzBytArr s; 117 | desc(fx, &s); 118 | if (s.compare(kw) == 0) { 119 | return fx; 120 | } 121 | } 122 | 
return -1; 123 | } 124 | }; 125 | 126 | #endif 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzSvFeatInfoClone.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzSvFeatInfoClone.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_SV_FEAT_INFO_CLONE_HPP_ 20 | #define _AZ_SV_FEAT_INFO_CLONE_HPP_ 21 | 22 | #include "AzSvFeatInfo.hpp" 23 | #include "AzStrArray.hpp" 24 | 25 | class AzSvFeatInfoClone : /* implements */ public virtual AzSvFeatInfo, 26 | /* implements */ public virtual AzStrArray 27 | { 28 | protected: 29 | AzDataPool arr_desc; 30 | 31 | public: 32 | AzSvFeatInfoClone() {} 33 | AzSvFeatInfoClone(const AzSvFeatInfo *inp) { 34 | reset(inp); 35 | } 36 | AzSvFeatInfoClone(const AzStrArray *inp) { 37 | reset(inp); 38 | } 39 | inline int featNum() const { 40 | return arr_desc.size(); 41 | } 42 | inline void concatDesc(int fx, AzBytArr *desc) const { 43 | if (fx < 0 || fx >= featNum()) { 44 | desc->c("?"); desc->cn(fx); desc->c("?"); 45 | return; 46 | } 47 | desc->concat(arr_desc.point(fx)); 48 | } 49 | void reset(const AzSvFeatInfo *inp) { 50 | int f_num = inp->featNum(); 51 | arr_desc.reset(); 52 | int fx; 53 | for (fx = 0; fx < f_num; ++fx) { 54 | AzBytArr *ptr = arr_desc.new_slot(); 55 | inp->desc(fx, ptr); 56 | } 57 | } 58 | void reset(const AzStrArray *inp) { 59 | int f_num = inp->size(); 60 | arr_desc.reset(); 61 | int fx; 62 | for (fx = 0; fx < f_num; ++fx) { 63 | AzBytArr *ptr = arr_desc.new_slot(); 64 | ptr->reset(inp->c_str(fx)); 65 | } 66 | } 67 | void reset(int inp_f_num) { 68 | arr_desc.reset(); 69 | int fx; 70 | for (fx = 0; fx < inp_f_num; ++fx) { 71 | AzBytArr s("F"); 72 | s.cn(fx, 3, true); /* width=3, fillWithZero */ 73 | arr_desc.new_slot()->reset(&s); 74 | } 75 | } 76 | 77 | void append(const AzSvFeatInfo *inp) { 78 | int fx; 79 | for (fx = 0; fx < inp->featNum(); ++fx) { 80 | inp->desc(fx, arr_desc.new_slot()); 81 | } 82 | } 83 | 84 | /*--- to implement AzStrArray ---*/ 85 | int size() const { return featNum(); } 86 | const char *c_str(int fx) const { 87 | if (fx < 0 || fx >= featNum()) { 88 | return "???"; 89 | } 90 | return arr_desc.point(fx)->c_str(); 91 | } 92 | }; 93 | #endif 94 | 
-------------------------------------------------------------------------------- /rgf1.2/src/com/AzTaskTools.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTaskTools.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TASK_TOOLS_HPP_ 20 | #define _AZ_TASK_TOOLS_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzSmat.hpp" 24 | #include "AzSvFeatInfo.hpp" 25 | #include "AzLoss.hpp" 26 | #include "AzPerfResult.hpp" 27 | #include "AzPrint.hpp" 28 | 29 | /* Some functions are overlapping with the legacy code AzCatProc */ 30 | /*--------------------------------------------------*/ 31 | //! Tools related to classification or regression tasks. 
32 | class AzTaskTools 33 | { 34 | public: 35 | static double analyzeLoss(AzLossType loss_type, 36 | const AzDvect *v_p, 37 | const AzDvect *v_y, 38 | const AzIntArr *inp_ia_dx, 39 | double p_coeff); 40 | 41 | static void showDist(const AzStrArray *sp_cat, 42 | const AzIntArr *ia_cat, 43 | const char *header, 44 | const AzOut &out); 45 | 46 | static double eval_breakEven( 47 | const AzIntArr *ia_gold, 48 | const AzDvect *v_pval, 49 | const AzStrArray *sp_cat, 50 | const char *eyecatcher, 51 | double *out_best_f=NULL, 52 | double *out_best_acc=NULL); 53 | 54 | static AzPerfResult eval(const AzDvect *v_p, 55 | const AzDvect *v_y, /* assume y in {+1,-1} */ 56 | AzLossType loss_type) { 57 | AzOut null_out; 58 | AzPerfResult res; 59 | eval("", loss_type, NULL, NULL, v_p, v_y, "", null_out, &res); 60 | return res; 61 | } 62 | static void eval(const AzDvect *v_p, 63 | const AzDvect *v_y, /* assume y in {+1,-1} */ 64 | AzPerfResult *result) { 65 | AzOut null_out; 66 | eval("", AzLoss_None, NULL, NULL, v_p, v_y, "", null_out, result); 67 | } 68 | static void eval(const AzDvect *v_p, 69 | const AzDvect *v_y, /* assume y in {+1,-1} */ 70 | const AzOut &test_out, 71 | AzPerfResult *result=NULL) { 72 | AzOut null_out; 73 | eval("", AzLoss_None, NULL, NULL, v_p, v_y, "", test_out, result); 74 | } 75 | static void eval(const char *ite_str, 76 | AzLossType loss_type, 77 | const AzIntArr *ia_dx, 78 | const double p_coeff[2], 79 | const AzDvect *v_test_pval, 80 | const AzDvect *v_test_yval, /* assume y in {+1,-1} */ 81 | const char *tt_eyec, 82 | const AzOut &test_out, 83 | AzPerfResult *result=NULL); 84 | 85 | static int genY(const AzIntArr *ia_cat, 86 | int focus_cat, 87 | double y_posi_val, 88 | double y_nega_val, 89 | AzDvect *v_yval); /* output */ 90 | 91 | /*--- for displaying the weights, feature names, etc. 
---*/ 92 | static void dumpWeights(const AzOut &out, 93 | const AzDvect *v_w, 94 | const char *name, 95 | const AzSvFeatInfo *feat, 96 | int print_max, 97 | bool changeLine); 98 | static void printPR(AzPrint &o, 99 | int ok, 100 | int t, 101 | int g); 102 | protected: 103 | static void formatWeight(const AzSvFeatInfo *feat, 104 | int ex, 105 | double val, 106 | AzBytArr *str_out); 107 | }; 108 | #endif 109 | -------------------------------------------------------------------------------- /rgf1.2/src/com/AzTimer.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTimer.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_TIMER_HPP_ 20 | #define _AZ_TIMER_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | 24 | class AzTimer { 25 | public: 26 | int chk; /* next check point */ 27 | int inc; /* increment: negative means no checking */ 28 | 29 | AzTimer() : chk(-1), inc(-1) {} 30 | ~AzTimer() {} 31 | 32 | inline void reset(int inp_inc) { 33 | chk = -1; 34 | inc = inp_inc; 35 | if (inc > 0) { 36 | chk = inc; 37 | } 38 | } 39 | 40 | inline bool ringing(bool isRinging, int inp) { /* timer is ringing */ 41 | if (isRinging) return true; 42 | 43 | if (chk > 0 && inp >= chk) { 44 | while(chk <= inp) { 45 | chk += inc; /* next check point */ 46 | } 47 | return true; 48 | } 49 | return false; 50 | } 51 | 52 | inline bool reachedMax(int inp, 53 | const char *msg, 54 | const AzOut &out) const { 55 | bool yes_no = reachedMax(inp); 56 | if (yes_no) { 57 | AzTimeLog::print(msg, " reached max", out); 58 | } 59 | return yes_no; 60 | } 61 | inline bool reachedMax(int inp) const { 62 | if (chk > 0 && inp >= chk) return true; 63 | else return false; 64 | } 65 | }; 66 | 67 | #endif 68 | 69 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzFindSplit.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzFindSplit.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 
14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_FIND_SPLIT_HPP_ 20 | #define _AZ_FIND_SPLIT_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDataForTrTree.hpp" 24 | #include "AzTrTtarget.hpp" 25 | #include "AzTrTsplit.hpp" 26 | #include "AzTrTree.hpp" 27 | 28 | class Az_forFindSplit { 29 | public: 30 | double wy_sum, w_sum; 31 | Az_forFindSplit() : wy_sum(0), w_sum(0) {} 32 | void reset() { 33 | wy_sum = w_sum = 0; 34 | } 35 | }; 36 | 37 | //! Abstract class: provides building blocks for node split search. 38 | /*------------------------------------------*/ 39 | class AzFindSplit 40 | { 41 | protected: 42 | const AzTrTtarget *target; 43 | const AzDataForTrTree *data; 44 | const AzTrTree_ReadOnly *tree; 45 | int min_size; 46 | 47 | AzIntArr ia_feats; 48 | const AzIntArr *ia_fx; 49 | 50 | public: 51 | AzFindSplit() : target(NULL), data(NULL), tree(NULL), ia_fx(NULL), 52 | min_size(-1) {} 53 | ~AzFindSplit() {} 54 | void reset() { 55 | target = NULL; 56 | data = NULL; 57 | tree = NULL; 58 | min_size = -1; 59 | } 60 | 61 | void _begin(const AzTrTree_ReadOnly *inp_tree, 62 | const AzDataForTrTree *inp_data, 63 | const AzTrTtarget *inp_target, 64 | int inp_min_size); 65 | void _end() { 66 | reset(); 67 | } 68 | 69 | //---------------------------------------------------------------- 70 | // void findBestSplit(const AzTrTtarget *tar, 71 | // const AzIntArr *ia_dx, 72 | // ... parameters ... 
73 | // AzTrTsplit *best_split); /* output */ 74 | //---------------------------------------------------------------- 75 | 76 | virtual void _pickFeats(int pick_num, int f_num); 77 | 78 | protected: 79 | /*----------------------------------------------------------------*/ 80 | virtual double getBestGain(double w_sum, 81 | double wy_sum, 82 | double *out_best_p) /* must not be null */ 83 | const = 0; 84 | virtual double evalSplit(const Az_forFindSplit i[2], 85 | double bestP[2]) /* output */ 86 | const; 87 | /*----------------------------------------------------------------*/ 88 | 89 | void _findBestSplit(int nx, 90 | /*--- output ---*/ 91 | AzTrTsplit *best_split); 92 | void loop(AzTrTsplit *best_split, 93 | int fx, /* feature# */ 94 | const AzSortedFeat *sorted, 95 | int dxs_num, 96 | const Az_forFindSplit *total); 97 | }; 98 | 99 | #endif 100 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzOptOnTree_TreeReg.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOptOnTree_TreeReg.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #include "AzOptOnTree_TreeReg.hpp" 20 | #include "AzPrint.hpp" 21 | 22 | /*--------------------------------------------------------*/ 23 | void 24 | AzOptOnTree_TreeReg::optimize(AzRgfTreeEnsemble *inp_rgf_ens, 25 | const AzTrTreeFeat *inp_tree_feat, 26 | int ite_num, 27 | double lam, 28 | double sig) 29 | { 30 | ens = inp_rgf_ens; 31 | tree_feat = inp_tree_feat; 32 | rgf_ens = inp_rgf_ens; 33 | 34 | synchronize(); 35 | updateTreeWeights(rgf_ens); 36 | 37 | int tree_num = ens->size(); 38 | if (reg_arr->size() < tree_num) { 39 | throw new AzException("AzOptOnTree_TreeReg::optimize", 40 | "max #tree has changed??"); 41 | } 42 | int tx; 43 | for (tx = 0; tx < tree_num; ++tx) { 44 | AzReg_TreeReg *reg = reg_arr->reg(tx); 45 | reg->reset(ens->tree(tx), reg_depth); 46 | } 47 | iterate(ite_num, lam, sig); 48 | 49 | ens = NULL; 50 | tree_feat = NULL; 51 | rgf_ens = NULL; 52 | } 53 | 54 | /*--------------------------------------------------------*/ 55 | void AzOptOnTree_TreeReg::update_with_features( 56 | double nlam, 57 | double nsig, 58 | double py_avg, 59 | AzRgf_forDelta *for_delta) /* updated */ 60 | { 61 | int tree_num = ens->size(); 62 | int tx; 63 | for (tx = 0; tx < tree_num; ++tx) { 64 | ens->tree_u(tx)->restoreDataIndexes(); 65 | AzReg_TreeReg *reg = reg_arr->reg(tx); 66 | reg->clearFocusNode(); 67 | 68 | AzIIarr iia_nx_fx; 69 | tree_feat->featIds(tx, &iia_nx_fx); 70 | int num = iia_nx_fx.size(); 71 | AzIIFarr iifa_nx_fx_delta; 72 | int ix; 73 | for (ix = 0; ix < num; ++ix) { 74 | int nx, fx; 75 | iia_nx_fx.get(ix, &nx, &fx); 76 | 77 | double delta = bestDelta(nx, fx, reg, nlam, nsig, py_avg, for_delta); 78 | update_weight(nx, fx, delta, reg); 79 | } 80 | ens->tree_u(tx)->releaseDataIndexes(); 81 | } 82 | } 83 | /*--------------------------------------------------------*/ 84 | void AzOptOnTree_TreeReg::update_weight(int nx, 85 | int fx, 86 | double delta, 87 | AzReg_TreeReg *reg) 88 | { 89 | double new_w = v_w.get(fx) + delta; 90 | 
v_w.set(fx, new_w); 91 | 92 | int dxs_num; 93 | const int *dxs = data_points(fx, &dxs_num); 94 | updatePred(dxs, dxs_num, delta, &v_p); 95 | 96 | /*--- update the weight in the ensemble ---*/ 97 | const AzTrTreeFeatInfo *fp = tree_feat->featInfo(fx); 98 | rgf_ens->tree_u(fp->tx)->setWeight(fp->nx, new_w); 99 | reg->changeWeight(nx, delta); 100 | } 101 | 102 | /*--------------------------------------------------------*/ 103 | double AzOptOnTree_TreeReg::bestDelta( 104 | int nx, 105 | int fx, 106 | AzReg_TreeReg *reg, 107 | double nlam, 108 | double nsig, 109 | double py_avg, 110 | AzRgf_forDelta *for_delta) /* updated */ 111 | const 112 | { 113 | const char *eyec = "AzOptOnTree_TI::bestDelta"; 114 | 115 | double w = v_w.get(fx); 116 | int dxs_num; 117 | const int *dxs = data_points(fx, &dxs_num); 118 | if (dxs_num <= 0) { 119 | throw new AzException(eyec, "no data indexes"); 120 | } 121 | 122 | const double *fixed_dw = NULL; 123 | if (!AzDvect::isNull(&v_fixed_dw)) fixed_dw = v_fixed_dw.point(); 124 | const double *p = v_p.point(); 125 | const double *y = v_y.point(); 126 | double nega_dL = 0, ddL= 0; 127 | if (fixed_dw == NULL) { 128 | AzLoss::sum_deriv(loss_type, dxs, dxs_num, p, y, py_avg, 129 | nega_dL, ddL); 130 | } 131 | else { 132 | AzLoss::sum_deriv_weighted(loss_type, dxs, dxs_num, p, y, fixed_dw, py_avg, 133 | nega_dL, ddL); 134 | } 135 | 136 | double dR, ddR; 137 | reg->penalty_deriv(nx, &dR, &ddR); 138 | 139 | double dd = ddL + nlam*ddR; 140 | if (dd == 0) dd = 1; 141 | double delta = (nega_dL-nlam*dR)*eta/dd; 142 | for_delta->check_delta(&delta, max_delta); 143 | 144 | return delta; 145 | } 146 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzOptOnTree_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOptOnTree_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it 
and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_OPT_ON_TREE_TREE_REG_HPP_ 20 | #define _AZ_OPT_ON_TREE_TREE_REG_HPP_ 21 | 22 | #include "AzOptOnTree.hpp" 23 | #include "AzReg_TreeRegArr.hpp" 24 | 25 | //! coordinate descent with regulatization using tree structure 26 | /*--------------------------------------------------------*/ 27 | class AzOptOnTree_TreeReg : /* extends */ public virtual AzOptOnTree 28 | { 29 | protected: 30 | AzRgfTreeEnsemble *rgf_ens; 31 | AzReg_TreeRegArr *reg_arr; 32 | 33 | public: 34 | AzOptOnTree_TreeReg() : rgf_ens(NULL), reg_arr(NULL) {} 35 | void reset(AzReg_TreeRegArr *inp_reg_arr) { 36 | reg_arr = inp_reg_arr; 37 | } 38 | 39 | virtual void optimize(AzRgfTreeEnsemble *ens, /* weights are updated */ 40 | const AzTrTreeFeat *tree_feat, 41 | int inp_ite_num=-1, 42 | double lam=-1, 43 | double sig=-1); 44 | 45 | /*--- ---*/ 46 | virtual void reset(const AzOptOnTree_TreeReg *inp) { 47 | AzOptOnTree::reset(inp); 48 | rgf_ens = inp->rgf_ens; 49 | reg_arr = inp->reg_arr; 50 | } 51 | 52 | protected: 53 | //! 
override 54 | virtual void update_with_features(double nlam, double nsig, double py_avg, 55 | AzRgf_forDelta *for_delta); 56 | 57 | virtual void update_weight(int nx, 58 | int fx, 59 | double delta, 60 | AzReg_TreeReg *reg); 61 | virtual double bestDelta( 62 | int nx, 63 | int fx, 64 | AzReg_TreeReg *reg, 65 | double nlam, 66 | double nsig, 67 | double py_avg, 68 | AzRgf_forDelta *for_delta) /* updated */ 69 | const; 70 | }; 71 | 72 | #endif 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzOptimizerT.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOptimizerT.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_OPTIMIZER_T_HPP_ 20 | #define _AZ_OPTIMIZER_T_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzSmat.hpp" 24 | #include "AzDmat.hpp" 25 | #include "AzBmat.hpp" 26 | #include "AzLoss.hpp" 27 | #include "AzTrTreeFeat.hpp" 28 | #include "AzTrTreeEnsemble_ReadOnly.hpp" 29 | #include "AzRgfTreeEnsemble.hpp" 30 | #include "AzRegDepth.hpp" 31 | #include "AzParam.hpp" 32 | 33 | //! coordinate descent for weight optimization. 
34 | /*--------------------------------------------------------*/ 35 | class AzOptimizerT 36 | { 37 | public: 38 | virtual void reset(AzLossType loss_type, 39 | const AzDvect *v_y, 40 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 41 | const AzRegDepth *reg_depth, 42 | AzParam ¶m, 43 | bool beVerbose, 44 | const AzOut out_req, 45 | /*--- for warm start ---*/ 46 | const AzTrTreeEnsemble_ReadOnly *ens=NULL, 47 | const AzTrTreeFeat *tree_feat=NULL, 48 | const AzDvect *inp_v_p=NULL) = 0; 49 | 50 | virtual void copyPred_to(AzDvect *out_v_p) const = 0; 51 | 52 | virtual void resetPred(const AzBmat *m_tran, 53 | AzDvect *v_p) /* output */ 54 | const = 0; 55 | virtual void optimize(AzRgfTreeEnsemble *ens, 56 | const AzTrTreeFeat *tree_feat, 57 | int inp_ite_num=-1, 58 | double lam=-1, 59 | double sig=-1) = 0; 60 | virtual void optimize(AzRgfTreeEnsemble *ens, 61 | const AzTrTreeFeat *tree_feat, 62 | bool doRefreshP, 63 | int inp_ite_num=-1, 64 | double lam=-1, 65 | double sig=-1) { 66 | throw new AzException("AzOptimizerT::optimize(...,doRefreshP,...)", "No support"); 67 | } 68 | virtual const AzDvect *weights() const = 0; 69 | virtual double constant() const = 0; 70 | virtual void printHelp(AzHelp &h) const = 0; 71 | }; 72 | 73 | #endif 74 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRegDepth.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRegDepth.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 
9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_REG_DEPTH_HPP_ 20 | #define _AZ_REG_DEPTH_HPP_ 21 | 22 | #include "AzRgf_kw.hpp" 23 | #include "AzUtil.hpp" 24 | #include "AzParam.hpp" 25 | #include "AzHelp.hpp" 26 | #include "AzRegDepth.hpp" 27 | 28 | #define depth_base_dflt 1 29 | /* #define depth_base_min_penalty_dflt 2 */ 30 | 31 | //! Regularizer using node depth. 32 | class AzRegDepth { 33 | protected: 34 | double depth_base; 35 | AzDvect v_dep2pow; /* to avoid repetitive calls of pow() */ 36 | const double *dep2pow; 37 | 38 | public: 39 | AzRegDepth() : depth_base(depth_base_dflt), dep2pow(NULL) {} 40 | 41 | virtual void set_default_for_min_penalty() { 42 | /* depth_base = depth_base_min_penalty_dflt; */ 43 | } 44 | 45 | virtual inline 46 | void check_if_nonincreasing(const char *who) const { 47 | if (depth_base < 1) { 48 | AzBytArr s(kw_depth_base); s.c(" must be no smaller than 1 for "); 49 | s.c(who); s.c("."); 50 | throw new AzException(AzInputNotValid, "AzRegDepth::check_if_nonincreasing", 51 | s.c_str()); 52 | } 53 | } 54 | 55 | virtual inline 56 | double apply(double val, int dep) const { 57 | if (depth_base == 1) return val; 58 | if (dep >= 0 && dep < v_dep2pow.rowNum()) { 59 | return val * dep2pow[dep]; 60 | } 61 | else { 62 | return val * pow(depth_base, (double)dep); 63 | } 64 | } 65 | 66 | virtual void reset(AzParam ¶m, 67 | const AzOut &out) { 68 | resetParam(param); 69 | if (depth_base <= 0) { 70 | throw new AzException(AzInputNotValid, "AzRegDepth::reset", 71 | kw_depth_base, "must be no smaller than 1."); 72 | } 73 | if (depth_base < 1) { 74 | AzBytArr s("!Warning! 
"); s.c(kw_depth_base); s.c(" should be no smaller than 1."); 75 | AzPrint::writeln(out, s); 76 | } 77 | 78 | printParam(out); 79 | } 80 | virtual void printHelp(AzHelp &h) const { 81 | h.begin(Azforest_config, "AzRegDepth", "Regularization on node depth"); 82 | h.item(kw_depth_base, help_depth_base, depth_base); 83 | h.end(); 84 | } 85 | virtual void printParam(const AzOut &out) const { 86 | if (out.isNull()) return; 87 | if (depth_base != 1) { 88 | AzPrint o(out); 89 | o.ppBegin("AzRegDepth", "Reg. on depth", ", "); 90 | o.printV(kw_depth_base, depth_base); 91 | o.ppEnd(); 92 | } 93 | } 94 | 95 | protected: 96 | virtual void resetParam(AzParam &p) { 97 | bool doCheck = false; 98 | p.vFloat(kw_depth_base, &depth_base); 99 | 100 | /*--- ---*/ 101 | v_dep2pow.reform(50); 102 | int dep; 103 | for (dep = 0; dep < v_dep2pow.rowNum(); ++dep) { 104 | v_dep2pow.set(dep, pow(depth_base, (double)dep)); 105 | } 106 | dep2pow = v_dep2pow.point(); 107 | } 108 | }; 109 | #endif 110 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzReg_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_REG_TREE_REG_HPP_ 20 | #define _AZ_REG_TREE_REG_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDmat.hpp" 24 | #include "AzTrTree_ReadOnly.hpp" 25 | #include "AzRegDepth.hpp" 26 | #include "AzParam.hpp" 27 | 28 | class AzReg_TreeRegShared { 29 | public: 30 | virtual AzDmat *share() = 0; 31 | virtual bool create(const AzTrTree_ReadOnly *tree, const AzDmat *info) = 0; 32 | virtual AzDmat *share(const AzTrTree_ReadOnly *tree) = 0; 33 | }; 34 | 35 | //! Default implementation of AzReg_TreeRegShared 36 | class AzReg_TreeRegShared_Dflt : /* implements */ public virtual AzReg_TreeRegShared 37 | { 38 | protected: 39 | AzDmat m_by_alltree; /* info not specific to individual tree */ 40 | 41 | public: 42 | /*--- override these to store tree-specific info ---*/ 43 | virtual AzDmat *share(const AzTrTree_ReadOnly *tree) { return NULL; } 44 | virtual bool create(const AzTrTree_ReadOnly *tree, const AzDmat *) { return false; } 45 | /*----------------------------------------------------*/ 46 | virtual AzDmat *share() { 47 | return &m_by_alltree; 48 | } 49 | }; 50 | 51 | //! Abstract class: interface to tree-structured regularizer 52 | class AzReg_TreeReg { 53 | public: 54 | virtual void set_shared(AzReg_TreeRegShared *shared) {} 55 | virtual void check_reg_depth(const AzRegDepth *) const {} 56 | 57 | virtual void reset(const AzTrTree_ReadOnly *inp_tree, 58 | const AzRegDepth *inp_reg_depth) = 0; 59 | 60 | virtual void penalty_deriv(int nx, double *dr, 61 | double *ddr) = 0; 62 | 63 | virtual void changeWeight(int nx, double w_diff) = 0; 64 | 65 | virtual void clearFocusNode() = 0; 66 | 67 | /*--- for node split ---*/ 68 | //! called by AzRgf_FindSplit_TR::begin 69 | virtual void reset_forNewLeaf(const AzTrTree_ReadOnly *t, 70 | const AzRegDepth *rdep) = 0; 71 | 72 | //! 
called by AzRgf_FindSplit_TR::findSplit 73 | virtual void reset_forNewLeaf(int f_nx, 74 | const AzTrTree_ReadOnly *t, 75 | const AzRegDepth *rdep) = 0; 76 | 77 | virtual double penalty_diff(const double leaf_w_delta[2]) const = 0; 78 | virtual void penalty_deriv(double *dr, 79 | double *ddr) const = 0; 80 | 81 | /*--- for maintenance ---*/ 82 | virtual void show(const AzOut &out, 83 | const char *header) const = 0; 84 | virtual double penalty() const { 85 | return -1; 86 | } 87 | 88 | /*---------------------------------------------------------*/ 89 | virtual void resetParam(AzParam ¶m) = 0; 90 | virtual void printParam(const AzOut &out) const = 0; 91 | virtual void printHelp(AzHelp &h) const = 0; 92 | 93 | virtual const char *signature() const = 0; 94 | virtual const char *description() const = 0; 95 | }; 96 | #endif 97 | 98 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzReg_TreeRegArr.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TreeRegArr.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_REG_TREE_REG_ARR_HPP_ 20 | #define _AZ_REG_TREE_REG_ARR_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzReg_TreeReg.hpp" 24 | 25 | class AzReg_TreeRegArr 26 | { 27 | public: 28 | virtual void reset(int tree_num) = 0; 29 | virtual AzReg_TreeReg *reg(int tx) = 0; 30 | virtual AzReg_TreeReg *reg_forNewLeaf(int tx) = 0; 31 | virtual int size() const = 0; 32 | }; 33 | #endif 34 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzReg_TreeRegArrImp.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TreeRegArrImp.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_REG_TREE_REG_ARR_IMP_HPP_ 20 | #define _AZ_REG_TREE_REG_ARR_IMP_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzRgfTreeEnsemble.hpp" 24 | #include "AzReg_TreeRegArr.hpp" 25 | 26 | template 27 | class AzReg_TreeRegArrImp : /* implements */ public virtual AzReg_TreeRegArr 28 | { 29 | protected: 30 | AzPtrPool areg; 31 | T template_reg; 32 | T temporary_reg; 33 | AzReg_TreeRegShared_Dflt shared; 34 | 35 | public: 36 | T *tmpl_u() { return &template_reg; } 37 | const T *tmpl() const { return &template_reg; } 38 | 39 | inline int size() const { return areg.size(); } 40 | void reset(int tree_num) { 41 | areg.reset(); 42 | int tx; 43 | for (tx = 0; tx < tree_num; ++tx) { 44 | T *reg = areg.new_slot(); 45 | reg->copyParam_from(&template_reg); 46 | reg->set_shared(&shared); 47 | } 48 | temporary_reg.set_shared(&shared); 49 | } 50 | inline AzReg_TreeReg *reg(int tx) { 51 | return areg.point_u(tx); 52 | } 53 | inline AzReg_TreeReg *reg_forNewLeaf(int tx) { 54 | int t_num = areg.size(); 55 | if (tx < t_num) return areg.point_u(tx); 56 | 57 | temporary_reg.copyParam_from(&template_reg); 58 | return &temporary_reg; /* should be root-only tree */ 59 | } 60 | }; 61 | #endif 62 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzReg_TsrOpt.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TsrOpt.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_REG_TSROPT_HPP_ 20 | #define _AZ_REG_TSROPT_HPP_ 21 | 22 | #include "AzRgf_kw.hpp" 23 | #include "AzReg_Tsrbase.hpp" 24 | 25 | //! 26 | class AzReg_TsrOpt : /* extends */ public virtual AzReg_Tsrbase 27 | { 28 | protected: 29 | int reg_ite_num; 30 | double curr_penalty; 31 | 32 | AzDmat *m_coeff; /* shared with the regularizers for other trees */ 33 | /* owned by AzReg_TreeRegShared */ 34 | 35 | //! tree structure 36 | AzIIarr iia_le_gt; 37 | 38 | public: 39 | AzReg_TsrOpt() : reg_ite_num(reg_ite_num_dflt), curr_penalty(0), m_coeff(NULL) {} 40 | 41 | virtual void copyParam_from(const AzReg_TsrOpt *inp) { 42 | AzReg_Tsrbase::copyParam_from(inp); 43 | reg_ite_num = inp->reg_ite_num; 44 | } 45 | 46 | void set_shared(AzReg_TreeRegShared *ptr) { 47 | m_coeff = ptr->share(); 48 | } 49 | virtual void check_reg_depth(const AzRegDepth *rd) const { 50 | if (rd == NULL) return; 51 | rd->check_if_nonincreasing("min-penalty regularizers"); 52 | } 53 | 54 | /*---------------------------------------------------------*/ 55 | virtual void _reset(const AzTrTree_ReadOnly *inp_tree, 56 | const AzRegDepth *inp_reg_depth); 57 | /*---------------------------------------------------------*/ 58 | 59 | /*--- for node split ---*/ 60 | /*---------------------------------------------------------*/ 61 | /*--- called by AzRgf_FindSplit_TR::begin for each tree ---*/ 62 | //! 
set current penalty 63 | virtual void _reset_forNewLeaf(const AzTrTree_ReadOnly *t, 64 | const AzRegDepth *rdep); 65 | /*--- called by AzRgf_FindSplit_TR::findSplit for each node ---*/ 66 | virtual double _reset_forNewLeaf(int f_nx, 67 | const AzTrTree_ReadOnly *t, 68 | const AzRegDepth *rdep); 69 | /*---------------------------------------------------------*/ 70 | 71 | /*--- for maintenance ---*/ 72 | virtual void show(const AzOut &out, 73 | const char *header) const { 74 | show(out, header, reg_depth, focus_nx, &av_dbar, tree, &v_bar); 75 | } 76 | 77 | /*---------------------------------------------------------*/ 78 | virtual void resetParam(AzParam ¶m); 79 | virtual void printParam(const AzOut &out) const; 80 | virtual void printHelp(AzHelp &h) const; 81 | 82 | virtual inline const char *signature() const { 83 | return "-___-_RGF_TsrOpt_"; 84 | } 85 | virtual inline const char *description() const { 86 | return "RGF w/min-penalty regularization"; 87 | } 88 | /*---------------------------------------------------------*/ 89 | 90 | /*--- static tools ---*/ 91 | static void show(const AzOut &out, 92 | const char *header, 93 | const AzRegDepth *reg_depth, 94 | int focus_nx, 95 | const AzDataArray *av_dbar, 96 | const AzTrTree_ReadOnly *tree, 97 | const AzDvect *v_bar); 98 | static void _propagate(int ite_num, 99 | const AzTrTree_ReadOnly *tree, 100 | int split_nx, 101 | const double new_leaf_w[2], 102 | const AzIntArr *ia_nonleaf, 103 | const AzRegDepth *reg_depth, 104 | AzDmat *m_coeff, /* inout */ 105 | AzDvect *v); /* output */ 106 | 107 | static void _show(int focus_nx, 108 | const AzDataArray *av_dbar, 109 | const AzTrTree_ReadOnly *tree, 110 | const AzDvect *v_bar, 111 | int nx, 112 | const AzDmat *m_coeff, 113 | const AzOut &out); 114 | 115 | static void setCoeff(const AzRegDepth *reg_depth, int depth, double coeff[4]); 116 | static void setCoeff(const AzRegDepth *reg_depth, 117 | const AzTrTree_ReadOnly *tree, 118 | AzDmat *m_coeff); 119 | 120 | protected: 
121 | void resetTreeStructure(); 122 | bool isSameTreeStructure() const; 123 | void storeTreeStructure(); 124 | 125 | void update_v(); 126 | void update_dv(); 127 | void reset_v_dv() { 128 | av_dv.reset(); 129 | v_v.reset(); 130 | v_dv2_sum.reset(); 131 | } 132 | 133 | /*--- compute "bar" (auxiliary variables) iteratively ---*/ 134 | virtual void reset_bar(int split_nx, 135 | const AzIntArr *ia_leaf, 136 | const AzIntArr *ia_nonleaf); 137 | /*--- compute "bar"'s derivatives iteratively ---*/ 138 | virtual void deriv(int base_nx, /* derivative w.r.t. this node's weight */ 139 | const AzIntArr *ia_nonleaf, 140 | AzDvect *v_dbar); 141 | 142 | }; 143 | #endif 144 | 145 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzReg_TsrSib.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TsrSib.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_REG_TSRSIB_HPP_ 20 | #define _AZ_REG_TSRSIB_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDmat.hpp" 24 | #include "AzTrTree_ReadOnly.hpp" 25 | #include "AzRegDepth.hpp" 26 | #include "AzReg_TreeReg.hpp" 27 | 28 | //! 
29 | class AzReg_TsrSib : /* implements */ public virtual AzReg_TreeReg { 30 | protected: 31 | const AzTrTree_ReadOnly *tree; 32 | 33 | int focus_nx; 34 | const AzRegDepth *reg_depth; 35 | bool forNewLeaf; 36 | 37 | AzDataArray av_dv; 38 | AzDvect v_v; 39 | 40 | double vdv_sum, dv2_sum, dr, ddr; 41 | double newleaf_dep_factor; 42 | 43 | virtual void reset_values() { 44 | vdv_sum = dv2_sum = dr = ddr = 0; 45 | newleaf_dep_factor = 1; 46 | } 47 | 48 | public: 49 | AzReg_TsrSib() 50 | : tree(NULL), forNewLeaf(false), focus_nx(-1), 51 | reg_depth(NULL), newleaf_dep_factor(1), 52 | vdv_sum(0), dv2_sum(0), dr(0), ddr(0) {} 53 | 54 | void copyParam_from(const AzReg_TsrSib *inp) {} 55 | 56 | /*---------------------------------------------------------*/ 57 | virtual void reset(const AzTrTree_ReadOnly *inp_tree, 58 | const AzRegDepth *inp_reg_depth); 59 | /*---------------------------------------------------------*/ 60 | 61 | virtual void penalty_deriv(int nx, double *dr, 62 | double *ddr); 63 | 64 | virtual void changeWeight(int nx, double w_diff); 65 | 66 | inline void clearFocusNode() { 67 | focus_nx = -1; 68 | } 69 | 70 | /*--- for node split ---*/ 71 | /*---------------------------------------------------------*/ 72 | /*--- called by AzRgf_FindSplit_TR::begin ---*/ 73 | //! 
set current penalty 74 | virtual void reset_forNewLeaf(const AzTrTree_ReadOnly *t, 75 | const AzRegDepth *rdep); 76 | /*--- called by AzRgf_FindSplit_TR::findSplit ---*/ 77 | virtual void reset_forNewLeaf(int f_nx, 78 | const AzTrTree_ReadOnly *t, 79 | const AzRegDepth *rdep); 80 | /*---------------------------------------------------------*/ 81 | 82 | virtual double penalty_diff(const double leaf_w_delta[2]) const; 83 | virtual void penalty_deriv(double *dr, 84 | double *ddr) const; 85 | 86 | /*--- for maintenance ---*/ 87 | virtual void show(const AzOut &out, 88 | const char *header) const { 89 | 90 | } 91 | 92 | /*---------------------------------------------------------*/ 93 | virtual void resetParam(AzParam ¶m) {} 94 | virtual void printParam(const AzOut &out) const {} 95 | virtual void printHelp(AzHelp &h) const {} 96 | 97 | virtual inline const char *signature() const { 98 | return "-___-_RGF_TsrSib_"; 99 | } 100 | virtual inline const char *description() const { 101 | return "RGF w/min-penalty regularization w/sum-to-zero sibling constraints"; 102 | } 103 | /*---------------------------------------------------------*/ 104 | 105 | protected: 106 | void checkLeaf(const char *msg) const { 107 | if (!forNewLeaf || focus_nx < 0) { 108 | throw new AzException("AzReg_TsrSib::checkLeaf", msg); 109 | } 110 | } 111 | virtual void deriv_v(const AzTrTree_ReadOnly *tree, 112 | int leaf_nx, 113 | bool forNewLeaf, 114 | /* output */ 115 | AzSvect *v_dv, 116 | /* inout */ 117 | AzDvect *v_v) const; 118 | 119 | inline int get_newleaf_depth() const { 120 | return tree->node(focus_nx)->depth + 1; 121 | } 122 | virtual void update(); 123 | }; 124 | #endif 125 | 126 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgfTrainerSel.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgfTrainerSel.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free 
software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_TRAINER_SEL_HPP_ 20 | #define _AZ_RGF_TRAINER_SEL_HPP_ 21 | 22 | #include "AzRgforest.hpp" 23 | #include "AzRgforest_TreeReg.hpp" 24 | #include "AzReg_TsrOpt.hpp" 25 | #include "AzReg_TsrSib.hpp" 26 | 27 | #include "AzTETselector.hpp" 28 | #include "AzPrint.hpp" 29 | 
//! Training algorithm selector.
/**
 * Owns one trainer object per supported algorithm and maps the
 * algorithm names "RGF", "RGF_Sib" and "RGF_Opt" to them.
 **/
class AzRgfTrainerSel : /* implements */ public virtual AzTETselector {
protected:
  AzRgforest rgf;
  /* NOTE(review): the template arguments on the next three declarations were
   * stripped from this dump.  They are restored from the headers included
   * above (AzReg_TsrSib.hpp / AzReg_TsrOpt.hpp) and from how `alg` is used
   * below (it stores addresses of trainers) -- confirm against the original. */
  AzRgforest_TreeReg<AzReg_TsrSib> rgf_sib;
  AzRgforest_TreeReg<AzReg_TsrOpt> rgf_opt;

  #define kw_rgf "RGF"
  #define kw_rgf_sib "RGF_Sib"
  #define kw_rgf_opt "RGF_Opt"

  AzStrPool sp_name;               /* algorithm name -> index into alg */
  AzDataArray<AzTETrainer *> alg;  /* trainers, in registration order */

  /* Registers the three trainers.  Called once, from the constructor. */
  virtual void reset() {
    int id = 0;
    sp_name.putv(kw_rgf, id++);     *alg.new_slot() = &rgf;
    sp_name.putv(kw_rgf_sib, id++); *alg.new_slot() = &rgf_sib;
    sp_name.putv(kw_rgf_opt, id++); *alg.new_slot() = &rgf_opt;
    sp_name.commit();
  }

public:
  AzRgfTrainerSel() {
    reset();
  }

  virtual const char *dflt_name() const {
    return kw_rgf;
  }
  virtual const char *another_name() const {
    return kw_rgf_sib;
  }
  const AzStrArray *names() const {
    return &sp_name;
  }

  /* Appends every algorithm name to s, separated by dlm. */
  virtual void printOptions(const char *dlm, AzBytArr *s) const {
    int ix;
    for (ix = 0; ix < sp_name.size(); ++ix) {
      if (ix > 0) s->c(dlm);
      s->c(sp_name.c_str(ix));
    }
  }

  /* Emits one help item per algorithm: its name and the trainer's description. */
  virtual void printHelp(AzHelp &h) const {
    h.begin("", "", "");
    int ix;
    for (ix = 0; ix < sp_name.size(); ++ix) {
      int id = sp_name.getValue(ix);
      const AzTETrainer *trainer = *alg.point(id);
      h.item(sp_name.c_str(ix), trainer->description());
    }
    h.end();
  }

  /* Looks up a trainer by name.
   * Returns NULL for an unknown name when dontThrow is true;
   * otherwise throws AzException(AzInputNotValid, ...). */
  virtual AzTETrainer *select(const char *alg_name, //!< name of algorithm.
                              //! if true, don't throw exception at error.
                              bool dontThrow=false) const
  {
    AzTETrainer *trainer = NULL;

    int ex = sp_name.find(alg_name);
    if (ex < 0 && !dontThrow) {
      throw new AzException(AzInputNotValid, "algorithm name", alg_name);
    }
    if (ex >= 0) {
      int id = sp_name.getValue(ex);
      trainer = *alg.point(id);
    }
    return trainer;
  }
};

#endif
105 | 106 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgfTreeEnsImp.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgfTreeEnsImp.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_TREE_ENS_IMP_HPP_ 20 | #define _AZ_RGF_TREE_ENS_IMP_HPP_ 21 | 22 | #include "AzRgfTreeEnsemble.hpp" 23 | #include "AzTrTreeEnsemble.hpp" 24 | 25 | //! implement AzRgfTreeEnsemble. T must be AzRgfTree or its extension. 26 | template 27 | class AzRgfTreeEnsImp : /*implements */public virtual AzRgfTreeEnsemble 28 | { 29 | protected: 30 | AzTrTreeEnsemble ens; 31 | 32 | public: 33 | AzRgfTreeEnsImp() {} 34 | ~AzRgfTreeEnsImp() {} 35 | inline bool usingTempFile() const { 36 | return ens.usingTempFile(); 37 | } 38 | inline void reset() { 39 | ens.reset(); 40 | } 41 | inline const char *param_c_str() const { 42 | return ens.param_c_str(); 43 | } 44 | 45 | inline double constant() const { 46 | return ens.constant(); 47 | } 48 | inline int orgdim() const { 49 | return ens.orgdim(); 50 | } 51 | inline void set_constant(double inp) { 52 | ens.set_constant(inp); 53 | } 54 | inline AzRgfTree *new_tree(int *out_tx=NULL) { 55 | return ens.new_tree(out_tx); 56 | } 57 | 58 | inline const AzRgfTree *tree(int tx) const { 59 | return ens.tree(tx); 60 | } 61 | inline AzRgfTree *tree_u(int tx) const { 62 | return ens.tree_u(tx); 63 | } 64 | 65 | inline T *rawtree_u(int tx) const { 66 | return ens.tree_u(tx); 67 | } 68 | 69 | inline int leafNum() const { 70 | return ens.leafNum(); 71 | } 72 | inline int leafNum(int tx0, int tx1) const { 73 | return ens.leafNum(tx0, tx1); 74 | } 75 | 76 | inline int lastIndex() const { 77 | return ens.lastIndex(); 78 | } 79 | inline int nextIndex() const { /* next slot */ 80 | return ens.nextIndex(); 81 | } 82 | 83 | inline int size() const { 84 | return ens.size(); 85 | } 86 | inline int max_size() const { 87 | return ens.max_size(); 88 | } 89 | inline bool isFull() const { 90 | return ens.isFull(); 91 | } 92 | inline void printHelp(AzHelp &h) const { 93 | ens.printHelp(h); 94 | } 95 | inline void copy_to(AzTreeEnsemble *out_ens, 96 | const char *config, const char *sign) const { 97 | ens.copy_to(out_ens, config, 
sign); 98 | } 99 | inline void copy_nodes_from(const AzTrTreeEnsemble_ReadOnly *inp) { 100 | ens.copy_nodes_from(inp); 101 | } 102 | inline void show(const AzSvFeatInfo *feat, 103 | const AzOut &out, const char *header="") const { 104 | ens.show(feat, out, header); 105 | } 106 | 107 | inline virtual void cold_start(AzParam ¶m, 108 | const AzBytArr *s_temp_prefix, 109 | int data_num, 110 | const AzOut &out, 111 | int tree_num_max, 112 | int inp_org_dim) { 113 | ens.cold_start(param, s_temp_prefix, data_num, 114 | out, tree_num_max, inp_org_dim); 115 | } 116 | inline virtual void warm_start(const AzTreeEnsemble *inp_ens, 117 | const AzDataForTrTree *data, 118 | AzParam ¶m, 119 | const AzBytArr *s_temp_prefix, 120 | const AzOut &out, 121 | int max_t_num, 122 | int search_t_num, /* to release work areas for the fixed trees */ 123 | AzDvect *v_p, /* inout */ 124 | const AzIntArr *inp_ia_tr_dx=NULL) { 125 | ens.warm_start(inp_ens, data, param, s_temp_prefix, out, max_t_num, search_t_num, 126 | v_p, inp_ia_tr_dx); 127 | } 128 | }; 129 | #endif 130 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgfTreeEnsemble.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgfTreeEnsemble.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. 
If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_TREE_ENSEMBLE_HPP_ 20 | #define _AZ_RGF_TREE_ENSEMBLE_HPP_ 21 | 22 | #include "AzTrTreeEnsemble_ReadOnly.hpp" 23 | #include "AzRgfTree.hpp" 24 | #include "AzParam.hpp" 25 | #include "AzHelp.hpp" 26 | 27 | //! Abstract class: interface for ensemble of RGF-trees. 28 | /** 29 | * implemented by AzRgfTreeEnsImp 30 | **/ 31 | class AzRgfTreeEnsemble : /* extends */ public virtual AzTrTreeEnsemble_ReadOnly 32 | { 33 | public: 34 | virtual void set_constant(double inp) = 0; 35 | virtual AzRgfTree *new_tree(int *out_tx=NULL) = 0; 36 | virtual AzRgfTree *tree_u(int tx) const = 0; 37 | 38 | virtual int nextIndex() const = 0; 39 | virtual bool isFull() const = 0; 40 | 41 | virtual void copy_nodes_from(const AzTrTreeEnsemble_ReadOnly *inp) = 0; 42 | virtual void printHelp(AzHelp &h) const = 0; 43 | 44 | virtual void cold_start(AzParam ¶m, 45 | const AzBytArr *s_temp_prefix, /* may be NULL */ 46 | int data_num, 47 | const AzOut &out, 48 | int tree_num_max, 49 | int inp_org_dim) = 0; 50 | virtual void warm_start(const AzTreeEnsemble *inp_ens, 51 | const AzDataForTrTree *data, 52 | AzParam ¶m, 53 | const AzBytArr *s_temp_prefix, /* may be NULL */ 54 | const AzOut &out, 55 | int max_t_num, 56 | int search_t_num, 57 | AzDvect *v_p, /* inout */ 58 | const AzIntArr *inp_ia_tr_dx=NULL) = 0; 59 | }; 60 | #endif 61 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_FindSplit.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 
9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_FIND_SPLIT_HPP_ 20 | #define _AZ_RGF_FIND_SPLIT_HPP_ 21 | 22 | #include "AzTrTree_ReadOnly.hpp" 23 | #include "AzTrTsplit.hpp" 24 | #include "AzTrTtarget.hpp" 25 | #include "AzRgf_kw.hpp" 26 | #include "AzRegDepth.hpp" 27 | #include "AzParam.hpp" 28 | #include "AzFsinfo.hpp" 29 | #include "AzHelp.hpp" 30 | 
/* Plain bundle of the per-tree inputs handed to AzRgf_FindSplit::begin().
 * Pointers are stored as given; nothing is copied or owned here.
 */
class AzRgf_FindSplit_input {
public:
  int tx;
  const AzDataForTrTree *data;
  const AzTrTtarget *target;
  double lam_scale;  /*!< for numerical stability of exp loss */
  double nn;         /* sum of data point weights if weighted */

  /* Stores each argument in the member of the same name. */
  AzRgf_FindSplit_input(int inp_tx,
                        const AzDataForTrTree *inp_data,
                        const AzTrTtarget *inp_target,
                        double inp_lam_scale,
                        double inp_nn)
    : tx(inp_tx),
      data(inp_data),
      target(inp_target),
      lam_scale(inp_lam_scale),
      nn(inp_nn) {}
};
51 | 52 | /*--------------------------------------------------------*/ 53 | //! Abstract class: interface for node split search for RGF. 
54 | /** 55 | * Implemented by AzRgf_FindSplit_Dflt, AzRgf_FindSplit_TreeReg 56 | **/ 57 | class AzRgf_FindSplit { 58 | public: 59 | virtual void reset(AzParam ¶m, 60 | const AzRegDepth *reg_depth, 61 | const AzOut &out) = 0; 62 | 63 | virtual void begin(const AzTrTree_ReadOnly *tree, 64 | const AzRgf_FindSplit_input &inp, 65 | int min_size) 66 | = 0; 67 | virtual void begin(const AzTrTree_ReadOnly *tree, 68 | const AzRgf_FindSplit_input &inp, 69 | int min_size, 70 | AzFsinfoOnTree *fot) { /* added for TreeRegFast */ 71 | throw new AzException("AzRgf_FindSplit::begin(...fot)", 72 | "no appropriate override"); 73 | } 74 | 75 | virtual void pickFeats(int f_num, int data_num) = 0; 76 | 77 | virtual void end() = 0; 78 | virtual 79 | void findSplit(int nx, //!< node id 80 | /*--- output ---*/ 81 | AzTrTsplit *best_split) = 0; 82 | 83 | virtual void printParam(const AzOut &out) const = 0; 84 | virtual void printHelp(AzHelp &h) const = 0; 85 | }; 86 | #endif 87 | 88 | 89 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_FindSplit_Dflt.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_Dflt.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #include "AzRgf_FindSplit_Dflt.hpp" 20 | #include "AzHelp.hpp" 21 | #include "AzRgf_kw.hpp" 22 | 23 | /*--------------------------------------------------------*/ 24 | void AzRgf_FindSplit_Dflt::begin( 25 | const AzTrTree_ReadOnly *inp_tree, 26 | const AzRgf_FindSplit_input &inp, /* tx is not used */ 27 | int inp_min_size) 28 | { 29 | AzFindSplit::_begin(inp_tree, inp.data, inp.target, inp_min_size); 30 | 31 | nlam = inp.nn*lambda; 32 | nsig = inp.nn*sigma; 33 | if (inp.lam_scale != 1) { /* for numerical stability for expo loss */ 34 | nlam *= inp.lam_scale; 35 | nsig *= inp.lam_scale; 36 | } 37 | 38 | doUseInternalNodes = tree->usingInternalNodes(); 39 | } 40 | 41 | /*--------------------------------------------------------*/ 42 | double AzRgf_FindSplit_Dflt::getBestGain(double wsum, /* some of data weights */ 43 | double wrsum, /* weighted sum of residual */ 44 | double *best_q) const 45 | { 46 | double p = p_node->weight; /* parent's weight */ 47 | double gain = 0; 48 | double q = 0; 49 | 50 | if (doUseInternalNodes) { 51 | q = wrsum/(wsum+c_nlam); 52 | gain = q*q*(wsum+c_nlam); /* n*gain */ 53 | } 54 | else if (nsig <= 0) { /* L2 only */ 55 | q = (wrsum-c_nlam*p)/(wsum+c_nlam); 56 | gain = q*q*(wsum+c_nlam)+(p_nlam-2*c_nlam)*p*p/2; /* "/2" for two child nodes */ 57 | /* n*gain */ 58 | q += p; 59 | } 60 | else { /* L1 and L2; not tested after code change */ 61 | double _wysum = wrsum + wsum*p; 62 | if (_wysum > c_nsig) q = (_wysum-c_nsig)/(wsum+c_nlam); 63 | else if (_wysum < -c_nsig) q = (_wysum+c_nsig)/(wsum+c_nlam); 64 | else q = 0; 65 | double org_losshat = -2*p*_wysum+p*p*(wsum+p_nlam)+2*p_nsig*fabs(p); 66 | double new_losshat = -q*q*(wsum+c_nlam); 67 | gain = org_losshat - new_losshat; 68 | } 69 | 70 | *best_q = q; 71 | return gain; 72 | } 73 | 74 | /*--------------------------------------------------------*/ 75 | /*--------------------------------------------------------*/ 76 | void AzRgf_FindSplit_Dflt::resetParam(AzParam &p) 77 
| { 78 | /*--- reg param shared with optimizer ---*/ 79 | p.vFloat(kw_lambda, &lambda); 80 | p.vFloat(kw_sigma, &sigma); 81 | 82 | /*--- override ... ---*/ 83 | p.vFloat(kw_s_lambda, &lambda); 84 | p.vFloat(kw_s_sigma, &sigma); 85 | 86 | if (lambda < 0) { 87 | throw new AzException(AzInputMissing, "AzRgf_FindSplit_Dflt", 88 | kw_lambda, "must be non-negative"); 89 | } 90 | if (sigma < 0) { 91 | throw new AzException(AzInputNotValid, "AzRgf_FindSplit_Dflt", 92 | kw_sigma, "must be non-negative"); 93 | } 94 | } 95 | 96 | /*--------------------------------------------------------*/ 97 | void AzRgf_FindSplit_Dflt::printParam(const AzOut &out) const 98 | { 99 | if (out.isNull()) return; 100 | 101 | AzPrint o(out); 102 | o.reset_options(); 103 | o.set_precision(5); 104 | o.ppBegin("AzRgf_FindSplit_Dflt", "Node split", ", "); 105 | o.printV(kw_lambda, lambda); 106 | o.printV_posiOnly(kw_sigma, sigma); 107 | o.ppEnd(); 108 | } 109 | 110 | /*--------------------------------------------------------*/ 111 | void AzRgf_FindSplit_Dflt::printHelp(AzHelp &h) const 112 | { 113 | h.begin(Azsplit_config, "AzRgf_FindSplit_Dflt", "Regularization at node split"); 114 | h.item_required(kw_lambda, help_lambda); 115 | h.item_experimental(kw_sigma, help_sigma, sigma_dflt); 116 | h.item(kw_s_lambda, help_s_lambda); 117 | h.item_experimental(kw_s_sigma, help_s_sigma); 118 | h.end(); 119 | } 120 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_FindSplit_Dflt.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_Dflt.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 
9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_FIND_SPLIT_DFLT_HPP_ 20 | #define _AZ_RGF_FIND_SPLIT_DFLT_HPP_ 21 | 22 | #include "AzFindSplit.hpp" 23 | #include "AzTrTree_ReadOnly.hpp" 24 | #include "AzRgf_FindSplit.hpp" 25 | #include "AzRegDepth.hpp" 26 | #include "AzParam.hpp" 27 | 
//! Node split search for RGF.  L2 regularization.
/*--------------------------------------------------------*/
class AzRgf_FindSplit_Dflt : /* extends */ public virtual AzFindSplit,
                             /* implements */ public virtual AzRgf_FindSplit
{
protected:
  double lambda, sigma;
  const AzRegDepth *reg_depth;

  double nlam, nsig;
  double p_nlam; //!< L2 reg param for parent (node to be split)
  double c_nlam; //!< L2 reg param for child (new node after split)
  double p_nsig, c_nsig;
  bool doUseInternalNodes;
  const AzTrTreeNode *p_node; //!< parent node (node to be split)

public:
  /* lambda starts at -1 so that resetParam() rejects a configuration
   * that never sets it.  Initializers follow member-declaration order,
   * which is the order in which C++ runs them. */
  AzRgf_FindSplit_Dflt() : lambda(-1), sigma(sigma_dflt),
                           reg_depth(NULL),
                           nlam(0), nsig(0),
                           p_nlam(0), c_nlam(0), p_nsig(0), c_nsig(0),
                           doUseInternalNodes(false),
                           p_node(NULL) {}
  virtual void begin(const AzTrTree_ReadOnly *tree,
                     const AzRgf_FindSplit_input &inp,
                     int inp_min_size);
  virtual void end() {
    _end();
  }
  /* Caches the node to be split and the depth-adjusted regularization
   * constants for it and for its would-be children, then runs the
   * generic split search. */
  virtual inline
  void findSplit(int nx,
                 /*---  output  ---*/
                 AzTrTsplit *best_split) {
    p_node = tree->node(nx);
    p_nlam = reg_depth->apply(nlam, p_node->depth);
    c_nlam = reg_depth->apply(nlam, p_node->depth+1);
    p_nsig = reg_depth->apply(nsig, p_node->depth);
    c_nsig = reg_depth->apply(nsig, p_node->depth+1);
    AzFindSplit::_findBestSplit(nx, best_split);
  }
  /* Stores reg_depth (must be non-NULL), then reads and logs parameters. */
  virtual void reset(AzParam &param,
                     const AzRegDepth *inp_reg_depth,
                     const AzOut &out)
  {
    reg_depth = inp_reg_depth;
    if (reg_depth == NULL) throw new AzException("AzRgf_FindSplit_Dflt",
                                                 "null reg_depth");
    resetParam(param);
    printParam(out);
  }

  /* NOTE(review): the interface declares pickFeats(f_num, data_num);
   * the parameter names differ here -- confirm the intended meaning. */
  virtual void pickFeats(int pick_num, int f_num) {
    AzFindSplit::_pickFeats(pick_num, f_num);
  }

  virtual void printParam(const AzOut &out) const;
  virtual void printHelp(AzHelp &h) const;

protected:
  virtual void resetParam(AzParam &param);
  virtual double getBestGain(double wsum,
                             double wysum,
                             double *best_q) const;
};
#endif
92 | 93 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_FindSplit_TreeReg.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_TreeReg.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #include "AzRgf_FindSplit_TreeReg.hpp" 20 | 21 | /*--------------------------------------------------------*/ 22 | void AzRgf_FindSplit_TreeReg::findSplit(int nx, 23 | /*--- output ---*/ 24 | AzTrTsplit *best_split) 25 | { 26 | if (tree->usingInternalNodes()) { 27 | throw new AzException("AzRgf_FindSplit_TreeReg::findSplit", 28 | "can't coexist with UseInternalNodes"); 29 | } 30 | 31 | reg->reset_forNewLeaf(nx, tree, reg_depth); 32 | dR = ddR = 0; 33 | reg->penalty_deriv(&dR, &ddR); 34 | AzRgf_FindSplit_Dflt::findSplit(nx, best_split); 35 | } 36 | 37 | /*--------------------------------------------------------*/ 38 | double AzRgf_FindSplit_TreeReg::evalSplit( 39 | const Az_forFindSplit i[2], 40 | double bestP[2]) 41 | const 42 | { 43 | double d[2]; /* delta */ 44 | int ix; 45 | for (ix = 0; ix < 2; ++ix) { 46 | double wrsum = i[ix].wy_sum; 47 | d[ix] = (wrsum-nlam*dR)/(i[ix].w_sum+nlam*ddR); 48 | bestP[ix] = p_node->weight + d[ix]; 49 | } 50 | 51 | double penalty_diff = reg->penalty_diff(d); /* new - old */ 52 | 53 | double gain = 2*d[0]*i[0].wy_sum - d[0]*d[0]*i[0].w_sum 54 | + 2*d[1]*i[1].wy_sum - d[1]*d[1]*i[1].w_sum; 55 | 56 | gain -= 2 * nlam * penalty_diff; 57 | /* "2*" b/c penalty is sum v^2/2 */ 58 | 59 | return gain; 60 | } 61 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_FindSplit_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 
9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_FIND_SPLIT_TREE_REG_HPP_ 20 | #define _AZ_RGF_FIND_SPLIT_TREE_REG_HPP_ 21 | 22 | #include "AzRgf_FindSplit_Dflt.hpp" 23 | #include "AzReg_TreeReg.hpp" 24 | #include "AzReg_TreeRegArr.hpp" 25 | 26 | //! Node split search for RGF. L2 and tree structure regularization 27 | /*--------------------------------------------------------*/ 28 | class AzRgf_FindSplit_TreeReg : /* extends */ public virtual AzRgf_FindSplit_Dflt 29 | { 30 | protected: 31 | AzReg_TreeRegArr *reg_arr; 32 | AzReg_TreeReg *reg; 33 | double dR, ddR; 34 | 35 | public: 36 | AzRgf_FindSplit_TreeReg() : dR(0), ddR(0), reg(NULL), reg_arr(NULL) {} 37 | void reset(AzReg_TreeRegArr *inp_reg_arr) { 38 | reg_arr = inp_reg_arr; 39 | } 40 | 41 | //! override 42 | virtual void begin(const AzTrTree_ReadOnly *tree, 43 | const AzRgf_FindSplit_input &inp, 44 | int inp_min_size) 45 | { 46 | AzRgf_FindSplit_Dflt::begin(tree, inp, inp_min_size); 47 | reg = reg_arr->reg_forNewLeaf(inp.tx); 48 | reg->reset_forNewLeaf(tree, reg_depth); 49 | } 50 | 51 | //! override 52 | virtual void end() { 53 | AzRgf_FindSplit_Dflt::end(); 54 | reg = NULL; 55 | } 56 | 57 | //! override 58 | virtual void findSplit(int nx, AzTrTsplit *best_split); 59 | 60 | //! 
override AzFindSplit::evalSplit 61 | virtual double evalSplit(const Az_forFindSplit i[2], 62 | double bestP[2]) const; 63 | }; 64 | #endif 65 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_Optimizer.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_Optimizer.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_OPTIMIZER_HPP_ 20 | #define _AZ_RGF_OPTIMIZER_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDmat.hpp" 24 | #include "AzBmat.hpp" 25 | #include "AzDataForTrTree.hpp" 26 | #include "AzLoss.hpp" 27 | #include "AzTrTreeEnsemble.hpp" 28 | #include "AzRgfTreeEnsemble.hpp" 29 | #include "AzRegDepth.hpp" 30 | #include "AzParam.hpp" 31 | #include "AzHelp.hpp" 32 | 33 | //! Abstract class: Weight optimizer interface. 34 | /*-------------------------------------------------------*/ 35 | class AzRgf_Optimizer 36 | { 37 | public: 38 | /*! 
Initialization */ 39 | virtual void cold_start(AzLossType loss_type, 40 | const AzDataForTrTree *training_data, /*!< training data */ 41 | const AzRegDepth *reg_depth, /*!< regularizer using tree attributes */ 42 | AzParam ¶m, /*!< confignuration */ 43 | const AzDvect *v_y, /*!< training targets */ 44 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 45 | const AzOut out, /*!< where to wrige log */ 46 | /*! output: prediction on training data. typically zeroes. */ 47 | AzDvect *v_p) 48 | = 0; 49 | 50 | virtual void warm_start(AzLossType loss_type, 51 | const AzDataForTrTree *training_data, /*!< training data */ 52 | const AzRegDepth *reg_depth, /*!< regularizer using tree attributes */ 53 | AzParam ¶m, /*!< confignuration */ 54 | const AzDvect *v_y, /*!< training targets */ 55 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 56 | const AzOut out, /*!< where to wrige log */ 57 | /*--- for warm start ---*/ 58 | const AzTrTreeEnsemble_ReadOnly *inp_ens, 59 | const AzDvect *inp_v_p) 60 | = 0; 61 | 62 | /*! Optimize weights. */ 63 | virtual void 64 | update(const AzDataForTrTree *data, /*!< training data */ 65 | AzRgfTreeEnsemble *ens, /*!< inout: tree ensmeble. */ 66 | /*--- output ---*/ 67 | AzDvect *v_p) /*. 
17 | * * * * */ 18 | 19 | #include "AzRgf_Optimizer_Dflt.hpp" 20 | #include "AzRgfTreeEnsImp.hpp" 21 | #include "AzHelp.hpp" 22 | 
/*------------------------------------------------------------------*/
/* Start from scratch: reset the trainer, copy its initial predictions
 * to out_v_pval, and reset the feature generator.  Logging is switched
 * off afterwards unless the verbose switch was given.
 */
void AzRgf_Optimizer_Dflt::cold_start(AzLossType loss_type,
                        const AzDataForTrTree *data,
                        const AzRegDepth *reg_depth,
                        AzParam &param,
                        const AzDvect *v_y, /* for training */
                        const AzDvect *v_fixed_dw, /* user-assigned data point weights */
                        const AzOut out_req,
                        AzDvect *out_v_pval) /* output */
{
  out = out_req;
  bool beVerbose = resetParam(param);
  trainer->reset(loss_type, v_y, v_fixed_dw, reg_depth, param, beVerbose, out);
  trainer->copyPred_to(out_v_pval);
  feat1.reset(data, param, out);

  if (!beVerbose) {
    out.deactivate();
  }
}

/*------------------------------------------------------------------*/
/* Start from an existing ensemble: rebuild the features from inp_ens
 * first, then reset the trainer with that ensemble and its predictions.
 */
void AzRgf_Optimizer_Dflt::warm_start(AzLossType loss_type,
                        const AzDataForTrTree *data,
                        const AzRegDepth *reg_depth,
                        AzParam &param,
                        const AzDvect *v_y, /* for training */
                        const AzDvect *v_fixed_dw, /* user-assigned data point weights */
                        const AzOut out_req,
                        /*---  for warm start  ---*/
                        const AzTrTreeEnsemble_ReadOnly *inp_ens,
                        const AzDvect *inp_v_p)
{
  out = out_req;
  bool beVerbose = resetParam(param);
  feat1.reset(data, param, out);
  AzIntArr ia_removed_fx; /* filled by update_with_ens; not used afterwards */
  feat1.update_with_ens(inp_ens, &ia_removed_fx);

  trainer->reset(loss_type, v_y, v_fixed_dw, reg_depth, param, beVerbose, out,
                 inp_ens, &feat1, inp_v_p); /* for warm start */
  if (!beVerbose) {
    out.deactivate();
  }
}

/*------------------------------------------------------------------*/
/* Re-optimize the weights when the ensemble produced new features (or
 * is empty); otherwise only log that nothing changed.  The trainer's
 * predictions are copied to v_p when v_p is non-NULL.
 */
void
AzRgf_Optimizer_Dflt::update(const AzDataForTrTree *data,
                             AzRgfTreeEnsemble *ens, /* inout */
                             /*---  inout  ---*/
                             AzDvect *v_p) /* prediction */
{
  AzIntArr ia_removed_fx;
  int f_num_delta = feat1.update_with_ens(ens, &ia_removed_fx);

  if (f_num_delta > 0 || ens->size() == 0) {
    trainer->optimize(ens, &feat1);
  }
  else {
    AzTimeLog::print("No new feature", out);
  }

  if (v_p != NULL) {
    trainer->copyPred_to(v_p);
  }
}

/*------------------------------------------------------------------*/
/*------------------------------------------------------------------*/
/* static */
/* Report the size of the weight vector (*f_num) and the count returned
 * by countNonzeroNodup (*nz_f_num).
 */
void AzRgf_Optimizer_Dflt::_info(const AzTrTreeEnsemble_ReadOnly *ens,
                             const AzOptimizerT *my_trainer,
                             const AzTrTreeFeat *my_feat,
                             int *f_num, int *nz_f_num)
{
  const AzDvect *v_w = my_trainer->weights();
  *f_num = v_w->rowNum();
  *nz_f_num = my_feat->countNonzeroNodup(v_w, ens);
}

/*------------------------------------------------------------------*/
/* Predict on test data (skipped when test_data is NULL) and always
 * report the feature counts through f_num / nz_f_num.
 */
void AzRgf_Optimizer_Dflt::apply(const AzDataForTrTree *test_data,
                      AzBmat *b_test_tran, /* inout */
                      const AzTrTreeEnsemble_ReadOnly *ens,
                      /*---  output  ---*/
                      AzDvect *v_p,
                      int *f_num,
                      int *nz_f_num) const
{
  if (test_data != NULL) {
    feat1.updateMatrix(test_data, ens, b_test_tran);
    trainer->resetPred(b_test_tran, v_p);
  }

  /*---  set info  ---*/
  _info(ens, trainer, &feat1, f_num, nz_f_num);
}

/*------------------------------------------------------------------*/
/*------------------------------------------------------------------*/
/* Returns whether the verbose switch is on in the parameters. */
bool AzRgf_Optimizer_Dflt::resetParam(AzParam &p)
{
  bool beVerbose = false;
  p.swOn(&beVerbose, kw_opt_beVerbose);
  return beVerbose;
}

/*------------------------------------------------------------------*/
/* Help for the trainer, the feature generator, and this class's switch. */
void AzRgf_Optimizer_Dflt::printHelp(AzHelp &h) const
{
  trainer->printHelp(h);
  feat1.printHelp(h);
  h.begin(Azopt_config, "AzRgf_Optimizer_Dflt", "Weight optimization/correction");
  h.item_experimental(kw_opt_beVerbose, help_opt_beVerbose);
  h.end();
}
140 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_Optimizer_Dflt.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_Optimizer_Dflt.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_OPTIMIZER_DFLT_HPP_ 20 | #define _AZ_RGF_OPTIMIZER_DFLT_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDmat.hpp" 24 | #include "AzDataForTrTree.hpp" 25 | #include "AzRgfTree.hpp" 26 | #include "AzTrTreeEnsemble.hpp" 27 | #include "AzLoss.hpp" 28 | #include "AzTrTreeFeat.hpp" 29 | #include "AzOptOnTree.hpp" 30 | #include "AzRgf_Optimizer.hpp" 31 | #include "AzOptimizerT.hpp" 32 | 33 | //! implement AzRgf_Optimizer.
34 | /*-------------------------------------------------------*/ 35 | class AzRgf_Optimizer_Dflt : /* implements */ public virtual AzRgf_Optimizer 36 | { 37 | protected: 38 | AzTrTreeFeat feat1; 39 | AzOut out; 40 | 41 | AzOptOnTree trainer_dflt; /* linear trainer */ 42 | AzOptimizerT *trainer; 43 | 44 | public: 45 | AzRgf_Optimizer_Dflt() : trainer(&trainer_dflt) {} 46 | ~AzRgf_Optimizer_Dflt() {} 47 | AzRgf_Optimizer_Dflt(const AzRgf_Optimizer_Dflt *inp) { 48 | reset(inp); 49 | } 50 | 51 | /*------------------------------------------------------*/ 52 | /* override this to replace trainer */ 53 | /*------------------------------------------------------*/ 54 | virtual void reset(const AzRgf_Optimizer_Dflt *inp) { 55 | if (inp == NULL) return; 56 | feat1.reset(&inp->feat1); 57 | out = inp->out; 58 | trainer_dflt.reset(&inp->trainer_dflt); 59 | trainer = &trainer_dflt; 60 | } 61 | /*------------------------------------------------------*/ 62 | 63 | /*------------------------------------------------------*/ 64 | /* derived classes must override this */ 65 | /*------------------------------------------------------*/ 66 | virtual void temp_update_apply(const AzDataForTrTree *tr_data, 67 | AzRgfTreeEnsemble *temp_ens, 68 | const AzDataForTrTree *test_data, 69 | AzBmat *temp_b, AzDvect *v_test_p, 70 | int *f_num, int *nz_f_num) const { 71 | AzRgf_Optimizer_Dflt temp_opt(this); 72 | temp_opt.update(tr_data, temp_ens); 73 | if (test_data != NULL) temp_opt.apply(test_data, temp_b, temp_ens, 74 | v_test_p, f_num, nz_f_num); 75 | } 76 | /*--------------------------------------------------------*/ 77 | 78 | 79 | virtual void cold_start(AzLossType loss_type, 80 | const AzDataForTrTree *data, 81 | const AzRegDepth *reg_depth, 82 | AzParam ¶m, 83 | const AzDvect *v_yval, 84 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 85 | const AzOut out_req, 86 | AzDvect *v_pval); /* output */ 87 | virtual void warm_start(AzLossType loss_type, 88 | const AzDataForTrTree 
*data, 89 | const AzRegDepth *reg_depth, 90 | AzParam ¶m, 91 | const AzDvect *v_y, /* for training */ 92 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 93 | const AzOut out_req, 94 | /*--- for warm start ---*/ 95 | const AzTrTreeEnsemble_ReadOnly *ens, 96 | const AzDvect *v_p); 97 | virtual void 98 | update(const AzDataForTrTree *data, 99 | AzRgfTreeEnsemble *ens, /* inout */ 100 | /*--- output ---*/ 101 | AzDvect *v_p=NULL); /* prediction */ 102 | 103 | virtual void apply(const AzDataForTrTree *data, 104 | AzBmat *b_test_tran, /* inout */ 105 | const AzTrTreeEnsemble_ReadOnly *ens, 106 | /*--- output ---*/ 107 | AzDvect *v_p, 108 | int *f_num, int *nz_f_num) const; 109 | 110 | virtual void printHelp(AzHelp &h) const; 111 | 112 | protected: 113 | virtual bool resetParam(AzParam ¶m); 114 | static void _info(const AzTrTreeEnsemble_ReadOnly *ens, 115 | const AzOptimizerT *my_trainer, 116 | const AzTrTreeFeat *my_feat, 117 | int *f_num, int *nz_f_num); 118 | }; 119 | #endif 120 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgf_Optimizer_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_Optimizer_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | * * * * */ 18 | 19 | #ifndef _AZ_RGF_OPTIMIZER_TREE_REG_HPP_ 20 | #define _AZ_RGF_OPTIMIZER_TREE_REG_HPP_ 21 | 22 | #include "AzRgf_Optimizer_Dflt.hpp" 23 | #include "AzOptOnTree_TreeReg.hpp" 24 | 25 | //! for regularization using tree structure. 26 | /*-------------------------------------------------------*/ 27 | class AzRgf_Optimizer_TreeReg : /* extends */ public virtual AzRgf_Optimizer_Dflt 28 | { 29 | protected: 30 | AzOptOnTree_TreeReg trainer_tr; 31 | 32 | public: 33 | AzRgf_Optimizer_TreeReg() { 34 | trainer = &trainer_tr; 35 | } 36 | void reset(AzReg_TreeRegArr *reg_arr) { 37 | trainer_tr.reset(reg_arr); 38 | } 39 | AzRgf_Optimizer_TreeReg(const AzRgf_Optimizer_TreeReg *inp) { 40 | reset(inp); 41 | } 42 | 43 | /*------------------------------------------------------*/ 44 | /* override this to replace trainer */ 45 | /*------------------------------------------------------*/ 46 | virtual void reset(const AzRgf_Optimizer_TreeReg *inp) { 47 | if (inp == NULL) return; 48 | AzRgf_Optimizer_Dflt::reset(inp); 49 | 50 | trainer_tr.reset(&inp->trainer_tr); 51 | trainer = &trainer_tr; 52 | } 53 | 54 | /*------------------------------------------------------*/ 55 | /* derived classes must override this */ 56 | /*------------------------------------------------------*/ 57 | virtual void temp_update_apply(const AzDataForTrTree *tr_data, 58 | AzRgfTreeEnsemble *temp_ens, 59 | const AzDataForTrTree *test_data, 60 | AzBmat *temp_b, AzDvect *v_test_p, 61 | int *f_num, int *nz_f_num) const { 62 | AzRgf_Optimizer_TreeReg temp_opt(this); 63 | temp_opt.update(tr_data, temp_ens); 64 | if (test_data != NULL) temp_opt.apply(test_data, temp_b, temp_ens, 65 | v_test_p, f_num, nz_f_num); 66 | } 67 | /*--------------------------------------------------------*/ 68 | }; 69 | #endif 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzRgforest_TreeReg.hpp: 
-------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgforest_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_RGFOREST_TREEREG_HPP_ 20 | #define _AZ_RGFOREST_TREEREG_HPP_ 21 | 22 | #include "AzRgforest.hpp" 23 | 24 | #include "AzRgf_Optimizer_TreeReg.hpp" 25 | #include "AzRgf_FindSplit_TreeReg.hpp" 26 | #include "AzReg_TreeRegArrImp.hpp" 27 | 28 | //! RGF with min-penalty regularization. 
29 | 30 | template 31 | class AzRgforest_TreeReg : /* implements */ public virtual AzRgforest { 32 | protected: 33 | AzRgf_FindSplit_TreeReg tr_fs; 34 | AzRgf_Optimizer_TreeReg tr_opt; /* weight optimizer */ 35 | AzReg_TreeRegArrImp reg_arr; 36 | AzBytArr s_sign, s_desc; 37 | 38 | public: 39 | AzRgforest_TreeReg() 40 | { 41 | tr_fs.reset(®_arr); 42 | tr_opt.reset(®_arr); 43 | opt = &tr_opt; 44 | fs = &tr_fs; 45 | s_sign.reset(reg_arr.tmpl()->signature()); 46 | s_desc.reset(reg_arr.tmpl()->description()); 47 | reg_depth->set_default_for_min_penalty(); 48 | } 49 | virtual inline const char *signature() const { 50 | return s_sign.c_str(); 51 | } 52 | virtual inline const char *description() const { 53 | return s_desc.c_str(); 54 | } 55 | 56 | virtual void printHelp(AzHelp &h) const { 57 | AzRgforest::printHelp(h); 58 | reg_arr.tmpl()->printHelp(h); 59 | h.begin(Azforest_config, "AzRgforest_TreeReg", "For min-penalty regularization"); 60 | h.item(kw_doApproxTsr, help_doApproxTsr); 61 | h.end(); 62 | } 63 | 64 | protected: 65 | virtual int resetParam(AzParam ¶m) { /* returns max #tree */ 66 | int max_tree_num = AzRgforest::resetParam(param); 67 | 68 | bool doApproxTsr = false; 69 | param.swOn(&doApproxTsr, kw_doApproxTsr); 70 | if (doApproxTsr) { 71 | AzPrint o(out); 72 | o.ppBegin("AzRgforest_TreeReg", "Approximation", ", "); 73 | o.printSw(kw_doApproxTsr, doApproxTsr); 74 | o.ppEnd(); 75 | if (doForceToRefreshAll) { 76 | doForceToRefreshAll = false; 77 | AzPrint::writeln(out, "Turning off ", kw_doForceToRefreshAll); 78 | } 79 | } 80 | else { 81 | if (!doForceToRefreshAll) { 82 | doForceToRefreshAll = true; 83 | AzPrint::writeln(out, "Turning on ", kw_doForceToRefreshAll); 84 | } 85 | } 86 | 87 | reg_arr.tmpl_u()->resetParam(param); 88 | reg_arr.tmpl()->printParam(out); 89 | reg_arr.reset(max_tree_num); 90 | return max_tree_num; 91 | } 92 | 93 | virtual void end_of_initialization() { 94 | AzRgforest::end_of_initialization(); 95 | 
reg_arr.tmpl()->check_reg_depth(reg_depth); 96 | } 97 | }; 98 | #endif 99 | 100 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTET_Eval.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTET_Eval.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TET_EVAL_HPP_ 20 | #define _AZ_TET_EVAL_HPP_ 21 | 22 | #include "AzDataForTrTree.hpp" 23 | #include "AzLoss.hpp" 24 | #include "AzTE_ModelInfo.hpp" 25 | #include "AzPerfResult.hpp" 26 | 27 | //! Abstract class: interface for evaluation modules for Tree Ensemble Trainer. 
28 | /*-------------------------------------------------------*/ 29 | class AzTET_Eval { 30 | public: 31 | virtual void reset(const AzDvect *inp_v_y, 32 | const char *perf_fn, 33 | bool inp_doAppend) = 0; 34 | virtual void begin(const char *config="", 35 | AzLossType loss_type=AzLoss_None) = 0; 36 | virtual void resetConfig(const char *config) = 0; 37 | virtual void end() = 0; 38 | virtual void evaluate(const AzDvect *v_p, const AzTE_ModelInfo *info, 39 | const char *user_str=NULL) = 0; 40 | virtual bool isActive() const = 0; 41 | }; 42 | #endif 43 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTET_Eval_Dflt.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTET_Eval_Dflt.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TET_EVAL_DFLT_HPP_ 20 | #define _AZ_TET_EVAL_DFLT_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzTaskTools.hpp" 24 | #include "AzPerfResult.hpp" 25 | #include "AzDataForTrTree.hpp" 26 | #include "AzLoss.hpp" 27 | #include "AzTET_Eval.hpp" 28 | 29 | //! Evaluationt module for Tree Ensemble Trainer. 
30 | /*-------------------------------------------------------*/ 31 | class AzTET_Eval_Dflt : /* implements */ public virtual AzTET_Eval { 32 | protected: 33 | /*--- targets ---*/ 34 | const AzDvect *v_y; 35 | 36 | /*--- to output evaluation results ---*/ 37 | AzBytArr s_perf_fn; 38 | AzPerfType perf_type; 39 | 40 | AzLossType loss_type; 41 | AzBytArr s_config; 42 | 43 | AzOfs ofs; 44 | AzOut perf_out; 45 | bool doAppend; 46 | 47 | public: 48 | AzTET_Eval_Dflt() : v_y(NULL), 49 | loss_type(AzLoss_None), doAppend(false) {} 50 | ~AzTET_Eval_Dflt() { 51 | end(); 52 | } 53 | inline virtual bool isActive() const { 54 | if (v_y != NULL) return true; 55 | return false; 56 | } 57 | 58 | virtual void reset() { 59 | v_y= NULL; 60 | s_perf_fn.reset(); 61 | s_config.reset(); 62 | if (ofs.is_open()) { 63 | ofs.close(); 64 | } 65 | } 66 | void reset(const AzDvect *inp_v_y, 67 | const char *perf_fn, 68 | bool inp_doAppend) 69 | { 70 | v_y = inp_v_y; 71 | s_perf_fn.reset(perf_fn); 72 | doAppend = inp_doAppend; 73 | } 74 | virtual void resetConfig(const char *config) { 75 | s_config.reset(config); 76 | _clean(&s_config); 77 | } 78 | 79 | virtual void begin(const char *config, 80 | AzLossType inp_loss_type) { 81 | if (!isActive()) return; 82 | _begin(config, inp_loss_type); 83 | _clean(&s_config); 84 | } 85 | virtual void end() { 86 | if (ofs.is_open()) { 87 | ofs.close(); 88 | } 89 | } 90 | 91 | virtual void evaluate(const AzDvect *v_p, 92 | const AzTE_ModelInfo *info, 93 | const char *user_str=NULL) { 94 | if (!isActive()) return; 95 | AzPerfResult result=AzTaskTools::eval(v_p, v_y, loss_type); 96 | 97 | /*--- signature and configuration ---*/ 98 | AzBytArr s_sign_config(info->s_sign); 99 | s_sign_config.concat(":"); 100 | concat_config(info, &s_sign_config); 101 | 102 | /*--- print ---*/ 103 | AzPrint o(perf_out); 104 | o.printBegin("", ",", ","); 105 | o.print("#tree", info->tree_num); 106 | o.print("#leaf", info->leaf_num); 107 | o.print("acc", result.acc, 4); 108 | 
o.print("rmse", result.rmse, 4); 109 | o.print("sqerr", result.rmse*result.rmse, 6); 110 | o.print(loss_str[loss_type]); 111 | o.print("loss", result.loss, 6); 112 | o.print("#test", v_p->rowNum()); 113 | o.print("cfg"); /* for compatibility */ 114 | o.print(s_sign_config); 115 | if (user_str != NULL) { 116 | o.print(user_str); 117 | } 118 | o.printEnd(); 119 | } 120 | 121 | protected: 122 | virtual void _begin(const char *config, AzLossType inp_loss_type) { 123 | s_config.reset(config); 124 | loss_type = inp_loss_type; 125 | 126 | if (ofs.is_open()) { 127 | ofs.close(); 128 | } 129 | const char *perf_fn = s_perf_fn.c_str(); 130 | if (AzTools::isSpecified(perf_fn)) { 131 | ios_base::openmode mode = ios_base::out; 132 | if (doAppend) { 133 | mode = ios_base::app | ios_base::out; 134 | } 135 | ofs.open(perf_fn, mode); 136 | ofs.set_to(perf_out); 137 | } 138 | else { 139 | perf_out.setStdout(); 140 | } 141 | } 142 | 143 | virtual void _clean(AzBytArr *s) const { 144 | /*-- replace comma with : for convenience later --*/ 145 | s->replace(',', ';'); 146 | } 147 | 148 | virtual void concat_config(const AzTE_ModelInfo *info, AzBytArr *s) const { 149 | if (s_config.length() > 0) { 150 | s->concat(&s_config); 151 | } 152 | else { 153 | AzBytArr s_cfg(&info->s_config); 154 | _clean(&s_cfg); 155 | s->concat(&s_cfg); 156 | } 157 | } 158 | }; 159 | #endif 160 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTETmain_kw.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTETmain_kw.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 
9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TET_MAIN_KW_HPP_ 20 | #define _AZ_TET_MAIN_KW_HPP_ 21 | 22 | #define kw_train "train" 23 | #define kw_train_test "train_test" 24 | #define kw_predict "predict" 25 | #define kw_batch_predict "batch_predict" 26 | #define kw_train_predict "train_predict" 27 | #define kw_features "output_features" 28 | #define help_train "Train and save models to files." 29 | #define help_train_test "Train and test models. Optionally models can be saved to files." 30 | #define help_train_predict "Train models and save predictions on test data to files. Models can also be saved to files." 31 | #define help_predict "Apply a model saved by \"train\" to new data." 32 | #define help_batch_predict "Apply several models to new data." 33 | #define help_features "Output features generated by tree ensembles." 
34 | 35 | #define kw_alg_name "algorithm=" 36 | #define kw_train_x_fn "train_x_fn=" 37 | #define kw_train_y_fn "train_y_fn=" 38 | #define kw_fdic_fn "x_name_fn=" 39 | #define kw_model_stem "model_fn_prefix=" 40 | #define kw_model_names_fn "model_names_fn=" 41 | #define kw_model_fn "model_fn=" 42 | #define kw_prev_model_fn "model_fn_for_warmstart=" 43 | #define kw_pred_fn "prediction_fn=" 44 | #define kw_pred_fn_suffix "pred_fn_suffix=" 45 | #define kw_not_doLog "DontLog" 46 | #define kw_doLog "Log" 47 | #define kw_doDump "Dump" 48 | #define kw_eval_fn "evaluation_fn=" 49 | #define kw_doAppend_eval "Append_evaluation" 50 | #define kw_test_x_fn "test_x_fn=" 51 | #define kw_test_y_fn "test_y_fn=" 52 | #define kw_dw_fn "train_w_fn=" 53 | #define kw_doSaveLastModelOnly "SaveLastModelOnly" 54 | 55 | #define kw_xv_doShuffle "ShuffleData" 56 | #define kw_xv_num "num_xv=" 57 | #define kw_xv_fn "xv_fn=" 58 | #define kw_input_x_fn "input_x_fn=" 59 | #define kw_output_x_fn "output_x_fn=" 60 | #define kw_features_digits "features_digits=" 61 | #define kw_doSparse_features "SparseFeatures" 62 | 63 | #define help_train_x_fn "Path to the feature file of training data." 64 | #define help_train_y_fn "Path to the target file of training data." 65 | #define help_fdic_fn "Path to the file of feature names." 66 | #define help_model_stem "To save models, path names are generated by attaching \"-01\", \"-02\",... to this value." 67 | #define help_model_stem_tp "Path names to model files are generated by attaching \"-01\", \"-02\",... to this value. Path names to prediction files and model info files are generated by attaching \".pred\" of \".info\" to the model path names." 68 | #define help_model_names_fn_out "Path to the file to write model path names to. If omitted, model path names are not saved." 69 | #define help_model_names_fn_inp "Path to the file to read model path names from." 
70 | #define help_model_fn "Path to the model file to be tested" 71 | #define help_prev_model_fn "Path to the input model file from which training should do warm-start." 72 | #define help_prev_model_fn_others "Path to the input model file from which training should do warm-start. (WARNING) Some algorithms do not support warm-start and return error if this parameter is specified." 73 | #define help_pred_fn_out "Path to the file to write predictions to." 74 | #define help_pred_fn_inp "Path to the file to read predictions from." 75 | #define help_pred_fn_suffix "The path names of the predictions files are generated by attaching this to the path names of the corresponding models." 76 | #define help_not_doLog "Suppress logging-purpose output to stdout." 77 | #define help_doDump "Enable dump to stderr for the verbose components." 78 | #define help_eval_fn "Path to the file to write evaluation to." 79 | #define help_doAppend_eval "Open the evaluation result file with append mode." 80 | #define help_test_x_fn "Path to the feature file of test data" 81 | #define help_test_y_fn "Path to the target file of test data" 82 | #define help_dw_fn "Path to the file of user-defined weights assigned to training data points." 83 | #define help_doSaveLastModelOnly "Save the last/largest model only." 84 | #define help_doSaveLastModelOnly_traintest "Save the last/largest model only. Referred to only when model_fn_suffix is specified." 85 | 86 | #define help_input_x_fn "Path to the input feature file." 87 | #define help_output_x_fn "Path to the output feature file." 88 | #define help_features_digits "How many digits should be retained in the output." 89 | #define help_doSparse_features "Write features in the sparse data format." 
90 | 91 | /* #define dflt_model_names_fn "model_list.txt" */ 92 | #define dflt_model_stem "" 93 | 94 | #define Az_config "config" 95 | 96 | #endif 97 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTETselector.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTETselector.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TET_SELECTOR_HPP_ 20 | #define _AZ_TET_SELECTOR_HPP_ 21 | 22 | #include "AzTETrainer.hpp" 23 | 24 | class AzTETselector { 25 | public: 26 | //! Return trainer 27 | virtual AzTETrainer *select(const char *alg_name, //! algorithm name 28 | //! if true, don't throw exception on error 29 | bool dontThrow=false 30 | ) const = 0; 31 | 32 | virtual const char *dflt_name() const = 0; 33 | virtual const char *another_name() const = 0; 34 | virtual const AzStrArray *names() const = 0; 35 | virtual bool isRGFfamily(const char *name) const { 36 | AzBytArr s(name); 37 | return s.beginsWith("RGF"); 38 | } 39 | virtual bool isGBfamily(const char *name) const { 40 | AzBytArr s(name); 41 | return s.beginsWith("GB"); 42 | } 43 | 44 | //! Return algorithm names. 45 | virtual void printOptions(const char *dlm, //! delimiter between algorithm names. 
46 | AzBytArr *s) //!< output: algorithm names separated by dlm. 47 | const = 0; 48 | 49 | //! Help 50 | virtual void printHelp(AzHelp &h) const = 0; 51 | }; 52 | 53 | #endif 54 | 55 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTE_ModelInfo.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTE_ModelInfo.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TE_MODEL_INFO_HPP_ 20 | #define _AZ_TE_MODEL_INFO_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | 24 | /*-------------------------------------------------------*/ 25 | /*! Tree Ensemble model info */ 26 | class AzTE_ModelInfo { 27 | public: 28 | int tree_num; //!< number of trees 29 | int leaf_num; //!< number of features/leaf nodes 30 | int f_num; //!< number of features including removed ones. 
31 | int nz_f_num; //!< number of non-zero-weight features after consolidation 32 | AzBytArr s_sign; 33 | AzBytArr s_config; //!< configuration 34 | const char *sign; //!< signature of trainer 35 | AzTE_ModelInfo() : f_num(-1),leaf_num(-1),nz_f_num(-1),tree_num(-1) {} 36 | 37 | void reset() { 38 | tree_num = leaf_num = f_num = nz_f_num = -1; 39 | s_sign.reset(); 40 | s_config.reset(); 41 | } 42 | }; 43 | #endif 44 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTrTreeEnsemble_ReadOnly.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTreeEnsemble_ReadOnly.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TR_TREE_ENSEMBLE_READONLY_HPP_ 20 | #define _AZ_TR_TREE_ENSEMBLE_READONLY_HPP_ 21 | 22 | #include "AzTrTree_ReadOnly.hpp" 23 | #include "AzTreeEnsemble.hpp" 24 | #include "AzSvFeatInfo.hpp" 25 | 26 | //! Abstract class: interface for read-only access to trainalbe tree ensemble. 
27 | class AzTrTreeEnsemble_ReadOnly { 28 | public: 29 | virtual bool usingTempFile() const { return false; } 30 | virtual const AzTrTree_ReadOnly *tree(int tx) const = 0; 31 | virtual int leafNum() const = 0; 32 | virtual int leafNum(int tx0, int tx1) const = 0; 33 | virtual int size() const = 0; 34 | virtual int max_size() const = 0; 35 | virtual int lastIndex() const = 0; 36 | virtual void copy_to(AzTreeEnsemble *out_ens, 37 | const char *config, const char *sign) const = 0; 38 | virtual void show(const AzSvFeatInfo *feat, 39 | const AzOut &out, const char *header="") const = 0; 40 | virtual double constant() const = 0; 41 | virtual int orgdim() const = 0; 42 | virtual const char *param_c_str() const = 0; 43 | }; 44 | #endif 45 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTrTreeNode.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTreeNode.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TR_TREE_NODE_HPP_ 20 | #define _AZ_TR_TREE_NODE_HPP_ 21 | 22 | #include "AzTreeNodes.hpp" 23 | 24 | class AzTrTree; 25 | 26 | /*---------------------------------------------*/ 27 | /*! 
used only for training */ 28 | class AzTrTreeNode : /* extends */ public virtual AzTreeNode { 29 | protected: 30 | const int *dxs; /* data indexes belonging to this node */ 31 | 32 | public: 33 | int dxs_offset; /* position in the data indexes at the root */ 34 | int dxs_num; 35 | int depth; //!< node depth 36 | 37 | AzTrTreeNode() : depth(-1), dxs(NULL), dxs_offset(-1), dxs_num(-1) {} 38 | void reset() { 39 | AzTreeNode::reset(); 40 | depth = dxs_offset = dxs_num = -1; 41 | dxs = NULL; 42 | } 43 | void transfer_from(AzTrTreeNode *inp) { 44 | AzTreeNode::transfer_from(inp); 45 | dxs = inp->dxs; 46 | dxs_offset = inp->dxs_offset; 47 | dxs_num = inp->dxs_num; 48 | depth = inp->depth; 49 | } 50 | 51 | inline const int *data_indexes() const { 52 | if (dxs_num > 0 && dxs == NULL) { 53 | throw new AzException("AzTrTreeNode::data_indexes", 54 | "data indexes are unavailable"); 55 | } 56 | return dxs; 57 | } 58 | inline void reset_data_indexes(const int *ptr) { 59 | dxs = ptr; 60 | } 61 | 62 | friend class AzTrTree; 63 | }; 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTrTree_ReadOnly.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTree_ReadOnly.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. 
If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TR_TREE_READONLY_HPP_ 20 | #define _AZ_TR_TREE_READONLY_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDataForTrTree.hpp" 24 | #include "AzTreeRule.hpp" 25 | #include "AzSvFeatInfo.hpp" 26 | #include "AzTreeNodes.hpp" 27 | #include "AzTrTreeNode.hpp" 28 | #include "AzSortedFeat.hpp" 29 | 30 | //! Abstract class: interface for read-only (information-seeking) access to trainable tree. 31 | /*------------------------------------------*/ 32 | /* Trainable tree; read only */ 33 | class AzTrTree_ReadOnly : /* implements */ public virtual AzTreeNodes 34 | { 35 | public: 36 | /*--- information seeking ... ---*/ 37 | virtual int nodeNum() const = 0; 38 | virtual int leafNum() const = 0; 39 | virtual int maxDepth() const = 0; 40 | virtual void show(const AzSvFeatInfo *feat, const AzOut &out) const = 0; 41 | virtual void concat_stat(AzBytArr *o) const = 0; 42 | virtual double getRule(int inp_nx, AzTreeRule *rule) const = 0; 43 | virtual void concatDesc(const AzSvFeatInfo *feat, int nx, 44 | AzBytArr *str_desc, /* output */ 45 | int max_len=-1) const = 0; 46 | virtual void isActiveNode(bool doAllowZeroWeightLeaf, 47 | AzIntArr *ia_isDecisionNode) const = 0; /* output */ 48 | virtual bool usingInternalNodes() const = 0; 49 | 50 | virtual const AzSortedFeatArr *sorted_array(int nx, 51 | const AzDataForTrTree *data) const = 0; 52 | /*--- (NOTE) this is const but changes sorted_arr[nx] ---*/ 53 | 54 | virtual const AzIntArr *root_dx() const = 0; 55 | 56 | /*--- apply ... 
---*/ 57 | virtual double apply(const AzDataForTrTree *data, int dx, 58 | AzIntArr *ia_nx=NULL) const /* node path */ 59 | = 0; 60 | 61 | virtual const AzTrTreeNode *node(int nx) const = 0; 62 | }; 63 | #endif 64 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTrTsplit.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTsplit.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TRT_SPLIT_HPP_ 20 | #define _AZ_TRT_SPLIT_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzTools.hpp" 24 | 25 | //! Node split information. 
26 | class AzTrTsplit { 27 | public: 28 | int fx; 29 | double border_val; 30 | double gain; 31 | double bestP[2]; /* le gt */ 32 | AzBytArr str_desc; 33 | 34 | int tx, nx; /* set only by Rgf; not used by Std */ 35 | 36 | AzTrTsplit() : fx(-1), border_val(0), gain(0), tx(-1), nx(-1) { 37 | bestP[0] = bestP[1] = 0; 38 | } 39 | 40 | virtual void print(const char *header) { 41 | #if 0 42 | printf("%s, fx=%d,border_val=%e,gain=%e,bestP[0]=%e,bestP[1]=%e,tx=%d,nx=%d\n", 43 | header, fx, border_val, gain, bestP[0], bestP[1], tx, nx); 44 | #endif 45 | } 46 | 47 | virtual 48 | void reset() { 49 | fx = -1; 50 | border_val = 0; 51 | bestP[0] = bestP[1] = 0; 52 | gain = 0; 53 | str_desc.reset(); 54 | tx = nx = -1; 55 | } 56 | AzTrTsplit(int fx, double border_val, 57 | double gain, 58 | double bestP_L, double bestP_G) { 59 | reset_values(fx, border_val, gain, bestP_L, bestP_G); 60 | } 61 | AzTrTsplit(const AzTrTsplit *inp) { /* copy */ 62 | copy(inp); 63 | } 64 | virtual 65 | inline bool isEmpty() const { 66 | if (fx < 0) return true; 67 | return false; 68 | } 69 | 70 | virtual 71 | inline void reset(const AzTrTsplit *inp) { 72 | copy(inp); 73 | } 74 | virtual 75 | inline void reset(const AzTrTsplit *inp, int inp_tx, int inp_nx) { 76 | reset(inp); 77 | tx = inp_tx; 78 | nx = inp_nx; 79 | } 80 | virtual 81 | void copy(const AzTrTsplit *inp) { 82 | if (inp == NULL) return; 83 | fx = inp->fx; 84 | border_val = inp->border_val; 85 | gain = inp->gain; 86 | str_desc.clear(); 87 | str_desc.concat(&inp->str_desc); 88 | bestP[0] = inp->bestP[0]; 89 | bestP[1] = inp->bestP[1]; 90 | tx = inp->tx; 91 | nx = inp->nx; 92 | } 93 | 94 | virtual 95 | void keep_if_good(int inp_fx, double inp_border_val, 96 | double inp_gain, 97 | double bestP_L, double bestP_G) { 98 | if (inp_gain > gain) { 99 | reset_values(inp_fx, inp_border_val, inp_gain, bestP_L, bestP_G); 100 | } 101 | } 102 | 103 | virtual 104 | void reset_values(int inp_fx, double inp_border_val, 105 | double inp_gain, 106 | double 
bestP_L, double bestP_G) 107 | { 108 | fx = inp_fx; 109 | border_val = inp_border_val; 110 | gain = inp_gain; 111 | bestP[0] = bestP_L; 112 | bestP[1] = bestP_G; 113 | 114 | tx = nx = -1; 115 | } 116 | virtual 117 | void release() { 118 | str_desc.reset(); 119 | } 120 | }; 121 | #endif 122 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTrTtarget.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTtarget.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TRT_TARGET_HPP_ 20 | #define _AZ_TRT_TARGET_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzDmat.hpp" 24 | 25 | //! Targets and data point weights for node split search. 
26 | /*--------------------------------------------------------*/ 27 | class AzTrTtarget { 28 | protected: 29 | AzDvect v_tar_dw, v_dw; 30 | AzDvect v_y; 31 | AzDvect v_fixed_dw; /* data point weights assigned by users */ 32 | double fixed_dw_sum; 33 | 34 | public: 35 | AzTrTtarget() : fixed_dw_sum(-1) {} 36 | AzTrTtarget(const AzDvect *inp_v_y, 37 | const AzDvect *inp_v_fixed_dw=NULL) { 38 | reset(inp_v_y, inp_v_fixed_dw); 39 | } 40 | void reset(const AzDvect *inp_v_y, 41 | const AzDvect *inp_v_fixed_dw=NULL) { 42 | v_dw.reform(inp_v_y->rowNum()); 43 | v_dw.set(1); 44 | v_tar_dw.set(inp_v_y); 45 | v_y.set(inp_v_y); 46 | fixed_dw_sum = -1; 47 | 48 | v_fixed_dw.reset(); 49 | if (!AzDvect::isNull(inp_v_fixed_dw)) { 50 | v_fixed_dw.set(inp_v_fixed_dw); 51 | if (v_fixed_dw.rowNum() != v_y.rowNum()) { 52 | throw new AzException(AzInputError, "AzTrTtarget::reset", 53 | "conlict in dimensionality: y and data point weights"); 54 | } 55 | fixed_dw_sum = v_fixed_dw.sum(); 56 | } 57 | } 58 | inline bool isWeighted() const { 59 | return !AzDvect::isNull(&v_fixed_dw); 60 | } 61 | inline double sum_fixed_dw() const { 62 | return fixed_dw_sum; 63 | } 64 | inline void weight_tarDw() { 65 | v_tar_dw.scale(&v_fixed_dw); 66 | } 67 | inline void weight_dw() { 68 | v_dw.scale(&v_fixed_dw); 69 | } 70 | 71 | AzTrTtarget(const AzTrTtarget *inp) { 72 | reset(inp); 73 | } 74 | 75 | void reset(const AzTrTtarget *inp) { 76 | if (inp != NULL) { 77 | v_tar_dw.set(&inp->v_tar_dw); 78 | v_dw.set(&inp->v_dw); 79 | v_y.set(&inp->v_y); 80 | v_fixed_dw.set(&inp->v_fixed_dw); 81 | fixed_dw_sum = inp->fixed_dw_sum; 82 | } 83 | } 84 | 85 | void resetTargetDw(const AzDvect *v_tar, const AzDvect *inp_v_dw) { 86 | v_tar_dw.set(v_tar); 87 | v_dw.set(inp_v_dw); 88 | v_tar_dw.scale(&v_dw); /* component-wise multiplication */ 89 | } 90 | void resetTarDw_residual(const AzDvect *v_p) { /* only for LS */ 91 | v_tar_dw.set(&v_y); 92 | v_tar_dw.add(v_p, -1); 93 | } 94 | inline const double *dw_arr() const { 95 | 
return v_dw.point(); 96 | } 97 | inline const double *tarDw_arr() const { 98 | return v_tar_dw.point(); 99 | } 100 | inline const AzDvect *y() const { 101 | return &v_y; 102 | } 103 | inline int dataNum() const { 104 | return v_tar_dw.rowNum(); 105 | } 106 | inline AzDvect *tarDw_forUpdate() { 107 | return &v_tar_dw; 108 | } 109 | inline const AzDvect *tarDw() const { 110 | return &v_tar_dw; 111 | } 112 | inline AzDvect *dw_forUpdate() { 113 | return &v_dw; 114 | } 115 | inline const AzDvect *dw() { 116 | return &v_dw; 117 | } 118 | inline double getTarDwSum(const int *dxs, int dxs_num) const { 119 | return v_tar_dw.sum(dxs, dxs_num); 120 | } 121 | inline double getDwSum(const int *dxs, int dxs_num) const { 122 | return v_dw.sum(dxs, dxs_num); 123 | } 124 | inline double getTarDwSum(const AzIntArr *ia_dx=NULL) const { 125 | return v_tar_dw.sum(ia_dx); 126 | } 127 | inline double getDwSum(const AzIntArr *ia_dx=NULL) const { 128 | return v_dw.sum(ia_dx); 129 | } 130 | 131 | int dim() const { 132 | return v_tar_dw.rowNum(); 133 | } 134 | }; 135 | #endif 136 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTree.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTree.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. 
If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TREE_HPP_ 20 | #define _AZ_TREE_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzSmat.hpp" 24 | #include "AzSvFeatInfo.hpp" 25 | #include "AzTreeNodes.hpp" 26 | 27 | //! Untrainalbe regression tree. 28 | /*------------------------------------------*/ 29 | class AzTree : /* implements */ public virtual AzTreeNodes { 30 | protected: 31 | int root_nx; 32 | int nodes_used; 33 | AzTreeNode *nodes; 34 | AzBaseArray a_nodes; 35 | 36 | inline void _checkNode(int nx, const char *eyec) const { 37 | if (nodes == NULL || nx < 0 || nx >= nodes_used) { 38 | throw new AzException(eyec, "nx is out of range"); 39 | } 40 | } 41 | 42 | public: 43 | AzTree() : root_nx(-1), nodes_used(0), nodes(NULL) {} 44 | AzTree(AzFile *file) 45 | : root_nx(-1), nodes_used(0), nodes(NULL) { 46 | _read(file); 47 | } 48 | AzTree(const AzTreeNodes *inp) : root_nx(-1), nodes_used(0), nodes(NULL) { 49 | if (inp != NULL) { 50 | copy_from(inp); 51 | } 52 | } 53 | ~AzTree() {} 54 | 55 | void copy_from(const AzTreeNodes *tree_nodes); 56 | 57 | inline void transfer_from(AzTree *) { 58 | throw new AzException("AzTree::transfer_from", "no support"); 59 | } 60 | inline AzTree & operator =(const AzTree &inp) { 61 | if (this == &inp) return *this; 62 | throw new AzException("AzTree:=", "Don't use ="); 63 | } 64 | void reset() { 65 | _release(); 66 | } 67 | 68 | int write(AzFile *file); 69 | void read(AzFile *file); 70 | 71 | inline int nodeNum() const { 72 | return nodes_used; 73 | } 74 | 75 | static double apply(const AzReadOnlyVector *v_data, 76 | const AzTreeNodes *nodes, 77 | AzIntArr *ia_node=NULL); 78 | 79 | double apply(const AzReadOnlyVector *v_data, 80 | AzIntArr *ia_node=NULL) const { 81 | checkNodes("apply"); 82 | return apply(v_data, this, ia_node); 83 | } 84 | 85 | void show(const AzSvFeatInfo *feat, const AzOut &out, 86 | const char *header="") const; 87 | int leafNum() const; 88 | void clean_up(); 89 | 90 | inline const AzTreeNode *node(int 
nx) const { 91 | checkNode(nx, "point"); 92 | return &nodes[nx]; 93 | } 94 | inline int root() const { 95 | return root_nx; 96 | } 97 | void finfo(AzIFarr *ifa_fx_count, 98 | AzIFarr *ifa_fx_sum) const; /* appended */ 99 | void finfo(AzIntArr *ia_fxs) const; /* appended */ 100 | 101 | void cooccurrences(AzIIFarr *iifa_fx1_fx2_count) const; 102 | 103 | virtual void genDesc(const AzSvFeatInfo *feat, 104 | int nx, 105 | AzBytArr *s) /* output */ 106 | const; 107 | 108 | protected: 109 | /*--- functions ---*/ 110 | void _read(AzFile *file); 111 | 112 | inline void checkNode(int nx, const char *eyec) const { 113 | if (nodes == NULL || nx < 0 || nx >= nodes_used) { 114 | throw new AzException(eyec, "AzTree, nx is out of range"); 115 | } 116 | } 117 | inline void checkNodes(const char *msg) const { 118 | if (nodes == NULL && nodes_used > 0) { 119 | throw new AzException("AzTree, no nodes", msg); 120 | } 121 | } 122 | void _show(const AzSvFeatInfo *feat, 123 | int nx, 124 | int depth, 125 | const AzOut &out) const; 126 | 127 | void _release(); 128 | virtual void _genDesc(const AzSvFeatInfo *feat, 129 | int nx, 130 | AzBytArr *s) /* output */ 131 | const; 132 | }; 133 | 134 | #endif 135 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTreeEnsemble.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTreeEnsemble.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TREE_ENSEMBLE_HPP_ 20 | #define _AZ_TREE_ENSEMBLE_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | #include "AzTree.hpp" 24 | #include "AzDmat.hpp" 25 | #include "AzTE_ModelInfo.hpp" 26 | 27 | //! Untrainable tree ensemble. Like applier. Generated as a result of training. 28 | class AzTreeEnsemble 29 | { 30 | protected: 31 | AzObjPtrArray a_tree; 32 | AzTree **t; 33 | int t_num; 34 | AzTree empty_tree; 35 | double const_val; 36 | 37 | AzBytArr s_config, s_sign; 38 | int org_dim; /* dimension of original features */ 39 | 40 | public: 41 | AzTreeEnsemble() : t(NULL), t_num(0), const_val(0), org_dim(-1) {} 42 | ~AzTreeEnsemble() {} 43 | 44 | AzTreeEnsemble(const char *fn) 45 | : t(NULL), t_num(0), const_val(0), org_dim(-1) { 46 | read(fn); 47 | } 48 | AzTreeEnsemble(AzFile *file) 49 | : t(NULL), t_num(0), const_val(0), org_dim(-1) { 50 | _read(file); 51 | } 52 | 53 | void transfer_from(AzTree *inp_tree[], /* destroys input */ 54 | int inp_tree_num, 55 | double const_val, 56 | int orgdim, 57 | const char *config, 58 | const char *sign); 59 | 60 | void read(const char *fn); 61 | void read(AzFile *file) { 62 | _release(); 63 | _read(file); 64 | } 65 | int write(const char *fn); 66 | int write(AzFile *file); 67 | 68 | inline void destroy() { 69 | _release(); 70 | } 71 | 72 | inline const AzTree *tree(int tx) const { 73 | checkIndex(tx, "tree"); 74 | if (t[tx] == NULL) { 75 | return &empty_tree; 76 | } 77 | return t[tx]; 78 | } 79 | 80 | int leafNum() const { 81 | return leafNum(0, t_num); 82 | } 83 | int leafNum(int tx0, int tx1) const; 84 | inline int size() const { return t_num; } 85 | 86 | void apply(const AzSmat *m_data, 87 | AzDvect *v_pred) /* output */ 88 | const; 89 | double apply(const AzSvect *v_data) const; 90 | 91 | inline double constant() const 
{ return const_val; } 92 | inline int orgdim() const { return org_dim; } 93 | const char *signature() const { return s_sign.c_str(); } 94 | const char *configuration() const { return s_config.c_str(); } 95 | 96 | /*---*/ 97 | void info(AzTE_ModelInfo *out_info) const; 98 | 99 | void show(const AzSvFeatInfo *feat, //!< may be NULL 100 | const AzOut &out, const char *header="") const; 101 | void finfo(AzIFarr *ifa_fx_count, 102 | AzIFarr *ifa_fx_sum) const { 103 | finfo(0, t_num, ifa_fx_count, ifa_fx_sum); 104 | } 105 | void finfo(int tx0, int tx1, 106 | AzIFarr *ifa_fx_count, 107 | AzIFarr *ifa_fx_sum) const; 108 | void finfo(AzIntArr *ia_fx2tx) const; 109 | void cooccurrences(AzIIFarr *iifa_fx1_fx2_count) const; 110 | 111 | void show_weights(const AzOut &out, AzSvFeatInfo *fi) const; 112 | 113 | protected: 114 | void _read(AzFile *file); 115 | inline void _release() { 116 | a_tree.free(&t); t_num = 0; 117 | s_config.reset(); 118 | s_sign.reset(); 119 | const_val = 0; 120 | org_dim = -1; 121 | } 122 | inline void checkIndex(int tx, const char *msg) const { 123 | if (tx < 0 || tx >= t_num) { 124 | throw new AzException("AzTreeEnsemble::checkIndex", msg); 125 | } 126 | } 127 | void clean_up(); 128 | }; 129 | #endif 130 | 131 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTreeNodes.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTreeNodes.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TREE_NODES_HPP_ 20 | #define _AZ_TREE_NODES_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | 24 | /*! Tree node */ 25 | class AzTreeNode { 26 | public: 27 | int fx; //!< feature id 28 | double border_val; 29 | int le_nx; //!< x[fx] <= border_val 30 | int gt_nx; //!< x[fx] > border_val 31 | int parent_nx; //!< pointing parent node 32 | double weight; //!< weight 33 | 34 | /*--- ---*/ 35 | AzTreeNode() { 36 | reset(); 37 | } 38 | void reset() { 39 | border_val = weight = 0; 40 | fx = le_nx = gt_nx = parent_nx = -1; 41 | } 42 | AzTreeNode(AzFile *file) { 43 | read(file); 44 | } 45 | inline bool isLeaf() const { 46 | if (le_nx < 0) return true; 47 | return false; 48 | } 49 | int write(AzFile *file); 50 | void read(AzFile *file); 51 | 52 | void transfer_from(AzTreeNode *inp) { 53 | *this = *inp; 54 | } 55 | }; 56 | 57 | class AzTreeNodes { 58 | public: 59 | virtual const AzTreeNode *node(int nx) const = 0; 60 | virtual int nodeNum() const = 0; 61 | virtual int root() const = 0; 62 | }; 63 | #endif 64 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/AzTreeRule.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTreeRule.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #ifndef _AZ_TREE_RULE_HPP_ 20 | #define _AZ_TREE_RULE_HPP_ 21 | 22 | #include "AzUtil.hpp" 23 | 24 | class AzTreeRule { 25 | protected: 26 | AzBytArr ba; 27 | 28 | public: 29 | inline void reset() { 30 | ba.clear(); 31 | } 32 | inline const AzBytArr *bytarr() { 33 | return &ba; 34 | } 35 | inline void reset(const AzTreeRule *inp) { 36 | ba.reset(); 37 | if (inp == NULL) return; 38 | ba.reset(&inp->ba); 39 | } 40 | inline void append(int fx, 41 | bool isLE, 42 | double border_val) 43 | { 44 | /*--- feat#, isLE, border_val ---*/ 45 | ba.concat((AzByte *)(&fx), sizeof(fx)); 46 | ba.concat((AzByte *)(&isLE), sizeof(isLE)); 47 | ba.concat((AzByte *)(&border_val), sizeof(border_val)); 48 | } 49 | inline void append(const AzTreeRule *inp) { 50 | if (inp != NULL) { 51 | ba.concat(&inp->ba); 52 | } 53 | } 54 | inline void finalize() { 55 | if (ba.getLen() == 0) { 56 | ba.concat('_'); /* root node (CONST) */ 57 | } 58 | } 59 | inline const AzBytArr *byteArr() { 60 | return &ba; 61 | } 62 | }; 63 | #endif 64 | -------------------------------------------------------------------------------- /rgf1.2/src/tet/driv_rgf.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * driv_rgf.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 17 | * * * * */ 18 | 19 | #define _AZ_MAIN_ 20 | #include "AzUtil.hpp" 21 | #include "AzTETmain.hpp" 22 | #include "AzRgfTrainerSel.hpp" 23 | #include "AzTET_Eval_Dflt.hpp" 24 | #include "AzHelp.hpp" 25 | 26 | /*-----------------------------------------------------------------*/ 27 | void help(int argc, const char *argv[]) 28 | { 29 | cout << "Arguments: action parameters" <getMessage() << endl; 100 | return -1; 101 | } 102 | 103 | return 0; 104 | } 105 | 106 | -------------------------------------------------------------------------------- /rgf1.2/test/call_exe.pl: -------------------------------------------------------------------------------- 1 | 2 | use strict 'vars'; 3 | 4 | my $dlm = ','; 5 | my $inp_ext = '.inp'; 6 | 7 | my $arg_num = $#ARGV + 1; 8 | if ($arg_num != 3) { 9 | print "Arguments: exe train|predict|train_test cfg_fn\n"; 10 | print " exe : Name of the executable. Typically, ../bin/rgf \n"; 11 | print " train|predict|train_test \n"; 12 | print " train ... Train and save models to files. \n"; 13 | print " predict ... Apply a model to new data. \n"; 14 | print " train_test ... Train and test models in one call. 
\n"; 15 | print " cfg_fn: Path to a configuration file without extension; e.g., sample/train \n"; 16 | print " The file extension should be \".inp\"\n"; 17 | print " For example, see sample/train.inp and sample/predict.inp.\n"; 18 | print "\n"; 19 | print " To get help on the parameters that can be used in configuration\n"; 20 | print " files, call the executable with train|predict|train_test, e.g.,\n"; 21 | print "\n"; 22 | print " ..\\bin\\rgf train \n"; 23 | print " ../bin/rgf predict \n"; 24 | exit 1; 25 | } 26 | 27 | my $argx = 0; 28 | my $exe = $ARGV[$argx++]; 29 | my $action = $ARGV[$argx++]; 30 | my $inp_fn = $ARGV[$argx++]; 31 | 32 | $inp_fn .= $inp_ext; 33 | 34 | my @list = &readList_inc($inp_fn); 35 | my $num = $#list + 1; 36 | 37 | my $config = ""; 38 | 39 | my $repeat_num = 0; 40 | my(@repeat); 41 | my $ix; 42 | for ($ix = 0; $ix < $num; ++$ix) { 43 | my $kw = ""; 44 | if ($list[$ix] =~ /^\@(.*)$/) { 45 | $repeat[$repeat_num++] = $1; 46 | } 47 | else { 48 | &concatKwVal($list[$ix]); 49 | } 50 | } 51 | 52 | 53 | if ($repeat_num == 0) { 54 | my $cmd = "$exe $action $config"; 55 | print "$cmd\n"; 56 | system($cmd); 57 | } 58 | else { 59 | my $rx; 60 | for ($rx = 0; $rx < $repeat_num; ++$rx) { 61 | my $my_config = $repeat[$rx] . ',' . $config; 62 | 63 | my $cmd = "$exe $action $my_config"; 64 | print "$cmd\n"; 65 | system($cmd); 66 | my $ret = $?; 67 | if ($ret != 0) { 68 | print "system() failed; exit call_exe.pl\n"; 69 | exit 1; 70 | } 71 | } 72 | } 73 | 74 | exit 0; 75 | 76 | ##----------------------------------- 77 | ## if there is a line "#include filename" ... 
##-----------------------------------
## Read a configuration file and expand each "include filename" line
## with the lines of that file.  Includes cannot be nested.
## Every resulting line goes through checkInputLine().
sub readList_inc {
  my($fn) = @_;
  my(@out);

  foreach my $line (&readList($fn)) {
    ## NOTE: "\i" is kept as in the original pattern; Perl passes an
    ## unrecognized escape through, so this matches a literal "include".
    if ($line =~ /^\include\s+(\S+)$/) {
      my $inc_fn = $1;
      foreach my $inc_line (&readList($inc_fn)) {
        if ($inc_line =~ /^include/) {
          print "Can't nest include: $inc_fn\n";
          exit 1;
        }
        push(@out, &checkInputLine($inc_line));
      }
    }
    else {
      push(@out, &checkInputLine($line));
    }
  }
  return @out;
}

##------
## Append a non-blank item to the file-level $config, separated by $dlm.
sub concatKwVal {
  my($inp) = @_;
  if ($inp =~ /\S/) {
    if ($config ne "") {
      $config .= $dlm;
    }
    $config .= $inp;
  }
}

#####
## Return the lines of a file, skipping the lines whose first
## non-space character is '#'.  Each kept line is chomped and stripped.
## Prints a message and exits with status 1 if the file cannot be opened.
sub readList {
  my($lst_fn) = @_;
  my(@item);

  ## three-argument open with a lexical handle: the file name is taken
  ## literally and the handle cannot clash with another one.
  my $fh;
  if (!open($fh, '<', $lst_fn)) {
    print "Can't open $lst_fn\n";
    exit 1;
  }

  ## read from the handle; an empty "while ()" would loop forever
  while (my $line = <$fh>) {
    chomp $line;
    next if ($line =~ /^\s*\#/);
    push(@item, &strip($line));
  }

  close($fh);
  return @item;
}

#####
## Remove leading and trailing white space.
## A string without any non-space character is returned unchanged.
sub strip {
  my($inp) = @_;
  my $out = $inp;
  if ($inp =~ /^\s*(\S.*\S)\s*$/) {
    $out = $1;
  }
  elsif ($inp =~ /^\s*(\S)\s*$/) {
    $out = $1;
  }
  return $out;
}

#####
## Return the parameter part of an input line without the trailing comment.
## White space is allowed only in front of a comment that starts with '#';
## otherwise a message is printed and the script exits with status 1
## so that the caller can detect the failure.
sub checkInputLine {
  my($inp) = @_;

  my $param = $inp;
  if ($param =~ /^(\S+)\s+(\S+)/) {
    $param = $1;
    my $comment = $2;
    if ($comment !~ /^\#/) {
      print "No space is allowed in the parameter. Comments should start with \#\n";
      print "Error in: [$inp]\n";
      exit 1;
    }
  }
  ## a comment attached directly to the parameter, e.g. "a=1#note"
  if ($param =~ /^([^\#]+)\#/) {
    $param = $1;
  }
  return $param;
}
-------------------------------------------------------------------------------- 1 | #### sample input to "predict" 2 | 3 | #--- apply a model to sample test data and save prediction values 4 | test_x_fn=sample/test.data.x # Test data points 5 | model_fn=output/sample.model-03 # Model 6 | prediction_fn=output/sample.pred # Where to write predictions 7 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/regress_train_test.inp: -------------------------------------------------------------------------------- 1 | # To use this example configuration file: 2 | # Set the current directory to rgf_v1/test. 3 | # In the command line, enter: 4 | # 5 | # perl call_exe.pl ../bin/rgf train_test sample/regress_train_test 6 | # 7 | 8 | #------------------ Perform 3 runs --------------------# 9 | @reg_L2=1,model_fn_prefix=output/regress.lam1.model # used by 1st run 10 | @reg_L2=0.1,model_fn_prefix=output/regress.lam0.1.model # used by 2nd run 11 | @reg_L2=0.01,model_fn_prefix=output/regress.lam0.01.model # used by 3rd run 12 | #-------------------------------------------------------------------------# 13 | 14 | #--- Other parameters are shared by 3 runs. 15 | 16 | train_x_fn=sample/regress.train.x # Training data points 17 | train_y_fn=sample/regress.train.y # Training targets 18 | 19 | test_x_fn=sample/regress.test.x # Test data points 20 | test_y_fn=sample/regress.test.y # Test targets 21 | 22 | algorithm=RGF # RGF with L2 regularization on leaf-only models 23 | loss=LS # Square loss 24 | test_interval=500 # Test (and save) models every time 500 leaves are added. 25 | max_leaf_forest=5000 # Stop training when #leaf reaches 5000. 26 | Verbose # Display info during training. 27 | NormalizeTarget # 28 | 29 | #train_w_fn=?? # User-specified weights of data points. 30 | #model_fn_for_warmstart=?? 
# Path to the model file to do warm-start with 31 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/test.data.x: -------------------------------------------------------------------------------- 1 | 79 93 97 99 68 14 94 90 3 34 2 | 41 40 88 38 87 71 34 82 18 91 3 | 9 73 66 12 88 84 26 21 74 48 4 | 52 82 75 67 98 46 95 5 49 17 5 | 15 67 89 80 73 61 90 19 9 72 6 | 15 25 11 6 54 95 7 37 83 57 7 | 63 2 96 90 16 99 68 99 27 67 8 | 39 45 84 4 61 27 64 60 90 12 9 | 40 80 6 51 50 78 48 20 0 39 10 | 40 35 42 9 25 43 95 57 56 34 11 | 63 91 35 79 7 68 65 16 93 73 12 | 79 31 18 18 86 78 13 50 25 26 13 | 78 73 1 4 83 85 83 29 38 93 14 | 67 68 47 43 47 57 11 73 97 49 15 | 37 95 20 7 68 47 0 61 33 16 16 | 93 7 50 75 18 91 35 78 86 29 17 | 64 7 86 25 0 14 85 31 52 10 18 | 5 74 47 57 50 33 68 73 16 26 19 | 89 33 0 47 51 93 34 61 21 69 20 | 86 79 21 52 28 68 89 63 66 84 21 | 47 87 56 44 55 89 1 92 68 53 22 | 49 41 33 29 61 92 45 27 79 57 23 | 81 36 2 74 50 90 84 34 58 48 24 | 6 53 74 92 7 19 74 86 17 89 25 | 72 63 77 47 1 61 29 3 92 25 26 | 99 89 64 88 95 67 22 70 73 2 27 | 63 92 61 66 8 34 99 34 54 76 28 | 70 88 46 56 54 60 6 51 28 74 29 | 13 83 63 49 97 6 89 97 70 89 30 | 60 93 73 19 54 10 46 99 25 98 31 | 66 51 75 44 56 38 93 3 35 89 32 | 50 85 55 81 91 30 84 64 52 33 33 | 16 57 37 31 99 49 46 25 22 56 34 | 0 58 51 3 14 34 58 95 94 73 35 | 56 62 23 76 63 48 90 4 42 36 36 | 48 63 86 72 44 19 66 70 1 49 37 | 23 93 59 39 67 24 5 93 34 29 38 | 41 5 86 21 79 74 7 26 90 78 39 | 34 97 16 70 27 41 9 61 72 35 40 | 87 3 37 79 21 73 13 70 36 50 41 | 2 58 37 80 11 88 21 31 21 40 42 | 1 5 32 86 62 23 84 91 11 74 43 | 36 94 72 70 45 28 59 27 69 30 44 | 33 85 70 17 8 49 70 14 86 92 45 | 1 35 92 8 53 41 28 85 96 96 46 | 56 82 79 32 64 87 86 10 1 13 47 | 24 58 91 41 36 86 32 5 8 47 48 | 50 93 65 33 83 20 28 33 63 77 49 | 63 15 26 80 39 66 70 85 32 3 50 | 27 21 13 75 41 46 70 64 70 80 51 | 19 2 57 89 35 40 46 12 29 66 52 | 46 6 14 52 98 37 95 10 
7 18 53 | 80 73 37 72 84 34 84 93 63 33 54 | 42 49 38 97 86 97 55 79 8 1 55 | 76 40 67 41 72 95 31 64 75 78 56 | 87 62 25 99 14 37 26 85 24 58 57 | 68 88 83 93 82 11 91 88 31 85 58 | 15 3 66 23 80 5 89 16 39 14 59 | 48 84 66 86 19 76 67 53 10 78 60 | 1 12 36 73 19 99 42 89 58 27 61 | 65 80 65 49 74 83 11 15 9 67 62 | 88 79 65 6 96 14 78 50 88 71 63 | 48 64 26 48 7 84 56 32 22 77 64 | 44 1 88 99 43 72 53 84 53 11 65 | 77 16 31 90 50 22 4 97 47 80 66 | 21 13 69 90 13 85 89 30 59 59 67 | 74 84 9 87 25 26 56 15 2 91 68 | 41 52 60 15 76 56 43 42 51 40 69 | 66 6 98 71 40 70 12 64 83 1 70 | 44 25 23 22 48 89 27 61 30 27 71 | 61 34 74 95 23 96 79 88 93 9 72 | 69 74 65 56 6 54 45 51 8 85 73 | 48 6 47 81 76 64 31 30 42 20 74 | 7 67 86 74 91 83 56 7 56 43 75 | 64 63 26 88 42 34 14 2 77 45 76 | 93 35 40 1 55 52 49 84 24 95 77 | 78 67 89 64 10 8 2 85 81 49 78 | 2 48 58 12 94 58 52 96 15 34 79 | 6 60 90 62 12 59 13 70 71 55 80 | 36 91 76 82 90 90 76 94 81 92 81 | 46 64 70 40 61 60 12 55 79 82 82 | 18 75 83 76 65 39 7 34 13 74 83 | 49 18 7 47 31 62 14 56 72 26 84 | 76 21 93 18 68 62 36 47 58 71 85 | 4 77 90 24 1 33 80 58 69 96 86 | 7 45 99 8 17 54 72 17 23 50 87 | 67 62 73 80 96 53 84 63 56 9 88 | 46 23 34 51 14 22 34 39 40 6 89 | 31 63 27 2 54 28 81 89 85 40 90 | 23 66 36 92 91 64 19 31 98 3 91 | 39 50 93 19 38 21 93 49 65 34 92 | 49 92 24 5 61 42 85 16 85 55 93 | 88 44 80 1 71 39 53 29 56 59 94 | 73 45 83 70 86 90 78 56 84 24 95 | 22 91 15 84 72 37 84 32 12 71 96 | 83 81 25 83 44 34 72 91 52 20 97 | 97 96 62 30 28 38 47 93 48 38 98 | 63 43 14 12 49 37 96 59 57 49 99 | 0 61 0 83 34 67 16 42 52 35 100 | 20 41 25 84 33 69 28 52 70 40 101 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/test.data.y: -------------------------------------------------------------------------------- 1 | +1 2 | +1 3 | -1 4 | +1 5 | +1 6 | +1 7 | +1 8 | -1 9 | -1 10 | -1 11 | -1 12 | -1 13 | -1 14 | +1 15 | +1 16 | -1 17 | -1 18 | +1 19 | -1 20 | -1 
21 | +1 22 | +1 23 | -1 24 | -1 25 | +1 26 | +1 27 | -1 28 | +1 29 | +1 30 | -1 31 | +1 32 | -1 33 | +1 34 | +1 35 | -1 36 | -1 37 | +1 38 | +1 39 | +1 40 | -1 41 | +1 42 | +1 43 | -1 44 | -1 45 | -1 46 | -1 47 | +1 48 | +1 49 | +1 50 | -1 51 | +1 52 | +1 53 | -1 54 | +1 55 | +1 56 | -1 57 | +1 58 | +1 59 | -1 60 | +1 61 | +1 62 | -1 63 | -1 64 | -1 65 | -1 66 | -1 67 | +1 68 | +1 69 | -1 70 | -1 71 | -1 72 | +1 73 | -1 74 | -1 75 | +1 76 | +1 77 | -1 78 | +1 79 | +1 80 | +1 81 | +1 82 | +1 83 | +1 84 | -1 85 | +1 86 | -1 87 | -1 88 | -1 89 | +1 90 | -1 91 | -1 92 | +1 93 | +1 94 | -1 95 | +1 96 | +1 97 | +1 98 | -1 99 | +1 100 | -1 101 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/train.data.sparse.x: -------------------------------------------------------------------------------- 1 | sparse 10 2 | 0:58 1:19 2:5 3:92 4:3 5:62 6:26 7:56 8:43 9:35 3 | 0:19 1:61 2:54 3:81 4:8 5:50 6:55 7:83 8:69 9:81 4 | 0:79 1:93 2:49 3:17 4:67 5:70 6:33 7:38 8:81 9:53 5 | 0:66 1:30 2:78 3:67 4:18 5:98 6:72 7:25 8:35 9:53 6 | 0:86 1:34 2:68 3:72 4:40 5:8 6:89 7:7 8:34 7 | 0:9 1:99 2:46 3:86 4:51 5:13 6:74 7:46 8:82 9:62 8 | 0:82 1:99 2:20 3:47 4:46 5:42 6:40 7:66 8:64 9:55 9 | 0:99 1:73 2:59 3:94 4:16 5:16 6:80 7:28 8:11 9:30 10 | 0:96 1:67 2:72 3:16 4:10 5:27 6:79 7:48 8:86 9:54 11 | 0:23 1:3 2:63 3:91 5:97 6:72 7:3 8:57 9:46 12 | 0:25 1:72 2:85 3:34 4:76 5:65 6:68 7:64 8:51 9:61 13 | 0:98 1:34 2:49 3:83 4:44 5:16 6:3 7:3 8:70 9:63 14 | 0:63 1:39 2:88 3:20 4:73 6:52 7:9 8:30 9:57 15 | 0:99 1:91 2:47 3:60 4:71 5:74 6:79 7:13 8:20 9:92 16 | 0:75 1:64 2:9 3:60 4:29 5:66 6:38 7:26 8:39 9:94 17 | 0:36 1:94 2:62 3:55 4:65 5:10 6:42 7:62 8:53 9:35 18 | 0:19 1:94 2:55 3:68 4:65 5:36 6:95 7:21 8:72 9:52 19 | 0:16 1:21 2:90 3:23 4:14 5:8 6:69 7:79 8:18 9:78 20 | 0:44 1:84 2:85 3:32 4:66 5:12 6:26 7:12 8:9 9:81 21 | 0:86 1:19 2:68 3:21 4:33 5:42 6:95 7:46 8:7 9:64 22 | 0:85 1:48 2:92 3:11 4:76 5:56 6:43 7:97 8:60 9:70 23 | 0:68 1:81 
2:89 3:89 4:65 5:26 6:87 7:29 8:30 9:7 24 | 0:95 1:93 2:5 3:27 4:68 5:26 6:25 7:28 8:61 9:8 25 | 0:20 1:78 2:44 3:80 4:31 5:93 6:55 7:81 8:69 9:39 26 | 0:91 1:67 2:92 3:15 4:55 5:53 6:27 7:47 8:34 9:9 27 | 0:91 1:15 2:95 3:30 4:32 5:18 6:35 7:13 8:46 9:76 28 | 0:68 1:64 2:89 4:73 5:75 6:6 7:44 8:97 9:84 29 | 0:64 1:8 2:9 3:50 4:45 5:35 6:68 7:24 8:92 9:28 30 | 0:74 1:74 2:7 3:31 4:53 5:28 6:92 7:38 8:17 9:34 31 | 0:94 1:78 2:31 3:96 4:33 5:54 6:50 7:47 8:77 9:44 32 | 0:18 1:11 2:7 3:93 4:26 5:79 6:5 7:70 8:99 9:46 33 | 0:16 1:52 2:3 3:27 4:63 5:56 6:85 7:83 8:81 9:63 34 | 0:55 1:25 2:91 3:67 4:78 5:15 6:47 7:55 8:89 9:87 35 | 0:24 1:16 2:8 3:12 4:44 5:17 6:46 7:30 8:66 9:91 36 | 0:65 1:31 2:95 3:68 4:13 5:16 6:66 7:65 8:60 9:95 37 | 0:31 1:55 2:84 3:70 4:74 5:83 6:45 7:28 8:25 9:12 38 | 0:61 1:3 2:43 3:20 4:5 5:2 6:25 7:50 8:2 9:95 39 | 0:12 1:7 2:92 3:19 4:89 5:76 6:57 7:77 8:74 9:43 40 | 0:85 1:93 2:10 3:31 4:37 5:2 6:71 7:5 8:85 9:13 41 | 0:97 1:20 2:67 3:55 4:53 5:61 6:24 7:93 8:3 9:60 42 | 0:57 1:83 2:81 3:70 4:87 5:23 6:33 7:67 8:82 9:90 43 | 0:65 1:45 2:48 3:39 4:21 5:73 6:68 7:35 8:25 9:58 44 | 0:57 1:45 2:32 3:68 4:66 5:15 6:60 7:51 8:58 9:29 45 | 0:2 1:36 3:48 4:92 5:3 6:98 7:45 8:97 9:26 46 | 0:65 1:87 2:37 3:4 4:82 5:94 6:73 7:81 8:36 9:2 47 | 0:38 1:33 2:2 3:96 4:1 5:68 6:18 7:39 8:86 9:73 48 | 0:46 1:22 2:13 3:26 4:81 5:95 6:88 7:74 8:90 9:95 49 | 0:51 1:92 2:9 3:16 4:8 5:55 6:30 7:87 8:78 9:47 50 | 0:86 1:48 2:60 3:35 4:3 5:6 6:46 7:87 8:71 9:31 51 | 0:12 1:52 2:36 3:48 4:34 5:65 6:46 7:49 8:98 9:21 52 | 0:34 1:43 2:39 3:98 4:78 5:5 6:14 7:80 8:80 9:76 53 | 0:48 1:54 2:96 3:3 4:64 5:40 6:88 7:47 8:9 9:85 54 | 0:74 1:80 2:33 3:87 4:24 5:15 6:95 7:18 8:3 9:44 55 | 0:70 1:57 2:14 3:95 4:75 5:31 7:67 8:49 9:96 56 | 0:1 1:51 2:3 3:42 4:69 5:45 6:6 7:99 8:65 9:87 57 | 0:39 1:16 2:98 3:20 4:35 5:19 6:15 7:99 8:65 9:68 58 | 0:86 1:98 2:72 3:37 4:95 5:53 6:6 7:36 8:64 9:92 59 | 0:87 1:65 2:36 3:39 4:94 5:8 6:5 7:32 8:33 60 | 0:87 1:98 2:52 3:49 4:15 5:36 6:20 
7:28 8:60 9:4 61 | 0:84 1:92 2:93 3:82 4:91 5:54 6:84 7:19 8:83 9:84 62 | 0:85 1:26 2:62 3:87 4:41 5:60 6:79 7:8 8:32 9:42 63 | 0:75 1:12 2:49 3:81 4:68 5:87 6:86 7:97 8:13 9:33 64 | 0:88 1:38 2:60 3:89 4:15 5:43 6:37 7:47 8:33 9:96 65 | 0:95 1:1 2:7 3:27 4:20 5:69 6:53 7:1 8:84 9:53 66 | 0:7 1:98 2:50 3:39 4:28 5:70 6:10 7:47 8:6 9:49 67 | 0:90 1:30 2:70 3:6 4:36 5:84 6:73 7:15 8:36 9:22 68 | 0:38 1:48 2:2 3:4 4:39 5:1 6:66 7:60 8:47 9:88 69 | 0:38 1:12 2:46 3:76 4:95 5:60 6:74 7:62 8:72 9:48 70 | 0:91 1:94 2:47 3:24 4:94 5:42 6:29 7:25 8:14 9:13 71 | 0:69 1:63 2:44 3:43 4:28 5:35 6:44 7:70 8:23 9:52 72 | 0:51 1:36 2:8 3:12 4:59 5:96 6:34 7:96 8:53 9:94 73 | 0:16 1:88 2:94 3:11 4:68 5:87 6:88 7:51 8:56 9:83 74 | 0:46 1:1 2:36 3:27 4:3 5:77 6:41 7:94 8:54 9:76 75 | 0:1 1:4 2:97 3:65 4:32 5:38 6:88 7:97 8:11 9:75 76 | 0:9 1:24 2:7 3:63 4:86 5:95 6:88 7:38 8:61 9:66 77 | 0:71 1:55 2:42 3:91 4:56 5:14 6:77 7:82 8:40 9:68 78 | 0:26 1:38 2:62 3:20 4:2 5:27 6:12 7:29 8:68 9:33 79 | 0:21 1:63 2:92 3:54 4:49 5:58 6:75 7:60 8:96 9:2 80 | 0:64 1:22 2:30 3:60 4:64 5:11 6:71 7:59 8:80 9:54 81 | 0:94 1:38 2:43 3:23 4:23 5:81 6:89 7:70 8:25 9:18 82 | 0:28 1:76 2:49 3:71 4:4 5:10 6:13 7:49 8:26 9:33 83 | 0:36 1:43 2:29 3:50 4:95 5:75 6:11 7:5 8:7 9:43 84 | 0:31 1:53 2:47 3:76 4:85 5:21 6:84 7:71 8:63 9:10 85 | 0:51 1:29 2:58 3:8 4:32 5:13 6:89 7:86 8:9 9:99 86 | 0:65 1:58 2:68 3:64 4:32 5:16 6:28 7:35 9:1 87 | 0:9 1:86 2:30 3:39 4:7 5:41 6:28 7:63 8:69 9:13 88 | 0:3 1:96 2:70 3:5 4:13 5:30 6:24 7:82 8:65 9:77 89 | 0:79 1:69 2:83 3:59 4:76 5:81 6:26 7:51 8:98 9:97 90 | 0:17 1:39 2:45 3:14 4:54 5:8 6:54 7:85 8:9 9:1 91 | 0:86 1:44 2:69 3:6 4:2 5:19 6:6 7:58 8:1 9:5 92 | 0:2 1:7 2:8 3:2 4:61 5:76 6:27 7:11 8:11 9:50 93 | 0:73 1:12 2:73 3:43 4:66 5:62 6:26 7:17 8:71 9:85 94 | 0:23 1:99 2:62 3:68 4:33 5:53 6:66 7:30 8:11 9:92 95 | 0:61 1:23 2:43 3:68 4:12 5:71 6:45 7:69 8:43 9:18 96 | 0:93 1:43 2:93 3:85 4:72 5:22 6:78 7:4 8:79 9:88 97 | 0:11 1:41 2:99 3:73 4:46 5:67 6:44 7:50 8:87 
9:59 98 | 0:67 1:81 2:50 3:79 4:13 5:84 6:47 7:72 8:42 9:42 99 | 0:98 1:33 2:98 3:89 4:34 5:48 6:19 7:62 8:39 9:89 100 | 0:23 1:47 2:30 3:76 4:88 5:19 6:6 7:96 8:61 9:71 101 | 0:61 1:45 2:70 3:51 4:5 5:26 6:75 7:43 8:68 9:88 102 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/train.data.x: -------------------------------------------------------------------------------- 1 | 58 19 5 92 3 62 26 56 43 35 2 | 19 61 54 81 8 50 55 83 69 81 3 | 79 93 49 17 67 70 33 38 81 53 4 | 66 30 78 67 18 98 72 25 35 53 5 | 86 34 68 72 40 8 89 7 34 0 6 | 9 99 46 86 51 13 74 46 82 62 7 | 82 99 20 47 46 42 40 66 64 55 8 | 99 73 59 94 16 16 80 28 11 30 9 | 96 67 72 16 10 27 79 48 86 54 10 | 23 3 63 91 0 97 72 3 57 46 11 | 25 72 85 34 76 65 68 64 51 61 12 | 98 34 49 83 44 16 3 3 70 63 13 | 63 39 88 20 73 0 52 9 30 57 14 | 99 91 47 60 71 74 79 13 20 92 15 | 75 64 9 60 29 66 38 26 39 94 16 | 36 94 62 55 65 10 42 62 53 35 17 | 19 94 55 68 65 36 95 21 72 52 18 | 16 21 90 23 14 8 69 79 18 78 19 | 44 84 85 32 66 12 26 12 9 81 20 | 86 19 68 21 33 42 95 46 7 64 21 | 85 48 92 11 76 56 43 97 60 70 22 | 68 81 89 89 65 26 87 29 30 7 23 | 95 93 5 27 68 26 25 28 61 8 24 | 20 78 44 80 31 93 55 81 69 39 25 | 91 67 92 15 55 53 27 47 34 9 26 | 91 15 95 30 32 18 35 13 46 76 27 | 68 64 89 0 73 75 6 44 97 84 28 | 64 8 9 50 45 35 68 24 92 28 29 | 74 74 7 31 53 28 92 38 17 34 30 | 94 78 31 96 33 54 50 47 77 44 31 | 18 11 7 93 26 79 5 70 99 46 32 | 16 52 3 27 63 56 85 83 81 63 33 | 55 25 91 67 78 15 47 55 89 87 34 | 24 16 8 12 44 17 46 30 66 91 35 | 65 31 95 68 13 16 66 65 60 95 36 | 31 55 84 70 74 83 45 28 25 12 37 | 61 3 43 20 5 2 25 50 2 95 38 | 12 7 92 19 89 76 57 77 74 43 39 | 85 93 10 31 37 2 71 5 85 13 40 | 97 20 67 55 53 61 24 93 3 60 41 | 57 83 81 70 87 23 33 67 82 90 42 | 65 45 48 39 21 73 68 35 25 58 43 | 57 45 32 68 66 15 60 51 58 29 44 | 2 36 0 48 92 3 98 45 97 26 45 | 65 87 37 4 82 94 73 81 36 2 46 | 38 33 2 96 1 68 18 39 86 73 47 | 46 22 13 26 
81 95 88 74 90 95 48 | 51 92 9 16 8 55 30 87 78 47 49 | 86 48 60 35 3 6 46 87 71 31 50 | 12 52 36 48 34 65 46 49 98 21 51 | 34 43 39 98 78 5 14 80 80 76 52 | 48 54 96 3 64 40 88 47 9 85 53 | 74 80 33 87 24 15 95 18 3 44 54 | 70 57 14 95 75 31 0 67 49 96 55 | 1 51 3 42 69 45 6 99 65 87 56 | 39 16 98 20 35 19 15 99 65 68 57 | 86 98 72 37 95 53 6 36 64 92 58 | 87 65 36 39 94 8 5 32 33 0 59 | 87 98 52 49 15 36 20 28 60 4 60 | 84 92 93 82 91 54 84 19 83 84 61 | 85 26 62 87 41 60 79 8 32 42 62 | 75 12 49 81 68 87 86 97 13 33 63 | 88 38 60 89 15 43 37 47 33 96 64 | 95 1 7 27 20 69 53 1 84 53 65 | 7 98 50 39 28 70 10 47 6 49 66 | 90 30 70 6 36 84 73 15 36 22 67 | 38 48 2 4 39 1 66 60 47 88 68 | 38 12 46 76 95 60 74 62 72 48 69 | 91 94 47 24 94 42 29 25 14 13 70 | 69 63 44 43 28 35 44 70 23 52 71 | 51 36 8 12 59 96 34 96 53 94 72 | 16 88 94 11 68 87 88 51 56 83 73 | 46 1 36 27 3 77 41 94 54 76 74 | 1 4 97 65 32 38 88 97 11 75 75 | 9 24 7 63 86 95 88 38 61 66 76 | 71 55 42 91 56 14 77 82 40 68 77 | 26 38 62 20 2 27 12 29 68 33 78 | 21 63 92 54 49 58 75 60 96 2 79 | 64 22 30 60 64 11 71 59 80 54 80 | 94 38 43 23 23 81 89 70 25 18 81 | 28 76 49 71 4 10 13 49 26 33 82 | 36 43 29 50 95 75 11 5 7 43 83 | 31 53 47 76 85 21 84 71 63 10 84 | 51 29 58 8 32 13 89 86 9 99 85 | 65 58 68 64 32 16 28 35 0 1 86 | 9 86 30 39 7 41 28 63 69 13 87 | 3 96 70 5 13 30 24 82 65 77 88 | 79 69 83 59 76 81 26 51 98 97 89 | 17 39 45 14 54 8 54 85 9 1 90 | 86 44 69 6 2 19 6 58 1 5 91 | 2 7 8 2 61 76 27 11 11 50 92 | 73 12 73 43 66 62 26 17 71 85 93 | 23 99 62 68 33 53 66 30 11 92 94 | 61 23 43 68 12 71 45 69 43 18 95 | 93 43 93 85 72 22 78 4 79 88 96 | 11 41 99 73 46 67 44 50 87 59 97 | 67 81 50 79 13 84 47 72 42 42 98 | 98 33 98 89 34 48 19 62 39 89 99 | 23 47 30 76 88 19 6 96 61 71 100 | 61 45 70 51 5 26 75 43 68 88 101 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/train.data.y: 
-------------------------------------------------------------------------------- 1 | -1 2 | -1 3 | -1 4 | -1 5 | -1 6 | -1 7 | +1 8 | -1 9 | -1 10 | -1 11 | -1 12 | -1 13 | +1 14 | +1 15 | -1 16 | +1 17 | -1 18 | -1 19 | -1 20 | -1 21 | +1 22 | -1 23 | -1 24 | -1 25 | +1 26 | -1 27 | -1 28 | -1 29 | -1 30 | -1 31 | -1 32 | +1 33 | -1 34 | -1 35 | -1 36 | +1 37 | -1 38 | +1 39 | -1 40 | +1 41 | -1 42 | +1 43 | -1 44 | -1 45 | +1 46 | -1 47 | +1 48 | +1 49 | -1 50 | +1 51 | +1 52 | -1 53 | +1 54 | +1 55 | +1 56 | -1 57 | +1 58 | +1 59 | -1 60 | -1 61 | -1 62 | -1 63 | -1 64 | -1 65 | +1 66 | +1 67 | -1 68 | +1 69 | +1 70 | -1 71 | -1 72 | -1 73 | +1 74 | +1 75 | -1 76 | -1 77 | -1 78 | -1 79 | +1 80 | +1 81 | +1 82 | +1 83 | +1 84 | -1 85 | -1 86 | +1 87 | +1 88 | -1 89 | -1 90 | -1 91 | -1 92 | -1 93 | -1 94 | -1 95 | -1 96 | -1 97 | -1 98 | -1 99 | +1 100 | -1 101 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/train.inp: -------------------------------------------------------------------------------- 1 | #### sample input to "train" #### 2 | 3 | train_x_fn=sample/train.data.x # Training data points 4 | train_y_fn=sample/train.data.y # Training targets 5 | 6 | #--- Save the models with filenames output/sample.model-01, 7 | #--- output/sample.model-02,... 8 | model_fn_prefix=output/sample.model 9 | 10 | #--- training parameters. 11 | reg_L2=1 # Regularization parameter. 12 | algorithm=RGF # RGF with L2 regularization with leaf-only models 13 | loss=LS # Square loss 14 | test_interval=100 # Save models every time 100 leaves are added. 15 | max_leaf_forest=500 # Stop training when #leaf reaches 500. 16 | Verbose # Display info during training. 17 | 18 | #--- other parameters (commented out) 19 | #NormalizeTarget # Normalize targets so that the average becomes zero. 20 | #train_w_fn=?? # User-specified weights of data points. 21 | #model_fn_for_warmstart=?? 
# Path to the model file to do warm-start with 22 | 23 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/train_predict.inp: -------------------------------------------------------------------------------- 1 | #### sample input to "train_predict" #### 2 | 3 | train_x_fn=sample/train.data.x # Training data points 4 | train_y_fn=sample/train.data.y # Training targets 5 | 6 | test_x_fn=sample/test.data.x # Test data points 7 | 8 | #--- 9 | model_fn_prefix=output/m 10 | #--- Models are saved with filenames output/m-01, 11 | #--- output/m-02,... 12 | #--- Predictions are saved with filenames output/m-01.pred, 13 | #--- output/m-02.pred,... 14 | #--- Model info such as #leaf is saved with filenames output/m-01.info, 15 | #--- output/m-02.info,... 16 | 17 | SaveLastModelOnly # Only the last (largest) model will be saved to a file. 18 | # Comment this out if all the models should be saved. 19 | 20 | #--- training parameters 21 | algorithm=RGF # RGF with L2 regularization on leaf-only models 22 | reg_L2=1 # Regularization parameter 23 | loss=LS # Square loss 24 | test_interval=100 # Test (and save) models every time 100 leaves are added. 25 | max_leaf_forest=500 # Stop training when #leaf reaches 500. 26 | Verbose # Display info during training. 27 | 28 | #--- other parameters (commented out) 29 | #NormalizeTarget # Normalize targets so that the average becomes zero. 30 | #train_w_fn=?? # User-specified weights of data points. 31 | #model_fn_for_warmstart=?? 
# Path to the model file to do warm-start with 32 | -------------------------------------------------------------------------------- /rgf1.2/test/sample/train_test.inp: -------------------------------------------------------------------------------- 1 | #### sample input to "train_test" #### 2 | 3 | train_x_fn=sample/train.data.x # Training data points 4 | train_y_fn=sample/train.data.y # Training targets 5 | 6 | test_x_fn=sample/test.data.x # Test data points 7 | test_y_fn=sample/test.data.y # Test targets 8 | 9 | evaluation_fn=output/sample.evaluation # Where to write performances 10 | 11 | model_fn_prefix=output/m # Comment this out if models should not be saved. 12 | 13 | 14 | #--- training parameters 15 | algorithm=RGF # RGF with L2 regularization on leaf-only models 16 | reg_L2=1 # Regularization parameter 17 | loss=LS # Square loss 18 | test_interval=100 # Test (and save) models every time 100 leaves are added. 19 | max_leaf_forest=500 # Stop training when #leaf reaches 500. 20 | Verbose # Display info during training. 21 | 22 | #--- other parameters (commented out) 23 | #NormalizeTarget # Normalize targets so that the average becomes zero. 24 | #train_w_fn=?? # User-specified weights of data points. 25 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with 26 | -------------------------------------------------------------------------------- /slides_Cambridge_Nov_14.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/slides_Cambridge_Nov_14.pdf --------------------------------------------------------------------------------