├── LICENSE
├── README.md
├── data
│   └── README.txt
├── matlab_blend
│   ├── do_blend.m
│   └── wexpreg.m
├── model_output
│   └── README.txt
├── python_train_predict
│   ├── higgsml_functions.py
│   ├── make_predictions.py
│   ├── start_rgf_models_7f_s.py
│   ├── start_rgf_models_7f_s_exp.py
│   ├── start_rgf_models_7f_w.py
│   ├── start_rgf_models_7f_w_exp.py
│   ├── start_rgf_models_7f_wl.py
│   └── start_rgf_models_7f_wl_exp.py
├── rgf1.2
│   ├── BUILDLOG
│   ├── COPYING
│   ├── README
│   ├── bin
│   │   ├── rgf
│   │   └── rgf.exe.file
│   ├── makefile
│   ├── proj_vc2010
│   │   └── rgf
│   │       ├── rgf.sln
│   │       └── rgf.vcxproj
│   ├── rgf1.2-guide.pdf
│   ├── src
│   │   ├── com
│   │   │   ├── AzBmat.hpp
│   │   │   ├── AzDmat.cpp
│   │   │   ├── AzDmat.hpp
│   │   │   ├── AzException.hpp
│   │   │   ├── AzHelp.hpp
│   │   │   ├── AzIntPool.cpp
│   │   │   ├── AzIntPool.hpp
│   │   │   ├── AzLoss.cpp
│   │   │   ├── AzLoss.hpp
│   │   │   ├── AzMemTempl.hpp
│   │   │   ├── AzOut.hpp
│   │   │   ├── AzParam.cpp
│   │   │   ├── AzParam.hpp
│   │   │   ├── AzPerfResult.hpp
│   │   │   ├── AzPrint.hpp
│   │   │   ├── AzReadOnlyMatrix.hpp
│   │   │   ├── AzSmat.cpp
│   │   │   ├── AzSmat.hpp
│   │   │   ├── AzStrArray.hpp
│   │   │   ├── AzStrPool.cpp
│   │   │   ├── AzStrPool.hpp
│   │   │   ├── AzSvDataS.cpp
│   │   │   ├── AzSvDataS.hpp
│   │   │   ├── AzSvFeatInfo.hpp
│   │   │   ├── AzSvFeatInfoClone.hpp
│   │   │   ├── AzTaskTools.cpp
│   │   │   ├── AzTaskTools.hpp
│   │   │   ├── AzTimer.hpp
│   │   │   ├── AzTools.cpp
│   │   │   ├── AzTools.hpp
│   │   │   ├── AzUtil.cpp
│   │   │   └── AzUtil.hpp
│   │   └── tet
│   │       ├── AzDataForTrTree.hpp
│   │       ├── AzFindSplit.cpp
│   │       ├── AzFindSplit.hpp
│   │       ├── AzFsinfo.hpp
│   │       ├── AzOptOnTree.cpp
│   │       ├── AzOptOnTree.hpp
│   │       ├── AzOptOnTree_TreeReg.cpp
│   │       ├── AzOptOnTree_TreeReg.hpp
│   │       ├── AzOptimizerT.hpp
│   │       ├── AzRegDepth.hpp
│   │       ├── AzReg_TreeReg.hpp
│   │       ├── AzReg_TreeRegArr.hpp
│   │       ├── AzReg_TreeRegArrImp.hpp
│   │       ├── AzReg_TsrOpt.cpp
│   │       ├── AzReg_TsrOpt.hpp
│   │       ├── AzReg_TsrSib.cpp
│   │       ├── AzReg_TsrSib.hpp
│   │       ├── AzReg_Tsrbase.cpp
│   │       ├── AzReg_Tsrbase.hpp
│   │       ├── AzRgfTrainerSel.hpp
│   │       ├── AzRgfTree.cpp
│   │       ├── AzRgfTree.hpp
│   │       ├── AzRgfTreeEnsImp.hpp
│   │       ├── AzRgfTreeEnsemble.hpp
│   │       ├── AzRgf_FindSplit.hpp
│   │       ├── AzRgf_FindSplit_Dflt.cpp
│   │       ├── AzRgf_FindSplit_Dflt.hpp
│   │       ├── AzRgf_FindSplit_TreeReg.cpp
│   │       ├── AzRgf_FindSplit_TreeReg.hpp
│   │       ├── AzRgf_Optimizer.hpp
│   │       ├── AzRgf_Optimizer_Dflt.cpp
│   │       ├── AzRgf_Optimizer_Dflt.hpp
│   │       ├── AzRgf_Optimizer_TreeReg.hpp
│   │       ├── AzRgf_kw.hpp
│   │       ├── AzRgforest.cpp
│   │       ├── AzRgforest.hpp
│   │       ├── AzRgforest_TreeReg.hpp
│   │       ├── AzSortedFeat.cpp
│   │       ├── AzSortedFeat.hpp
│   │       ├── AzTET_Eval.hpp
│   │       ├── AzTET_Eval_Dflt.hpp
│   │       ├── AzTETmain.cpp
│   │       ├── AzTETmain.hpp
│   │       ├── AzTETmain_kw.hpp
│   │       ├── AzTETproc.cpp
│   │       ├── AzTETproc.hpp
│   │       ├── AzTETrainer.hpp
│   │       ├── AzTETselector.hpp
│   │       ├── AzTE_ModelInfo.hpp
│   │       ├── AzTrTree.cpp
│   │       ├── AzTrTree.hpp
│   │       ├── AzTrTreeEnsemble.hpp
│   │       ├── AzTrTreeEnsemble_ReadOnly.hpp
│   │       ├── AzTrTreeFeat.cpp
│   │       ├── AzTrTreeFeat.hpp
│   │       ├── AzTrTreeNode.hpp
│   │       ├── AzTrTree_ReadOnly.hpp
│   │       ├── AzTrTsplit.hpp
│   │       ├── AzTrTtarget.hpp
│   │       ├── AzTree.cpp
│   │       ├── AzTree.hpp
│   │       ├── AzTreeEnsemble.cpp
│   │       ├── AzTreeEnsemble.hpp
│   │       ├── AzTreeNodes.hpp
│   │       ├── AzTreeRule.hpp
│   │       └── driv_rgf.cpp
│   └── test
│       ├── call_exe.pl
│       ├── output
│       │   ├── sample.model-01
│       │   ├── sample.model-02
│       │   ├── sample.model-03
│       │   ├── sample.model-04
│       │   └── sample.model-05
│       └── sample
│           ├── predict.inp
│           ├── regress.test.x
│           ├── regress.test.y
│           ├── regress.train.x
│           ├── regress.train.y
│           ├── regress_train_test.inp
│           ├── test.data.x
│           ├── test.data.y
│           ├── train.data.sparse.x
│           ├── train.data.x
│           ├── train.data.y
│           ├── train.inp
│           ├── train_predict.inp
│           └── train_test.inp
└── slides_Cambridge_Nov_14.pdf
/data/README.txt:
--------------------------------------------------------------------------------
1 | The data files "training.csv" and "test.csv" go here.
2 | See http://www.kaggle.com/c/higgs-boson/data.
--------------------------------------------------------------------------------
/matlab_blend/do_blend.m:
--------------------------------------------------------------------------------
% This performs non-negative linear blending of several RGF models for the
% HiggsML Challenge.
% Run this in Matlab 2014 (or later) for support of data tables, or update
% the 'readtable' statements below when using an older version.
% Author: Tim Salimans
%
% Expected inputs (paths relative to the current working directory):
%   ../data/training.csv, ../data/test.csv
%   ../model_output/<family>_cv<j>_output/m-<nn>.pred   cross-validated predictions
%   ../model_output/<family>_full_output/m-<nn>.pred    test-set predictions
% Output: preds_<p>.csv (columns EventId, RankOrder, Class) written to the
% current working directory.

% load all the necessary data
model_dir = [pwd '/../model_output'];
data_dir = [pwd '/../data'];
traindat = readtable([data_dir '/training.csv']);
% only events with DER_mass_MMC > 0 are blended
trainsel = traindat.DER_mass_MMC>0;
raw_weights = traindat.Weight(trainsel);
target = strcmp(traindat.Label(trainsel),'s');
% rescale the weights so that each class has mean weight 1
train_weights = raw_weights;
train_weights(target) = train_weights(target)/mean(train_weights(target));
train_weights(~target) = train_weights(~target)/mean(train_weights(~target));

% model families, listed in the column order used for blending
families = {'w', 's', 'wl', 'w_exp', 's_exp', 'wl_exp'};
% files of the 'w' families hold two stacked prediction vectors; the
% difference (first half minus second half) is used as the model score
is_paired = [true, false, false, true, false, false];
pair_diff = @(p) p(1:(end/2)) - p((end/2+1):end);
n_fam = numel(families);
n_models = 70;   % model snapshots m-01 ... m-70
n_folds = 7;

% load cross validated predictions
% (selected training row r belongs to fold mod(r-1, n_folds); a missing
% prediction file leaves zeros, which drops that column further below)
n = length(target);
cvind = mod(0:(n-1),n_folds);
x_cv = zeros(n,n_models,n_fam);
for f=1:n_fam
    for i=1:n_models
        nrs = sprintf('%02d',i);
        for j=0:(n_folds-1)
            fn = [model_dir '/' families{f} '_cv' int2str(j) '_output/m-' nrs '.pred'];
            if exist(fn,'file')
                preds = csvread(fn);
                if is_paired(f)
                    preds = pair_diff(preds);
                end
                x_cv(cvind==j,i,f) = preds;
            end
        end
    end
end

% load predictions on the test set
% (the number of selected test events is taken from one reference file)
fn = [model_dir '/wl_exp_full_output/m-01.pred'];
ntest = length(csvread(fn));
x_test = zeros(ntest,n_models,n_fam);
for f=1:n_fam
    for i=1:n_models
        fn = [model_dir '/' families{f} '_full_output/m-' sprintf('%02d',i) '.pred'];
        if exist(fn,'file')
            preds = csvread(fn);
            if is_paired(f)
                preds = pair_diff(preds);
            end
            x_test(:,i,f) = preds;
        end
    end
end

% combine all the predictions behind an intercept column, and only keep the
% complete ones: columns with any exact zero (e.g. from a missing file) are
% dropped. reshape puts family f, model i in column i + n_models*(f-1).
x_all = [ones(n,1) reshape(x_cv,n,[])];
x_full = [ones(ntest,1) reshape(x_test,ntest,[])];
clear x_cv x_test
sel = ((sum(x_all==0)+sum(x_full==0))==0);
x_all = x_all(:,sel);
x_full = x_full(:,sel);
k = size(x_all,2);

% perform non-negative blending with exponential loss
% (the intercept is unbounded, all model coefficients are constrained >= 0)
obj = @(b)wexpreg(b,double(target),x_all,train_weights);
opt = optimset('GradObj','on','Hessian','on','Algorithm','trust-region-reflective','Display','iter','TolFun',1e-15);
b = fmincon(obj,zeros(k,1),[],[],[],[],[-Inf; zeros(k-1,1)],[],[],opt);

% define percentage cutoff: the top fraction p of the selected test events
% (by blended score) is labelled 's'
p = 0.175;

% write predictions
testdat = readtable([data_dir '/test.csv']);
testsel = testdat.DER_mass_MMC>0;
ftest = x_full*b;
assert(nnz(testsel)==length(ftest), ...
    'Got %d test predictions for %d selected test events', length(ftest), nnz(testsel));
n_unsel = length(testsel)-length(ftest);
% RankOrder is a permutation of 1..N that increases with the blended score
% (the same direction used for the 's' cutoff): unselected events get the
% lowest ranks. sort's second output is the sorting permutation, so it must
% be inverted to obtain each event's rank.
rank_order = zeros(length(testsel),1);
rank_order(~testsel) = 1:n_unsel;
[~,ind] = sort(ftest);
score_rank = zeros(size(ftest));
score_rank(ind) = 1:length(ftest);
rank_order(testsel) = score_rank + n_unsel;
pred_s = false(length(testsel),1);
pred_s(testsel) = ftest>=quantile(ftest,1-p);
pred_strings = repmat('b',length(pred_s),1);
pred_strings(pred_s) = 's';
to_predict = table(testdat.EventId,rank_order,cellstr(pred_strings),'VariableNames',{'EventId' 'RankOrder' 'Class'});
writetable(to_predict,['preds_' strrep(num2str(p),'.','_') '.csv']);
--------------------------------------------------------------------------------
/matlab_blend/wexpreg.m:
--------------------------------------------------------------------------------
% linear classification with exponential loss
function [nllh,grad,hess] = wexpreg(b,y,x,w)
% WEXPREG  Weighted exponential loss of a linear scorer, with derivatives.
%   [nllh,grad,hess] = WEXPREG(b,y,x,w) sets f = x*b and returns
%       nllh = sum( w.*y.*exp(-f) ) + sum( w.*(1-y).*exp(f) )
%   grad (computed when nargout >= 2) is the gradient of nllh w.r.t. b and
%   hess (computed when nargout == 3) is its Hessian w.r.t. b.
%   b: coefficient column vector; x: one row per example;
%   y: 0/1 labels (do_blend.m passes double(target)); w: per-example weights.

scores = x*b;
% per-example loss terms of the positive and of the negative examples
loss_pos = (w.*y).*exp(-scores);
loss_neg = (w.*(1-y)).*exp(scores);
nllh = sum(loss_pos)+sum(loss_neg);

if nargout > 1
    % dLoss/df = loss_neg - loss_pos, chained through f = x*b
    grad = x'*(loss_neg-loss_pos);

    if nargout > 2
        % d2Loss/df2 is diagonal with entries loss_pos + loss_neg
        hess = x'*bsxfun(@times,x,loss_pos+loss_neg);
    end
end
--------------------------------------------------------------------------------
/model_output/README.txt:
--------------------------------------------------------------------------------
1 | This folder holds the model output. The model predictions take over 3GB to store and many CPU days to create.
2 | The complete contents of this folder can be downloaded from: https://drive.google.com/file/d/0B4Zly9eEgwFsbUx5cm15UHpJZTg/edit?usp=sharing
--------------------------------------------------------------------------------
/python_train_predict/make_predictions.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files produced while predicting.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Directory holding the trained models (read here; this script does not create it).
save_dir = '../model_output'

# Load and preprocess the test set, then predict with every model in save_dir.
test_dat = loadAndProcessData("../data/test.csv")
makePredictions(test_dat, temp_dir, save_dir)
--------------------------------------------------------------------------------
/python_train_predict/start_rgf_models_7f_s.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Output directory for the trained models.
save_dir = '../model_output'
myMakeDir(save_dir)

# Load and preprocess the training set. The weights are returned as well,
# but this model family ('s') is trained without them.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# RGF hyper-parameters (logistic loss).
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Log',
    'test_interval': 1000,
    'max_leaf_forest': 50000,
}

# Train / cross-validate the 's' models.
startTraining(train_dat, train_target, None, 's', temp_dir, save_dir, param)
--------------------------------------------------------------------------------
/python_train_predict/start_rgf_models_7f_s_exp.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Output directory for the trained models.
save_dir = '../model_output'
myMakeDir(save_dir)

# Load and preprocess the training set. The weights are returned as well,
# but this model family ('s_exp') is trained without them.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# RGF hyper-parameters (exponential loss).
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Expo',
    'test_interval': 1000,
    'max_leaf_forest': 70000,
}

# Train / cross-validate the 's_exp' models.
startTraining(train_dat, train_target, None, 's_exp', temp_dir, save_dir, param)
--------------------------------------------------------------------------------
/python_train_predict/start_rgf_models_7f_w.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Output directory for the trained models.
save_dir = '../model_output'
myMakeDir(save_dir)

# Load and preprocess the training set, including the event weights.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# RGF hyper-parameters (logistic loss).
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Log',
    'test_interval': 1000,
    'max_leaf_forest': 50000,
}

# Train / cross-validate the weighted 'w' models, with predict_weights enabled.
startTraining(train_dat, train_target, train_weights, 'w', temp_dir, save_dir, param,
              predict_weights=True)
--------------------------------------------------------------------------------
/python_train_predict/start_rgf_models_7f_w_exp.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Output directory for the trained models.
save_dir = '../model_output'
myMakeDir(save_dir)

# Load and preprocess the training set, including the event weights.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# RGF hyper-parameters (exponential loss).
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Expo',
    'test_interval': 1000,
    'max_leaf_forest': 70000,
}

# Train / cross-validate the weighted 'w_exp' models, with predict_weights enabled.
startTraining(train_dat, train_target, train_weights, 'w_exp', temp_dir, save_dir, param,
              predict_weights=True)
--------------------------------------------------------------------------------
/python_train_predict/start_rgf_models_7f_wl.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Output directory for the trained models.
save_dir = '../model_output'
myMakeDir(save_dir)

# Load and preprocess the training set, including the event weights.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# RGF hyper-parameters (logistic loss).
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Log',
    'test_interval': 1000,
    'max_leaf_forest': 50000,
}

# Train / cross-validate the weighted 'wl' models (predict_weights left at its default).
startTraining(train_dat, train_target, train_weights, 'wl', temp_dir, save_dir, param)
--------------------------------------------------------------------------------
/python_train_predict/start_rgf_models_7f_wl_exp.py:
--------------------------------------------------------------------------------
from higgsml_functions import *

# Scratch space for intermediate files.
temp_dir = 'temp'
myMakeDir(temp_dir)

# Output directory for the trained models.
save_dir = '../model_output'
myMakeDir(save_dir)

# Load and preprocess the training set, including the event weights.
train_dat, train_target, train_weights = loadAndProcessData("../data/training.csv")

# RGF hyper-parameters (exponential loss).
param = {
    'algorithm': 'RGF',
    'reg_L2': 0.1,
    'reg_sL2': 0.001,
    'loss': 'Expo',
    'test_interval': 1000,
    'max_leaf_forest': 34000,
}

# Train / cross-validate the weighted 'wl_exp' models (predict_weights left at its default).
startTraining(train_dat, train_target, train_weights, 'wl_exp', temp_dir, save_dir, param)
--------------------------------------------------------------------------------
/rgf1.2/BUILDLOG:
--------------------------------------------------------------------------------
1 | 09/24/2012: Version 1.2 first release.
2 | 10/15/2012: Fixed a compile error with g++ on OS X.
3 |
--------------------------------------------------------------------------------
/rgf1.2/README:
--------------------------------------------------------------------------------
1 | ************************************************************************
2 |
3 | README
4 |
5 | RGF Version 1.2
6 | September 2012
7 |
8 | C++ programs for regularized greedy forest
9 |
10 | ************************************************************************
11 |
12 | Contents:
13 | ---------
14 |
15 | 1. Introduction
16 | 1.1. System Requirements
17 | 2. Download and Installation
18 | 3. Creating the Executable
19 | 3.1. Windows
20 | 3.2. Unix-like Systems
21 | 3.3. [Optional] Endianness Consideration
22 | 4. Documentation
23 | 5. Contact
24 | 6. Copyright
25 | 7. References
26 |
27 | ---------------
28 | 1. Introduction
29 | This software package provides implementation of regularized greedy forest
30 | (RGF) described in [1].
31 |
32 | 1.1. System Requirements
33 | Executables are provided only for some versions of Windows (see Section 3 for
34 | details). If the provided executables do not work in your environment, you need
35 | to compile the C++ code.
36 |
37 | To use the provided tools and to go through the examples in the user guide,
38 | Perl is required.
39 |
40 | -----------
41 | 2. Download and Installation
42 | The package is provided as a zip file:
43 |
44 | rgf1.2.zip
45 |
46 | Download and extract the content.
47 |
48 | The top directory of the extracted content is "rgf1.2". Below, all
49 | path expressions are relative to "rgf1.2".
50 |
51 | --------------------------
52 | 3. Creating the Executable
53 |
54 | To go through the examples in the user guide, your executable needs to be
55 | in the "bin" directory. Otherwise, your executable can be anywhere you like.
56 |
57 | -----------
58 | 3.1 Windows
59 | A 64-bit executable is provided with the filename "rgf.exe.file" in the
60 | "bin" directory. It was tested on Windows 7 with the latest service pack as
61 | of 9/17/2012. You can either rename it to "rgf.exe" and use it, or you can
62 | rebuild it yourself using the provided solution file for MS Visual C++
63 | 2010 Express: "proj_vc2010\rgf\rgf.sln".
64 |
65 | ---------------------
66 | 3.2 Unix-like Systems
67 | You need to build your executable from the source code. A makefile,
68 | "makefile", is provided in the top directory. It is configured to use
69 | "g++" and always compiles everything. You may need to customize "makefile"
70 | for your environment.
71 |
72 | To build the executable, change the current directory to the top
73 | directory "rgf1.2" and run "make" from the command line. Check the
74 | "bin" directory to make sure that your new executable "rgf" is there.
75 |
76 | ----------------------------------------
77 | 3.3 [Optional] Endianness Consideration
78 | The models obtained by RGF training can be saved to files.
79 | The model files are essentially snapshots of memory that include
80 | numerical values. Therefore, the model files are sensitive to
81 | "endianness" of the environments. For this reason, if you wish to
82 | share model files among environments of different endianness, you need
83 | to follow the instructions below. Otherwise, you can skip this section.
84 |
85 | To share model files among environments of different endianness,
86 | build your executable for the big-endian environment with the
87 | compile option:
88 |
89 | /D_AZ_BIG_ENDIAN_   (MS Visual C++ syntax; with g++, use -D_AZ_BIG_ENDIAN_)
90 |
91 | By doing so, the executable in your big-endian environment swaps the
92 | byte order of numerical values before writing and after reading the
93 | model files.
94 |
95 | ----------------
96 | 4. Documentation
97 | rgf1.2-guide.pdf "Regularized Greedy Forest Version 1.2: User Guide" is included.
98 |
99 | ----------
100 | 5. Contact
101 | riejohnson@gmail.com
102 |
103 | ------------
104 | 6. Copyright
105 | RGF Version 1.2 is distributed under the GNU General Public License. Please
106 | read the file COPYING.
107 |
108 | -------------
109 | 7. References
110 |
111 | [1] Rie Johnson and Tong Zhang. Learning nonlinear functions using
112 | regularized greedy forest. Technical report: arXiv:1109.0887v5, 2011.
113 |
--------------------------------------------------------------------------------
/rgf1.2/bin/rgf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/bin/rgf
--------------------------------------------------------------------------------
/rgf1.2/bin/rgf.exe.file:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/bin/rgf.exe.file
--------------------------------------------------------------------------------
/rgf1.2/makefile:
--------------------------------------------------------------------------------
1 |
BIN_NAME = rgf
BIN_DIR = bin
TARGET = $(BIN_DIR)/$(BIN_NAME)
# C++ compiler (default g++); override on the command line, e.g. "make CXX=clang++".
CXX = g++
# Header search paths for the two source directories, plus optimization level.
# (This previously named src/tet_tools, which does not exist in this package.)
CFLAGS = -Isrc/com -Isrc/tet -O2

CPP_FILES= \
    src/tet/driv_rgf.cpp \
    src/com/AzDmat.cpp \
    src/tet/AzFindSplit.cpp \
    src/com/AzIntPool.cpp \
    src/com/AzLoss.cpp \
    src/tet/AzOptOnTree_TreeReg.cpp \
    src/tet/AzOptOnTree.cpp \
    src/com/AzParam.cpp \
    src/tet/AzReg_Tsrbase.cpp \
    src/tet/AzReg_TsrOpt.cpp \
    src/tet/AzReg_TsrSib.cpp \
    src/tet/AzRgf_FindSplit_Dflt.cpp \
    src/tet/AzRgf_FindSplit_TreeReg.cpp \
    src/tet/AzRgf_Optimizer_Dflt.cpp \
    src/tet/AzRgforest.cpp \
    src/tet/AzRgfTree.cpp \
    src/com/AzSmat.cpp \
    src/tet/AzSortedFeat.cpp \
    src/com/AzStrPool.cpp \
    src/com/AzSvDataS.cpp \
    src/com/AzTaskTools.cpp \
    src/tet/AzTETmain.cpp \
    src/tet/AzTETproc.cpp \
    src/com/AzTools.cpp \
    src/tet/AzTree.cpp \
    src/tet/AzTreeEnsemble.cpp \
    src/tet/AzTrTree.cpp \
    src/tet/AzTrTreeFeat.cpp \
    src/com/AzUtil.cpp

# "all" and "clean" name actions, not files.
.PHONY: all clean

# Always rebuilds everything from scratch: there is no dependency tracking.
all:
	mkdir -p $(BIN_DIR)
	/bin/rm -f $(TARGET)
	$(CXX) $(CPP_FILES) $(CFLAGS) -o $(TARGET)

clean:
	/bin/rm -f $(TARGET)
--------------------------------------------------------------------------------
/rgf1.2/proj_vc2010/rgf/rgf.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 11.00
3 | # Visual C++ Express 2010
4 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "rgf", "rgf.vcxproj", "{27341E76-36FB-4CA4-848C-1782869FE39B}"
5 | EndProject
6 | Global
7 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
8 | Debug|Win32 = Debug|Win32
9 | Debug|x64 = Debug|x64
10 | Release|Win32 = Release|Win32
11 | Release|x64 = Release|x64
12 | EndGlobalSection
13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
14 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|Win32.ActiveCfg = Release|Win32
15 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|Win32.Build.0 = Release|Win32
16 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|x64.ActiveCfg = Release|x64
17 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Debug|x64.Build.0 = Release|x64
18 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|Win32.ActiveCfg = Release|Win32
19 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|Win32.Build.0 = Release|Win32
20 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|x64.ActiveCfg = Release|x64
21 | {27341E76-36FB-4CA4-848C-1782869FE39B}.Release|x64.Build.0 = Release|x64
22 | EndGlobalSection
23 | GlobalSection(SolutionProperties) = preSolution
24 | HideSolutionNode = FALSE
25 | EndGlobalSection
26 | EndGlobal
27 |
--------------------------------------------------------------------------------
/rgf1.2/rgf1.2-guide.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/rgf1.2-guide.pdf
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzBmat.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzBmat.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_BMAT_HPP_
20 | #define _AZ_BMAT_HPP_
21 | #include "AzUtil.hpp"
22 |
23 | //! binary matrix
24 | class AzBmat {
25 | protected:
26 | int row_num;
27 | AzDataArray a;
28 |
29 | public:
30 | AzBmat() : row_num(0) {}
31 | AzBmat(int inp_row_num, int inp_col_num) : row_num(0) {
32 | resize(inp_col_num);
33 | }
34 | AzBmat(const AzBmat *inp) : row_num(0) {
35 | set(inp);
36 | }
37 | inline void set(const AzBmat *inp) {
38 | row_num = inp->row_num;
39 | a.reset(&inp->a);
40 | }
41 | AzBmat(const AzBmat &inp) : row_num(0) {
42 | set(&inp);
43 | }
44 | AzBmat & operator =(const AzBmat &inp) {
45 | if (this == &inp) return *this;
46 | set(&inp);
47 | return *this;
48 | }
49 | inline void reform(int inp_row_num, int inp_col_num) {
50 | reset();
51 | row_num = inp_row_num;
52 | resize(inp_col_num);
53 | }
54 | inline void resize(int new_col_num) {
55 | a.resz(new_col_num);
56 | }
57 |
58 | inline void reset() {
59 | row_num = 0;
60 | a.reset();
61 | }
62 | inline int rowNum() const {
63 | return row_num;
64 | }
65 | inline int colNum() const {
66 | return a.cursor();
67 | }
68 |
69 | inline const AzIntArr *on_rows(int col) const {
70 | return a.point(col);
71 | }
72 | inline void clear(int fx) {
73 | a.point_u(fx)->reset();
74 | }
75 |
76 | inline void load(int col, const AzIntArr *ia_on_rows) {
77 | if (ia_on_rows == NULL || ia_on_rows->size() <= 0) return;
78 |
79 | if (ia_on_rows->min() < 0 ||
80 | ia_on_rows->max() >= row_num) {
81 | throw new AzException("AzBmat::load", "wrong row#");
82 | }
83 | a.point_u(col)->reset(ia_on_rows);
84 | }
85 | };
86 | #endif
87 |
88 |
89 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzException.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzException.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_EXCEPTION_HPP_
20 | #define _AZ_EXCEPTION_HPP_
21 |
22 | #include
23 | #include
24 | #include
25 | #include
26 | #include
27 | using namespace std;
28 |
//! Return codes carried by AzException.
enum AzRetCode {
  AzNormal=0,
  AzAllocError=10,
  AzFileIOError=20,
  AzInputError=30,
  AzInputMissing=31,
  AzInputNotValid=32,
  AzConflict=100 /* all others */
};

/*-----------------------------------------------------*/
//! Exception holding a return code and three message fragments, formatted
//! by getMessage(). Code in this package throws it by pointer, e.g.
//!   throw new AzException("AzBmat::load", "wrong row#");
class AzException {
public:
  //! Exception with return code AzConflict.
  AzException(const char *string1,
              const char *string2,
              const char *string3=NULL)
  {
    reset(AzConflict, string1, string2, string3);
  }

  //! Exception with an explicit return code.
  AzException(AzRetCode retcode,
              const char *string1,
              const char *string2,
              const char *string3=NULL)
  {
    reset(retcode, string1, string2, string3);
  }

  //! As above, then appends "; " followed by 'anything' (any streamable
  //! value) to the third fragment.
  template <class T>
  AzException(AzRetCode retcode,
              const char *string1,
              const char *string2,
              const char *string3,
              T anything)
  {
    reset(retcode, string1, string2, string3);
    s3 << "; " << anything;
  }

  //! Set the return code and replace the three fragments; a NULL fragment
  //! is left empty.
  void reset(AzRetCode retcode,
             const char *str1,
             const char *str2,
             const char *str3)
  {
    this->retcode = retcode;
    /* clear first so that a repeated reset() replaces instead of appending */
    s1.str(""); s2.str(""); s3.str("");
    if (str1 != NULL) s1 << str1;
    if (str2 != NULL) s2 << str2;
    if (str3 != NULL) s3 << str3;
  }

  AzRetCode getReturnCode() const {
    return retcode;
  }

  //! Human-readable message: "<code text>: <where> <what> [<detail>]\n".
  //! A first fragment starting with "Az" is treated as the detecting
  //! component and printed on its own line.
  string getMessage()
  {
    /* start from an empty buffer: without this, every call appended another
       copy of the message to the previous result */
    message.str("");
    message.clear();

    switch (retcode) {
    case AzNormal:                                               break;
    case AzAllocError:    message << "!Memory alloc error!";      break;
    case AzFileIOError:   message << "!File I/O error!";          break;
    case AzInputError:    message << "!Input error!";             break;
    case AzInputMissing:  message << "!Missing input!";           break;
    case AzInputNotValid: message << "!Input value is not valid!"; break;
    case AzConflict:      message << "Conflict";                  break;
    default:              message << "Unknown error";             break;
    }

    message << ": ";
    if (s1.str().find("Az") == 0) {
      message << "(Detected in " << s1.str() << ") " << endl;
    }
    else {
      message << s1.str() << " ";
    }
    message << s2.str();
    if (s3.str().length() > 0) {
      message << " " << s3.str();
    }
    message << endl;
    return message.str();
  }

protected:
  AzRetCode retcode;

  stringstream s1, s2, s3;
  stringstream message;
};
131 |
132 | #endif
133 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzIntPool.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzIntPool.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_INT_POOL_HPP_
20 | #define _AZ_INT_POOL_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
//! Bookkeeping record for one integer array stored in AzIntPool.
//! A default-constructed entry has offs = 0, ints = NULL, num = count = 0
//! and value = -1 (the same "no value" default that AzIntPool::put uses).
class AzIpEnt {
public:
  int offs;
  const int *ints;
  int num;
  int count;
  int value;
  AzIpEnt() : offs(0), ints(NULL), num(0), count(0), value(-1) {}
};
38 |
39 | //! Store integer arrays. Searchable after committed.
40 | class AzIntPool {
41 | protected:
42 | AzIpEnt *ent;
43 | AzBaseArray a_ent;
44 | int ent_num;
45 |
46 | AzBaseArray a_data;
47 | int *data;
48 | int data_num;
49 |
50 | bool isCommitted;
51 |
52 | public:
53 | AzIntPool() : ent(NULL), ent_num(0), data(NULL), data_num(0), isCommitted(true) {}
54 | AzIntPool(AzFile *file)
55 | : ent(NULL), ent_num(0), data(NULL), data_num(0), isCommitted(true) {
56 | _read(file);
57 | }
58 | AzIntPool(const AzIntPool *inp) /* copy */
59 | : ent(NULL), ent_num(0), data(NULL), data_num(0), isCommitted(true) {
60 | reset(inp);
61 | }
62 | ~AzIntPool() {}
63 |
64 | void reset() {
65 | a_ent.free(&ent); ent_num = 0;
66 | a_data.free(&data); data_num = 0;
67 | isCommitted = true;
68 | }
69 | void reset(const AzIntPool *ip);
70 |
71 | void read(AzFile *file) {
72 | reset();
73 | _read(file);
74 | }
75 |
76 | int write(AzFile *file);
77 |
78 | inline void update(int ex,
79 | const AzIntArr *iq,
80 | int count=1,
81 | int val=-1) {
82 | update(ex, iq->point(), iq->size(), count, val);
83 | }
84 | void update(int ex,
85 | const int *ints,
86 | int ints_num,
87 | int count=1,
88 | int val=-1);
89 |
90 | int put(const int *ints, int ints_num,
91 | int count=1,
92 | int value=-1);
93 | int put(const AzIntArr *intq, int
94 | count=1) {
95 | return this->put(intq->point(), intq->size(), count);
96 | }
97 |
98 | /*--- put with a value ---*/
99 | int putv(const AzIntArr *intq,
100 | int value) {
101 | int count = 1;
102 | return this->put(intq->point(), intq->size(), count, value);
103 | }
104 |
105 | inline int getValue(int idx) const {
106 | checkRange(idx, "AzIntPool::getValue");
107 | return ent[idx].value;
108 | }
109 |
110 | void commit();
111 | int size() const {return ent_num;}
112 | const int *point(int ent_no, int *out_ints_num=NULL) const;
113 | int getCount(int ent_no) const;
114 | void setCount(int ent_no, int new_count);
115 | int find(const int *ints, int ints_num) const;
116 | int find(const AzIntArr *intq) const {
117 | return find(intq->point(), intq->size());
118 | }
119 |
120 | void get(int ent_no, AzIntArr *intq) const
121 | {
122 | int num;
123 | const int *ints = point(ent_no, &num);
124 | intq->concat(ints, num);
125 | }
126 |
127 | inline void erase(int ent_no) {
128 | shorten(ent_no, 0);
129 | }
130 | void shorten(int ent_no, int new_len);
131 |
132 | void dump(const AzOut &, const char *header) const;
133 |
134 | bool isThisCommitted() const {
135 | return isCommitted;
136 | }
137 |
138 | inline void clear() {
139 | reset();
140 | }
141 |
142 | void concat(const AzIntPool *inp) {
143 | int num = inp->size();
144 | int ix;
145 | for (ix = 0; ix < num; ++ix) {
146 | int len;
147 | const int *ints = inp->point(ix, &len);
148 | int count = inp->getCount(ix);
149 | int value = inp->getValue(ix);
150 | put(ints, len, count, value);
151 | }
152 | }
153 |
154 | /*--- prohibit assign operator ---*/
155 | AzIntPool(const AzIntPool &) {
156 | throw new AzException("AzIntPool =", "no support");
157 | }
158 | AzIntPool & operator =(const AzIntPool &inp) {
159 | if (this == &inp) return *this;
160 | throw new AzException("AzIntPool =", "no support");
161 | }
162 |
163 |
164 | protected:
165 | void _swap();
166 | void _read(AzFile *file);
167 |
168 | int inc_ent();
169 | int inc_data(int min_inc);
170 |
171 | inline void checkRange(int ent_no, const char *eyec) const
172 | {
173 | if (ent_no < 0 || ent_no >= ent_num) {
174 | throw new AzException(eyec, "out of range");
175 | }
176 | }
177 | };
178 |
179 | #endif
180 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzOut.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzOut.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_OUT_HPP_
20 | #define _AZ_OUT_HPP_
21 |
/*--- Output destination: an ostream pointer plus an on/off switch and a
      caller-defined level.  Output is "null" when deactivated or when no
      stream is set (see isNull()). ---*/
class AzOut {
protected:
  bool isActive;  /* false => isNull() reports true even if o is set */
  int level;      /* only stored and returned here; meaning is up to callers */
public:
  ostream *o;     /* destination stream; NULL means no output */

  /* Initializers follow the declaration order (isActive, level, o),
     which is the order in which members are actually initialized. */
  inline AzOut() : isActive(true), level(0), o(NULL) {}
  inline AzOut(ostream *o_ptr) : isActive(true), level(0), o(o_ptr) {}

  /*--- Point at a new stream and (re)activate. ---*/
  inline void reset(ostream *o_ptr) {
    o = o_ptr;
    activate();
  }

  inline void deactivate() {
    isActive = false;
  }
  inline void activate() {
    isActive = true;
  }
  inline void setStdout() {
    o = &cout;
    activate();
  }
  inline void setStderr() {
    o = &cerr;
    activate();
  }
  /*--- true when output should be suppressed ---*/
  inline bool isNull() const {
    if (!isActive) return true;
    if (o == NULL) return true;
    return false;
  }
  /*--- Flushes the stream whenever one is set, even if deactivated. ---*/
  inline void flush() const {
    if (o != NULL) o->flush();
  }
  inline void setLevel(int inp_level) {
    level = inp_level;
  }
  inline int getLevel() const {
    return level;
  }
};
67 | #endif
68 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzParam.cpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzParam.cpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #include "AzParam.hpp"
20 |
21 | /*-------------------------------------------------------------*/
22 | void AzParam::check(const AzOut &out,
23 | AzBytArr *s_unused)
24 | {
25 | if (param == NULL) return;
26 |
27 | AzStrPool sp_unused, sp_kw;
28 | analyze(&sp_unused, &sp_kw);
29 |
30 | bool error = false;
31 |
32 | int ix;
33 | for (ix = 0; ix < sp_unused.size(); ++ix) {
34 | if (sp_unused.getLen(ix) <= 0) continue;
35 |
36 | /*--- this parameter wasn't used by anyone. ---*/
37 | if (s_unused != NULL) {
38 | /*--- may be used by someone else ---*/
39 | if (s_unused->length() > 0) s_unused->concat(dlm);
40 | s_unused->concat(sp_unused.c_str(ix));
41 | }
42 | else {
43 | /*--- it could be a typo ---*/
44 | AzBytArr s;
45 | s.concat("!Warning! Unknown parameter: \"");
46 | s.concat(sp_unused.c_str(ix)); s.concat("\"");
47 | AzPrint::writeln(out, s);
48 | }
49 | }
50 |
51 | AzBytArr s_err;
52 | int kw_num = sp_kw.size();
53 | sp_kw.commit();
54 | if (sp_kw.size() != kw_num) {
55 | for (ix = 0; ix < sp_kw.size(); ++ix) {
56 | int count = sp_kw.getCount(ix);
57 | if (count > 1) {
58 | /*--- this keyword appears more than once. ---*/
59 | if (s_err.length() <= 0) s_err.newline();
60 | s_err.concat(" Duplicated keyword: ");
61 | s_err.concat(sp_kw.c_str(ix));
62 | s_err.newline();
63 | error = true;
64 | }
65 | }
66 | }
67 |
68 | if (error) {
69 | throw new AzException(AzInputError, "AzParam::check", s_err.c_str());
70 | }
71 | }
72 |
73 | /*-------------------------------------------------------------*/
74 | void AzParam::analyze(AzStrPool *sp_unused,
75 | AzStrPool *sp_kw)
76 | {
77 | if (param == NULL) return;
78 | sp_used_kw.commit();
79 |
80 | AzStrPool sp_kwval;
81 | AzTools::getStrings(param, dlm, &sp_kwval);
82 | int ix;
83 | for (ix = 0; ix < sp_kwval.size(); ++ix) {
84 | int len;
85 | const AzByte *ptr = sp_kwval.point(ix, &len);
86 | if (len <= 0) continue;
87 |
88 | AzBytArr s_kw;
89 | AzTools::getString(&ptr, ptr+len, kwval_dlm, &s_kw);
90 | if (s_kw.length() < len) s_kw.concat(kwval_dlm);
91 | if (sp_used_kw.find(&s_kw) < 0) {
92 | /*--- this parameter wasn't used by anyone. ---*/
93 | sp_unused->put(sp_kwval.c_str(ix));
94 | }
95 | if (sp_kw != NULL) sp_kw->put(&s_kw);
96 | }
97 | }
98 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzParam.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzParam.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_PARAM_HPP_
20 | #define _AZ_PARAM_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzLoss.hpp"
24 | #include "AzStrPool.hpp"
25 | #include "AzPrint.hpp"
26 | #include "AzTools.hpp"
27 |
28 | //! Parse parameters
29 | class AzParam {
30 | protected:
31 | const char *param;
32 | char dlm, kwval_dlm;
33 | AzStrPool sp_used_kw;
34 | bool doCheck; /* check unknown/duplicated keywords */
35 | public:
36 | AzParam(const char *inp_param,
37 | bool inp_doCheck=true,
38 | char inp_dlm=',',
39 | char inp_kwval_dlm='=') : sp_used_kw(100, 30)
40 | {
41 | param = inp_param;
42 | doCheck = inp_doCheck;
43 | dlm = inp_dlm;
44 | kwval_dlm = inp_kwval_dlm;
45 | }
46 |
47 | inline const char *c_str() const { return param; }
48 | inline void swOn(bool *swch, const char *kw,
49 | bool doCheckKw=true) {
50 | if (param == NULL) return;
51 | if (doCheckKw) {
52 | if (strstr(kw, "Dont") == kw ||
53 | (strstr(kw, "No") == kw && strstr(kw, "Normalize") == NULL)) {
54 | throw new AzException("AzParam::swOn",
55 | "On-kw shouldn't begin with \"Dont\" or \"No\"", kw);
56 | }
57 | }
58 | const char *ptr = pointAfterKw(param, kw);
59 | if (ptr != NULL &&
60 | (*ptr == '\0' || *ptr == dlm)) {
61 | *swch = true;
62 | }
63 | if (doCheck) sp_used_kw.put(kw);
64 | }
65 | inline void swOff(bool *swch, const char *kw,
66 | bool doCheckKw=true) {
67 | if (param == NULL) return;
68 | if (doCheckKw) {
69 | if (strstr(kw, "Dont") != kw &&
70 | (strstr(kw, "No") != kw && strstr(kw, "Normalize") == NULL)) {
71 | throw new AzException("AzParam::swOff",
72 | "Off-kw should start with \"dont\" or \"No\"", kw);
73 | }
74 | }
75 | const char *ptr = pointAfterKw(param, kw);
76 | if (ptr != NULL &&
77 | (*ptr == '\0' || *ptr == dlm) ) {
78 | *swch = false;
79 | }
80 | if (doCheck) sp_used_kw.put(kw);
81 | }
82 | inline void vStr(const char *kw, AzBytArr *s) {
83 | if (param == NULL) return;
84 | const char *bp = pointAfterKw(param, kw);
85 | if (bp == NULL) return;
86 | const char *ep = pointAt(bp, dlm);
87 | s->reset();
88 | s->concat(bp, Az64::ptr_diff(ep-bp, "AzParam::vStr"));
89 | if (doCheck) sp_used_kw.put(kw);
90 | }
91 |
92 | inline void vFloat(const char *kw, double *out_value) {
93 | if (param == NULL) return;
94 | const char *ptr = pointAfterKw(param, kw);
95 | if (ptr == NULL) return;
96 | *out_value = atof(ptr);
97 | if (doCheck) sp_used_kw.put(kw);
98 | }
99 | inline void vInt(const char *kw, int *out_value) {
100 | if (param == NULL) return;
101 | const char *ptr = pointAfterKw(param, kw);
102 | if (ptr == NULL) return;
103 | *out_value = atol(ptr);
104 | if (doCheck) sp_used_kw.put(kw);
105 | }
106 |
107 | inline void vLoss(const char *kw, AzLossType *out_loss) {
108 | if (param == NULL) return;
109 | AzBytArr s;
110 | vStr(kw, &s);
111 | if (s.length() > 0) {
112 | *out_loss = AzLoss::lossType(s.c_str());
113 | }
114 | if (doCheck) sp_used_kw.put(kw);
115 | }
116 |
117 | void check(const AzOut &out, AzBytArr *s_unused_param=NULL);
118 |
119 | protected:
120 | inline const char *pointAfterKw(const char *inp_inp, const char *kw) const {
121 | const char *ptr = NULL;
122 | const char *inp = inp_inp;
123 | for ( ; ; ) {
124 | ptr = strstr(inp, kw);
125 | if (ptr == NULL) return NULL;
126 | if (ptr == inp || *(ptr-1) == dlm) {
127 | break;
128 | }
129 | inp = ptr + strlen(kw);
130 | }
131 | ptr += strlen(kw);
132 | return ptr;
133 | }
134 | inline static const char *pointAt(const char *inp, const char *kw) {
135 | const char *ptr = strstr(inp, kw);
136 | if (ptr == NULL) return inp + strlen(inp);
137 | return ptr;
138 | }
139 | inline static const char *pointAt(const char *inp, char ch) {
140 | const char *ptr = strchr(inp, ch);
141 | if (ptr == NULL) return inp + strlen(inp);
142 | return ptr;
143 | }
144 |
145 | void analyze(AzStrPool *sp_unused,
146 | AzStrPool *sp_kw);
147 | };
148 | #endif
149 |
150 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzPerfResult.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzPerfResult.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_PERF_RESULT_HPP_
20 | #define _AZ_PERF_RESULT_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
24 | #define AzOtherCat "_X_"
25 |
/*--------------------------------------------------*/
/* Performance measures.  The enum value doubles as the index into perf_str. */
enum AzPerfType {
  AzPerfType_Acc = 0,
  AzPerfType_RMSE = 1
};
#define AzPerfType_Num 2
/* Display names indexed by AzPerfType; must hold AzPerfType_Num entries.
   Both the table and its pointers are const: it is read-only, and a const
   namespace-scope object typically draws no unused-variable warning in
   translation units that include this header without using it. */
static const char *const perf_str[AzPerfType_Num] = {
  "acc", "rmse"
};
35 |
36 | /*--------------------------------------------------*/
37 | class AzPerfResult {
38 | public:
39 | AzPerfResult() {
40 | p=r=f=acc=breakEven_f=breakEven_acc=rmse=loss=-1;
41 | }
42 | double p, r, f, acc, breakEven_f, breakEven_acc, rmse, loss;
43 | inline void put(double inp_p, double inp_r, double inp_f, double inp_acc,
44 | double inp_be_f, double inp_be_acc,
45 | double inp_rmse, double inp_loss) {
46 | p = inp_p;
47 | r = inp_r;
48 | f = inp_f;
49 | acc = inp_acc;
50 | breakEven_f = inp_be_f;
51 | breakEven_acc = inp_be_acc;
52 | rmse = inp_rmse;
53 | loss=inp_loss;
54 | }
55 | double getPerf(AzPerfType p_type) {
56 | if (p_type == AzPerfType_Acc) return acc;
57 | if (p_type == AzPerfType_RMSE) return rmse;
58 | return -1;
59 | }
60 | static const char *getPerfStr(AzPerfType p_type) {
61 | if (p_type < 0 ||
62 | p_type >= AzPerfType_Num) return "???";
63 | return perf_str[p_type];
64 | }
65 |
66 | static double isBetter(AzPerfType p_type,
67 | double p, double comp_p) {
68 | /*--- negative means unset ---*/
69 | if (p < 0) return false;
70 | if (comp_p < 0) return true;
71 |
72 | if (p_type == AzPerfType_RMSE) {
73 | if (p < comp_p) return true;
74 | }
75 | else {
76 | if (p > comp_p) return true;
77 | }
78 | return false;
79 | }
80 | void zeroOut() {
81 | p=r=f=acc=breakEven_f=breakEven_acc=rmse=loss=0;
82 | }
83 | void add(const AzPerfResult *inp) {
84 | p+=inp->p;
85 | r+=inp->r;
86 | f+=inp->f;
87 | acc+=inp->acc;
88 | breakEven_f=inp->breakEven_f;
89 | breakEven_acc=inp->breakEven_acc;
90 | rmse+=inp->rmse;
91 | loss+=inp->loss;
92 | }
93 | void multiply(double val) {
94 | p*=val;
95 | r*=val;
96 | f*=val;
97 | acc*=val;
98 | breakEven_f*=val;
99 | breakEven_acc*=val;
100 | rmse*=val;
101 | loss*=val;
102 | }
103 | };
104 | #endif
105 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzReadOnlyMatrix.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzReadOnlyMatrix.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_READONLY_MATRIX_HPP_
20 | #define _AZ_READONLY_MATRIX_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzStrArray.hpp"
24 |
25 | //! Abstract class: interface for read-only vectors.
26 | class AzReadOnlyVector {
27 | public:
28 | virtual int rowNum() const = 0;
29 | virtual double get(int row_no) const = 0;
30 | virtual int next(AzCursor &cursor, double &out_val) const = 0;
31 | virtual bool isZero() const = 0;
32 | virtual void dump(const AzOut &out, const char *header,
33 | const AzStrArray *sp_row = NULL,
34 | int cut_num = -1) const = 0;
35 | virtual double selfInnerProduct() const = 0;
36 | virtual int nonZeroRowNum() const = 0;
37 | virtual double sum() const = 0;
38 |
39 | /*--------------------------------------*/
40 | virtual void writeText(const char *fn, int digits) const {
41 | AzIntArr ia;
42 | ia.range(0, rowNum());
43 | writeText(fn, &ia, digits);
44 | }
45 |
46 | /*--------------------------------------*/
47 | virtual void writeText(const char *fn, const AzIntArr *ia, int digits) const {
48 | AzFile file(fn);
49 | file.open("wb");
50 | AzBytArr s;
51 | int ix;
52 | for (ix = 0; ix < ia->size(); ++ix) {
53 | int row = ia->get(ix);
54 | double val = get(row);
55 | s.cn(val, digits);
56 | s.nl();
57 | }
58 | s.writeText(&file);
59 | file.close(true);
60 | }
61 |
62 | /*--------------------------------------*/
63 | virtual void to_sparse(AzBytArr *s, int digits) const {
64 | AzCursor cur;
65 | for ( ; ; ) {
66 | double val;
67 | int row = next(cur, val);
68 | if (row < 0) break;
69 | if (s->length() > 0) {
70 | s->c(' ');
71 | }
72 | s->cn(row);
73 | if (val != 1) {
74 | s->c(':'); s->cn(val, digits);
75 | }
76 | }
77 | }
78 |
79 | /*--------------------------------------*/
80 | virtual void to_dense(AzBytArr *s, int digits) const {
81 | int row;
82 | for (row = 0; row < rowNum(); ++row) {
83 | double val = get(row);
84 | if (row > 0) {
85 | s->c(" ");
86 | }
87 | s->cn(val, digits);
88 | }
89 | }
90 | };
91 |
92 | //! Abstract class: interface for read-only matrices.
93 | class AzReadOnlyMatrix {
94 | public:
95 | virtual int rowNum() const = 0;
96 | virtual int colNum() const = 0;
97 |
98 | virtual void destroy() = 0;
99 | #if 0
100 | virtual void destroy(int col) = 0;
101 | #endif
102 |
103 | virtual double get(int row_no, int col_no) const = 0;
104 |
105 | virtual const AzReadOnlyVector *col(int col_no) const = 0;
106 |
107 | virtual bool isZero() const = 0;
108 | virtual bool isZero(int col) const = 0;
109 |
110 | virtual void dump(const AzOut &out, const char *header,
111 | const AzStrArray *sp_row = NULL, const AzStrArray *sp_col = NULL,
112 | int cut_num = -1) const = 0;
113 |
114 | /*--------------------------------------*/
115 | virtual void writeText(const char *fn, int digits,
116 | bool doSparse=false) const {
117 | AzIntArr ia;
118 | ia.range(0, colNum());
119 | writeText(fn, &ia, digits, doSparse);
120 | }
121 |
122 | /*--------------------------------------*/
123 | virtual void writeText(const char *fn, const AzIntArr *ia,
124 | int digits,
125 | bool doSparse=false) const {
126 | if (doSparse) {
127 | writeText_sparse(fn, ia, digits);
128 | return;
129 | }
130 | AzFile file(fn);
131 | file.open("wb");
132 | int num;
133 | const int *cxs = ia->point(&num);
134 | int ix;
135 | for (ix = 0; ix < num; ++ix) {
136 | int cx = cxs[ix];
137 | AzBytArr s;
138 | col(cx)->to_dense(&s, digits);
139 | s.nl();
140 | s.writeText(&file);
141 | }
142 | file.close(true);
143 | }
144 |
145 | /*--------------------------------------*/
146 | virtual void writeText_sparse(const char *fn, const AzIntArr *ia, int digits) const {
147 | AzFile file(fn);
148 | file.open("wb");
149 | AzBytArr s_header("sparse "); s_header.cn(rowNum()); s_header.nl();
150 | s_header.writeText(&file);
151 |
152 | int num;
153 | const int *cxs = ia->point(&num);
154 | int ix;
155 | for (ix = 0; ix < num; ++ix) {
156 | int cx = cxs[ix];
157 | AzBytArr s;
158 | col(cx)->to_sparse(&s, digits);
159 | s.nl();
160 | s.writeText(&file);
161 | }
162 | file.close(true);
163 | }
164 | };
165 | #endif
166 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzStrArray.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzStrArray.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_STR_ARRAY_HPP_
20 | #define _AZ_STR_ARRAY_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
24 | class AzStrArray {
25 | public:
26 | virtual int size() const = 0;
27 | virtual const char *c_str(int no) const = 0;
28 | void get(int no, AzBytArr *byteq) const {
29 | byteq->reset();
30 | byteq->concat(c_str(no));
31 | }
32 |
33 | virtual bool isSame(const AzStrArray *inp) const {
34 | if (size() != inp->size()) {
35 | return false;
36 | }
37 | int ix;
38 | for (ix = 0; ix < size(); ++ix) {
39 | AzBytArr s0;
40 | get(ix, &s0);
41 | if (s0.compare(inp->c_str(ix)) != 0) {
42 | return false;
43 | }
44 | }
45 | return true;
46 | }
47 |
48 | /*---*/
49 | virtual void writeText(const char *fn) const {
50 | AzFile file(fn);
51 | file.open("wb");
52 | int ix;
53 | for (ix = 0; ix < size(); ++ix) {
54 | AzBytArr s;
55 | get(ix, &s);
56 | s.nl();
57 | s.writeText(&file);
58 | }
59 | file.close(true);
60 | }
61 | };
62 |
63 | #endif
64 |
65 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzSvFeatInfo.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzSvFeatInfo.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_SV_FEAT_INFO_HPP_
20 | #define _AZ_SV_FEAT_INFO_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzStrPool.hpp"
24 | #include "AzPrint.hpp"
25 |
26 | //! Abstract class: interface to access feature descriptions.
27 | class AzSvFeatInfo {
28 | public:
29 | // Concatenate feature description to str_desc.
30 | virtual void concatDesc(int ex, //!< feature id
31 | AzBytArr *str_desc) const = 0;
32 |
33 | //! Return number of features.
34 | virtual int featNum() const = 0; /* # of features */
35 |
36 | void desc(int ex, AzBytArr *str_desc) const {
37 | str_desc->reset();
38 | concatDesc(ex, str_desc);
39 | }
40 |
41 | int desc2fno(const char *fnm) const {
42 | int fx;
43 | for (fx = 0; fx < featNum(); ++fx) {
44 | AzBytArr s;
45 | desc(fx, &s);
46 | if (s.compare(fnm) == 0) {
47 | return fx;
48 | }
49 | }
50 | return -1;
51 | }
52 |
53 | void show(const AzOut &out, const AzIntArr *ia_fxs) const {
54 | int ix;
55 | for (ix = 0; ix < ia_fxs->size(); ++ix) {
56 | int fx = ia_fxs->get(ix);
57 | AzBytArr s("???");
58 | if (fx>=0 && fxreset();
80 | if (sp_kw->size()==0) return;
81 | int fx;
82 | for (fx = 0; fx < featNum(); ++fx) {
83 | AzBytArr s;
84 | desc(fx, &s);
85 | int ix;
86 | for (ix = 0; ix < sp_kw->size(); ++ix) {
87 | if (s.beginsWith(sp_kw->c_str(ix))) {
88 | ia_fxs->put(fx);
89 | break;
90 | }
91 | }
92 | }
93 | }
94 |
95 | void contains(const AzStrArray *sp_kw,
96 | AzIntArr *ia_fxs) const {
97 | ia_fxs->reset();
98 | if (sp_kw->size()==0) return;
99 | int fx;
100 | for (fx = 0; fx < featNum(); ++fx) {
101 | AzBytArr s;
102 | desc(fx, &s);
103 | int ix;
104 | for (ix = 0; ix < sp_kw->size(); ++ix) {
105 | if (s.contains(sp_kw->c_str(ix))) {
106 | ia_fxs->put(fx);
107 | break;
108 | }
109 | }
110 | }
111 | }
112 |
113 | int equals(const char *kw) const {
114 | int fx;
115 | for (fx = 0; fx < featNum(); ++fx) {
116 | AzBytArr s;
117 | desc(fx, &s);
118 | if (s.compare(kw) == 0) {
119 | return fx;
120 | }
121 | }
122 | return -1;
123 | }
124 | };
125 |
126 | #endif
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzSvFeatInfoClone.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzSvFeatInfoClone.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_SV_FEAT_INFO_CLONE_HPP_
20 | #define _AZ_SV_FEAT_INFO_CLONE_HPP_
21 |
22 | #include "AzSvFeatInfo.hpp"
23 | #include "AzStrArray.hpp"
24 |
25 | class AzSvFeatInfoClone : /* implements */ public virtual AzSvFeatInfo,
26 | /* implements */ public virtual AzStrArray
27 | {
28 | protected:
29 | AzDataPool arr_desc;
30 |
31 | public:
32 | AzSvFeatInfoClone() {}
33 | AzSvFeatInfoClone(const AzSvFeatInfo *inp) {
34 | reset(inp);
35 | }
36 | AzSvFeatInfoClone(const AzStrArray *inp) {
37 | reset(inp);
38 | }
39 | inline int featNum() const {
40 | return arr_desc.size();
41 | }
42 | inline void concatDesc(int fx, AzBytArr *desc) const {
43 | if (fx < 0 || fx >= featNum()) {
44 | desc->c("?"); desc->cn(fx); desc->c("?");
45 | return;
46 | }
47 | desc->concat(arr_desc.point(fx));
48 | }
49 | void reset(const AzSvFeatInfo *inp) {
50 | int f_num = inp->featNum();
51 | arr_desc.reset();
52 | int fx;
53 | for (fx = 0; fx < f_num; ++fx) {
54 | AzBytArr *ptr = arr_desc.new_slot();
55 | inp->desc(fx, ptr);
56 | }
57 | }
58 | void reset(const AzStrArray *inp) {
59 | int f_num = inp->size();
60 | arr_desc.reset();
61 | int fx;
62 | for (fx = 0; fx < f_num; ++fx) {
63 | AzBytArr *ptr = arr_desc.new_slot();
64 | ptr->reset(inp->c_str(fx));
65 | }
66 | }
67 | void reset(int inp_f_num) {
68 | arr_desc.reset();
69 | int fx;
70 | for (fx = 0; fx < inp_f_num; ++fx) {
71 | AzBytArr s("F");
72 | s.cn(fx, 3, true); /* width=3, fillWithZero */
73 | arr_desc.new_slot()->reset(&s);
74 | }
75 | }
76 |
77 | void append(const AzSvFeatInfo *inp) {
78 | int fx;
79 | for (fx = 0; fx < inp->featNum(); ++fx) {
80 | inp->desc(fx, arr_desc.new_slot());
81 | }
82 | }
83 |
84 | /*--- to implement AzStrArray ---*/
85 | int size() const { return featNum(); }
86 | const char *c_str(int fx) const {
87 | if (fx < 0 || fx >= featNum()) {
88 | return "???";
89 | }
90 | return arr_desc.point(fx)->c_str();
91 | }
92 | };
93 | #endif
94 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzTaskTools.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTaskTools.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_TASK_TOOLS_HPP_
20 | #define _AZ_TASK_TOOLS_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzSmat.hpp"
24 | #include "AzSvFeatInfo.hpp"
25 | #include "AzLoss.hpp"
26 | #include "AzPerfResult.hpp"
27 | #include "AzPrint.hpp"
28 |
29 | /* Some functions are overlapping with the legacy code AzCatProc */
30 | /*--------------------------------------------------*/
31 | //! Tools related to classification or regression tasks.
32 | class AzTaskTools
33 | {
34 | public:
35 | static double analyzeLoss(AzLossType loss_type,
36 | const AzDvect *v_p,
37 | const AzDvect *v_y,
38 | const AzIntArr *inp_ia_dx,
39 | double p_coeff);
40 |
41 | static void showDist(const AzStrArray *sp_cat,
42 | const AzIntArr *ia_cat,
43 | const char *header,
44 | const AzOut &out);
45 |
46 | static double eval_breakEven(
47 | const AzIntArr *ia_gold,
48 | const AzDvect *v_pval,
49 | const AzStrArray *sp_cat,
50 | const char *eyecatcher,
51 | double *out_best_f=NULL,
52 | double *out_best_acc=NULL);
53 |
54 | static AzPerfResult eval(const AzDvect *v_p,
55 | const AzDvect *v_y, /* assume y in {+1,-1} */
56 | AzLossType loss_type) {
57 | AzOut null_out;
58 | AzPerfResult res;
59 | eval("", loss_type, NULL, NULL, v_p, v_y, "", null_out, &res);
60 | return res;
61 | }
62 | static void eval(const AzDvect *v_p,
63 | const AzDvect *v_y, /* assume y in {+1,-1} */
64 | AzPerfResult *result) {
65 | AzOut null_out;
66 | eval("", AzLoss_None, NULL, NULL, v_p, v_y, "", null_out, result);
67 | }
68 | static void eval(const AzDvect *v_p,
69 | const AzDvect *v_y, /* assume y in {+1,-1} */
70 | const AzOut &test_out,
71 | AzPerfResult *result=NULL) {
72 | AzOut null_out;
73 | eval("", AzLoss_None, NULL, NULL, v_p, v_y, "", test_out, result);
74 | }
75 | static void eval(const char *ite_str,
76 | AzLossType loss_type,
77 | const AzIntArr *ia_dx,
78 | const double p_coeff[2],
79 | const AzDvect *v_test_pval,
80 | const AzDvect *v_test_yval, /* assume y in {+1,-1} */
81 | const char *tt_eyec,
82 | const AzOut &test_out,
83 | AzPerfResult *result=NULL);
84 |
85 | static int genY(const AzIntArr *ia_cat,
86 | int focus_cat,
87 | double y_posi_val,
88 | double y_nega_val,
89 | AzDvect *v_yval); /* output */
90 |
91 | /*--- for displaying the weights, feature names, etc. ---*/
92 | static void dumpWeights(const AzOut &out,
93 | const AzDvect *v_w,
94 | const char *name,
95 | const AzSvFeatInfo *feat,
96 | int print_max,
97 | bool changeLine);
98 | static void printPR(AzPrint &o,
99 | int ok,
100 | int t,
101 | int g);
102 | protected:
103 | static void formatWeight(const AzSvFeatInfo *feat,
104 | int ex,
105 | double val,
106 | AzBytArr *str_out);
107 | };
108 | #endif
109 |
--------------------------------------------------------------------------------
/rgf1.2/src/com/AzTimer.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTimer.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_TIMER_HPP_
20 | #define _AZ_TIMER_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
24 | class AzTimer {
25 | public:
26 | int chk; /* next check point */
27 | int inc; /* increment: negative means no checking */
28 |
29 | AzTimer() : chk(-1), inc(-1) {}
30 | ~AzTimer() {}
31 |
32 | inline void reset(int inp_inc) {
33 | chk = -1;
34 | inc = inp_inc;
35 | if (inc > 0) {
36 | chk = inc;
37 | }
38 | }
39 |
40 | inline bool ringing(bool isRinging, int inp) { /* timer is ringing */
41 | if (isRinging) return true;
42 |
43 | if (chk > 0 && inp >= chk) {
44 | while(chk <= inp) {
45 | chk += inc; /* next check point */
46 | }
47 | return true;
48 | }
49 | return false;
50 | }
51 |
52 | inline bool reachedMax(int inp,
53 | const char *msg,
54 | const AzOut &out) const {
55 | bool yes_no = reachedMax(inp);
56 | if (yes_no) {
57 | AzTimeLog::print(msg, " reached max", out);
58 | }
59 | return yes_no;
60 | }
61 | inline bool reachedMax(int inp) const {
62 | if (chk > 0 && inp >= chk) return true;
63 | else return false;
64 | }
65 | };
66 |
67 | #endif
68 |
69 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzFindSplit.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzFindSplit.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_FIND_SPLIT_HPP_
20 | #define _AZ_FIND_SPLIT_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDataForTrTree.hpp"
24 | #include "AzTrTtarget.hpp"
25 | #include "AzTrTsplit.hpp"
26 | #include "AzTrTree.hpp"
27 |
/*! Pair of running sums (wy_sum, w_sum) used during split search; both start at 0. */
class Az_forFindSplit {
public:
  double wy_sum; /* presumably sum of weight*target -- inferred from name */
  double w_sum;  /* presumably sum of weights -- inferred from name */

  Az_forFindSplit() {
    reset();
  }
  //! Zero both sums.
  void reset() {
    wy_sum = 0;
    w_sum = 0;
  }
};
36 |
37 | //! Abstract class: provides building blocks for node split search.
38 | /*------------------------------------------*/
39 | class AzFindSplit
40 | {
41 | protected:
42 | const AzTrTtarget *target;
43 | const AzDataForTrTree *data;
44 | const AzTrTree_ReadOnly *tree;
45 | int min_size;
46 |
47 | AzIntArr ia_feats;
48 | const AzIntArr *ia_fx;
49 |
50 | public:
51 | AzFindSplit() : target(NULL), data(NULL), tree(NULL), ia_fx(NULL),
52 | min_size(-1) {}
53 | ~AzFindSplit() {}
54 | void reset() {
55 | target = NULL;
56 | data = NULL;
57 | tree = NULL;
58 | min_size = -1;
59 | }
60 |
61 | void _begin(const AzTrTree_ReadOnly *inp_tree,
62 | const AzDataForTrTree *inp_data,
63 | const AzTrTtarget *inp_target,
64 | int inp_min_size);
65 | void _end() {
66 | reset();
67 | }
68 |
69 | //----------------------------------------------------------------
70 | // void findBestSplit(const AzTrTtarget *tar,
71 | // const AzIntArr *ia_dx,
72 | // ... parameters ...
73 | // AzTrTsplit *best_split); /* output */
74 | //----------------------------------------------------------------
75 |
76 | virtual void _pickFeats(int pick_num, int f_num);
77 |
78 | protected:
79 | /*----------------------------------------------------------------*/
80 | virtual double getBestGain(double w_sum,
81 | double wy_sum,
82 | double *out_best_p) /* must not be null */
83 | const = 0;
84 | virtual double evalSplit(const Az_forFindSplit i[2],
85 | double bestP[2]) /* output */
86 | const;
87 | /*----------------------------------------------------------------*/
88 |
89 | void _findBestSplit(int nx,
90 | /*--- output ---*/
91 | AzTrTsplit *best_split);
92 | void loop(AzTrTsplit *best_split,
93 | int fx, /* feature# */
94 | const AzSortedFeat *sorted,
95 | int dxs_num,
96 | const Az_forFindSplit *total);
97 | };
98 |
99 | #endif
100 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzOptOnTree_TreeReg.cpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzOptOnTree_TreeReg.cpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #include "AzOptOnTree_TreeReg.hpp"
20 | #include "AzPrint.hpp"
21 |
22 | /*--------------------------------------------------------*/
23 | void
24 | AzOptOnTree_TreeReg::optimize(AzRgfTreeEnsemble *inp_rgf_ens,
25 | const AzTrTreeFeat *inp_tree_feat,
26 | int ite_num,
27 | double lam,
28 | double sig)
29 | {
30 | ens = inp_rgf_ens;
31 | tree_feat = inp_tree_feat;
32 | rgf_ens = inp_rgf_ens;
33 |
34 | synchronize();
35 | updateTreeWeights(rgf_ens);
36 |
37 | int tree_num = ens->size();
38 | if (reg_arr->size() < tree_num) {
39 | throw new AzException("AzOptOnTree_TreeReg::optimize",
40 | "max #tree has changed??");
41 | }
42 | int tx;
43 | for (tx = 0; tx < tree_num; ++tx) {
44 | AzReg_TreeReg *reg = reg_arr->reg(tx);
45 | reg->reset(ens->tree(tx), reg_depth);
46 | }
47 | iterate(ite_num, lam, sig);
48 |
49 | ens = NULL;
50 | tree_feat = NULL;
51 | rgf_ens = NULL;
52 | }
53 |
54 | /*--------------------------------------------------------*/
55 | void AzOptOnTree_TreeReg::update_with_features(
56 | double nlam,
57 | double nsig,
58 | double py_avg,
59 | AzRgf_forDelta *for_delta) /* updated */
60 | {
61 | int tree_num = ens->size();
62 | int tx;
63 | for (tx = 0; tx < tree_num; ++tx) {
64 | ens->tree_u(tx)->restoreDataIndexes();
65 | AzReg_TreeReg *reg = reg_arr->reg(tx);
66 | reg->clearFocusNode();
67 |
68 | AzIIarr iia_nx_fx;
69 | tree_feat->featIds(tx, &iia_nx_fx);
70 | int num = iia_nx_fx.size();
71 | AzIIFarr iifa_nx_fx_delta;
72 | int ix;
73 | for (ix = 0; ix < num; ++ix) {
74 | int nx, fx;
75 | iia_nx_fx.get(ix, &nx, &fx);
76 |
77 | double delta = bestDelta(nx, fx, reg, nlam, nsig, py_avg, for_delta);
78 | update_weight(nx, fx, delta, reg);
79 | }
80 | ens->tree_u(tx)->releaseDataIndexes();
81 | }
82 | }
83 | /*--------------------------------------------------------*/
84 | void AzOptOnTree_TreeReg::update_weight(int nx,
85 | int fx,
86 | double delta,
87 | AzReg_TreeReg *reg)
88 | {
89 | double new_w = v_w.get(fx) + delta;
90 | v_w.set(fx, new_w);
91 |
92 | int dxs_num;
93 | const int *dxs = data_points(fx, &dxs_num);
94 | updatePred(dxs, dxs_num, delta, &v_p);
95 |
96 | /*--- update the weight in the ensemble ---*/
97 | const AzTrTreeFeatInfo *fp = tree_feat->featInfo(fx);
98 | rgf_ens->tree_u(fp->tx)->setWeight(fp->nx, new_w);
99 | reg->changeWeight(nx, delta);
100 | }
101 |
102 | /*--------------------------------------------------------*/
103 | double AzOptOnTree_TreeReg::bestDelta(
104 | int nx,
105 | int fx,
106 | AzReg_TreeReg *reg,
107 | double nlam,
108 | double nsig,
109 | double py_avg,
110 | AzRgf_forDelta *for_delta) /* updated */
111 | const
112 | {
113 | const char *eyec = "AzOptOnTree_TI::bestDelta";
114 |
115 | double w = v_w.get(fx);
116 | int dxs_num;
117 | const int *dxs = data_points(fx, &dxs_num);
118 | if (dxs_num <= 0) {
119 | throw new AzException(eyec, "no data indexes");
120 | }
121 |
122 | const double *fixed_dw = NULL;
123 | if (!AzDvect::isNull(&v_fixed_dw)) fixed_dw = v_fixed_dw.point();
124 | const double *p = v_p.point();
125 | const double *y = v_y.point();
126 | double nega_dL = 0, ddL= 0;
127 | if (fixed_dw == NULL) {
128 | AzLoss::sum_deriv(loss_type, dxs, dxs_num, p, y, py_avg,
129 | nega_dL, ddL);
130 | }
131 | else {
132 | AzLoss::sum_deriv_weighted(loss_type, dxs, dxs_num, p, y, fixed_dw, py_avg,
133 | nega_dL, ddL);
134 | }
135 |
136 | double dR, ddR;
137 | reg->penalty_deriv(nx, &dR, &ddR);
138 |
139 | double dd = ddL + nlam*ddR;
140 | if (dd == 0) dd = 1;
141 | double delta = (nega_dL-nlam*dR)*eta/dd;
142 | for_delta->check_delta(&delta, max_delta);
143 |
144 | return delta;
145 | }
146 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzOptOnTree_TreeReg.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzOptOnTree_TreeReg.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_OPT_ON_TREE_TREE_REG_HPP_
20 | #define _AZ_OPT_ON_TREE_TREE_REG_HPP_
21 |
22 | #include "AzOptOnTree.hpp"
23 | #include "AzReg_TreeRegArr.hpp"
24 |
25 | //! coordinate descent with regulatization using tree structure
26 | /*--------------------------------------------------------*/
27 | class AzOptOnTree_TreeReg : /* extends */ public virtual AzOptOnTree
28 | {
29 | protected:
30 | AzRgfTreeEnsemble *rgf_ens;
31 | AzReg_TreeRegArr *reg_arr;
32 |
33 | public:
34 | AzOptOnTree_TreeReg() : rgf_ens(NULL), reg_arr(NULL) {}
35 | void reset(AzReg_TreeRegArr *inp_reg_arr) {
36 | reg_arr = inp_reg_arr;
37 | }
38 |
39 | virtual void optimize(AzRgfTreeEnsemble *ens, /* weights are updated */
40 | const AzTrTreeFeat *tree_feat,
41 | int inp_ite_num=-1,
42 | double lam=-1,
43 | double sig=-1);
44 |
45 | /*--- ---*/
46 | virtual void reset(const AzOptOnTree_TreeReg *inp) {
47 | AzOptOnTree::reset(inp);
48 | rgf_ens = inp->rgf_ens;
49 | reg_arr = inp->reg_arr;
50 | }
51 |
52 | protected:
53 | //! override
54 | virtual void update_with_features(double nlam, double nsig, double py_avg,
55 | AzRgf_forDelta *for_delta);
56 |
57 | virtual void update_weight(int nx,
58 | int fx,
59 | double delta,
60 | AzReg_TreeReg *reg);
61 | virtual double bestDelta(
62 | int nx,
63 | int fx,
64 | AzReg_TreeReg *reg,
65 | double nlam,
66 | double nsig,
67 | double py_avg,
68 | AzRgf_forDelta *for_delta) /* updated */
69 | const;
70 | };
71 |
72 | #endif
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzOptimizerT.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzOptimizerT.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_OPTIMIZER_T_HPP_
20 | #define _AZ_OPTIMIZER_T_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzSmat.hpp"
24 | #include "AzDmat.hpp"
25 | #include "AzBmat.hpp"
26 | #include "AzLoss.hpp"
27 | #include "AzTrTreeFeat.hpp"
28 | #include "AzTrTreeEnsemble_ReadOnly.hpp"
29 | #include "AzRgfTreeEnsemble.hpp"
30 | #include "AzRegDepth.hpp"
31 | #include "AzParam.hpp"
32 |
33 | //! coordinate descent for weight optimization.
34 | /*--------------------------------------------------------*/
35 | class AzOptimizerT
36 | {
37 | public:
38 | virtual void reset(AzLossType loss_type,
39 | const AzDvect *v_y,
40 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
41 | const AzRegDepth *reg_depth,
42 | AzParam ¶m,
43 | bool beVerbose,
44 | const AzOut out_req,
45 | /*--- for warm start ---*/
46 | const AzTrTreeEnsemble_ReadOnly *ens=NULL,
47 | const AzTrTreeFeat *tree_feat=NULL,
48 | const AzDvect *inp_v_p=NULL) = 0;
49 |
50 | virtual void copyPred_to(AzDvect *out_v_p) const = 0;
51 |
52 | virtual void resetPred(const AzBmat *m_tran,
53 | AzDvect *v_p) /* output */
54 | const = 0;
55 | virtual void optimize(AzRgfTreeEnsemble *ens,
56 | const AzTrTreeFeat *tree_feat,
57 | int inp_ite_num=-1,
58 | double lam=-1,
59 | double sig=-1) = 0;
60 | virtual void optimize(AzRgfTreeEnsemble *ens,
61 | const AzTrTreeFeat *tree_feat,
62 | bool doRefreshP,
63 | int inp_ite_num=-1,
64 | double lam=-1,
65 | double sig=-1) {
66 | throw new AzException("AzOptimizerT::optimize(...,doRefreshP,...)", "No support");
67 | }
68 | virtual const AzDvect *weights() const = 0;
69 | virtual double constant() const = 0;
70 | virtual void printHelp(AzHelp &h) const = 0;
71 | };
72 |
73 | #endif
74 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRegDepth.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRegDepth.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_REG_DEPTH_HPP_
20 | #define _AZ_REG_DEPTH_HPP_
21 |
22 | #include "AzRgf_kw.hpp"
23 | #include "AzUtil.hpp"
24 | #include "AzParam.hpp"
25 | #include "AzHelp.hpp"
26 | #include "AzRegDepth.hpp"
27 |
28 | #define depth_base_dflt 1
29 | /* #define depth_base_min_penalty_dflt 2 */
30 |
31 | //! Regularizer using node depth.
32 | class AzRegDepth {
33 | protected:
34 | double depth_base;
35 | AzDvect v_dep2pow; /* to avoid repetitive calls of pow() */
36 | const double *dep2pow;
37 |
38 | public:
39 | AzRegDepth() : depth_base(depth_base_dflt), dep2pow(NULL) {}
40 |
41 | virtual void set_default_for_min_penalty() {
42 | /* depth_base = depth_base_min_penalty_dflt; */
43 | }
44 |
45 | virtual inline
46 | void check_if_nonincreasing(const char *who) const {
47 | if (depth_base < 1) {
48 | AzBytArr s(kw_depth_base); s.c(" must be no smaller than 1 for ");
49 | s.c(who); s.c(".");
50 | throw new AzException(AzInputNotValid, "AzRegDepth::check_if_nonincreasing",
51 | s.c_str());
52 | }
53 | }
54 |
55 | virtual inline
56 | double apply(double val, int dep) const {
57 | if (depth_base == 1) return val;
58 | if (dep >= 0 && dep < v_dep2pow.rowNum()) {
59 | return val * dep2pow[dep];
60 | }
61 | else {
62 | return val * pow(depth_base, (double)dep);
63 | }
64 | }
65 |
66 | virtual void reset(AzParam ¶m,
67 | const AzOut &out) {
68 | resetParam(param);
69 | if (depth_base <= 0) {
70 | throw new AzException(AzInputNotValid, "AzRegDepth::reset",
71 | kw_depth_base, "must be no smaller than 1.");
72 | }
73 | if (depth_base < 1) {
74 | AzBytArr s("!Warning! "); s.c(kw_depth_base); s.c(" should be no smaller than 1.");
75 | AzPrint::writeln(out, s);
76 | }
77 |
78 | printParam(out);
79 | }
80 | virtual void printHelp(AzHelp &h) const {
81 | h.begin(Azforest_config, "AzRegDepth", "Regularization on node depth");
82 | h.item(kw_depth_base, help_depth_base, depth_base);
83 | h.end();
84 | }
85 | virtual void printParam(const AzOut &out) const {
86 | if (out.isNull()) return;
87 | if (depth_base != 1) {
88 | AzPrint o(out);
89 | o.ppBegin("AzRegDepth", "Reg. on depth", ", ");
90 | o.printV(kw_depth_base, depth_base);
91 | o.ppEnd();
92 | }
93 | }
94 |
95 | protected:
96 | virtual void resetParam(AzParam &p) {
97 | bool doCheck = false;
98 | p.vFloat(kw_depth_base, &depth_base);
99 |
100 | /*--- ---*/
101 | v_dep2pow.reform(50);
102 | int dep;
103 | for (dep = 0; dep < v_dep2pow.rowNum(); ++dep) {
104 | v_dep2pow.set(dep, pow(depth_base, (double)dep));
105 | }
106 | dep2pow = v_dep2pow.point();
107 | }
108 | };
109 | #endif
110 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzReg_TreeReg.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzReg_TreeReg.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_REG_TREE_REG_HPP_
20 | #define _AZ_REG_TREE_REG_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDmat.hpp"
24 | #include "AzTrTree_ReadOnly.hpp"
25 | #include "AzRegDepth.hpp"
26 | #include "AzParam.hpp"
27 |
28 | class AzReg_TreeRegShared {
29 | public:
30 | virtual AzDmat *share() = 0;
31 | virtual bool create(const AzTrTree_ReadOnly *tree, const AzDmat *info) = 0;
32 | virtual AzDmat *share(const AzTrTree_ReadOnly *tree) = 0;
33 | };
34 |
35 | //! Default implementation of AzReg_TreeRegShared
36 | class AzReg_TreeRegShared_Dflt : /* implements */ public virtual AzReg_TreeRegShared
37 | {
38 | protected:
39 | AzDmat m_by_alltree; /* info not specific to individual tree */
40 |
41 | public:
42 | /*--- override these to store tree-specific info ---*/
43 | virtual AzDmat *share(const AzTrTree_ReadOnly *tree) { return NULL; }
44 | virtual bool create(const AzTrTree_ReadOnly *tree, const AzDmat *) { return false; }
45 | /*----------------------------------------------------*/
46 | virtual AzDmat *share() {
47 | return &m_by_alltree;
48 | }
49 | };
50 |
51 | //! Abstract class: interface to tree-structured regularizer
52 | class AzReg_TreeReg {
53 | public:
54 | virtual void set_shared(AzReg_TreeRegShared *shared) {}
55 | virtual void check_reg_depth(const AzRegDepth *) const {}
56 |
57 | virtual void reset(const AzTrTree_ReadOnly *inp_tree,
58 | const AzRegDepth *inp_reg_depth) = 0;
59 |
60 | virtual void penalty_deriv(int nx, double *dr,
61 | double *ddr) = 0;
62 |
63 | virtual void changeWeight(int nx, double w_diff) = 0;
64 |
65 | virtual void clearFocusNode() = 0;
66 |
67 | /*--- for node split ---*/
68 | //! called by AzRgf_FindSplit_TR::begin
69 | virtual void reset_forNewLeaf(const AzTrTree_ReadOnly *t,
70 | const AzRegDepth *rdep) = 0;
71 |
72 | //! called by AzRgf_FindSplit_TR::findSplit
73 | virtual void reset_forNewLeaf(int f_nx,
74 | const AzTrTree_ReadOnly *t,
75 | const AzRegDepth *rdep) = 0;
76 |
77 | virtual double penalty_diff(const double leaf_w_delta[2]) const = 0;
78 | virtual void penalty_deriv(double *dr,
79 | double *ddr) const = 0;
80 |
81 | /*--- for maintenance ---*/
82 | virtual void show(const AzOut &out,
83 | const char *header) const = 0;
84 | virtual double penalty() const {
85 | return -1;
86 | }
87 |
88 | /*---------------------------------------------------------*/
89 | virtual void resetParam(AzParam ¶m) = 0;
90 | virtual void printParam(const AzOut &out) const = 0;
91 | virtual void printHelp(AzHelp &h) const = 0;
92 |
93 | virtual const char *signature() const = 0;
94 | virtual const char *description() const = 0;
95 | };
96 | #endif
97 |
98 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzReg_TreeRegArr.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzReg_TreeRegArr.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_REG_TREE_REG_ARR_HPP_
20 | #define _AZ_REG_TREE_REG_ARR_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzReg_TreeReg.hpp"
24 |
25 | class AzReg_TreeRegArr
26 | {
27 | public:
28 | virtual void reset(int tree_num) = 0;
29 | virtual AzReg_TreeReg *reg(int tx) = 0;
30 | virtual AzReg_TreeReg *reg_forNewLeaf(int tx) = 0;
31 | virtual int size() const = 0;
32 | };
33 | #endif
34 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzReg_TreeRegArrImp.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzReg_TreeRegArrImp.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_REG_TREE_REG_ARR_IMP_HPP_
20 | #define _AZ_REG_TREE_REG_ARR_IMP_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzRgfTreeEnsemble.hpp"
24 | #include "AzReg_TreeRegArr.hpp"
25 |
26 | template
27 | class AzReg_TreeRegArrImp : /* implements */ public virtual AzReg_TreeRegArr
28 | {
29 | protected:
30 | AzPtrPool areg;
31 | T template_reg;
32 | T temporary_reg;
33 | AzReg_TreeRegShared_Dflt shared;
34 |
35 | public:
36 | T *tmpl_u() { return &template_reg; }
37 | const T *tmpl() const { return &template_reg; }
38 |
39 | inline int size() const { return areg.size(); }
40 | void reset(int tree_num) {
41 | areg.reset();
42 | int tx;
43 | for (tx = 0; tx < tree_num; ++tx) {
44 | T *reg = areg.new_slot();
45 | reg->copyParam_from(&template_reg);
46 | reg->set_shared(&shared);
47 | }
48 | temporary_reg.set_shared(&shared);
49 | }
50 | inline AzReg_TreeReg *reg(int tx) {
51 | return areg.point_u(tx);
52 | }
53 | inline AzReg_TreeReg *reg_forNewLeaf(int tx) {
54 | int t_num = areg.size();
55 | if (tx < t_num) return areg.point_u(tx);
56 |
57 | temporary_reg.copyParam_from(&template_reg);
58 | return &temporary_reg; /* should be root-only tree */
59 | }
60 | };
61 | #endif
62 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzReg_TsrOpt.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzReg_TsrOpt.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_REG_TSROPT_HPP_
20 | #define _AZ_REG_TSROPT_HPP_
21 |
22 | #include "AzRgf_kw.hpp"
23 | #include "AzReg_Tsrbase.hpp"
24 |
25 | //!
26 | class AzReg_TsrOpt : /* extends */ public virtual AzReg_Tsrbase
27 | {
28 | protected:
29 | int reg_ite_num;
30 | double curr_penalty;
31 |
32 | AzDmat *m_coeff; /* shared with the regularizers for other trees */
33 | /* owned by AzReg_TreeRegShared */
34 |
35 | //! tree structure
36 | AzIIarr iia_le_gt;
37 |
38 | public:
39 | AzReg_TsrOpt() : reg_ite_num(reg_ite_num_dflt), curr_penalty(0), m_coeff(NULL) {}
40 |
41 | virtual void copyParam_from(const AzReg_TsrOpt *inp) {
42 | AzReg_Tsrbase::copyParam_from(inp);
43 | reg_ite_num = inp->reg_ite_num;
44 | }
45 |
46 | void set_shared(AzReg_TreeRegShared *ptr) {
47 | m_coeff = ptr->share();
48 | }
49 | virtual void check_reg_depth(const AzRegDepth *rd) const {
50 | if (rd == NULL) return;
51 | rd->check_if_nonincreasing("min-penalty regularizers");
52 | }
53 |
54 | /*---------------------------------------------------------*/
55 | virtual void _reset(const AzTrTree_ReadOnly *inp_tree,
56 | const AzRegDepth *inp_reg_depth);
57 | /*---------------------------------------------------------*/
58 |
59 | /*--- for node split ---*/
60 | /*---------------------------------------------------------*/
61 | /*--- called by AzRgf_FindSplit_TR::begin for each tree ---*/
62 | //! set current penalty
63 | virtual void _reset_forNewLeaf(const AzTrTree_ReadOnly *t,
64 | const AzRegDepth *rdep);
65 | /*--- called by AzRgf_FindSplit_TR::findSplit for each node ---*/
66 | virtual double _reset_forNewLeaf(int f_nx,
67 | const AzTrTree_ReadOnly *t,
68 | const AzRegDepth *rdep);
69 | /*---------------------------------------------------------*/
70 |
71 | /*--- for maintenance ---*/
72 | virtual void show(const AzOut &out,
73 | const char *header) const {
74 | show(out, header, reg_depth, focus_nx, &av_dbar, tree, &v_bar);
75 | }
76 |
77 | /*---------------------------------------------------------*/
78 | virtual void resetParam(AzParam ¶m);
79 | virtual void printParam(const AzOut &out) const;
80 | virtual void printHelp(AzHelp &h) const;
81 |
82 | virtual inline const char *signature() const {
83 | return "-___-_RGF_TsrOpt_";
84 | }
85 | virtual inline const char *description() const {
86 | return "RGF w/min-penalty regularization";
87 | }
88 | /*---------------------------------------------------------*/
89 |
90 | /*--- static tools ---*/
91 | static void show(const AzOut &out,
92 | const char *header,
93 | const AzRegDepth *reg_depth,
94 | int focus_nx,
95 | const AzDataArray *av_dbar,
96 | const AzTrTree_ReadOnly *tree,
97 | const AzDvect *v_bar);
98 | static void _propagate(int ite_num,
99 | const AzTrTree_ReadOnly *tree,
100 | int split_nx,
101 | const double new_leaf_w[2],
102 | const AzIntArr *ia_nonleaf,
103 | const AzRegDepth *reg_depth,
104 | AzDmat *m_coeff, /* inout */
105 | AzDvect *v); /* output */
106 |
107 | static void _show(int focus_nx,
108 | const AzDataArray *av_dbar,
109 | const AzTrTree_ReadOnly *tree,
110 | const AzDvect *v_bar,
111 | int nx,
112 | const AzDmat *m_coeff,
113 | const AzOut &out);
114 |
115 | static void setCoeff(const AzRegDepth *reg_depth, int depth, double coeff[4]);
116 | static void setCoeff(const AzRegDepth *reg_depth,
117 | const AzTrTree_ReadOnly *tree,
118 | AzDmat *m_coeff);
119 |
120 | protected:
121 | void resetTreeStructure();
122 | bool isSameTreeStructure() const;
123 | void storeTreeStructure();
124 |
125 | void update_v();
126 | void update_dv();
127 | void reset_v_dv() {
128 | av_dv.reset();
129 | v_v.reset();
130 | v_dv2_sum.reset();
131 | }
132 |
133 | /*--- compute "bar" (auxiliary variables) iteratively ---*/
134 | virtual void reset_bar(int split_nx,
135 | const AzIntArr *ia_leaf,
136 | const AzIntArr *ia_nonleaf);
137 | /*--- compute "bar"'s derivatives iteratively ---*/
138 | virtual void deriv(int base_nx, /* derivative w.r.t. this node's weight */
139 | const AzIntArr *ia_nonleaf,
140 | AzDvect *v_dbar);
141 |
142 | };
143 | #endif
144 |
145 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzReg_TsrSib.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzReg_TsrSib.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_REG_TSRSIB_HPP_
20 | #define _AZ_REG_TSRSIB_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDmat.hpp"
24 | #include "AzTrTree_ReadOnly.hpp"
25 | #include "AzRegDepth.hpp"
26 | #include "AzReg_TreeReg.hpp"
27 |
28 | //! Tree-structured regularizer: RGF w/min-penalty regularization w/sum-to-zero sibling constraints.
29 | class AzReg_TsrSib : /* implements */ public virtual AzReg_TreeReg {
30 | protected:
31 | const AzTrTree_ReadOnly *tree; /* tree being regularized (presumably set by reset / reset_forNewLeaf) */
32 |
33 | int focus_nx; /* node to be split; -1 = none (see clearFocusNode, checkLeaf) */
34 | const AzRegDepth *reg_depth;
35 | bool forNewLeaf; /* new-leaf mode flag; checkLeaf() requires it */
36 |
37 | AzDataArray<AzSvect> av_dv; /* per-node dv vectors (deriv_v produces AzSvect) */
38 | AzDvect v_v;
39 |
40 | double vdv_sum, dv2_sum, dr, ddr;
41 | double newleaf_dep_factor;
42 |
43 | virtual void reset_values() { /* zero the cached sums/derivatives; depth factor back to 1 */
44 | vdv_sum = dv2_sum = dr = ddr = 0;
45 | newleaf_dep_factor = 1;
46 | }
47 |
48 | public:
49 | AzReg_TsrSib() /* initializers listed in declaration order (no -Wreorder) */
50 | : tree(NULL), focus_nx(-1), reg_depth(NULL), forNewLeaf(false),
51 | vdv_sum(0), dv2_sum(0), dr(0), ddr(0),
52 | newleaf_dep_factor(1) {}
53 |
54 | void copyParam_from(const AzReg_TsrSib *inp) {} /* nothing to copy: no parameters */
55 |
56 | /*---------------------------------------------------------*/
57 | virtual void reset(const AzTrTree_ReadOnly *inp_tree,
58 | const AzRegDepth *inp_reg_depth);
59 | /*---------------------------------------------------------*/
60 |
61 | virtual void penalty_deriv(int nx, double *out_dr, /* outputs; presumably 1st/2nd derivative w.r.t. node nx's weight */
62 | double *out_ddr);
63 |
64 | virtual void changeWeight(int nx, double w_diff);
65 |
66 | inline void clearFocusNode() {
67 | focus_nx = -1;
68 | }
69 |
70 | /*--- for node split ---*/
71 | /*---------------------------------------------------------*/
72 | /*--- called by AzRgf_FindSplit_TreeReg::begin ---*/
73 | //! set current penalty
74 | virtual void reset_forNewLeaf(const AzTrTree_ReadOnly *t,
75 | const AzRegDepth *rdep);
76 | /*--- called by AzRgf_FindSplit_TreeReg::findSplit; f_nx = node to be split ---*/
77 | virtual void reset_forNewLeaf(int f_nx,
78 | const AzTrTree_ReadOnly *t,
79 | const AzRegDepth *rdep);
80 | /*---------------------------------------------------------*/
81 |
82 | virtual double penalty_diff(const double leaf_w_delta[2]) const; /* new - old penalty, given the two new leaves' weight deltas */
83 | virtual void penalty_deriv(double *out_dr, /* outputs; AzRgf_FindSplit_TreeReg stores them in dR/ddR */
84 | double *out_ddr) const;
85 |
86 | /*--- for maintenance ---*/
87 | virtual void show(const AzOut &out,
88 | const char *header) const {
89 | /* intentionally prints nothing */
90 | }
91 |
92 | /*---------------------------------------------------------*/
93 | virtual void resetParam(AzParam &param) {} /* no parameters: resetParam/printParam/printHelp are no-ops */
94 | virtual void printParam(const AzOut &out) const {}
95 | virtual void printHelp(AzHelp &h) const {}
96 |
97 | virtual inline const char *signature() const {
98 | return "-___-_RGF_TsrSib_";
99 | }
100 | virtual inline const char *description() const {
101 | return "RGF w/min-penalty regularization w/sum-to-zero sibling constraints";
102 | }
103 | /*---------------------------------------------------------*/
104 |
105 | protected:
106 | void checkLeaf(const char *msg) const { /* throws unless in new-leaf mode with a focus node set */
107 | if (!forNewLeaf || focus_nx < 0) {
108 | throw new AzException("AzReg_TsrSib::checkLeaf", msg);
109 | }
110 | }
111 | virtual void deriv_v(const AzTrTree_ReadOnly *tree,
112 | int leaf_nx,
113 | bool forNewLeaf,
114 | /* output */
115 | AzSvect *v_dv,
116 | /* inout */
117 | AzDvect *v_v) const;
118 |
119 | inline int get_newleaf_depth() const { /* depth of a child of focus_nx; needs a valid focus_nx */
120 | return tree->node(focus_nx)->depth + 1;
121 | }
122 | virtual void update();
123 | };
124 | #endif
125 |
126 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgfTrainerSel.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgfTrainerSel.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_TRAINER_SEL_HPP_
20 | #define _AZ_RGF_TRAINER_SEL_HPP_
21 |
22 | #include "AzRgforest.hpp"
23 | #include "AzRgforest_TreeReg.hpp"
24 | #include "AzReg_TsrOpt.hpp"
25 | #include "AzReg_TsrSib.hpp"
26 |
27 | #include "AzTETselector.hpp"
28 | #include "AzPrint.hpp"
29 |
30 | //! Training algorithm selector: maps an algorithm name to the trainer object that implements it.
31 | class AzRgfTrainerSel : /* implements */ public virtual AzTETselector {
32 | protected: /* one trainer per selectable algorithm, held by value */
33 | AzRgforest rgf;
34 | AzRgforest_TreeReg<AzReg_TsrSib> rgf_sib; /* tree-structured reg. using AzReg_TsrSib */
35 | AzRgforest_TreeReg<AzReg_TsrOpt> rgf_opt; /* tree-structured reg. using AzReg_TsrOpt */
36 |
37 | #define kw_rgf "RGF"
38 | #define kw_rgf_sib "RGF_Sib"
39 | #define kw_rgf_opt "RGF_Opt"
40 | /* note: the kw_* names are macros, so they stay defined after this class */
41 | AzStrPool sp_name; /* algorithm name -> id (the id is the slot index in alg) */
42 | AzDataArray<AzTETrainer *> alg; /* id -> trainer (points at the members above) */
43 |
44 | virtual void reset() { /* register all names; presumably appends via new_slot, so meant to run once */
45 | int id = 0;
46 | sp_name.putv(kw_rgf, id++); *alg.new_slot() = &rgf;
47 | sp_name.putv(kw_rgf_sib, id++); *alg.new_slot() = &rgf_sib;
48 | sp_name.putv(kw_rgf_opt, id++); *alg.new_slot() = &rgf_opt;
49 | sp_name.commit();
50 | }
51 |
52 | public:
53 | AzRgfTrainerSel() { /* virtual call in a ctor: always runs AzRgfTrainerSel::reset */
54 | reset();
55 | }
56 | /* dflt_name / another_name: the "RGF" and "RGF_Sib" keywords */
57 | virtual const char *dflt_name() const {
58 | return kw_rgf;
59 | }
60 | virtual const char *another_name() const {
61 | return kw_rgf_sib;
62 | }
63 | const AzStrArray *names() const {
64 | return &sp_name;
65 | }
66 | /* printOptions: append every algorithm name to *s, separated by dlm */
67 | virtual void printOptions(const char *dlm, AzBytArr *s) const {
68 | int ix;
69 | for (ix = 0; ix < sp_name.size(); ++ix) {
70 | if (ix > 0) s->c(dlm);
71 | s->c(sp_name.c_str(ix));
72 | }
73 | }
74 | /* printHelp: one help item per algorithm, text taken from the trainer's description() */
75 | virtual void printHelp(AzHelp &h) const {
76 | h.begin("", "", "");
77 | int ix;
78 | for (ix = 0; ix < sp_name.size(); ++ix) {
79 | int id = sp_name.getValue(ix);
80 | const AzTETrainer *trainer = *alg.point(id);
81 | h.item(sp_name.c_str(ix), trainer->description());
82 | }
83 | h.end();
84 | }
85 | /* select: trainer for alg_name; an unknown name throws, or yields NULL when dontThrow */
86 | virtual AzTETrainer *select(const char *alg_name, //!< name of algorithm.
87 | //! if true, don't throw exception at error.
88 | bool dontThrow=false) const
89 | {
90 | AzTETrainer *trainer = NULL;
91 |
92 | int ex = sp_name.find(alg_name);
93 | if (ex < 0 && !dontThrow) {
94 | throw new AzException(AzInputNotValid, "algorithm name", alg_name);
95 | }
96 | if (ex >= 0) {
97 | int id = sp_name.getValue(ex);
98 | trainer = *alg.point(id);
99 | }
100 | return trainer;
101 | }
102 | };
103 |
104 | #endif
105 |
106 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgfTreeEnsImp.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgfTreeEnsImp.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_TREE_ENS_IMP_HPP_
20 | #define _AZ_RGF_TREE_ENS_IMP_HPP_
21 |
22 | #include "AzRgfTreeEnsemble.hpp"
23 | #include "AzTrTreeEnsemble.hpp"
24 |
25 | //! implement AzRgfTreeEnsemble. T must be AzRgfTree or its extension.
26 | template <class T>
27 | class AzRgfTreeEnsImp : /*implements */public virtual AzRgfTreeEnsemble
28 | {
29 | protected:
30 | AzTrTreeEnsemble<T> ens; /* every method below forwards to this member */
31 |
32 | public:
33 | AzRgfTreeEnsImp() {}
34 | ~AzRgfTreeEnsImp() {}
35 | inline bool usingTempFile() const {
36 | return ens.usingTempFile();
37 | }
38 | inline void reset() {
39 | ens.reset();
40 | }
41 | inline const char *param_c_str() const {
42 | return ens.param_c_str();
43 | }
44 |
45 | inline double constant() const {
46 | return ens.constant();
47 | }
48 | inline int orgdim() const {
49 | return ens.orgdim();
50 | }
51 | inline void set_constant(double inp) {
52 | ens.set_constant(inp);
53 | }
54 | inline AzRgfTree *new_tree(int *out_tx=NULL) {
55 | return ens.new_tree(out_tx);
56 | }
57 |
58 | inline const AzRgfTree *tree(int tx) const {
59 | return ens.tree(tx);
60 | }
61 | inline AzRgfTree *tree_u(int tx) const {
62 | return ens.tree_u(tx);
63 | }
64 | /* rawtree_u: same tree as tree_u(tx), but typed as T* (the concrete tree class) */
65 | inline T *rawtree_u(int tx) const {
66 | return ens.tree_u(tx);
67 | }
68 |
69 | inline int leafNum() const {
70 | return ens.leafNum();
71 | }
72 | inline int leafNum(int tx0, int tx1) const {
73 | return ens.leafNum(tx0, tx1);
74 | }
75 |
76 | inline int lastIndex() const {
77 | return ens.lastIndex();
78 | }
79 | inline int nextIndex() const { /* next slot */
80 | return ens.nextIndex();
81 | }
82 |
83 | inline int size() const {
84 | return ens.size();
85 | }
86 | inline int max_size() const {
87 | return ens.max_size();
88 | }
89 | inline bool isFull() const {
90 | return ens.isFull();
91 | }
92 | inline void printHelp(AzHelp &h) const {
93 | ens.printHelp(h);
94 | }
95 | inline void copy_to(AzTreeEnsemble *out_ens,
96 | const char *config, const char *sign) const {
97 | ens.copy_to(out_ens, config, sign);
98 | }
99 | inline void copy_nodes_from(const AzTrTreeEnsemble_ReadOnly *inp) {
100 | ens.copy_nodes_from(inp);
101 | }
102 | inline void show(const AzSvFeatInfo *feat,
103 | const AzOut &out, const char *header="") const {
104 | ens.show(feat, out, header);
105 | }
106 |
107 | inline virtual void cold_start(AzParam &param,
108 | const AzBytArr *s_temp_prefix,
109 | int data_num,
110 | const AzOut &out,
111 | int tree_num_max,
112 | int inp_org_dim) {
113 | ens.cold_start(param, s_temp_prefix, data_num,
114 | out, tree_num_max, inp_org_dim);
115 | }
116 | inline virtual void warm_start(const AzTreeEnsemble *inp_ens,
117 | const AzDataForTrTree *data,
118 | AzParam &param,
119 | const AzBytArr *s_temp_prefix,
120 | const AzOut &out,
121 | int max_t_num,
122 | int search_t_num, /* to release work areas for the fixed trees */
123 | AzDvect *v_p, /* inout */
124 | const AzIntArr *inp_ia_tr_dx=NULL) {
125 | ens.warm_start(inp_ens, data, param, s_temp_prefix, out, max_t_num, search_t_num,
126 | v_p, inp_ia_tr_dx);
127 | }
128 | };
129 | #endif
130 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgfTreeEnsemble.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgfTreeEnsemble.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_TREE_ENSEMBLE_HPP_
20 | #define _AZ_RGF_TREE_ENSEMBLE_HPP_
21 |
22 | #include "AzTrTreeEnsemble_ReadOnly.hpp"
23 | #include "AzRgfTree.hpp"
24 | #include "AzParam.hpp"
25 | #include "AzHelp.hpp"
26 |
27 | //! Abstract class: interface for ensemble of RGF-trees.
28 | /**
29 | * implemented by AzRgfTreeEnsImp
30 | **/
31 | class AzRgfTreeEnsemble : /* extends */ public virtual AzTrTreeEnsemble_ReadOnly
32 | {
33 | public:
34 | virtual void set_constant(double inp) = 0;
35 | virtual AzRgfTree *new_tree(int *out_tx=NULL) = 0; /* out_tx: optional output, presumably the new tree's index */
36 | virtual AzRgfTree *tree_u(int tx) const = 0; /* non-const access to tree tx */
37 |
38 | virtual int nextIndex() const = 0;
39 | virtual bool isFull() const = 0;
40 |
41 | virtual void copy_nodes_from(const AzTrTreeEnsemble_ReadOnly *inp) = 0;
42 | virtual void printHelp(AzHelp &h) const = 0;
43 | /* cold_start presumably begins from an empty ensemble; warm_start from inp_ens */
44 | virtual void cold_start(AzParam &param,
45 | const AzBytArr *s_temp_prefix, /* may be NULL */
46 | int data_num,
47 | const AzOut &out,
48 | int tree_num_max,
49 | int inp_org_dim) = 0;
50 | virtual void warm_start(const AzTreeEnsemble *inp_ens,
51 | const AzDataForTrTree *data,
52 | AzParam &param,
53 | const AzBytArr *s_temp_prefix, /* may be NULL */
54 | const AzOut &out,
55 | int max_t_num,
56 | int search_t_num,
57 | AzDvect *v_p, /* inout */
58 | const AzIntArr *inp_ia_tr_dx=NULL) = 0;
59 | };
60 | #endif
61 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_FindSplit.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_FindSplit.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_FIND_SPLIT_HPP_
20 | #define _AZ_RGF_FIND_SPLIT_HPP_
21 |
22 | #include "AzTrTree_ReadOnly.hpp"
23 | #include "AzTrTsplit.hpp"
24 | #include "AzTrTtarget.hpp"
25 | #include "AzRgf_kw.hpp"
26 | #include "AzRegDepth.hpp"
27 | #include "AzParam.hpp"
28 | #include "AzFsinfo.hpp"
29 | #include "AzHelp.hpp"
30 |
31 | class AzRgf_FindSplit_input { /* per-tree inputs handed to AzRgf_FindSplit::begin */
32 | public:
33 | int tx; /* tree id; AzRgf_FindSplit_TreeReg::begin passes it to reg_forNewLeaf */
34 | const AzDataForTrTree *data; /* stored as-is (pointer copy) */
35 | const AzTrTtarget *target; /* stored as-is (pointer copy) */
36 | double lam_scale; /*!< for numerical stability of exp loss */
37 | double nn; /* sum of data point weights if weighted */
38 |
39 | AzRgf_FindSplit_input(int inp_tx,
40 | const AzDataForTrTree *inp_data,
41 | const AzTrTtarget *inp_target,
42 | double inp_lam_scale,
43 | double inp_nn)
44 | : tx(inp_tx), /* members initialized in declaration order */
45 | data(inp_data),
46 | target(inp_target),
47 | lam_scale(inp_lam_scale),
48 | nn(inp_nn) /* already a double: no cast needed */
49 | {}
50 | };
51 |
52 | /*--------------------------------------------------------*/
53 | //! Abstract class: interface for node split search for RGF.
54 | /**
55 | * Implemented by AzRgf_FindSplit_Dflt, AzRgf_FindSplit_TreeReg
56 | **/
57 | class AzRgf_FindSplit {
58 | public:
59 | virtual void reset(AzParam &param,
60 | const AzRegDepth *reg_depth,
61 | const AzOut &out) = 0;
62 | /* in the implementations, findSplit() is called between begin() and end() */
63 | virtual void begin(const AzTrTree_ReadOnly *tree,
64 | const AzRgf_FindSplit_input &inp,
65 | int min_size)
66 | = 0;
67 | virtual void begin(const AzTrTree_ReadOnly *tree,
68 | const AzRgf_FindSplit_input &inp,
69 | int min_size,
70 | AzFsinfoOnTree *fot) { /* added for TreeRegFast; default implementation throws */
71 | throw new AzException("AzRgf_FindSplit::begin(...fot)",
72 | "no appropriate override");
73 | }
74 | /* NOTE(review): AzRgf_FindSplit_Dflt names these (pick_num, f_num) -- confirm which is right */
75 | virtual void pickFeats(int f_num, int data_num) = 0;
76 |
77 | virtual void end() = 0;
78 | virtual
79 | void findSplit(int nx, //!< node id
80 | /*--- output ---*/
81 | AzTrTsplit *best_split) = 0;
82 |
83 | virtual void printParam(const AzOut &out) const = 0;
84 | virtual void printHelp(AzHelp &h) const = 0;
85 | };
86 | #endif
87 |
88 |
89 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_FindSplit_Dflt.cpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_FindSplit_Dflt.cpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #include "AzRgf_FindSplit_Dflt.hpp"
20 | #include "AzHelp.hpp"
21 | #include "AzRgf_kw.hpp"
22 |
23 | /*--- begin: start a split search on a tree and scale the reg params ---*/
24 | void AzRgf_FindSplit_Dflt::begin(
25 | const AzTrTree_ReadOnly *inp_tree,
26 | const AzRgf_FindSplit_input &inp, /* tx is not used */
27 | int inp_min_size)
28 | {
29 | AzFindSplit::_begin(inp_tree, inp.data, inp.target, inp_min_size);
30 | /* reg params grow with nn (sum of data point weights if weighted) */
31 | nlam = inp.nn*lambda;
32 | nsig = inp.nn*sigma;
33 | /* lam_scale: for numerical stability for expo loss. Scaling by */
34 | /* exactly 1 is a no-op, so it is applied unconditionally. */
35 | nlam *= inp.lam_scale;
36 | nsig *= inp.lam_scale;
37 |
38 | doUseInternalNodes = tree->usingInternalNodes();
39 | }
40 |
41 | /*--- getBestGain: best weight for one child of p_node and its (n-scaled) gain ---*/
42 | double AzRgf_FindSplit_Dflt::getBestGain(double wsum, /* sum of data weights */
43 | double wrsum, /* weighted sum of residual */
44 | double *best_q) const /* output: weight for this child */
45 | {
46 | double p = p_node->weight; /* parent's weight */
47 | double gain = 0;
48 | double q = 0;
49 | /* gain is per child ("n*gain"); the parent's penalty is split between the two children */
50 | if (doUseInternalNodes) {
51 | q = wrsum/(wsum+c_nlam);
52 | gain = q*q*(wsum+c_nlam); /* n*gain */
53 | }
54 | else if (nsig <= 0) { /* L2 only */
55 | q = (wrsum-c_nlam*p)/(wsum+c_nlam);
56 | gain = q*q*(wsum+c_nlam)+(p_nlam-2*c_nlam)*p*p/2; /* "/2" for two child nodes */
57 | /* n*gain */
58 | q += p; /* delta -> child weight */
59 | }
60 | else { /* L1 and L2; experimental, not tested after code change */
61 | double _wysum = wrsum + wsum*p;
62 | if (_wysum > c_nsig) q = (_wysum-c_nsig)/(wsum+c_nlam);
63 | else if (_wysum < -c_nsig) q = (_wysum+c_nsig)/(wsum+c_nlam);
64 | else q = 0; /* soft-thresholded to zero by the L1 term */
65 | double org_losshat = -2*p*_wysum+p*p*wsum+(p*p*p_nlam+2*p_nsig*fabs(p))/2; /* parent penalty halved per child, as in the L2 branch */
66 | double new_losshat = -q*q*(wsum+c_nlam);
67 | gain = org_losshat - new_losshat;
68 | }
69 |
70 | *best_q = q;
71 | return gain;
72 | }
73 |
74 | /*--------------------------------------------------------*/
75 | /*--- resetParam: read lambda/sigma, let the s_* keywords override, validate ---*/
76 | void AzRgf_FindSplit_Dflt::resetParam(AzParam &param)
77 | {
78 | /* reg params shared with the optimizer ... */
79 | param.vFloat(kw_lambda, &lambda);
80 | param.vFloat(kw_sigma, &sigma);
81 | /* ... then the split-only keywords, read last so they override */
82 | param.vFloat(kw_s_lambda, &lambda);
83 | param.vFloat(kw_s_sigma, &sigma);
84 |
85 | /* a negative lambda includes the ctor default (-1), i.e. never given */
86 | if (lambda < 0) {
87 | throw new AzException(AzInputMissing, "AzRgf_FindSplit_Dflt",
88 | kw_lambda, "must be non-negative");
89 | } else if (sigma < 0) {
90 | throw new AzException(AzInputNotValid, "AzRgf_FindSplit_Dflt",
91 | kw_sigma, "must be non-negative");
92 | }
93 |
94 | }
95 |
96 | /*--- printParam: log the node-split reg params; silent when out is null ---*/
97 | void AzRgf_FindSplit_Dflt::printParam(const AzOut &out) const
98 | {
99 | if (!out.isNull()) {
100 | AzPrint printer(out);
101 | printer.reset_options();
102 | printer.set_precision(5);
103 | printer.ppBegin("AzRgf_FindSplit_Dflt", "Node split", ", ");
104 | printer.printV(kw_lambda, lambda);
105 | printer.printV_posiOnly(kw_sigma, sigma); /* presumably printed only when positive */
106 | printer.ppEnd();
107 | }
108 | }
109 |
110 | /*--- printHelp: describe the node-split keywords for the help screen ---*/
111 | void AzRgf_FindSplit_Dflt::printHelp(AzHelp &helper) const
112 | {
113 | helper.begin(Azsplit_config, "AzRgf_FindSplit_Dflt", "Regularization at node split");
114 | helper.item_required(kw_lambda, help_lambda);
115 | helper.item_experimental(kw_sigma, help_sigma, sigma_dflt);
116 | helper.item(kw_s_lambda, help_s_lambda);
117 | helper.item_experimental(kw_s_sigma, help_s_sigma);
118 | helper.end();
119 | }
120 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_FindSplit_Dflt.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_FindSplit_Dflt.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_FIND_SPLIT_DFLT_HPP_
20 | #define _AZ_RGF_FIND_SPLIT_DFLT_HPP_
21 |
22 | #include "AzFindSplit.hpp"
23 | #include "AzTrTree_ReadOnly.hpp"
24 | #include "AzRgf_FindSplit.hpp"
25 | #include "AzRegDepth.hpp"
26 | #include "AzParam.hpp"
27 |
28 | //! Node split search for RGF. L2 regularization.
29 | /*--------------------------------------------------------*/
30 | class AzRgf_FindSplit_Dflt : /* extends */ public virtual AzFindSplit,
31 | /* implements */ public virtual AzRgf_FindSplit
32 | {
33 | protected:
34 | double lambda, sigma; /* L2 / L1 reg params read by resetParam() */
35 | const AzRegDepth *reg_depth; /* set by reset(); must be non-NULL */
36 |
37 | double nlam, nsig; /* lambda, sigma scaled in begin() by nn and lam_scale */
38 | double p_nlam; //!< L2 reg param for parent (node to be split)
39 | double c_nlam; //!< L2 reg param for child (new node after split)
40 | double p_nsig, c_nsig; //!< L1 reg params for parent / child
41 | bool doUseInternalNodes; /* copied from the tree in begin() */
42 | const AzTrTreeNode *p_node; //!< parent node (node to be split)
43 |
44 | public:
45 | AzRgf_FindSplit_Dflt() : lambda(-1), sigma(sigma_dflt), /* lambda<0: not set yet */
46 | reg_depth(NULL),
47 | nlam(0), nsig(0),
48 | p_nlam(0), c_nlam(0), p_nsig(0), c_nsig(0),
49 | doUseInternalNodes(false), p_node(NULL) {} /* declaration order (no -Wreorder) */
50 | virtual void begin(const AzTrTree_ReadOnly *tree,
51 | const AzRgf_FindSplit_input &inp,
52 | int inp_min_size);
53 | virtual void end() {
54 | _end();
55 | }
56 | virtual inline
57 | void findSplit(int nx, /* depth-adjusts the reg params for node nx, then searches */
58 | /*--- output ---*/
59 | AzTrTsplit *best_split) {
60 | p_node = tree->node(nx);
61 | p_nlam = reg_depth->apply(nlam, p_node->depth);
62 | c_nlam = reg_depth->apply(nlam, p_node->depth+1);
63 | p_nsig = reg_depth->apply(nsig, p_node->depth);
64 | c_nsig = reg_depth->apply(nsig, p_node->depth+1);
65 | AzFindSplit::_findBestSplit(nx, best_split);
66 | }
67 | virtual void reset(AzParam &param,
68 | const AzRegDepth *inp_reg_depth,
69 | const AzOut &out)
70 | {
71 | reg_depth = inp_reg_depth;
72 | if (reg_depth == NULL) throw new AzException("AzRgf_FindSplit_Dflt",
73 | "null reg_depth");
74 | resetParam(param);
75 | printParam(out);
76 | }
77 |
78 | virtual void pickFeats(int pick_num, int f_num) {
79 | AzFindSplit::_pickFeats(pick_num, f_num);
80 | }
81 |
82 | virtual void printParam(const AzOut &out) const;
83 | virtual void printHelp(AzHelp &h) const;
84 |
85 | protected:
86 | virtual void resetParam(AzParam &param);
87 | virtual double getBestGain(double wsum, /* sum of data weights */
88 | double wrsum, /* weighted sum of residual (named as in the .cpp) */
89 | double *best_q) const;
90 | };
91 | #endif
92 |
93 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_FindSplit_TreeReg.cpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_FindSplit_TreeReg.cpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #include "AzRgf_FindSplit_TreeReg.hpp"
20 |
21 | /*--- findSplit: point the regularizer at node nx, cache its penalty derivatives, then search ---*/
22 | void AzRgf_FindSplit_TreeReg::findSplit(int nx,
23 | /*--- output ---*/
24 | AzTrTsplit *best_split)
25 | {
26 | if (tree->usingInternalNodes()) {
27 | throw new AzException("AzRgf_FindSplit_TreeReg::findSplit",
28 | "can't coexist with UseInternalNodes");
29 | }
30 | /* nx is the node to be split; reg is only set between begin() and end() */
31 | reg->reset_forNewLeaf(nx, tree, reg_depth);
32 | dR = ddR = 0; /* presumably filled by penalty_deriv; read later by evalSplit */
33 | reg->penalty_deriv(&dR, &ddR);
34 | AzRgf_FindSplit_Dflt::findSplit(nx, best_split); /* sets p_node and depth-adjusted params, then searches */
35 | }
36 |
37 | /*--- evalSplit: fit gain of a candidate split minus the tree-reg penalty change ---*/
38 | double AzRgf_FindSplit_TreeReg::evalSplit(
39 | const Az_forFindSplit i[2],
40 | double bestP[2])
41 | const
42 | {
43 | /* weight change of each child relative to the parent's weight */
44 | double delta[2];
45 | for (int side = 0; side < 2; ++side) {
46 | const Az_forFindSplit &cur = i[side];
47 | delta[side] = (cur.wy_sum - nlam*dR)/(cur.w_sum + nlam*ddR);
48 | bestP[side] = p_node->weight + delta[side]; /* output: child weight */
49 | }
50 |
51 | const double reg_change = reg->penalty_diff(delta); /* new - old */
52 |
53 | double gain = 2*delta[0]*i[0].wy_sum - delta[0]*delta[0]*i[0].w_sum
54 | + 2*delta[1]*i[1].wy_sum - delta[1]*delta[1]*i[1].w_sum;
55 |
56 | /* "2*" b/c penalty is sum v^2/2 */
57 | gain -= 2 * nlam * reg_change;
58 |
59 | return gain;
60 | }
61 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_FindSplit_TreeReg.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_FindSplit_TreeReg.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_FIND_SPLIT_TREE_REG_HPP_
20 | #define _AZ_RGF_FIND_SPLIT_TREE_REG_HPP_
21 |
22 | #include "AzRgf_FindSplit_Dflt.hpp"
23 | #include "AzReg_TreeReg.hpp"
24 | #include "AzReg_TreeRegArr.hpp"
25 |
26 | //! Node split search for RGF. L2 and tree structure regularization
27 | /*--------------------------------------------------------*/
28 | class AzRgf_FindSplit_TreeReg : /* extends */ public virtual AzRgf_FindSplit_Dflt
29 | {
30 | protected:
31 | AzReg_TreeRegArr *reg_arr; /* set by reset(AzReg_TreeRegArr*); required by begin() */
32 | AzReg_TreeReg *reg; /* obtained from reg_arr in begin(), cleared in end() */
33 | double dR, ddR; /* penalty derivatives for the node in findSplit(); read by evalSplit() */
34 |
35 | public:
36 | AzRgf_FindSplit_TreeReg() : reg_arr(NULL), reg(NULL), dR(0), ddR(0) {} /* declaration order */
37 | void reset(AzReg_TreeRegArr *inp_reg_arr) {
38 | reg_arr = inp_reg_arr;
39 | }
40 |
41 | //! override: Dflt setup, then fetch this tree's regularizer (by inp.tx) and prime it
42 | virtual void begin(const AzTrTree_ReadOnly *tree,
43 | const AzRgf_FindSplit_input &inp,
44 | int inp_min_size) {
45 | if (reg_arr == NULL) throw new AzException("AzRgf_FindSplit_TreeReg::begin", "null reg_arr");
46 | AzRgf_FindSplit_Dflt::begin(tree, inp, inp_min_size);
47 | reg = reg_arr->reg_forNewLeaf(inp.tx);
48 | reg->reset_forNewLeaf(tree, reg_depth);
49 | }
50 |
51 | //! override: Dflt::end(), then drop the per-tree regularizer
52 | virtual void end() {
53 | AzRgf_FindSplit_Dflt::end();
54 | reg = NULL;
55 | }
56 |
57 | //! override
58 | virtual void findSplit(int nx, AzTrTsplit *best_split);
59 |
60 | //! override AzFindSplit::evalSplit
61 | virtual double evalSplit(const Az_forFindSplit i[2],
62 | double bestP[2]) const;
63 | };
64 | #endif
65 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_Optimizer.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_Optimizer.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_OPTIMIZER_HPP_
20 | #define _AZ_RGF_OPTIMIZER_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDmat.hpp"
24 | #include "AzBmat.hpp"
25 | #include "AzDataForTrTree.hpp"
26 | #include "AzLoss.hpp"
27 | #include "AzTrTreeEnsemble.hpp"
28 | #include "AzRgfTreeEnsemble.hpp"
29 | #include "AzRegDepth.hpp"
30 | #include "AzParam.hpp"
31 | #include "AzHelp.hpp"
32 |
33 | //! Abstract class: Weight optimizer interface.
34 | /*-------------------------------------------------------*/
35 | class AzRgf_Optimizer
36 | {
37 | public:
38 | /*! Initialization */
39 | virtual void cold_start(AzLossType loss_type,
40 | const AzDataForTrTree *training_data, /*!< training data */
41 | const AzRegDepth *reg_depth, /*!< regularizer using tree attributes */
42 | AzParam ¶m, /*!< confignuration */
43 | const AzDvect *v_y, /*!< training targets */
44 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
45 | const AzOut out, /*!< where to wrige log */
46 | /*! output: prediction on training data. typically zeroes. */
47 | AzDvect *v_p)
48 | = 0;
49 |
50 | virtual void warm_start(AzLossType loss_type,
51 | const AzDataForTrTree *training_data, /*!< training data */
52 | const AzRegDepth *reg_depth, /*!< regularizer using tree attributes */
53 | AzParam ¶m, /*!< confignuration */
54 | const AzDvect *v_y, /*!< training targets */
55 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
56 | const AzOut out, /*!< where to wrige log */
57 | /*--- for warm start ---*/
58 | const AzTrTreeEnsemble_ReadOnly *inp_ens,
59 | const AzDvect *inp_v_p)
60 | = 0;
61 |
62 | /*! Optimize weights. */
63 | virtual void
64 | update(const AzDataForTrTree *data, /*!< training data */
65 | AzRgfTreeEnsemble *ens, /*!< inout: tree ensmeble. */
66 | /*--- output ---*/
67 | AzDvect *v_p) /*.
17 | * * * * */
18 |
19 | #include "AzRgf_Optimizer_Dflt.hpp"
20 | #include "AzRgfTreeEnsImp.hpp"
21 | #include "AzHelp.hpp"
22 |
23 | /*--- cold_start: read params, reset the trainer and the feature set for fresh training ---*/
24 | void AzRgf_Optimizer_Dflt::cold_start(AzLossType loss_type,
25 | const AzDataForTrTree *data,
26 | const AzRegDepth *reg_depth,
27 | AzParam &param,
28 | const AzDvect *v_y, /* for training */
29 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
30 | const AzOut out_req,
31 | AzDvect *out_v_pval) /* output */
32 | {
33 | out = out_req;
34 | bool beVerbose = resetParam(param);
35 | trainer->reset(loss_type, v_y, v_fixed_dw, reg_depth, param, beVerbose, out);
36 | trainer->copyPred_to(out_v_pval); /* predictions on the training data (typically zeroes here) */
37 | feat1.reset(data, param, out);
38 |
39 | if (!beVerbose) {
40 | out.deactivate(); /* presumably mutes further logging via the member 'out' */
41 | }
42 | }
43 |
44 | /*------------------------------------------------------------------*/
45 | void AzRgf_Optimizer_Dflt::warm_start(AzLossType loss_type,
46 | const AzDataForTrTree *data,
47 | const AzRegDepth *reg_depth,
48 | AzParam ¶m,
49 | const AzDvect *v_y, /* for training */
50 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
51 | const AzOut out_req,
52 | /*--- for warm start ---*/
53 | const AzTrTreeEnsemble_ReadOnly *inp_ens,
54 | const AzDvect *inp_v_p)
55 | {
56 | out = out_req;
57 | bool beVerbose = resetParam(param);
58 | feat1.reset(data, param, out);
59 | AzIntArr ia_removed_fx;
60 | feat1.update_with_ens(inp_ens, &ia_removed_fx);
61 |
62 | trainer->reset(loss_type, v_y, v_fixed_dw, reg_depth, param, beVerbose, out,
63 | inp_ens, &feat1, inp_v_p); /* for warm start */
64 | if (!beVerbose) {
65 | out.deactivate();
66 | }
67 | }
68 |
69 | /*------------------------------------------------------------------*/
70 | void
71 | AzRgf_Optimizer_Dflt::update(const AzDataForTrTree *data,
72 | AzRgfTreeEnsemble *ens, /* inout */
73 | /*--- inout ---*/
74 | AzDvect *v_p) /* prediction */
75 | {
76 | AzIntArr ia_removed_fx;
77 | int f_num_delta = feat1.update_with_ens(ens, &ia_removed_fx);
78 |
79 | if (f_num_delta > 0 || ens->size() == 0) {
80 | trainer->optimize(ens, &feat1);
81 | }
82 | else {
83 | AzTimeLog::print("No new feature", out);
84 | }
85 |
86 | if (v_p != NULL) {
87 | trainer->copyPred_to(v_p);
88 | }
89 | }
90 |
91 | /*------------------------------------------------------------------*/
92 | /*------------------------------------------------------------------*/
93 | /* static */
94 | void AzRgf_Optimizer_Dflt::_info(const AzTrTreeEnsemble_ReadOnly *ens,
95 | const AzOptimizerT *my_trainer,
96 | const AzTrTreeFeat *my_feat,
97 | int *f_num, int *nz_f_num)
98 | {
99 | const AzDvect *v_w = my_trainer->weights();
100 | *f_num = v_w->rowNum();
101 | *nz_f_num = my_feat->countNonzeroNodup(v_w, ens);
102 | }
103 |
104 | /*------------------------------------------------------------------*/
105 | void AzRgf_Optimizer_Dflt::apply(const AzDataForTrTree *test_data,
106 | AzBmat *b_test_tran, /* inout */
107 | const AzTrTreeEnsemble_ReadOnly *ens,
108 | /*--- output ---*/
109 | AzDvect *v_p,
110 | int *f_num,
111 | int *nz_f_num) const
112 | {
113 | if (test_data != NULL) {
114 | feat1.updateMatrix(test_data, ens, b_test_tran);
115 | trainer->resetPred(b_test_tran, v_p);
116 | }
117 |
118 | /*--- set info ---*/
119 | _info(ens, trainer, &feat1, f_num, nz_f_num);
120 | }
121 |
122 | /*------------------------------------------------------------------*/
123 | /*------------------------------------------------------------------*/
124 | bool AzRgf_Optimizer_Dflt::resetParam(AzParam &p)
125 | {
126 | bool beVerbose = false;
127 | p.swOn(&beVerbose, kw_opt_beVerbose);
128 | return beVerbose;
129 | }
130 |
131 | /*------------------------------------------------------------------*/
132 | void AzRgf_Optimizer_Dflt::printHelp(AzHelp &h) const
133 | {
134 | trainer->printHelp(h);
135 | feat1.printHelp(h);
136 | h.begin(Azopt_config, "AzRgf_Optimizer_Dflt", "Weight optimization/correction");
137 | h.item_experimental(kw_opt_beVerbose, help_opt_beVerbose);
138 | h.end();
139 | }
140 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_Optimizer_Dflt.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_Optimizer_Dflt.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_OPTIMIZER_DFLT_HPP_
20 | #define _AZ_RGF_OPTIMIZER_DFLT_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDmat.hpp"
24 | #include "AzDataForTrTree.hpp"
25 | #include "AzRgfTree.hpp"
26 | #include "AzTrTreeEnsemble.hpp"
27 | #include "AzLoss.hpp"
28 | #include "AzTrTreeFeat.hpp"
29 | #include "AzOptOnTree.hpp"
30 | #include "AzRgf_Optimizer.hpp"
31 | #include "AzOptimizerT.hpp"
32 |
33 | //! implement AzRgf_Optimizer.
34 | /*-------------------------------------------------------*/
35 | class AzRgf_Optimizer_Dflt : /* implements */ public virtual AzRgf_Optimizer
36 | {
37 | protected:
38 | AzTrTreeFeat feat1;
39 | AzOut out;
40 |
41 | AzOptOnTree trainer_dflt; /* linear trainer */
42 | AzOptimizerT *trainer;
43 |
44 | public:
45 | AzRgf_Optimizer_Dflt() : trainer(&trainer_dflt) {}
46 | ~AzRgf_Optimizer_Dflt() {}
47 | AzRgf_Optimizer_Dflt(const AzRgf_Optimizer_Dflt *inp) {
48 | reset(inp);
49 | }
50 |
51 | /*------------------------------------------------------*/
52 | /* override this to replace trainer */
53 | /*------------------------------------------------------*/
54 | virtual void reset(const AzRgf_Optimizer_Dflt *inp) {
55 | if (inp == NULL) return;
56 | feat1.reset(&inp->feat1);
57 | out = inp->out;
58 | trainer_dflt.reset(&inp->trainer_dflt);
59 | trainer = &trainer_dflt;
60 | }
61 | /*------------------------------------------------------*/
62 |
63 | /*------------------------------------------------------*/
64 | /* derived classes must override this */
65 | /*------------------------------------------------------*/
66 | virtual void temp_update_apply(const AzDataForTrTree *tr_data,
67 | AzRgfTreeEnsemble *temp_ens,
68 | const AzDataForTrTree *test_data,
69 | AzBmat *temp_b, AzDvect *v_test_p,
70 | int *f_num, int *nz_f_num) const {
71 | AzRgf_Optimizer_Dflt temp_opt(this);
72 | temp_opt.update(tr_data, temp_ens);
73 | if (test_data != NULL) temp_opt.apply(test_data, temp_b, temp_ens,
74 | v_test_p, f_num, nz_f_num);
75 | }
76 | /*--------------------------------------------------------*/
77 |
78 |
79 | virtual void cold_start(AzLossType loss_type,
80 | const AzDataForTrTree *data,
81 | const AzRegDepth *reg_depth,
82 | AzParam ¶m,
83 | const AzDvect *v_yval,
84 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
85 | const AzOut out_req,
86 | AzDvect *v_pval); /* output */
87 | virtual void warm_start(AzLossType loss_type,
88 | const AzDataForTrTree *data,
89 | const AzRegDepth *reg_depth,
90 | AzParam ¶m,
91 | const AzDvect *v_y, /* for training */
92 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */
93 | const AzOut out_req,
94 | /*--- for warm start ---*/
95 | const AzTrTreeEnsemble_ReadOnly *ens,
96 | const AzDvect *v_p);
97 | virtual void
98 | update(const AzDataForTrTree *data,
99 | AzRgfTreeEnsemble *ens, /* inout */
100 | /*--- output ---*/
101 | AzDvect *v_p=NULL); /* prediction */
102 |
103 | virtual void apply(const AzDataForTrTree *data,
104 | AzBmat *b_test_tran, /* inout */
105 | const AzTrTreeEnsemble_ReadOnly *ens,
106 | /*--- output ---*/
107 | AzDvect *v_p,
108 | int *f_num, int *nz_f_num) const;
109 |
110 | virtual void printHelp(AzHelp &h) const;
111 |
112 | protected:
113 | virtual bool resetParam(AzParam ¶m);
114 | static void _info(const AzTrTreeEnsemble_ReadOnly *ens,
115 | const AzOptimizerT *my_trainer,
116 | const AzTrTreeFeat *my_feat,
117 | int *f_num, int *nz_f_num);
118 | };
119 | #endif
120 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgf_Optimizer_TreeReg.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgf_Optimizer_TreeReg.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_RGF_OPTIMIZER_TREE_REG_HPP_
20 | #define _AZ_RGF_OPTIMIZER_TREE_REG_HPP_
21 |
22 | #include "AzRgf_Optimizer_Dflt.hpp"
23 | #include "AzOptOnTree_TreeReg.hpp"
24 |
25 | //! for regularization using tree structure.
26 | /*-------------------------------------------------------*/
27 | class AzRgf_Optimizer_TreeReg : /* extends */ public virtual AzRgf_Optimizer_Dflt
28 | {
29 | protected:
30 | AzOptOnTree_TreeReg trainer_tr;
31 |
32 | public:
33 | AzRgf_Optimizer_TreeReg() {
34 | trainer = &trainer_tr;
35 | }
36 | void reset(AzReg_TreeRegArr *reg_arr) {
37 | trainer_tr.reset(reg_arr);
38 | }
39 | AzRgf_Optimizer_TreeReg(const AzRgf_Optimizer_TreeReg *inp) {
40 | reset(inp);
41 | }
42 |
43 | /*------------------------------------------------------*/
44 | /* override this to replace trainer */
45 | /*------------------------------------------------------*/
46 | virtual void reset(const AzRgf_Optimizer_TreeReg *inp) {
47 | if (inp == NULL) return;
48 | AzRgf_Optimizer_Dflt::reset(inp);
49 |
50 | trainer_tr.reset(&inp->trainer_tr);
51 | trainer = &trainer_tr;
52 | }
53 |
54 | /*------------------------------------------------------*/
55 | /* derived classes must override this */
56 | /*------------------------------------------------------*/
57 | virtual void temp_update_apply(const AzDataForTrTree *tr_data,
58 | AzRgfTreeEnsemble *temp_ens,
59 | const AzDataForTrTree *test_data,
60 | AzBmat *temp_b, AzDvect *v_test_p,
61 | int *f_num, int *nz_f_num) const {
62 | AzRgf_Optimizer_TreeReg temp_opt(this);
63 | temp_opt.update(tr_data, temp_ens);
64 | if (test_data != NULL) temp_opt.apply(test_data, temp_b, temp_ens,
65 | v_test_p, f_num, nz_f_num);
66 | }
67 | /*--------------------------------------------------------*/
68 | };
69 | #endif
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzRgforest_TreeReg.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzRgforest_TreeReg.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_RGFOREST_TREEREG_HPP_
20 | #define _AZ_RGFOREST_TREEREG_HPP_
21 |
22 | #include "AzRgforest.hpp"
23 |
24 | #include "AzRgf_Optimizer_TreeReg.hpp"
25 | #include "AzRgf_FindSplit_TreeReg.hpp"
26 | #include "AzReg_TreeRegArrImp.hpp"
27 |
28 | //! RGF with min-penalty regularization.
29 |
30 | template
31 | class AzRgforest_TreeReg : /* implements */ public virtual AzRgforest {
32 | protected:
33 | AzRgf_FindSplit_TreeReg tr_fs;
34 | AzRgf_Optimizer_TreeReg tr_opt; /* weight optimizer */
35 | AzReg_TreeRegArrImp reg_arr;
36 | AzBytArr s_sign, s_desc;
37 |
38 | public:
39 | AzRgforest_TreeReg()
40 | {
41 | tr_fs.reset(®_arr);
42 | tr_opt.reset(®_arr);
43 | opt = &tr_opt;
44 | fs = &tr_fs;
45 | s_sign.reset(reg_arr.tmpl()->signature());
46 | s_desc.reset(reg_arr.tmpl()->description());
47 | reg_depth->set_default_for_min_penalty();
48 | }
49 | virtual inline const char *signature() const {
50 | return s_sign.c_str();
51 | }
52 | virtual inline const char *description() const {
53 | return s_desc.c_str();
54 | }
55 |
56 | virtual void printHelp(AzHelp &h) const {
57 | AzRgforest::printHelp(h);
58 | reg_arr.tmpl()->printHelp(h);
59 | h.begin(Azforest_config, "AzRgforest_TreeReg", "For min-penalty regularization");
60 | h.item(kw_doApproxTsr, help_doApproxTsr);
61 | h.end();
62 | }
63 |
64 | protected:
65 | virtual int resetParam(AzParam ¶m) { /* returns max #tree */
66 | int max_tree_num = AzRgforest::resetParam(param);
67 |
68 | bool doApproxTsr = false;
69 | param.swOn(&doApproxTsr, kw_doApproxTsr);
70 | if (doApproxTsr) {
71 | AzPrint o(out);
72 | o.ppBegin("AzRgforest_TreeReg", "Approximation", ", ");
73 | o.printSw(kw_doApproxTsr, doApproxTsr);
74 | o.ppEnd();
75 | if (doForceToRefreshAll) {
76 | doForceToRefreshAll = false;
77 | AzPrint::writeln(out, "Turning off ", kw_doForceToRefreshAll);
78 | }
79 | }
80 | else {
81 | if (!doForceToRefreshAll) {
82 | doForceToRefreshAll = true;
83 | AzPrint::writeln(out, "Turning on ", kw_doForceToRefreshAll);
84 | }
85 | }
86 |
87 | reg_arr.tmpl_u()->resetParam(param);
88 | reg_arr.tmpl()->printParam(out);
89 | reg_arr.reset(max_tree_num);
90 | return max_tree_num;
91 | }
92 |
93 | virtual void end_of_initialization() {
94 | AzRgforest::end_of_initialization();
95 | reg_arr.tmpl()->check_reg_depth(reg_depth);
96 | }
97 | };
98 | #endif
99 |
100 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTET_Eval.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTET_Eval.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TET_EVAL_HPP_
20 | #define _AZ_TET_EVAL_HPP_
21 |
22 | #include "AzDataForTrTree.hpp"
23 | #include "AzLoss.hpp"
24 | #include "AzTE_ModelInfo.hpp"
25 | #include "AzPerfResult.hpp"
26 |
27 | //! Abstract class: interface for evaluation modules for Tree Ensemble Trainer.
28 | /*-------------------------------------------------------*/
29 | class AzTET_Eval {
30 | public:
31 | virtual void reset(const AzDvect *inp_v_y,
32 | const char *perf_fn,
33 | bool inp_doAppend) = 0;
34 | virtual void begin(const char *config="",
35 | AzLossType loss_type=AzLoss_None) = 0;
36 | virtual void resetConfig(const char *config) = 0;
37 | virtual void end() = 0;
38 | virtual void evaluate(const AzDvect *v_p, const AzTE_ModelInfo *info,
39 | const char *user_str=NULL) = 0;
40 | virtual bool isActive() const = 0;
41 | };
42 | #endif
43 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTET_Eval_Dflt.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTET_Eval_Dflt.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TET_EVAL_DFLT_HPP_
20 | #define _AZ_TET_EVAL_DFLT_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzTaskTools.hpp"
24 | #include "AzPerfResult.hpp"
25 | #include "AzDataForTrTree.hpp"
26 | #include "AzLoss.hpp"
27 | #include "AzTET_Eval.hpp"
28 |
29 | //! Evaluationt module for Tree Ensemble Trainer.
30 | /*-------------------------------------------------------*/
31 | class AzTET_Eval_Dflt : /* implements */ public virtual AzTET_Eval {
32 | protected:
33 | /*--- targets ---*/
34 | const AzDvect *v_y;
35 |
36 | /*--- to output evaluation results ---*/
37 | AzBytArr s_perf_fn;
38 | AzPerfType perf_type;
39 |
40 | AzLossType loss_type;
41 | AzBytArr s_config;
42 |
43 | AzOfs ofs;
44 | AzOut perf_out;
45 | bool doAppend;
46 |
47 | public:
48 | AzTET_Eval_Dflt() : v_y(NULL),
49 | loss_type(AzLoss_None), doAppend(false) {}
50 | ~AzTET_Eval_Dflt() {
51 | end();
52 | }
53 | inline virtual bool isActive() const {
54 | if (v_y != NULL) return true;
55 | return false;
56 | }
57 |
58 | virtual void reset() {
59 | v_y= NULL;
60 | s_perf_fn.reset();
61 | s_config.reset();
62 | if (ofs.is_open()) {
63 | ofs.close();
64 | }
65 | }
66 | void reset(const AzDvect *inp_v_y,
67 | const char *perf_fn,
68 | bool inp_doAppend)
69 | {
70 | v_y = inp_v_y;
71 | s_perf_fn.reset(perf_fn);
72 | doAppend = inp_doAppend;
73 | }
74 | virtual void resetConfig(const char *config) {
75 | s_config.reset(config);
76 | _clean(&s_config);
77 | }
78 |
79 | virtual void begin(const char *config,
80 | AzLossType inp_loss_type) {
81 | if (!isActive()) return;
82 | _begin(config, inp_loss_type);
83 | _clean(&s_config);
84 | }
85 | virtual void end() {
86 | if (ofs.is_open()) {
87 | ofs.close();
88 | }
89 | }
90 |
91 | virtual void evaluate(const AzDvect *v_p,
92 | const AzTE_ModelInfo *info,
93 | const char *user_str=NULL) {
94 | if (!isActive()) return;
95 | AzPerfResult result=AzTaskTools::eval(v_p, v_y, loss_type);
96 |
97 | /*--- signature and configuration ---*/
98 | AzBytArr s_sign_config(info->s_sign);
99 | s_sign_config.concat(":");
100 | concat_config(info, &s_sign_config);
101 |
102 | /*--- print ---*/
103 | AzPrint o(perf_out);
104 | o.printBegin("", ",", ",");
105 | o.print("#tree", info->tree_num);
106 | o.print("#leaf", info->leaf_num);
107 | o.print("acc", result.acc, 4);
108 | o.print("rmse", result.rmse, 4);
109 | o.print("sqerr", result.rmse*result.rmse, 6);
110 | o.print(loss_str[loss_type]);
111 | o.print("loss", result.loss, 6);
112 | o.print("#test", v_p->rowNum());
113 | o.print("cfg"); /* for compatibility */
114 | o.print(s_sign_config);
115 | if (user_str != NULL) {
116 | o.print(user_str);
117 | }
118 | o.printEnd();
119 | }
120 |
121 | protected:
122 | virtual void _begin(const char *config, AzLossType inp_loss_type) {
123 | s_config.reset(config);
124 | loss_type = inp_loss_type;
125 |
126 | if (ofs.is_open()) {
127 | ofs.close();
128 | }
129 | const char *perf_fn = s_perf_fn.c_str();
130 | if (AzTools::isSpecified(perf_fn)) {
131 | ios_base::openmode mode = ios_base::out;
132 | if (doAppend) {
133 | mode = ios_base::app | ios_base::out;
134 | }
135 | ofs.open(perf_fn, mode);
136 | ofs.set_to(perf_out);
137 | }
138 | else {
139 | perf_out.setStdout();
140 | }
141 | }
142 |
143 | virtual void _clean(AzBytArr *s) const {
144 | /*-- replace comma with : for convenience later --*/
145 | s->replace(',', ';');
146 | }
147 |
148 | virtual void concat_config(const AzTE_ModelInfo *info, AzBytArr *s) const {
149 | if (s_config.length() > 0) {
150 | s->concat(&s_config);
151 | }
152 | else {
153 | AzBytArr s_cfg(&info->s_config);
154 | _clean(&s_cfg);
155 | s->concat(&s_cfg);
156 | }
157 | }
158 | };
159 | #endif
160 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTETmain_kw.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTETmain_kw.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TET_MAIN_KW_HPP_
20 | #define _AZ_TET_MAIN_KW_HPP_
21 |
/*--- actions (first command-line argument) and their help texts ---*/
#define kw_train "train"
#define kw_train_test "train_test"
#define kw_predict "predict"
#define kw_batch_predict "batch_predict"
#define kw_train_predict "train_predict"
#define kw_features "output_features"
#define help_train "Train and save models to files."
#define help_train_test "Train and test models. Optionally models can be saved to files."
#define help_train_predict "Train models and save predictions on test data to files. Models can also be saved to files."
#define help_predict "Apply a model saved by \"train\" to new data."
#define help_batch_predict "Apply several models to new data."
#define help_features "Output features generated by tree ensembles."

/*--- parameter keywords ---*/
#define kw_alg_name "algorithm="
#define kw_train_x_fn "train_x_fn="
#define kw_train_y_fn "train_y_fn="
#define kw_fdic_fn "x_name_fn="
#define kw_model_stem "model_fn_prefix="
#define kw_model_names_fn "model_names_fn="
#define kw_model_fn "model_fn="
#define kw_prev_model_fn "model_fn_for_warmstart="
#define kw_pred_fn "prediction_fn="
#define kw_pred_fn_suffix "pred_fn_suffix="
#define kw_not_doLog "DontLog"
#define kw_doLog "Log"
#define kw_doDump "Dump"
#define kw_eval_fn "evaluation_fn="
#define kw_doAppend_eval "Append_evaluation"
#define kw_test_x_fn "test_x_fn="
#define kw_test_y_fn "test_y_fn="
#define kw_dw_fn "train_w_fn="
#define kw_doSaveLastModelOnly "SaveLastModelOnly"

#define kw_xv_doShuffle "ShuffleData"
#define kw_xv_num "num_xv="
#define kw_xv_fn "xv_fn="
#define kw_input_x_fn "input_x_fn="
#define kw_output_x_fn "output_x_fn="
#define kw_features_digits "features_digits="
#define kw_doSparse_features "SparseFeatures"

/*--- help texts for the parameter keywords ---*/
#define help_train_x_fn "Path to the feature file of training data."
#define help_train_y_fn "Path to the target file of training data."
#define help_fdic_fn "Path to the file of feature names."
#define help_model_stem "To save models, path names are generated by attaching \"-01\", \"-02\",... to this value."
#define help_model_stem_tp "Path names to model files are generated by attaching \"-01\", \"-02\",... to this value. Path names to prediction files and model info files are generated by attaching \".pred\" and \".info\", respectively, to the model path names."
#define help_model_names_fn_out "Path to the file to write model path names to. If omitted, model path names are not saved."
#define help_model_names_fn_inp "Path to the file to read model path names from."
#define help_model_fn "Path to the model file to be tested"
#define help_prev_model_fn "Path to the input model file from which training should do warm-start."
#define help_prev_model_fn_others "Path to the input model file from which training should do warm-start. (WARNING) Some algorithms do not support warm-start and return error if this parameter is specified."
#define help_pred_fn_out "Path to the file to write predictions to."
#define help_pred_fn_inp "Path to the file to read predictions from."
#define help_pred_fn_suffix "The path names of the predictions files are generated by attaching this to the path names of the corresponding models."
#define help_not_doLog "Suppress logging-purpose output to stdout."
#define help_doDump "Enable dump to stderr for the verbose components."
#define help_eval_fn "Path to the file to write evaluation to."
#define help_doAppend_eval "Open the evaluation result file with append mode."
#define help_test_x_fn "Path to the feature file of test data"
#define help_test_y_fn "Path to the target file of test data"
#define help_dw_fn "Path to the file of user-defined weights assigned to training data points."
#define help_doSaveLastModelOnly "Save the last/largest model only."
/* the model path keyword is kw_model_stem ("model_fn_prefix="); there is no model_fn_suffix */
#define help_doSaveLastModelOnly_traintest "Save the last/largest model only. Referred to only when model_fn_prefix is specified."

#define help_input_x_fn "Path to the input feature file."
#define help_output_x_fn "Path to the output feature file."
#define help_features_digits "How many digits should be retained in the output."
#define help_doSparse_features "Write features in the sparse data format."

/* #define dflt_model_names_fn "model_list.txt" */
#define dflt_model_stem ""

#define Az_config "config"
95 |
96 | #endif
97 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTETselector.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTETselector.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TET_SELECTOR_HPP_
20 | #define _AZ_TET_SELECTOR_HPP_
21 |
22 | #include "AzTETrainer.hpp"
23 |
24 | class AzTETselector {
25 | public:
26 | //! Return trainer
27 | virtual AzTETrainer *select(const char *alg_name, //! algorithm name
28 | //! if true, don't throw exception on error
29 | bool dontThrow=false
30 | ) const = 0;
31 |
32 | virtual const char *dflt_name() const = 0;
33 | virtual const char *another_name() const = 0;
34 | virtual const AzStrArray *names() const = 0;
35 | virtual bool isRGFfamily(const char *name) const {
36 | AzBytArr s(name);
37 | return s.beginsWith("RGF");
38 | }
39 | virtual bool isGBfamily(const char *name) const {
40 | AzBytArr s(name);
41 | return s.beginsWith("GB");
42 | }
43 |
44 | //! Return algorithm names.
45 | virtual void printOptions(const char *dlm, //! delimiter between algorithm names.
46 | AzBytArr *s) //!< output: algorithm names separated by dlm.
47 | const = 0;
48 |
49 | //! Help
50 | virtual void printHelp(AzHelp &h) const = 0;
51 | };
52 |
53 | #endif
54 |
55 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTE_ModelInfo.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTE_ModelInfo.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TE_MODEL_INFO_HPP_
20 | #define _AZ_TE_MODEL_INFO_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
24 | /*-------------------------------------------------------*/
25 | /*! Tree Ensemble model info */
26 | class AzTE_ModelInfo {
27 | public:
28 | int tree_num; //!< number of trees
29 | int leaf_num; //!< number of features/leaf nodes
30 | int f_num; //!< number of features including removed ones.
31 | int nz_f_num; //!< number of non-zero-weight features after consolidation
32 | AzBytArr s_sign;
33 | AzBytArr s_config; //!< configuration
34 | const char *sign; //!< signature of trainer
35 | AzTE_ModelInfo() : f_num(-1),leaf_num(-1),nz_f_num(-1),tree_num(-1) {}
36 |
37 | void reset() {
38 | tree_num = leaf_num = f_num = nz_f_num = -1;
39 | s_sign.reset();
40 | s_config.reset();
41 | }
42 | };
43 | #endif
44 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTrTreeEnsemble_ReadOnly.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTrTreeEnsemble_ReadOnly.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TR_TREE_ENSEMBLE_READONLY_HPP_
20 | #define _AZ_TR_TREE_ENSEMBLE_READONLY_HPP_
21 |
22 | #include "AzTrTree_ReadOnly.hpp"
23 | #include "AzTreeEnsemble.hpp"
24 | #include "AzSvFeatInfo.hpp"
25 |
26 | //! Abstract class: interface for read-only access to trainalbe tree ensemble.
27 | class AzTrTreeEnsemble_ReadOnly {
28 | public:
29 | virtual bool usingTempFile() const { return false; }
30 | virtual const AzTrTree_ReadOnly *tree(int tx) const = 0;
31 | virtual int leafNum() const = 0;
32 | virtual int leafNum(int tx0, int tx1) const = 0;
33 | virtual int size() const = 0;
34 | virtual int max_size() const = 0;
35 | virtual int lastIndex() const = 0;
36 | virtual void copy_to(AzTreeEnsemble *out_ens,
37 | const char *config, const char *sign) const = 0;
38 | virtual void show(const AzSvFeatInfo *feat,
39 | const AzOut &out, const char *header="") const = 0;
40 | virtual double constant() const = 0;
41 | virtual int orgdim() const = 0;
42 | virtual const char *param_c_str() const = 0;
43 | };
44 | #endif
45 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTrTreeNode.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTrTreeNode.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TR_TREE_NODE_HPP_
20 | #define _AZ_TR_TREE_NODE_HPP_
21 |
22 | #include "AzTreeNodes.hpp"
23 |
24 | class AzTrTree;
25 |
26 | /*---------------------------------------------*/
27 | /*! used only for training */
28 | class AzTrTreeNode : /* extends */ public virtual AzTreeNode {
29 | protected:
30 | const int *dxs; /* data indexes belonging to this node */
31 |
32 | public:
33 | int dxs_offset; /* position in the data indexes at the root */
34 | int dxs_num;
35 | int depth; //!< node depth
36 |
37 | AzTrTreeNode() : depth(-1), dxs(NULL), dxs_offset(-1), dxs_num(-1) {}
38 | void reset() {
39 | AzTreeNode::reset();
40 | depth = dxs_offset = dxs_num = -1;
41 | dxs = NULL;
42 | }
43 | void transfer_from(AzTrTreeNode *inp) {
44 | AzTreeNode::transfer_from(inp);
45 | dxs = inp->dxs;
46 | dxs_offset = inp->dxs_offset;
47 | dxs_num = inp->dxs_num;
48 | depth = inp->depth;
49 | }
50 |
51 | inline const int *data_indexes() const {
52 | if (dxs_num > 0 && dxs == NULL) {
53 | throw new AzException("AzTrTreeNode::data_indexes",
54 | "data indexes are unavailable");
55 | }
56 | return dxs;
57 | }
58 | inline void reset_data_indexes(const int *ptr) {
59 | dxs = ptr;
60 | }
61 |
62 | friend class AzTrTree;
63 | };
64 |
65 | #endif
66 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTrTree_ReadOnly.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTrTree_ReadOnly.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TR_TREE_READONLY_HPP_
20 | #define _AZ_TR_TREE_READONLY_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDataForTrTree.hpp"
24 | #include "AzTreeRule.hpp"
25 | #include "AzSvFeatInfo.hpp"
26 | #include "AzTreeNodes.hpp"
27 | #include "AzTrTreeNode.hpp"
28 | #include "AzSortedFeat.hpp"
29 |
30 | //! Abstract class: interface for read-only (information-seeking) access to trainable tree.
31 | /*------------------------------------------*/
32 | /* Trainable tree; read only */
33 | class AzTrTree_ReadOnly : /* implements */ public virtual AzTreeNodes
34 | {
35 | public:
36 | /*--- information seeking ... ---*/
37 | virtual int nodeNum() const = 0;
38 | virtual int leafNum() const = 0;
39 | virtual int maxDepth() const = 0;
40 | virtual void show(const AzSvFeatInfo *feat, const AzOut &out) const = 0;
41 | virtual void concat_stat(AzBytArr *o) const = 0;
42 | virtual double getRule(int inp_nx, AzTreeRule *rule) const = 0;
43 | virtual void concatDesc(const AzSvFeatInfo *feat, int nx,
44 | AzBytArr *str_desc, /* output */
45 | int max_len=-1) const = 0;
46 | virtual void isActiveNode(bool doAllowZeroWeightLeaf,
47 | AzIntArr *ia_isDecisionNode) const = 0; /* output */
48 | virtual bool usingInternalNodes() const = 0;
49 |
50 | virtual const AzSortedFeatArr *sorted_array(int nx,
51 | const AzDataForTrTree *data) const = 0;
52 | /*--- (NOTE) this is const but changes sorted_arr[nx] ---*/
53 |
54 | virtual const AzIntArr *root_dx() const = 0;
55 |
56 | /*--- apply ... ---*/
57 | virtual double apply(const AzDataForTrTree *data, int dx,
58 | AzIntArr *ia_nx=NULL) const /* node path */
59 | = 0;
60 |
61 | virtual const AzTrTreeNode *node(int nx) const = 0;
62 | };
63 | #endif
64 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTrTsplit.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTrTsplit.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TRT_SPLIT_HPP_
20 | #define _AZ_TRT_SPLIT_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzTools.hpp"
24 |
25 | //! Node split information.
26 | class AzTrTsplit {
27 | public:
28 | int fx;
29 | double border_val;
30 | double gain;
31 | double bestP[2]; /* le gt */
32 | AzBytArr str_desc;
33 |
34 | int tx, nx; /* set only by Rgf; not used by Std */
35 |
36 | AzTrTsplit() : fx(-1), border_val(0), gain(0), tx(-1), nx(-1) {
37 | bestP[0] = bestP[1] = 0;
38 | }
39 |
40 | virtual void print(const char *header) {
41 | #if 0
42 | printf("%s, fx=%d,border_val=%e,gain=%e,bestP[0]=%e,bestP[1]=%e,tx=%d,nx=%d\n",
43 | header, fx, border_val, gain, bestP[0], bestP[1], tx, nx);
44 | #endif
45 | }
46 |
47 | virtual
48 | void reset() {
49 | fx = -1;
50 | border_val = 0;
51 | bestP[0] = bestP[1] = 0;
52 | gain = 0;
53 | str_desc.reset();
54 | tx = nx = -1;
55 | }
56 | AzTrTsplit(int fx, double border_val,
57 | double gain,
58 | double bestP_L, double bestP_G) {
59 | reset_values(fx, border_val, gain, bestP_L, bestP_G);
60 | }
61 | AzTrTsplit(const AzTrTsplit *inp) { /* copy */
62 | copy(inp);
63 | }
64 | virtual
65 | inline bool isEmpty() const {
66 | if (fx < 0) return true;
67 | return false;
68 | }
69 |
70 | virtual
71 | inline void reset(const AzTrTsplit *inp) {
72 | copy(inp);
73 | }
74 | virtual
75 | inline void reset(const AzTrTsplit *inp, int inp_tx, int inp_nx) {
76 | reset(inp);
77 | tx = inp_tx;
78 | nx = inp_nx;
79 | }
80 | virtual
81 | void copy(const AzTrTsplit *inp) {
82 | if (inp == NULL) return;
83 | fx = inp->fx;
84 | border_val = inp->border_val;
85 | gain = inp->gain;
86 | str_desc.clear();
87 | str_desc.concat(&inp->str_desc);
88 | bestP[0] = inp->bestP[0];
89 | bestP[1] = inp->bestP[1];
90 | tx = inp->tx;
91 | nx = inp->nx;
92 | }
93 |
94 | virtual
95 | void keep_if_good(int inp_fx, double inp_border_val,
96 | double inp_gain,
97 | double bestP_L, double bestP_G) {
98 | if (inp_gain > gain) {
99 | reset_values(inp_fx, inp_border_val, inp_gain, bestP_L, bestP_G);
100 | }
101 | }
102 |
103 | virtual
104 | void reset_values(int inp_fx, double inp_border_val,
105 | double inp_gain,
106 | double bestP_L, double bestP_G)
107 | {
108 | fx = inp_fx;
109 | border_val = inp_border_val;
110 | gain = inp_gain;
111 | bestP[0] = bestP_L;
112 | bestP[1] = bestP_G;
113 |
114 | tx = nx = -1;
115 | }
116 | virtual
117 | void release() {
118 | str_desc.reset();
119 | }
120 | };
121 | #endif
122 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTrTtarget.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTrTtarget.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TRT_TARGET_HPP_
20 | #define _AZ_TRT_TARGET_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzDmat.hpp"
24 |
25 | //! Targets and data point weights for node split search.
26 | /*--------------------------------------------------------*/
27 | class AzTrTtarget {
28 | protected:
29 | AzDvect v_tar_dw, v_dw;
30 | AzDvect v_y;
31 | AzDvect v_fixed_dw; /* data point weights assigned by users */
32 | double fixed_dw_sum;
33 |
34 | public:
35 | AzTrTtarget() : fixed_dw_sum(-1) {}
36 | AzTrTtarget(const AzDvect *inp_v_y,
37 | const AzDvect *inp_v_fixed_dw=NULL) {
38 | reset(inp_v_y, inp_v_fixed_dw);
39 | }
40 | void reset(const AzDvect *inp_v_y,
41 | const AzDvect *inp_v_fixed_dw=NULL) {
42 | v_dw.reform(inp_v_y->rowNum());
43 | v_dw.set(1);
44 | v_tar_dw.set(inp_v_y);
45 | v_y.set(inp_v_y);
46 | fixed_dw_sum = -1;
47 |
48 | v_fixed_dw.reset();
49 | if (!AzDvect::isNull(inp_v_fixed_dw)) {
50 | v_fixed_dw.set(inp_v_fixed_dw);
51 | if (v_fixed_dw.rowNum() != v_y.rowNum()) {
52 | throw new AzException(AzInputError, "AzTrTtarget::reset",
53 | "conlict in dimensionality: y and data point weights");
54 | }
55 | fixed_dw_sum = v_fixed_dw.sum();
56 | }
57 | }
58 | inline bool isWeighted() const {
59 | return !AzDvect::isNull(&v_fixed_dw);
60 | }
61 | inline double sum_fixed_dw() const {
62 | return fixed_dw_sum;
63 | }
64 | inline void weight_tarDw() {
65 | v_tar_dw.scale(&v_fixed_dw);
66 | }
67 | inline void weight_dw() {
68 | v_dw.scale(&v_fixed_dw);
69 | }
70 |
71 | AzTrTtarget(const AzTrTtarget *inp) {
72 | reset(inp);
73 | }
74 |
75 | void reset(const AzTrTtarget *inp) {
76 | if (inp != NULL) {
77 | v_tar_dw.set(&inp->v_tar_dw);
78 | v_dw.set(&inp->v_dw);
79 | v_y.set(&inp->v_y);
80 | v_fixed_dw.set(&inp->v_fixed_dw);
81 | fixed_dw_sum = inp->fixed_dw_sum;
82 | }
83 | }
84 |
85 | void resetTargetDw(const AzDvect *v_tar, const AzDvect *inp_v_dw) {
86 | v_tar_dw.set(v_tar);
87 | v_dw.set(inp_v_dw);
88 | v_tar_dw.scale(&v_dw); /* component-wise multiplication */
89 | }
90 | void resetTarDw_residual(const AzDvect *v_p) { /* only for LS */
91 | v_tar_dw.set(&v_y);
92 | v_tar_dw.add(v_p, -1);
93 | }
94 | inline const double *dw_arr() const {
95 | return v_dw.point();
96 | }
97 | inline const double *tarDw_arr() const {
98 | return v_tar_dw.point();
99 | }
100 | inline const AzDvect *y() const {
101 | return &v_y;
102 | }
103 | inline int dataNum() const {
104 | return v_tar_dw.rowNum();
105 | }
106 | inline AzDvect *tarDw_forUpdate() {
107 | return &v_tar_dw;
108 | }
109 | inline const AzDvect *tarDw() const {
110 | return &v_tar_dw;
111 | }
112 | inline AzDvect *dw_forUpdate() {
113 | return &v_dw;
114 | }
115 | inline const AzDvect *dw() {
116 | return &v_dw;
117 | }
118 | inline double getTarDwSum(const int *dxs, int dxs_num) const {
119 | return v_tar_dw.sum(dxs, dxs_num);
120 | }
121 | inline double getDwSum(const int *dxs, int dxs_num) const {
122 | return v_dw.sum(dxs, dxs_num);
123 | }
124 | inline double getTarDwSum(const AzIntArr *ia_dx=NULL) const {
125 | return v_tar_dw.sum(ia_dx);
126 | }
127 | inline double getDwSum(const AzIntArr *ia_dx=NULL) const {
128 | return v_dw.sum(ia_dx);
129 | }
130 |
131 | int dim() const {
132 | return v_tar_dw.rowNum();
133 | }
134 | };
135 | #endif
136 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTree.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTree.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TREE_HPP_
20 | #define _AZ_TREE_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzSmat.hpp"
24 | #include "AzSvFeatInfo.hpp"
25 | #include "AzTreeNodes.hpp"
26 |
27 | //! Untrainalbe regression tree.
28 | /*------------------------------------------*/
29 | class AzTree : /* implements */ public virtual AzTreeNodes {
30 | protected:
31 | int root_nx;
32 | int nodes_used;
33 | AzTreeNode *nodes;
34 | AzBaseArray a_nodes;
35 |
36 | inline void _checkNode(int nx, const char *eyec) const {
37 | if (nodes == NULL || nx < 0 || nx >= nodes_used) {
38 | throw new AzException(eyec, "nx is out of range");
39 | }
40 | }
41 |
42 | public:
43 | AzTree() : root_nx(-1), nodes_used(0), nodes(NULL) {}
44 | AzTree(AzFile *file)
45 | : root_nx(-1), nodes_used(0), nodes(NULL) {
46 | _read(file);
47 | }
48 | AzTree(const AzTreeNodes *inp) : root_nx(-1), nodes_used(0), nodes(NULL) {
49 | if (inp != NULL) {
50 | copy_from(inp);
51 | }
52 | }
53 | ~AzTree() {}
54 |
55 | void copy_from(const AzTreeNodes *tree_nodes);
56 |
57 | inline void transfer_from(AzTree *) {
58 | throw new AzException("AzTree::transfer_from", "no support");
59 | }
60 | inline AzTree & operator =(const AzTree &inp) {
61 | if (this == &inp) return *this;
62 | throw new AzException("AzTree:=", "Don't use =");
63 | }
64 | void reset() {
65 | _release();
66 | }
67 |
68 | int write(AzFile *file);
69 | void read(AzFile *file);
70 |
71 | inline int nodeNum() const {
72 | return nodes_used;
73 | }
74 |
75 | static double apply(const AzReadOnlyVector *v_data,
76 | const AzTreeNodes *nodes,
77 | AzIntArr *ia_node=NULL);
78 |
79 | double apply(const AzReadOnlyVector *v_data,
80 | AzIntArr *ia_node=NULL) const {
81 | checkNodes("apply");
82 | return apply(v_data, this, ia_node);
83 | }
84 |
85 | void show(const AzSvFeatInfo *feat, const AzOut &out,
86 | const char *header="") const;
87 | int leafNum() const;
88 | void clean_up();
89 |
90 | inline const AzTreeNode *node(int nx) const {
91 | checkNode(nx, "point");
92 | return &nodes[nx];
93 | }
94 | inline int root() const {
95 | return root_nx;
96 | }
97 | void finfo(AzIFarr *ifa_fx_count,
98 | AzIFarr *ifa_fx_sum) const; /* appended */
99 | void finfo(AzIntArr *ia_fxs) const; /* appended */
100 |
101 | void cooccurrences(AzIIFarr *iifa_fx1_fx2_count) const;
102 |
103 | virtual void genDesc(const AzSvFeatInfo *feat,
104 | int nx,
105 | AzBytArr *s) /* output */
106 | const;
107 |
108 | protected:
109 | /*--- functions ---*/
110 | void _read(AzFile *file);
111 |
112 | inline void checkNode(int nx, const char *eyec) const {
113 | if (nodes == NULL || nx < 0 || nx >= nodes_used) {
114 | throw new AzException(eyec, "AzTree, nx is out of range");
115 | }
116 | }
117 | inline void checkNodes(const char *msg) const {
118 | if (nodes == NULL && nodes_used > 0) {
119 | throw new AzException("AzTree, no nodes", msg);
120 | }
121 | }
122 | void _show(const AzSvFeatInfo *feat,
123 | int nx,
124 | int depth,
125 | const AzOut &out) const;
126 |
127 | void _release();
128 | virtual void _genDesc(const AzSvFeatInfo *feat,
129 | int nx,
130 | AzBytArr *s) /* output */
131 | const;
132 | };
133 |
134 | #endif
135 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTreeEnsemble.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTreeEnsemble.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TREE_ENSEMBLE_HPP_
20 | #define _AZ_TREE_ENSEMBLE_HPP_
21 |
22 | #include "AzUtil.hpp"
23 | #include "AzTree.hpp"
24 | #include "AzDmat.hpp"
25 | #include "AzTE_ModelInfo.hpp"
26 |
27 | //! Untrainable tree ensemble. Like applier. Generated as a result of training.
28 | class AzTreeEnsemble
29 | {
30 | protected:
31 | AzObjPtrArray a_tree;
32 | AzTree **t;
33 | int t_num;
34 | AzTree empty_tree;
35 | double const_val;
36 |
37 | AzBytArr s_config, s_sign;
38 | int org_dim; /* dimension of original features */
39 |
40 | public:
41 | AzTreeEnsemble() : t(NULL), t_num(0), const_val(0), org_dim(-1) {}
42 | ~AzTreeEnsemble() {}
43 |
44 | AzTreeEnsemble(const char *fn)
45 | : t(NULL), t_num(0), const_val(0), org_dim(-1) {
46 | read(fn);
47 | }
48 | AzTreeEnsemble(AzFile *file)
49 | : t(NULL), t_num(0), const_val(0), org_dim(-1) {
50 | _read(file);
51 | }
52 |
53 | void transfer_from(AzTree *inp_tree[], /* destroys input */
54 | int inp_tree_num,
55 | double const_val,
56 | int orgdim,
57 | const char *config,
58 | const char *sign);
59 |
60 | void read(const char *fn);
61 | void read(AzFile *file) {
62 | _release();
63 | _read(file);
64 | }
65 | int write(const char *fn);
66 | int write(AzFile *file);
67 |
68 | inline void destroy() {
69 | _release();
70 | }
71 |
72 | inline const AzTree *tree(int tx) const {
73 | checkIndex(tx, "tree");
74 | if (t[tx] == NULL) {
75 | return &empty_tree;
76 | }
77 | return t[tx];
78 | }
79 |
80 | int leafNum() const {
81 | return leafNum(0, t_num);
82 | }
83 | int leafNum(int tx0, int tx1) const;
84 | inline int size() const { return t_num; }
85 |
86 | void apply(const AzSmat *m_data,
87 | AzDvect *v_pred) /* output */
88 | const;
89 | double apply(const AzSvect *v_data) const;
90 |
91 | inline double constant() const { return const_val; }
92 | inline int orgdim() const { return org_dim; }
93 | const char *signature() const { return s_sign.c_str(); }
94 | const char *configuration() const { return s_config.c_str(); }
95 |
96 | /*---*/
97 | void info(AzTE_ModelInfo *out_info) const;
98 |
99 | void show(const AzSvFeatInfo *feat, //!< may be NULL
100 | const AzOut &out, const char *header="") const;
101 | void finfo(AzIFarr *ifa_fx_count,
102 | AzIFarr *ifa_fx_sum) const {
103 | finfo(0, t_num, ifa_fx_count, ifa_fx_sum);
104 | }
105 | void finfo(int tx0, int tx1,
106 | AzIFarr *ifa_fx_count,
107 | AzIFarr *ifa_fx_sum) const;
108 | void finfo(AzIntArr *ia_fx2tx) const;
109 | void cooccurrences(AzIIFarr *iifa_fx1_fx2_count) const;
110 |
111 | void show_weights(const AzOut &out, AzSvFeatInfo *fi) const;
112 |
113 | protected:
114 | void _read(AzFile *file);
115 | inline void _release() {
116 | a_tree.free(&t); t_num = 0;
117 | s_config.reset();
118 | s_sign.reset();
119 | const_val = 0;
120 | org_dim = -1;
121 | }
122 | inline void checkIndex(int tx, const char *msg) const {
123 | if (tx < 0 || tx >= t_num) {
124 | throw new AzException("AzTreeEnsemble::checkIndex", msg);
125 | }
126 | }
127 | void clean_up();
128 | };
129 | #endif
130 |
131 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTreeNodes.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTreeNodes.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TREE_NODES_HPP_
20 | #define _AZ_TREE_NODES_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
24 | /*! Tree node */
25 | class AzTreeNode {
26 | public:
27 | int fx; //!< feature id
28 | double border_val;
29 | int le_nx; //!< x[fx] <= border_val
30 | int gt_nx; //!< x[fx] > border_val
31 | int parent_nx; //!< pointing parent node
32 | double weight; //!< weight
33 |
34 | /*--- ---*/
35 | AzTreeNode() {
36 | reset();
37 | }
38 | void reset() {
39 | border_val = weight = 0;
40 | fx = le_nx = gt_nx = parent_nx = -1;
41 | }
42 | AzTreeNode(AzFile *file) {
43 | read(file);
44 | }
45 | inline bool isLeaf() const {
46 | if (le_nx < 0) return true;
47 | return false;
48 | }
49 | int write(AzFile *file);
50 | void read(AzFile *file);
51 |
52 | void transfer_from(AzTreeNode *inp) {
53 | *this = *inp;
54 | }
55 | };
56 |
57 | class AzTreeNodes {
58 | public:
59 | virtual const AzTreeNode *node(int nx) const = 0;
60 | virtual int nodeNum() const = 0;
61 | virtual int root() const = 0;
62 | };
63 | #endif
64 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/AzTreeRule.hpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * AzTreeRule.hpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #ifndef _AZ_TREE_RULE_HPP_
20 | #define _AZ_TREE_RULE_HPP_
21 |
22 | #include "AzUtil.hpp"
23 |
24 | class AzTreeRule {
25 | protected:
26 | AzBytArr ba;
27 |
28 | public:
29 | inline void reset() {
30 | ba.clear();
31 | }
32 | inline const AzBytArr *bytarr() {
33 | return &ba;
34 | }
35 | inline void reset(const AzTreeRule *inp) {
36 | ba.reset();
37 | if (inp == NULL) return;
38 | ba.reset(&inp->ba);
39 | }
40 | inline void append(int fx,
41 | bool isLE,
42 | double border_val)
43 | {
44 | /*--- feat#, isLE, border_val ---*/
45 | ba.concat((AzByte *)(&fx), sizeof(fx));
46 | ba.concat((AzByte *)(&isLE), sizeof(isLE));
47 | ba.concat((AzByte *)(&border_val), sizeof(border_val));
48 | }
49 | inline void append(const AzTreeRule *inp) {
50 | if (inp != NULL) {
51 | ba.concat(&inp->ba);
52 | }
53 | }
54 | inline void finalize() {
55 | if (ba.getLen() == 0) {
56 | ba.concat('_'); /* root node (CONST) */
57 | }
58 | }
59 | inline const AzBytArr *byteArr() {
60 | return &ba;
61 | }
62 | };
63 | #endif
64 |
--------------------------------------------------------------------------------
/rgf1.2/src/tet/driv_rgf.cpp:
--------------------------------------------------------------------------------
1 | /* * * * *
2 | * driv_rgf.cpp
3 | * Copyright (C) 2011, 2012 Rie Johnson
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 | * along with this program. If not, see .
17 | * * * * */
18 |
19 | #define _AZ_MAIN_
20 | #include "AzUtil.hpp"
21 | #include "AzTETmain.hpp"
22 | #include "AzRgfTrainerSel.hpp"
23 | #include "AzTET_Eval_Dflt.hpp"
24 | #include "AzHelp.hpp"
25 |
26 | /*-----------------------------------------------------------------*/
27 | void help(int argc, const char *argv[])
28 | {
29 | cout << "Arguments: action parameters" <getMessage() << endl;
100 | return -1;
101 | }
102 |
103 | return 0;
104 | }
105 |
106 |
--------------------------------------------------------------------------------
/rgf1.2/test/call_exe.pl:
--------------------------------------------------------------------------------
1 |
use strict 'vars';

# Separator placed between keyword=value pairs when building the parameter string.
my $dlm = ',';
# Extension appended to the configuration-file argument.
my $inp_ext = '.inp';

# Exactly three arguments are required (exe, action, cfg_fn); otherwise show usage.
if (scalar(@ARGV) != 3) {
    print "Arguments: exe train|predict|train_test cfg_fn\n",
          " exe : Name of the executable. Typically, ../bin/rgf \n",
          " train|predict|train_test \n",
          " train ... Train and save models to files. \n",
          " predict ... Apply a model to new data. \n",
          " train_test ... Train and test models in one call. \n",
          " cfg_fn: Path to a configuration file without extension; e.g., sample/train \n",
          " The file extension should be \".inp\"\n",
          " For example, see sample/train.inp and sample/predict.inp.\n",
          "\n",
          " To get help on the parameters that can be used in configuration\n",
          " files, call the executable with train|predict|train_test, e.g.,\n",
          "\n",
          " ..\\bin\\rgf train \n",
          " ../bin/rgf predict \n";
    exit 1;
}
26 |
# Positional arguments; their count was validated above.
my ($exe, $action, $inp_fn) = @ARGV;

$inp_fn .= $inp_ext;

my @list = &readList_inc($inp_fn);

# Shared parameters, accumulated by concatKwVal() from every line that does
# not start with "@".  Must stay at file scope: concatKwVal() refers to it.
my $config = "";

# Each "@..." line defines one run; its text is prepended to $config.
my(@repeat);
foreach my $line (@list) {
    if ($line =~ /^\@(.*)$/) {
        push @repeat, $1;
    }
    else {
        &concatKwVal($line);
    }
}

# One parameter string per run (a single run when there are no "@" lines).
my @runs;
if (!@repeat) {
    @runs = ($config);
}
else {
    foreach my $rep (@repeat) {
        # Join with $dlm, without a dangling delimiter when $config is empty.
        my $run_config = ($config eq "") ? $rep : $rep . $dlm . $config;
        push @runs, $run_config;
    }
}

# Run each command; stop with a non-zero status on the first failure
# (this now applies to the single-run case too).
foreach my $run_config (@runs) {
    my $cmd = "$exe $action $run_config";
    print "$cmd\n";
    if (system($cmd) != 0) {
        print "system() failed; exit call_exe.pl\n";
        exit 1;
    }
}

exit 0;
75 |
##-----------------------------------
## Read a configuration file, expanding include directives.
## A line of the form "include filename" is replaced by the lines of that
## file.  Lines starting with "#" are dropped by readList() as comments, so
## the directive is written WITHOUT a leading "#".
## Nested includes are rejected: a line starting with "include" inside an
## included file is an error (exit 1).
## Every resulting line is passed through checkInputLine().
sub readList_inc {
    my($fn) = @_;
    my(@out);

    foreach my $line (&readList($fn)) {
        if ($line =~ /^include\s+(\S+)$/) {
            my $inc_fn = $1;
            foreach my $inc_line (&readList($inc_fn)) {
                if ($inc_line =~ /^include/) {
                    print "Can't nest include: $inc_fn\n";
                    exit 1;
                }
                push @out, &checkInputLine($inc_line);
            }
        }
        else {
            push @out, &checkInputLine($line);
        }
    }
    return @out;
}
108 |
109 | ##------
## Append one keyword=value item to the shared $config string, inserting
## $dlm between items.  Items that are empty or whitespace-only are ignored.
sub concatKwVal {
    my($kwval) = @_;
    return unless $kwval =~ /\S/;
    $config .= $dlm if $config ne "";
    $config .= $kwval;
}
119 |
120 | #####
## Read a file and return its lines, with "#" comment lines skipped and
## surrounding whitespace stripped from every other line.
## Prints a message and exits 1 if the file cannot be opened.
sub readList {
    my($lst_fn) = @_;
    my(@item);

    # Three-argument open: the file name is never interpreted as a mode.
    my $fh;
    if (!open($fh, '<', $lst_fn)) {
        print "Can't open $lst_fn\n";
        exit 1;
    }

    # Read line by line (the previous empty "while()" condition never ended).
    while (my $line = <$fh>) {
        chomp $line;
        next if $line =~ /^\s*\#/;   # whole-line comment
        push @item, &strip($line);
    }

    close($fh);
    return @item;
}
143 |
144 | #####
## Remove leading and trailing whitespace.  An input with no non-whitespace
## character (or one that spans lines in a way the pattern cannot cover) is
## returned unchanged.
sub strip {
    my($inp) = @_;
    return ($inp =~ /^\s*(\S(?:.*\S)?)\s*$/) ? $1 : $inp;
}
156 |
157 | #####
## Validate one configuration line and return its parameter part.
## A line holds one whitespace-free token, optionally followed by a comment
## that starts with "#"; anything from "#" onward is dropped.
## If a second token does not start with "#", prints an error and exits
## with status 1 so the caller sees the failure (a bare "exit" returned 0).
sub checkInputLine {
    my($inp) = @_;

    my $param = $inp;
    if ($param =~ /^(\S+)\s+(\S+)/) {
        $param = $1;
        my $comment = $2;
        if ($comment !~ /^\#/) {
            print "No space is allowed in the parameter. Comments should start with \#\n";
            print "Error in: [$inp]\n";
            exit 1;
        }
    }
    # Drop a comment attached directly to the token, e.g. "a=1#note".
    if ($param =~ /^([^\#]+)\#/) {
        $param = $1;
    }
    return $param;
}
176 |
--------------------------------------------------------------------------------
/rgf1.2/test/output/sample.model-01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/test/output/sample.model-01
--------------------------------------------------------------------------------
/rgf1.2/test/output/sample.model-02:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/test/output/sample.model-02
--------------------------------------------------------------------------------
/rgf1.2/test/output/sample.model-03:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/test/output/sample.model-03
--------------------------------------------------------------------------------
/rgf1.2/test/output/sample.model-04:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/test/output/sample.model-04
--------------------------------------------------------------------------------
/rgf1.2/test/output/sample.model-05:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/rgf1.2/test/output/sample.model-05
--------------------------------------------------------------------------------
/rgf1.2/test/sample/predict.inp:
--------------------------------------------------------------------------------
1 | #### sample input to "predict"
2 |
3 | #--- apply a model to sample test data and save prediction values
4 | test_x_fn=sample/test.data.x # Test data points
5 | model_fn=output/sample.model-03 # Model
6 | prediction_fn=output/sample.pred # Where to write predictions
7 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/regress_train_test.inp:
--------------------------------------------------------------------------------
1 | # To use this example configuration file:
2 | # Set the current directory to rgf1.2/test.
3 | # In the command line, enter:
4 | #
5 | # perl call_exe.pl ../bin/rgf train_test sample/regress_train_test
6 | #
7 |
8 | #------------------ Perform 3 runs --------------------#
9 | @reg_L2=1,model_fn_prefix=output/regress.lam1.model # used by 1st run
10 | @reg_L2=0.1,model_fn_prefix=output/regress.lam0.1.model # used by 2nd run
11 | @reg_L2=0.01,model_fn_prefix=output/regress.lam0.01.model # used by 3rd run
12 | #-------------------------------------------------------------------------#
13 |
14 | #--- All other parameters are shared by the 3 runs.
15 |
16 | train_x_fn=sample/regress.train.x # Training data points
17 | train_y_fn=sample/regress.train.y # Training targets
18 |
19 | test_x_fn=sample/regress.test.x # Test data points
20 | test_y_fn=sample/regress.test.y # Test targets
21 |
22 | algorithm=RGF # RGF with L2 regularization on leaf-only models
23 | loss=LS # Square loss
24 | test_interval=500 # Test (and save) models every time 500 leaves are added.
25 | max_leaf_forest=5000 # Stop training when #leaf reaches 5000.
26 | Verbose # Display info during training.
27 | NormalizeTarget # Normalize targets so that the average becomes zero.
28 |
29 | #train_w_fn=?? # User-specified weights of data points.
30 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with
31 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/test.data.x:
--------------------------------------------------------------------------------
1 | 79 93 97 99 68 14 94 90 3 34
2 | 41 40 88 38 87 71 34 82 18 91
3 | 9 73 66 12 88 84 26 21 74 48
4 | 52 82 75 67 98 46 95 5 49 17
5 | 15 67 89 80 73 61 90 19 9 72
6 | 15 25 11 6 54 95 7 37 83 57
7 | 63 2 96 90 16 99 68 99 27 67
8 | 39 45 84 4 61 27 64 60 90 12
9 | 40 80 6 51 50 78 48 20 0 39
10 | 40 35 42 9 25 43 95 57 56 34
11 | 63 91 35 79 7 68 65 16 93 73
12 | 79 31 18 18 86 78 13 50 25 26
13 | 78 73 1 4 83 85 83 29 38 93
14 | 67 68 47 43 47 57 11 73 97 49
15 | 37 95 20 7 68 47 0 61 33 16
16 | 93 7 50 75 18 91 35 78 86 29
17 | 64 7 86 25 0 14 85 31 52 10
18 | 5 74 47 57 50 33 68 73 16 26
19 | 89 33 0 47 51 93 34 61 21 69
20 | 86 79 21 52 28 68 89 63 66 84
21 | 47 87 56 44 55 89 1 92 68 53
22 | 49 41 33 29 61 92 45 27 79 57
23 | 81 36 2 74 50 90 84 34 58 48
24 | 6 53 74 92 7 19 74 86 17 89
25 | 72 63 77 47 1 61 29 3 92 25
26 | 99 89 64 88 95 67 22 70 73 2
27 | 63 92 61 66 8 34 99 34 54 76
28 | 70 88 46 56 54 60 6 51 28 74
29 | 13 83 63 49 97 6 89 97 70 89
30 | 60 93 73 19 54 10 46 99 25 98
31 | 66 51 75 44 56 38 93 3 35 89
32 | 50 85 55 81 91 30 84 64 52 33
33 | 16 57 37 31 99 49 46 25 22 56
34 | 0 58 51 3 14 34 58 95 94 73
35 | 56 62 23 76 63 48 90 4 42 36
36 | 48 63 86 72 44 19 66 70 1 49
37 | 23 93 59 39 67 24 5 93 34 29
38 | 41 5 86 21 79 74 7 26 90 78
39 | 34 97 16 70 27 41 9 61 72 35
40 | 87 3 37 79 21 73 13 70 36 50
41 | 2 58 37 80 11 88 21 31 21 40
42 | 1 5 32 86 62 23 84 91 11 74
43 | 36 94 72 70 45 28 59 27 69 30
44 | 33 85 70 17 8 49 70 14 86 92
45 | 1 35 92 8 53 41 28 85 96 96
46 | 56 82 79 32 64 87 86 10 1 13
47 | 24 58 91 41 36 86 32 5 8 47
48 | 50 93 65 33 83 20 28 33 63 77
49 | 63 15 26 80 39 66 70 85 32 3
50 | 27 21 13 75 41 46 70 64 70 80
51 | 19 2 57 89 35 40 46 12 29 66
52 | 46 6 14 52 98 37 95 10 7 18
53 | 80 73 37 72 84 34 84 93 63 33
54 | 42 49 38 97 86 97 55 79 8 1
55 | 76 40 67 41 72 95 31 64 75 78
56 | 87 62 25 99 14 37 26 85 24 58
57 | 68 88 83 93 82 11 91 88 31 85
58 | 15 3 66 23 80 5 89 16 39 14
59 | 48 84 66 86 19 76 67 53 10 78
60 | 1 12 36 73 19 99 42 89 58 27
61 | 65 80 65 49 74 83 11 15 9 67
62 | 88 79 65 6 96 14 78 50 88 71
63 | 48 64 26 48 7 84 56 32 22 77
64 | 44 1 88 99 43 72 53 84 53 11
65 | 77 16 31 90 50 22 4 97 47 80
66 | 21 13 69 90 13 85 89 30 59 59
67 | 74 84 9 87 25 26 56 15 2 91
68 | 41 52 60 15 76 56 43 42 51 40
69 | 66 6 98 71 40 70 12 64 83 1
70 | 44 25 23 22 48 89 27 61 30 27
71 | 61 34 74 95 23 96 79 88 93 9
72 | 69 74 65 56 6 54 45 51 8 85
73 | 48 6 47 81 76 64 31 30 42 20
74 | 7 67 86 74 91 83 56 7 56 43
75 | 64 63 26 88 42 34 14 2 77 45
76 | 93 35 40 1 55 52 49 84 24 95
77 | 78 67 89 64 10 8 2 85 81 49
78 | 2 48 58 12 94 58 52 96 15 34
79 | 6 60 90 62 12 59 13 70 71 55
80 | 36 91 76 82 90 90 76 94 81 92
81 | 46 64 70 40 61 60 12 55 79 82
82 | 18 75 83 76 65 39 7 34 13 74
83 | 49 18 7 47 31 62 14 56 72 26
84 | 76 21 93 18 68 62 36 47 58 71
85 | 4 77 90 24 1 33 80 58 69 96
86 | 7 45 99 8 17 54 72 17 23 50
87 | 67 62 73 80 96 53 84 63 56 9
88 | 46 23 34 51 14 22 34 39 40 6
89 | 31 63 27 2 54 28 81 89 85 40
90 | 23 66 36 92 91 64 19 31 98 3
91 | 39 50 93 19 38 21 93 49 65 34
92 | 49 92 24 5 61 42 85 16 85 55
93 | 88 44 80 1 71 39 53 29 56 59
94 | 73 45 83 70 86 90 78 56 84 24
95 | 22 91 15 84 72 37 84 32 12 71
96 | 83 81 25 83 44 34 72 91 52 20
97 | 97 96 62 30 28 38 47 93 48 38
98 | 63 43 14 12 49 37 96 59 57 49
99 | 0 61 0 83 34 67 16 42 52 35
100 | 20 41 25 84 33 69 28 52 70 40
101 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/test.data.y:
--------------------------------------------------------------------------------
1 | +1
2 | +1
3 | -1
4 | +1
5 | +1
6 | +1
7 | +1
8 | -1
9 | -1
10 | -1
11 | -1
12 | -1
13 | -1
14 | +1
15 | +1
16 | -1
17 | -1
18 | +1
19 | -1
20 | -1
21 | +1
22 | +1
23 | -1
24 | -1
25 | +1
26 | +1
27 | -1
28 | +1
29 | +1
30 | -1
31 | +1
32 | -1
33 | +1
34 | +1
35 | -1
36 | -1
37 | +1
38 | +1
39 | +1
40 | -1
41 | +1
42 | +1
43 | -1
44 | -1
45 | -1
46 | -1
47 | +1
48 | +1
49 | +1
50 | -1
51 | +1
52 | +1
53 | -1
54 | +1
55 | +1
56 | -1
57 | +1
58 | +1
59 | -1
60 | +1
61 | +1
62 | -1
63 | -1
64 | -1
65 | -1
66 | -1
67 | +1
68 | +1
69 | -1
70 | -1
71 | -1
72 | +1
73 | -1
74 | -1
75 | +1
76 | +1
77 | -1
78 | +1
79 | +1
80 | +1
81 | +1
82 | +1
83 | +1
84 | -1
85 | +1
86 | -1
87 | -1
88 | -1
89 | +1
90 | -1
91 | -1
92 | +1
93 | +1
94 | -1
95 | +1
96 | +1
97 | +1
98 | -1
99 | +1
100 | -1
101 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/train.data.sparse.x:
--------------------------------------------------------------------------------
1 | sparse 10
2 | 0:58 1:19 2:5 3:92 4:3 5:62 6:26 7:56 8:43 9:35
3 | 0:19 1:61 2:54 3:81 4:8 5:50 6:55 7:83 8:69 9:81
4 | 0:79 1:93 2:49 3:17 4:67 5:70 6:33 7:38 8:81 9:53
5 | 0:66 1:30 2:78 3:67 4:18 5:98 6:72 7:25 8:35 9:53
6 | 0:86 1:34 2:68 3:72 4:40 5:8 6:89 7:7 8:34
7 | 0:9 1:99 2:46 3:86 4:51 5:13 6:74 7:46 8:82 9:62
8 | 0:82 1:99 2:20 3:47 4:46 5:42 6:40 7:66 8:64 9:55
9 | 0:99 1:73 2:59 3:94 4:16 5:16 6:80 7:28 8:11 9:30
10 | 0:96 1:67 2:72 3:16 4:10 5:27 6:79 7:48 8:86 9:54
11 | 0:23 1:3 2:63 3:91 5:97 6:72 7:3 8:57 9:46
12 | 0:25 1:72 2:85 3:34 4:76 5:65 6:68 7:64 8:51 9:61
13 | 0:98 1:34 2:49 3:83 4:44 5:16 6:3 7:3 8:70 9:63
14 | 0:63 1:39 2:88 3:20 4:73 6:52 7:9 8:30 9:57
15 | 0:99 1:91 2:47 3:60 4:71 5:74 6:79 7:13 8:20 9:92
16 | 0:75 1:64 2:9 3:60 4:29 5:66 6:38 7:26 8:39 9:94
17 | 0:36 1:94 2:62 3:55 4:65 5:10 6:42 7:62 8:53 9:35
18 | 0:19 1:94 2:55 3:68 4:65 5:36 6:95 7:21 8:72 9:52
19 | 0:16 1:21 2:90 3:23 4:14 5:8 6:69 7:79 8:18 9:78
20 | 0:44 1:84 2:85 3:32 4:66 5:12 6:26 7:12 8:9 9:81
21 | 0:86 1:19 2:68 3:21 4:33 5:42 6:95 7:46 8:7 9:64
22 | 0:85 1:48 2:92 3:11 4:76 5:56 6:43 7:97 8:60 9:70
23 | 0:68 1:81 2:89 3:89 4:65 5:26 6:87 7:29 8:30 9:7
24 | 0:95 1:93 2:5 3:27 4:68 5:26 6:25 7:28 8:61 9:8
25 | 0:20 1:78 2:44 3:80 4:31 5:93 6:55 7:81 8:69 9:39
26 | 0:91 1:67 2:92 3:15 4:55 5:53 6:27 7:47 8:34 9:9
27 | 0:91 1:15 2:95 3:30 4:32 5:18 6:35 7:13 8:46 9:76
28 | 0:68 1:64 2:89 4:73 5:75 6:6 7:44 8:97 9:84
29 | 0:64 1:8 2:9 3:50 4:45 5:35 6:68 7:24 8:92 9:28
30 | 0:74 1:74 2:7 3:31 4:53 5:28 6:92 7:38 8:17 9:34
31 | 0:94 1:78 2:31 3:96 4:33 5:54 6:50 7:47 8:77 9:44
32 | 0:18 1:11 2:7 3:93 4:26 5:79 6:5 7:70 8:99 9:46
33 | 0:16 1:52 2:3 3:27 4:63 5:56 6:85 7:83 8:81 9:63
34 | 0:55 1:25 2:91 3:67 4:78 5:15 6:47 7:55 8:89 9:87
35 | 0:24 1:16 2:8 3:12 4:44 5:17 6:46 7:30 8:66 9:91
36 | 0:65 1:31 2:95 3:68 4:13 5:16 6:66 7:65 8:60 9:95
37 | 0:31 1:55 2:84 3:70 4:74 5:83 6:45 7:28 8:25 9:12
38 | 0:61 1:3 2:43 3:20 4:5 5:2 6:25 7:50 8:2 9:95
39 | 0:12 1:7 2:92 3:19 4:89 5:76 6:57 7:77 8:74 9:43
40 | 0:85 1:93 2:10 3:31 4:37 5:2 6:71 7:5 8:85 9:13
41 | 0:97 1:20 2:67 3:55 4:53 5:61 6:24 7:93 8:3 9:60
42 | 0:57 1:83 2:81 3:70 4:87 5:23 6:33 7:67 8:82 9:90
43 | 0:65 1:45 2:48 3:39 4:21 5:73 6:68 7:35 8:25 9:58
44 | 0:57 1:45 2:32 3:68 4:66 5:15 6:60 7:51 8:58 9:29
45 | 0:2 1:36 3:48 4:92 5:3 6:98 7:45 8:97 9:26
46 | 0:65 1:87 2:37 3:4 4:82 5:94 6:73 7:81 8:36 9:2
47 | 0:38 1:33 2:2 3:96 4:1 5:68 6:18 7:39 8:86 9:73
48 | 0:46 1:22 2:13 3:26 4:81 5:95 6:88 7:74 8:90 9:95
49 | 0:51 1:92 2:9 3:16 4:8 5:55 6:30 7:87 8:78 9:47
50 | 0:86 1:48 2:60 3:35 4:3 5:6 6:46 7:87 8:71 9:31
51 | 0:12 1:52 2:36 3:48 4:34 5:65 6:46 7:49 8:98 9:21
52 | 0:34 1:43 2:39 3:98 4:78 5:5 6:14 7:80 8:80 9:76
53 | 0:48 1:54 2:96 3:3 4:64 5:40 6:88 7:47 8:9 9:85
54 | 0:74 1:80 2:33 3:87 4:24 5:15 6:95 7:18 8:3 9:44
55 | 0:70 1:57 2:14 3:95 4:75 5:31 7:67 8:49 9:96
56 | 0:1 1:51 2:3 3:42 4:69 5:45 6:6 7:99 8:65 9:87
57 | 0:39 1:16 2:98 3:20 4:35 5:19 6:15 7:99 8:65 9:68
58 | 0:86 1:98 2:72 3:37 4:95 5:53 6:6 7:36 8:64 9:92
59 | 0:87 1:65 2:36 3:39 4:94 5:8 6:5 7:32 8:33
60 | 0:87 1:98 2:52 3:49 4:15 5:36 6:20 7:28 8:60 9:4
61 | 0:84 1:92 2:93 3:82 4:91 5:54 6:84 7:19 8:83 9:84
62 | 0:85 1:26 2:62 3:87 4:41 5:60 6:79 7:8 8:32 9:42
63 | 0:75 1:12 2:49 3:81 4:68 5:87 6:86 7:97 8:13 9:33
64 | 0:88 1:38 2:60 3:89 4:15 5:43 6:37 7:47 8:33 9:96
65 | 0:95 1:1 2:7 3:27 4:20 5:69 6:53 7:1 8:84 9:53
66 | 0:7 1:98 2:50 3:39 4:28 5:70 6:10 7:47 8:6 9:49
67 | 0:90 1:30 2:70 3:6 4:36 5:84 6:73 7:15 8:36 9:22
68 | 0:38 1:48 2:2 3:4 4:39 5:1 6:66 7:60 8:47 9:88
69 | 0:38 1:12 2:46 3:76 4:95 5:60 6:74 7:62 8:72 9:48
70 | 0:91 1:94 2:47 3:24 4:94 5:42 6:29 7:25 8:14 9:13
71 | 0:69 1:63 2:44 3:43 4:28 5:35 6:44 7:70 8:23 9:52
72 | 0:51 1:36 2:8 3:12 4:59 5:96 6:34 7:96 8:53 9:94
73 | 0:16 1:88 2:94 3:11 4:68 5:87 6:88 7:51 8:56 9:83
74 | 0:46 1:1 2:36 3:27 4:3 5:77 6:41 7:94 8:54 9:76
75 | 0:1 1:4 2:97 3:65 4:32 5:38 6:88 7:97 8:11 9:75
76 | 0:9 1:24 2:7 3:63 4:86 5:95 6:88 7:38 8:61 9:66
77 | 0:71 1:55 2:42 3:91 4:56 5:14 6:77 7:82 8:40 9:68
78 | 0:26 1:38 2:62 3:20 4:2 5:27 6:12 7:29 8:68 9:33
79 | 0:21 1:63 2:92 3:54 4:49 5:58 6:75 7:60 8:96 9:2
80 | 0:64 1:22 2:30 3:60 4:64 5:11 6:71 7:59 8:80 9:54
81 | 0:94 1:38 2:43 3:23 4:23 5:81 6:89 7:70 8:25 9:18
82 | 0:28 1:76 2:49 3:71 4:4 5:10 6:13 7:49 8:26 9:33
83 | 0:36 1:43 2:29 3:50 4:95 5:75 6:11 7:5 8:7 9:43
84 | 0:31 1:53 2:47 3:76 4:85 5:21 6:84 7:71 8:63 9:10
85 | 0:51 1:29 2:58 3:8 4:32 5:13 6:89 7:86 8:9 9:99
86 | 0:65 1:58 2:68 3:64 4:32 5:16 6:28 7:35 9:1
87 | 0:9 1:86 2:30 3:39 4:7 5:41 6:28 7:63 8:69 9:13
88 | 0:3 1:96 2:70 3:5 4:13 5:30 6:24 7:82 8:65 9:77
89 | 0:79 1:69 2:83 3:59 4:76 5:81 6:26 7:51 8:98 9:97
90 | 0:17 1:39 2:45 3:14 4:54 5:8 6:54 7:85 8:9 9:1
91 | 0:86 1:44 2:69 3:6 4:2 5:19 6:6 7:58 8:1 9:5
92 | 0:2 1:7 2:8 3:2 4:61 5:76 6:27 7:11 8:11 9:50
93 | 0:73 1:12 2:73 3:43 4:66 5:62 6:26 7:17 8:71 9:85
94 | 0:23 1:99 2:62 3:68 4:33 5:53 6:66 7:30 8:11 9:92
95 | 0:61 1:23 2:43 3:68 4:12 5:71 6:45 7:69 8:43 9:18
96 | 0:93 1:43 2:93 3:85 4:72 5:22 6:78 7:4 8:79 9:88
97 | 0:11 1:41 2:99 3:73 4:46 5:67 6:44 7:50 8:87 9:59
98 | 0:67 1:81 2:50 3:79 4:13 5:84 6:47 7:72 8:42 9:42
99 | 0:98 1:33 2:98 3:89 4:34 5:48 6:19 7:62 8:39 9:89
100 | 0:23 1:47 2:30 3:76 4:88 5:19 6:6 7:96 8:61 9:71
101 | 0:61 1:45 2:70 3:51 4:5 5:26 6:75 7:43 8:68 9:88
102 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/train.data.x:
--------------------------------------------------------------------------------
1 | 58 19 5 92 3 62 26 56 43 35
2 | 19 61 54 81 8 50 55 83 69 81
3 | 79 93 49 17 67 70 33 38 81 53
4 | 66 30 78 67 18 98 72 25 35 53
5 | 86 34 68 72 40 8 89 7 34 0
6 | 9 99 46 86 51 13 74 46 82 62
7 | 82 99 20 47 46 42 40 66 64 55
8 | 99 73 59 94 16 16 80 28 11 30
9 | 96 67 72 16 10 27 79 48 86 54
10 | 23 3 63 91 0 97 72 3 57 46
11 | 25 72 85 34 76 65 68 64 51 61
12 | 98 34 49 83 44 16 3 3 70 63
13 | 63 39 88 20 73 0 52 9 30 57
14 | 99 91 47 60 71 74 79 13 20 92
15 | 75 64 9 60 29 66 38 26 39 94
16 | 36 94 62 55 65 10 42 62 53 35
17 | 19 94 55 68 65 36 95 21 72 52
18 | 16 21 90 23 14 8 69 79 18 78
19 | 44 84 85 32 66 12 26 12 9 81
20 | 86 19 68 21 33 42 95 46 7 64
21 | 85 48 92 11 76 56 43 97 60 70
22 | 68 81 89 89 65 26 87 29 30 7
23 | 95 93 5 27 68 26 25 28 61 8
24 | 20 78 44 80 31 93 55 81 69 39
25 | 91 67 92 15 55 53 27 47 34 9
26 | 91 15 95 30 32 18 35 13 46 76
27 | 68 64 89 0 73 75 6 44 97 84
28 | 64 8 9 50 45 35 68 24 92 28
29 | 74 74 7 31 53 28 92 38 17 34
30 | 94 78 31 96 33 54 50 47 77 44
31 | 18 11 7 93 26 79 5 70 99 46
32 | 16 52 3 27 63 56 85 83 81 63
33 | 55 25 91 67 78 15 47 55 89 87
34 | 24 16 8 12 44 17 46 30 66 91
35 | 65 31 95 68 13 16 66 65 60 95
36 | 31 55 84 70 74 83 45 28 25 12
37 | 61 3 43 20 5 2 25 50 2 95
38 | 12 7 92 19 89 76 57 77 74 43
39 | 85 93 10 31 37 2 71 5 85 13
40 | 97 20 67 55 53 61 24 93 3 60
41 | 57 83 81 70 87 23 33 67 82 90
42 | 65 45 48 39 21 73 68 35 25 58
43 | 57 45 32 68 66 15 60 51 58 29
44 | 2 36 0 48 92 3 98 45 97 26
45 | 65 87 37 4 82 94 73 81 36 2
46 | 38 33 2 96 1 68 18 39 86 73
47 | 46 22 13 26 81 95 88 74 90 95
48 | 51 92 9 16 8 55 30 87 78 47
49 | 86 48 60 35 3 6 46 87 71 31
50 | 12 52 36 48 34 65 46 49 98 21
51 | 34 43 39 98 78 5 14 80 80 76
52 | 48 54 96 3 64 40 88 47 9 85
53 | 74 80 33 87 24 15 95 18 3 44
54 | 70 57 14 95 75 31 0 67 49 96
55 | 1 51 3 42 69 45 6 99 65 87
56 | 39 16 98 20 35 19 15 99 65 68
57 | 86 98 72 37 95 53 6 36 64 92
58 | 87 65 36 39 94 8 5 32 33 0
59 | 87 98 52 49 15 36 20 28 60 4
60 | 84 92 93 82 91 54 84 19 83 84
61 | 85 26 62 87 41 60 79 8 32 42
62 | 75 12 49 81 68 87 86 97 13 33
63 | 88 38 60 89 15 43 37 47 33 96
64 | 95 1 7 27 20 69 53 1 84 53
65 | 7 98 50 39 28 70 10 47 6 49
66 | 90 30 70 6 36 84 73 15 36 22
67 | 38 48 2 4 39 1 66 60 47 88
68 | 38 12 46 76 95 60 74 62 72 48
69 | 91 94 47 24 94 42 29 25 14 13
70 | 69 63 44 43 28 35 44 70 23 52
71 | 51 36 8 12 59 96 34 96 53 94
72 | 16 88 94 11 68 87 88 51 56 83
73 | 46 1 36 27 3 77 41 94 54 76
74 | 1 4 97 65 32 38 88 97 11 75
75 | 9 24 7 63 86 95 88 38 61 66
76 | 71 55 42 91 56 14 77 82 40 68
77 | 26 38 62 20 2 27 12 29 68 33
78 | 21 63 92 54 49 58 75 60 96 2
79 | 64 22 30 60 64 11 71 59 80 54
80 | 94 38 43 23 23 81 89 70 25 18
81 | 28 76 49 71 4 10 13 49 26 33
82 | 36 43 29 50 95 75 11 5 7 43
83 | 31 53 47 76 85 21 84 71 63 10
84 | 51 29 58 8 32 13 89 86 9 99
85 | 65 58 68 64 32 16 28 35 0 1
86 | 9 86 30 39 7 41 28 63 69 13
87 | 3 96 70 5 13 30 24 82 65 77
88 | 79 69 83 59 76 81 26 51 98 97
89 | 17 39 45 14 54 8 54 85 9 1
90 | 86 44 69 6 2 19 6 58 1 5
91 | 2 7 8 2 61 76 27 11 11 50
92 | 73 12 73 43 66 62 26 17 71 85
93 | 23 99 62 68 33 53 66 30 11 92
94 | 61 23 43 68 12 71 45 69 43 18
95 | 93 43 93 85 72 22 78 4 79 88
96 | 11 41 99 73 46 67 44 50 87 59
97 | 67 81 50 79 13 84 47 72 42 42
98 | 98 33 98 89 34 48 19 62 39 89
99 | 23 47 30 76 88 19 6 96 61 71
100 | 61 45 70 51 5 26 75 43 68 88
101 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/train.data.y:
--------------------------------------------------------------------------------
1 | -1
2 | -1
3 | -1
4 | -1
5 | -1
6 | -1
7 | +1
8 | -1
9 | -1
10 | -1
11 | -1
12 | -1
13 | +1
14 | +1
15 | -1
16 | +1
17 | -1
18 | -1
19 | -1
20 | -1
21 | +1
22 | -1
23 | -1
24 | -1
25 | +1
26 | -1
27 | -1
28 | -1
29 | -1
30 | -1
31 | -1
32 | +1
33 | -1
34 | -1
35 | -1
36 | +1
37 | -1
38 | +1
39 | -1
40 | +1
41 | -1
42 | +1
43 | -1
44 | -1
45 | +1
46 | -1
47 | +1
48 | +1
49 | -1
50 | +1
51 | +1
52 | -1
53 | +1
54 | +1
55 | +1
56 | -1
57 | +1
58 | +1
59 | -1
60 | -1
61 | -1
62 | -1
63 | -1
64 | -1
65 | +1
66 | +1
67 | -1
68 | +1
69 | +1
70 | -1
71 | -1
72 | -1
73 | +1
74 | +1
75 | -1
76 | -1
77 | -1
78 | -1
79 | +1
80 | +1
81 | +1
82 | +1
83 | +1
84 | -1
85 | -1
86 | +1
87 | +1
88 | -1
89 | -1
90 | -1
91 | -1
92 | -1
93 | -1
94 | -1
95 | -1
96 | -1
97 | -1
98 | -1
99 | +1
100 | -1
101 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/train.inp:
--------------------------------------------------------------------------------
1 | #### sample input to "train" ####
2 |
3 | train_x_fn=sample/train.data.x # Training data points
4 | train_y_fn=sample/train.data.y # Training targets
5 |
6 | #--- Save the models with filenames output/sample.model-01,
7 | #--- output/sample.model-02,...
8 | model_fn_prefix=output/sample.model
9 |
10 | #--- training parameters.
11 | reg_L2=1 # Regularization parameter.
12 | algorithm=RGF # RGF with L2 regularization on leaf-only models
13 | loss=LS # Square loss
14 | test_interval=100 # Save models every time 100 leaves are added.
15 | max_leaf_forest=500 # Stop training when #leaf reaches 500.
16 | Verbose # Display info during training.
17 |
18 | #--- other parameters (commented out)
19 | #NormalizeTarget # Normalize targets so that the average becomes zero.
20 | #train_w_fn=?? # User-specified weights of data points.
21 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with
22 |
23 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/train_predict.inp:
--------------------------------------------------------------------------------
1 | #### sample input to "train_predict" ####
2 |
3 | train_x_fn=sample/train.data.x # Training data points
4 | train_y_fn=sample/train.data.y # Training targets
5 |
6 | test_x_fn=sample/test.data.x # Test data points
7 |
8 | #---
9 | model_fn_prefix=output/m
10 | #--- Models are saved with filenames output/m-01,
11 | #--- output/m-02,...
12 | #--- Predictions are saved with filenames output/m-01.pred,
13 | #--- output/m-02.pred,...
14 | #--- Model info such as #leaf is saved with filenames output/m-01.info,
15 | #--- output/m-02.info,...
16 |
17 | SaveLastModelOnly # Only the last (largest) model will be saved to a file.
18 | # Comment this out if all the models should be saved.
19 |
20 | #--- training parameters
21 | algorithm=RGF # RGF with L2 regularization on leaf-only models
22 | reg_L2=1 # Regularization parameter
23 | loss=LS # Square loss
24 | test_interval=100 # Test (and save) models every time 100 leaves are added.
25 | max_leaf_forest=500 # Stop training when #leaf reaches 500.
26 | Verbose # Display info during training.
27 |
28 | #--- other parameters (commented out)
29 | #NormalizeTarget # Normalize targets so that the average becomes zero.
30 | #train_w_fn=?? # User-specified weights of data points.
31 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with
32 |
--------------------------------------------------------------------------------
/rgf1.2/test/sample/train_test.inp:
--------------------------------------------------------------------------------
1 | #### sample input to "train_test" ####
2 |
3 | train_x_fn=sample/train.data.x # Training data points
4 | train_y_fn=sample/train.data.y # Training targets
5 |
6 | test_x_fn=sample/test.data.x # Test data points
7 | test_y_fn=sample/test.data.y # Test targets
8 |
9 | evaluation_fn=output/sample.evaluation # Where to write performances
10 |
11 | model_fn_prefix=output/m # Comment this out if models should not be saved.
12 |
13 |
14 | #--- training parameters
15 | algorithm=RGF # RGF with L2 regularization on leaf-only models
16 | reg_L2=1 # Regularization parameter
17 | loss=LS # Square loss
18 | test_interval=100 # Test (and save) models every time 100 leaves are added.
19 | max_leaf_forest=500 # Stop training when #leaf reaches 500.
20 | Verbose # Display info during training.
21 |
22 | #--- other parameters (commented out)
23 | #NormalizeTarget # Normalize targets so that the average becomes zero.
24 | #train_w_fn=?? # User-specified weights of data points.
25 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with
26 |
--------------------------------------------------------------------------------
/slides_Cambridge_Nov_14.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimSalimans/HiggsML/86fad61d392e54e6beca5e90066d2137ebda8a0b/slides_Cambridge_Nov_14.pdf
--------------------------------------------------------------------------------