├── .ci ├── python_tests.sh ├── python_tests_windows.ps1 ├── r_tests.sh └── r_tests_windows.ps1 ├── .github ├── ISSUE_TEMPLATE.MD └── workflows │ └── main.yml ├── .gitignore ├── AWESOME_RGF.md ├── FastRGF ├── CHANGES.md ├── CMakeLists.txt ├── LICENSE ├── README.md ├── examples │ ├── README.md │ ├── binary_classification │ │ ├── README.md │ │ ├── inputs │ │ │ ├── madelon.test │ │ │ └── madelon.train │ │ ├── outputs │ │ │ └── .gitignore │ │ └── run.sh │ └── regression │ │ ├── README.md │ │ ├── inputs │ │ ├── config │ │ ├── feature.names │ │ ├── housing.test │ │ └── housing.train │ │ ├── outputs │ │ └── .gitignore │ │ └── run.sh ├── include │ ├── classifier.h │ ├── data.h │ ├── discretization.h │ ├── dtree.h │ ├── forest.h │ ├── header.h │ └── utils.h └── src │ ├── base │ ├── CMakeLists.txt │ ├── classifier.cpp │ ├── data.cpp │ ├── discretization.cpp │ └── utils.cpp │ ├── exe │ ├── CMakeLists.txt │ ├── auc.cpp │ ├── discretized_gendata.cpp │ ├── discretized_trainer.cpp │ ├── forest_predict.cpp │ ├── forest_train.cpp │ ├── parser.h │ └── test_output.h │ └── forest │ ├── CMakeLists.txt │ ├── dtree.cpp │ ├── dtree_trainer.cpp │ ├── feature_mapper.h │ ├── forest.cpp │ ├── forest_trainer.h │ ├── node_trainer.h │ └── training_target.h ├── R-package ├── DESCRIPTION ├── LICENSE ├── LICENSE.note ├── NAMESPACE ├── NEWS.md ├── R │ ├── FastRGF_Classifier.R │ ├── FastRGF_Regressor.R │ ├── Internal_class.R │ ├── RGF_Classifier.R │ ├── RGF_Regressor.R │ ├── RGF_cleanup_temp_files.R │ ├── TO_scipy_sparse.R │ ├── mat_2scipy_sparse.R │ └── package.R ├── README.md ├── inst │ └── CITATION ├── man │ ├── FastRGF_Classifier.Rd │ ├── FastRGF_Regressor.Rd │ ├── Internal_class.Rd │ ├── RGF_Classifier.Rd │ ├── RGF_Regressor.Rd │ ├── RGF_cleanup_temp_files.Rd │ ├── TO_scipy_sparse.Rd │ └── mat_2scipy_sparse.Rd ├── tests │ ├── testthat.R │ └── testthat │ │ ├── helper-init.R │ │ ├── helper-skip.R │ │ ├── setup.R │ │ ├── teardown.R │ │ └── test-RGF_package.R └── vignettes │ └── the_RGF_package.Rmd 
├── README.md ├── RGF ├── CHANGES.md ├── CMakeLists.txt ├── COPYING ├── README.md ├── Windows │ └── rgf │ │ ├── rgf.sln │ │ └── rgf.vcxproj ├── build │ └── makefile ├── examples │ ├── call_exe.pl │ └── sample │ │ ├── predict.inp │ │ ├── regress.test.x │ │ ├── regress.test.y │ │ ├── regress.train.x │ │ ├── regress.train.y │ │ ├── regress_train_test.inp │ │ ├── test.data.x │ │ ├── test.data.y │ │ ├── train.data.sparse.x │ │ ├── train.data.x │ │ ├── train.data.y │ │ ├── train.inp │ │ ├── train_predict.inp │ │ └── train_test.inp ├── rgf-guide.rst └── src │ ├── com │ ├── AzBmat.hpp │ ├── AzDmat.cpp │ ├── AzDmat.hpp │ ├── AzException.hpp │ ├── AzHelp.hpp │ ├── AzIntPool.cpp │ ├── AzIntPool.hpp │ ├── AzLoss.cpp │ ├── AzLoss.hpp │ ├── AzMemTempl.hpp │ ├── AzOut.hpp │ ├── AzParam.cpp │ ├── AzParam.hpp │ ├── AzPerfResult.hpp │ ├── AzPrint.hpp │ ├── AzReadOnlyMatrix.hpp │ ├── AzSmat.cpp │ ├── AzSmat.hpp │ ├── AzStrArray.hpp │ ├── AzStrPool.cpp │ ├── AzStrPool.hpp │ ├── AzSvDataS.cpp │ ├── AzSvDataS.hpp │ ├── AzSvFeatInfo.hpp │ ├── AzSvFeatInfoClone.hpp │ ├── AzTaskTools.cpp │ ├── AzTaskTools.hpp │ ├── AzTimer.hpp │ ├── AzTools.cpp │ ├── AzTools.hpp │ ├── AzUtil.cpp │ └── AzUtil.hpp │ └── tet │ ├── AzDataForTrTree.hpp │ ├── AzFindSplit.cpp │ ├── AzFindSplit.hpp │ ├── AzFsinfo.hpp │ ├── AzOptOnTree.cpp │ ├── AzOptOnTree.hpp │ ├── AzOptOnTree_TreeReg.cpp │ ├── AzOptOnTree_TreeReg.hpp │ ├── AzOptimizerT.hpp │ ├── AzRegDepth.hpp │ ├── AzReg_TreeReg.hpp │ ├── AzReg_TreeRegArr.hpp │ ├── AzReg_TreeRegArrImp.hpp │ ├── AzReg_TsrOpt.cpp │ ├── AzReg_TsrOpt.hpp │ ├── AzReg_TsrSib.cpp │ ├── AzReg_TsrSib.hpp │ ├── AzReg_Tsrbase.cpp │ ├── AzReg_Tsrbase.hpp │ ├── AzRgfTrainerSel.hpp │ ├── AzRgfTree.cpp │ ├── AzRgfTree.hpp │ ├── AzRgfTreeEnsImp.hpp │ ├── AzRgfTreeEnsemble.hpp │ ├── AzRgf_FindSplit.hpp │ ├── AzRgf_FindSplit_Dflt.cpp │ ├── AzRgf_FindSplit_Dflt.hpp │ ├── AzRgf_FindSplit_TreeReg.cpp │ ├── AzRgf_FindSplit_TreeReg.hpp │ ├── AzRgf_Optimizer.hpp │ ├── AzRgf_Optimizer_Dflt.cpp │ ├── 
AzRgf_Optimizer_Dflt.hpp │ ├── AzRgf_Optimizer_TreeReg.hpp │ ├── AzRgf_kw.hpp │ ├── AzRgforest.cpp │ ├── AzRgforest.hpp │ ├── AzRgforest_TreeReg.hpp │ ├── AzSortedFeat.cpp │ ├── AzSortedFeat.hpp │ ├── AzTET_Eval.hpp │ ├── AzTET_Eval_Dflt.hpp │ ├── AzTETmain.cpp │ ├── AzTETmain.hpp │ ├── AzTETmain_kw.hpp │ ├── AzTETproc.cpp │ ├── AzTETproc.hpp │ ├── AzTETrainer.hpp │ ├── AzTETselector.hpp │ ├── AzTE_ModelInfo.hpp │ ├── AzTrTree.cpp │ ├── AzTrTree.hpp │ ├── AzTrTreeEnsemble.hpp │ ├── AzTrTreeEnsemble_ReadOnly.hpp │ ├── AzTrTreeFeat.cpp │ ├── AzTrTreeFeat.hpp │ ├── AzTrTreeNode.hpp │ ├── AzTrTree_ReadOnly.hpp │ ├── AzTrTsplit.hpp │ ├── AzTrTtarget.hpp │ ├── AzTree.cpp │ ├── AzTree.hpp │ ├── AzTreeEnsemble.cpp │ ├── AzTreeEnsemble.hpp │ ├── AzTreeNodes.hpp │ ├── AzTreeRule.hpp │ └── driv_rgf.cpp ├── codecov.yml └── python-package ├── LICENSE ├── MANIFEST.in ├── Readme.rst ├── docker └── Dockerfile ├── examples ├── FastRGF │ ├── FastRGF_classifier_on_iris_dataset.py │ └── FastRGF_regressor_on_boston_dataset.py └── RGF │ ├── classification_on_iris_dataset.ipynb │ ├── comparison_RGF_and_GBM_classifiers_on_iris_dataset.py │ ├── comparison_RGF_and_RF_regressors_on_diabetes_dataset.py │ └── regression_on_diabetes_dataset.ipynb ├── rgf ├── VERSION ├── __init__.py ├── fastrgf_model.py ├── rgf_model.py ├── sklearn.py └── utils.py ├── setup.py └── tests ├── test_examples.py └── test_rgf_python.py /.ci/python_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $OS_NAME == "macos-latest" ]]; then 4 | brew install gcc 5 | curl -sL https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o miniconda.sh 6 | else 7 | apt-get update 8 | apt-get install --no-install-recommends -y \ 9 | apt-transport-https \ 10 | build-essential \ 11 | ca-certificates \ 12 | cmake \ 13 | curl \ 14 | gnupg-agent \ 15 | software-properties-common 16 | if [[ $TASK != "R_PACKAGE" ]]; then 17 | add-apt-repository -y 
ppa:ubuntu-toolchain-r/test 18 | apt-get update 19 | apt-get install --no-install-recommends -y "g++-$GCC_VER_LINUX" 20 | export CXX="g++-$GCC_VER_LINUX" && export CC="gcc-$GCC_VER_LINUX" 21 | fi 22 | curl -sL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh 23 | fi 24 | bash miniconda.sh -b -p $CONDA_PATH 25 | conda config --set always_yes yes --set changeps1 no 26 | conda update -q conda 27 | if [[ $TASK == "R_PACKAGE" ]]; then 28 | conda create -q -n $CONDA_ENV python=$PYTHON_VERSION pip openssl libffi zlib --no-deps 29 | source activate $CONDA_ENV 30 | pip install setuptools joblib numpy scikit-learn scipy pandas wheel 31 | else 32 | conda create -q -n $CONDA_ENV python=$PYTHON_VERSION joblib numpy scikit-learn scipy pandas pytest 33 | source activate $CONDA_ENV 34 | fi 35 | cd $GITHUB_WORKSPACE/python-package 36 | python setup.py sdist --formats gztar || exit -1 37 | pip install dist/rgf_python-$RGF_VER.tar.gz -v || exit -1 38 | if [[ $TASK != "R_PACKAGE" ]]; then 39 | pytest tests/ -v || exit -1 40 | fi 41 | -------------------------------------------------------------------------------- /.ci/python_tests_windows.ps1: -------------------------------------------------------------------------------- 1 | function Check-Output { 2 | param( [bool]$Success ) 3 | if (!$Success) { 4 | $host.SetShouldExit(-1) 5 | Exit -1 6 | } 7 | } 8 | 9 | 10 | $ProgressPreference = "SilentlyContinue" # progress bar bug extremely slows down download speed 11 | $InstallerName = "$env:GITHUB_WORKSPACE\Miniconda3-latest-Windows-x86_64.exe" 12 | Invoke-WebRequest -Uri "https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe" -OutFile $InstallerName -MaximumRetryCount 5 13 | Start-Process -FilePath $InstallerName -ArgumentList "/InstallationType=JustMe /RegisterPython=0 /S /D=$env:CONDA_PATH" -Wait 14 | conda init powershell 15 | Invoke-Expression -Command "$env:USERPROFILE\Documents\WindowsPowerShell\profile.ps1" 16 | conda config 
--set always_yes yes --set changeps1 no 17 | conda update -q conda 18 | conda create -q -n $env:CONDA_ENV python=$env:PYTHON_VERSION joblib numpy scikit-learn scipy pandas pytest 19 | conda activate $env:CONDA_ENV 20 | cd $env:GITHUB_WORKSPACE\python-package 21 | python setup.py sdist --formats gztar ; Check-Output $? 22 | pip install dist\rgf_python-$env:RGF_VER.tar.gz -v ; Check-Output $? 23 | if ($env:TASK -ne "R_PACKAGE") { 24 | pytest tests -v ; Check-Output $? 25 | } 26 | -------------------------------------------------------------------------------- /.ci/r_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate $CONDA_ENV 4 | 5 | cd $GITHUB_WORKSPACE/R-package 6 | if [[ $OS_NAME == "macos-latest" ]]; then 7 | brew install r qpdf pandoc 8 | brew install --cask basictex 9 | export PATH="/Library/TeX/texbin:$PATH" 10 | sudo tlmgr --verify-repo=none update --self 11 | sudo tlmgr --verify-repo=none install inconsolata helvetic 12 | 13 | echo 'options(pkgType = "mac.binary")' > .Rprofile 14 | echo 'options(install.packages.check.source = "no")' >> .Rprofile 15 | else 16 | tlmgr --verify-repo=none update --self 17 | tlmgr --verify-repo=none install ec hyperref iftex infwarerr kvoptions pdftexcmds 18 | 19 | echo "Sys.setenv(RETICULATE_PYTHON = '$CONDA_PREFIX/bin/python')" >> .Rprofile 20 | fi 21 | 22 | R_LIB_PATH=$HOME/R 23 | mkdir -p $R_LIB_PATH 24 | echo "R_LIBS=$R_LIB_PATH" > .Renviron 25 | 26 | # ignore R CMD CHECK NOTE checking how long it has 27 | # been since the last submission 28 | export _R_CHECK_CRAN_INCOMING_=0 29 | export _R_CHECK_CRAN_INCOMING_REMOTE_=0 30 | 31 | # increase the allowed time to run the examples 32 | export _R_CHECK_EXAMPLE_TIMING_THRESHOLD_=30 33 | 34 | # fix the 'unable to verify current time' NOTE 35 | # see: https://stackoverflow.com/a/63837547/8302386 36 | export _R_CHECK_SYSTEM_CLOCK_=0 37 | 38 | if [[ $OS_NAME == "macos-latest" ]]; then 39 | Rscript -e 
"install.packages('devtools', dependencies = TRUE, repos = 'https://cran.r-project.org')" 40 | fi 41 | Rscript -e 'devtools::install_deps(pkg = ".", dependencies = TRUE)' 42 | 43 | R CMD build . || exit -1 44 | 45 | PKG_FILE_NAME=$(ls -1t *.tar.gz | head -n 1) 46 | PKG_NAME="${PKG_FILE_NAME%%_*}" 47 | LOG_FILE_NAME="$PKG_NAME.Rcheck/00check.log" 48 | COVERAGE_FILE_NAME="$PKG_NAME.Rcheck/coverage.log" 49 | 50 | R CMD check "${PKG_FILE_NAME}" --as-cran || exit -1 51 | if grep -q -E "NOTE|WARNING|ERROR" "$LOG_FILE_NAME"; then 52 | echo "NOTEs, WARNINGs or ERRORs have been found by R CMD check" 53 | exit -1 54 | fi 55 | 56 | Rscript -e 'covr::codecov(quiet = FALSE)' 2>&1 | tee "$COVERAGE_FILE_NAME" 57 | if [[ "$(grep -R "RGF Coverage:" $COVERAGE_FILE_NAME | rev | cut -d" " -f1 | rev | cut -d"." -f1)" -le 50 ]]; then 58 | echo "Code coverage is extremely small!" 59 | exit -1 60 | fi 61 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.MD: -------------------------------------------------------------------------------- 1 | For bugs and unexpected issues, please provide the following information, so that we could reproduce them on our system. 
2 | 3 | ## Environment Info 4 | 5 | Operating System: 6 | 7 | RGF/FastRGF/rgf_python version: 8 | 9 | Python version (for rgf_python errors): 10 | 11 | ## Error Message 12 | 13 | 14 | 15 | ## Reproducible Example 16 | 17 | 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # miscellaneous stuff 2 | .DS_Store 3 | */.DS_Store 4 | 5 | # build artifacts 6 | CMakeCache.txt 7 | CMakeFiles/ 8 | cmake_install.cmake 9 | 10 | # Python package stuff 11 | temp/ 12 | *.pyc 13 | rgf_python.egg-info/ 14 | dist/ 15 | build/* 16 | python-package/include/rgf/bin/* 17 | python-package/include/fast_rgf/bin/* 18 | *.pkl 19 | python-package/build/ 20 | python-package/compile/ 21 | python-package/temp_* 22 | python-package/.cache/ 23 | 24 | # R package stuff 25 | R-package/*.tar.gz 26 | *.tar.gz 27 | *.Rhistory 28 | -------------------------------------------------------------------------------- /AWESOME_RGF.md: -------------------------------------------------------------------------------- 1 | # Awesome RGF 2 | 3 | This page contains a curated list of blog posts and competition solutions about RGF. 4 | 5 | Please feel free to send a pull request if you find things that could be included in this document. 
6 | 7 | ### Blog Posts 8 | 9 | * [Regularized Greedy Forest (RGF) - Nice alternative to tree-boosting](https://www.linkedin.com/pulse/regularized-greedy-forest-rgf-nice-alternative-marios-michailidis/) 10 | * [An Introductory Guide to Regularized Greedy Forests (RGF) with a case study in Python](https://www.analyticsvidhya.com/blog/2018/02/introductory-guide-regularized-greedy-forests-rgf-python/) 11 | * [Regularized Greedy Forest – The Scottish Play (Act I)](https://www.statworx.com/ch/blog/regularized-greedy-forest-the-scottish-play-act-i/) 12 | * [Regularized Greedy Forest – The Scottish Play (ACT II)](https://www.statworx.com/ch/blog/regularized-greedy-forest-the-scottish-play-act-ii/) 13 | * [The Gradient Boosters II: Regularized Greedy Forest](https://deep-and-shallow.com/2020/02/09/the-gradient-boosters-ii-regularized-greedy-forest/) 14 | 15 | ### Machine Learning Competitions Winning Solutions and Kernels 16 | 17 | * [1st Place Solution (TReNDS Neuroimaging)](https://www.kaggle.com/c/trends-assessment-prediction/discussion/163017) 18 | * [1st Place Solution (Home Credit Default Risk)](https://www.kaggle.com/c/home-credit-default-risk/discussion/64821) 19 | * [1st Place Solution (Allstate Claims Severity)](https://www.kaggle.com/c/allstate-claims-severity/discussion/26416) 20 | * [2nd Place Solution (Home Credit Default Risk)](https://www.kaggle.com/c/home-credit-default-risk/discussion/64722) 21 | * [3rd Place Solution (Santander Customer Satisfaction)](https://www.kaggle.com/c/santander-customer-satisfaction/discussion/20978) 22 | * [13th Place Solution (IEEE-CIS Fraud Detection)](https://www.kaggle.com/c/ieee-fraud-detection/discussion/111485) 23 | * [18th Place Solution (Porto Seguro’s Safe Driver Prediction)](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction/discussion/44579) 24 | * [35th Place Solution (Porto Seguro’s Safe Driver Prediction)](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction/discussion/44711) 25 | * [Kernel (Porto 
Seguro’s Safe Driver Prediction)](https://www.kaggle.com/scirpus/regularized-greedy-forest/notebook) 26 | -------------------------------------------------------------------------------- /FastRGF/CHANGES.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 0.7 (Feb 2020) 4 | 5 | - Added support for compilation with Clang and AppleClang. 6 | 7 | ## 0.6 (Feb 2018) 8 | 9 | - Fixed bug which led to program crash in case of usage of small sample weights. 10 | 11 | ## 0.5 (Sept 2017) 12 | 13 | - Added OpenMP support and multithreading for discretization. 14 | - Added loop unrolling and compilation option for simd optimization. 15 | 16 | ## 0.4 (Aug 2017) 17 | 18 | - Fixed bug which truncated negative float values to `numeric_limits::min()`, causing degradation in prediction performance for datasets with negative values; changed truncation to `numeric_limits::lowest()`. 19 | 20 | ## 0.3 (Dec 2016) 21 | 22 | - Fixed several bugs that affect prediction performance (especially for small datasets). 23 | 24 | ## 0.2 (Aug 2016) 25 | 26 | - This is the first release. 27 | 28 | It only supports binary classification and regression, with significant simplifications from the [original RGF algorithm](https://arxiv.org/abs/1109.0887) for speed consideration. 29 | 30 | Additional functionalities will be supported in future releases. 
31 | -------------------------------------------------------------------------------- /FastRGF/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(APPLE) 2 | cmake_minimum_required(VERSION 3.0) 3 | else() 4 | cmake_minimum_required(VERSION 2.8) 5 | endif() 6 | 7 | project (FastRGF) 8 | 9 | # whether to use OpenMP (default is ON) 10 | option(OPENMP "Use OpenMP for multithreading" ON) 11 | 12 | set(CMAKE_CXX_FLAGS "-O3 -std=c++11") 13 | #set(CMAKE_CXX_FLAGS "-g -std=c++11 -Wall") 14 | 15 | 16 | if(OPENMP) 17 | message("Use OpenMP for multithreading") 18 | # use OpenMP 19 | add_definitions("-DUSE_OMP") 20 | 21 | if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 22 | if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.0") 23 | message(FATAL_ERROR "Insufficient gcc version") 24 | endif() 25 | elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") 26 | if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "3.8") 27 | message(FATAL_ERROR "Insufficient Clang version") 28 | endif() 29 | elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") 30 | if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "8.1.0") 31 | message(FATAL_ERROR "Insufficient AppleClang version") 32 | endif() 33 | cmake_minimum_required(VERSION 3.16) 34 | endif() 35 | find_package(OpenMP REQUIRED) 36 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 37 | else() 38 | message("Use standard C++11 thread library") 39 | endif() 40 | 41 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR "${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") 42 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -ftree-vectorize -ffast-math") 43 | endif() 44 | 45 | if(WIN32 AND MINGW) 46 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -static-libstdc++") 47 | endif() 48 | 49 | message("C++ compiler: " ${CMAKE_CXX_COMPILER}) 50 | message("C++ options: " ${CMAKE_CXX_FLAGS}) 51 | get_directory_property(cDirDefs DIRECTORY ${CMAKE_SOURCE_DIR} COMPILE_DEFINITIONS) 52 | message("C++ definitions: " ${cDirDefs}) 53 | 54 | 
include_directories(include) 55 | 56 | add_subdirectory(src/base) 57 | add_subdirectory(src/forest) 58 | add_subdirectory(src/exe) 59 | -------------------------------------------------------------------------------- /FastRGF/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2016 Baidu, Inc. All Rights Reserved. 2018 RGF-team 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a 5 | copy of this software and associated documentation files (the "Software"), 6 | to deal in the Software without restriction, including without limitation 7 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | and/or sell copies of the Software, and to permit persons to whom the 9 | Software is furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 15 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /FastRGF/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | You can learn how to use FastRGF by these examples. 4 | 5 | Note that for these small examples, the running time with multithreading may be slower than with single-threading due to the overhead it introduces. 
6 | However, for large datasets, one can observe an almost linear speedup. 7 | 8 | FastRGF can directly handle high-dimensional sparse features in the libsvm format as in [binary_classification example](./binary_classification). 9 | This is the recommended format to use when the dataset is relatively large (although some other formats are supported). 10 | -------------------------------------------------------------------------------- /FastRGF/examples/binary_classification/README.md: -------------------------------------------------------------------------------- 1 | # Binary Classification Example 2 | 3 | Here is an example of running a binary classification task with FastRGF. 4 | The dataset for this example is taken from [here](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#madelon) and the features are written in the libsvm sparse format. 5 | 6 | You should make sure that the executable files are placed into the `../../bin` folder. 7 | 8 | Execute the shell script in this folder to run the example: 9 | 10 | for Windows: 11 | 12 | ``` 13 | run.sh 14 | ``` 15 | 16 | for Unix-like systems: 17 | 18 | ``` 19 | bash run.sh 20 | ``` 21 | -------------------------------------------------------------------------------- /FastRGF/examples/binary_classification/outputs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /FastRGF/examples/binary_classification/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -f 2 | 3 | exe_train=../../bin/forest_train 4 | exe_predict=../../bin/forest_predict 5 | 6 | trn=inputs/madelon.train 7 | 8 | tst=inputs/madelon.test 9 | 10 | model_rgf=outputs/model-rgf 11 | 12 | prediction=outputs/prediction 13 | 14 | orig_format="y.sparse" 15 | save_freq=200 16 | 17 | echo ------ training 
------ 18 | time ${exe_train} trn.x-file=${trn} trn.x-file_format=${orig_format} trn.target=BINARY tst.x-file=${tst} tst.x-file_format=${orig_format} tst.target=BINARY model.save=${model_rgf} dtree.new_tree_gain_ratio=1.0 dtree.lamL2=5000 forest.ntrees=1000 dtree.loss=LOGISTIC forest.save_frequency=${save_freq} 19 | echo " " 20 | 21 | echo ------ testing intermediate model at ${save_freq} on ${tst} ------ 22 | time ${exe_predict} tst.x-file=${tst} tst.x-file_format=${orig_format} tst.target=BINARY model.load=${model_rgf}-${save_freq} 23 | echo " " 24 | 25 | echo ------ testing ------ 26 | echo === ${trn} === 27 | time ${exe_predict} tst.x-file=${trn} tst.x-file_format=${orig_format} tst.target=BINARY model.load=${model_rgf} tst.output-prediction=${prediction}-train 28 | echo " " 29 | echo === ${tst} === 30 | time ${exe_predict} tst.x-file=${tst} tst.x-file_format=${orig_format} tst.target=BINARY model.load=${model_rgf} tst.output-prediction=${prediction}-test 31 | -------------------------------------------------------------------------------- /FastRGF/examples/regression/README.md: -------------------------------------------------------------------------------- 1 | # Regression Example 2 | 3 | Here is an example of running a regression task with FastRGF. 4 | The dataset for this example is taken from [here](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html#housing) and the features are written in the dense format. 5 | 6 | You should make sure that the executable files are placed into the `../../bin` folder. 
7 | 8 | Execute the shell script in this folder to run the example: 9 | 10 | for Windows: 11 | 12 | ``` 13 | run.sh 14 | ``` 15 | 16 | for Unix-like systems: 17 | 18 | ``` 19 | bash run.sh 20 | ``` 21 | -------------------------------------------------------------------------------- /FastRGF/examples/regression/inputs/config: -------------------------------------------------------------------------------- 1 | # discretization options 2 | discretize.dense.max_buckets=250 3 | discretize.dense.lamL2=10 4 | 5 | # training options 6 | dtree.new_tree_gain_ratio=1.0 7 | dtree.loss=LS 8 | dtree.lamL1=10 9 | dtree.lamL2=1000 10 | forest.ntrees=1000 11 | -------------------------------------------------------------------------------- /FastRGF/examples/regression/inputs/feature.names: -------------------------------------------------------------------------------- 1 | Crime 2 | Zoned 3 | Industry 4 | River 5 | Nox 6 | Rooms 7 | Old 8 | Wdist 9 | Hiway 10 | Tax 11 | Pup/Teach 12 | Blacks 13 | Poor 14 | -------------------------------------------------------------------------------- /FastRGF/examples/regression/outputs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /FastRGF/examples/regression/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -f 2 | 3 | exe_train=../../bin/forest_train 4 | exe_predict=../../bin/forest_predict 5 | 6 | trn=inputs/housing.train 7 | tst=inputs/housing.test 8 | feat_name=inputs/feature.names 9 | 10 | config=inputs/config 11 | 12 | model_rgf=outputs/model-rgf 13 | 14 | prediction=outputs/prediction 15 | 16 | orig_format="y.x" 17 | 18 | echo ------ training ------ 19 | time ${exe_train} -config=${config} trn.x-file=${trn} trn.x-file_format=${orig_format} trn.target=REAL 
tst.x-file=${tst} tst.x-file_format=${orig_format} tst.target=REAL model.save=${model_rgf} 20 | echo " " 21 | 22 | echo ------ printing forest ------ 23 | ${exe_predict} model.load=${model_rgf} tst.print-forest=${model_rgf}.print tst.feature-names=${feat_name} 24 | echo " " 25 | 26 | echo ------ testing ------ 27 | echo === ${trn} === 28 | time ${exe_predict} tst.x-file=${trn} tst.x-file_format=${orig_format} tst.target=REAL model.load=${model_rgf} tst.output-prediction=${prediction}-train 29 | echo " " 30 | echo === ${tst} === 31 | time ${exe_predict} tst.x-file=${tst} tst.x-file_format=${orig_format} tst.target=REAL model.load=${model_rgf} tst.output-prediction=${prediction}-test 32 | -------------------------------------------------------------------------------- /FastRGF/include/classifier.h: -------------------------------------------------------------------------------- 1 | /************************************************************************ 2 | * classifier.h (2016) by Tong Zhang 3 | * 4 | * For Copyright, see LICENSE. 
5 | * 6 | ************************************************************************/ 7 | 8 | 9 | #ifndef _RGF_CLASSIFIER_H 10 | 11 | #define _RGF_CLASSIFIER_H 12 | 13 | #include "data.h" 14 | 15 | namespace rgf { 16 | 17 | 18 | namespace TrainLoss { 19 | enum { 20 | LS=0, 21 | MODLS=1, 22 | LOGISTIC=2, 23 | INVALID =3 24 | }; 25 | 26 | 27 | int str2loss(string loss_str); 28 | 29 | 30 | string loss2str(int loss); 31 | 32 | 33 | double binary_loss(int loss, double scr, double y); 34 | } 35 | 36 | 37 | template 38 | class BinaryClassifier { 39 | public: 40 | 41 | double threshold; 42 | 43 | 44 | BinaryClassifier() : 45 | threshold(0.0) { 46 | } 47 | 48 | 49 | virtual double apply(DataPoint & dp)=0; 50 | 51 | 52 | bool classify(double scr) { 53 | return scr > threshold; 54 | } 55 | 56 | virtual ~BinaryClassifier() {} 57 | }; 58 | 59 | 60 | class BinaryTestStat { 61 | 62 | class TestResult { 63 | public: 64 | 65 | double scr; 66 | 67 | double y; 68 | 69 | const bool operator<(const TestResult & b) const { 70 | return scr < b.scr; 71 | } 72 | 73 | 74 | TestResult(double _scr, double _y) : 75 | scr(_scr), y(_y) { 76 | } 77 | }; 78 | 79 | 80 | vector _results; 81 | 82 | 83 | Target _y_type; 84 | 85 | 86 | int _loss; 87 | public: 88 | 89 | size_t tp; 90 | 91 | size_t tn; 92 | 93 | size_t fp; 94 | 95 | size_t fn; 96 | 97 | 98 | size_t num; 99 | 100 | 101 | double total_loss; 102 | 103 | 104 | bool keep_results; 105 | 106 | 107 | BinaryTestStat(Target y_type, int loss) : 108 | _y_type(y_type), _loss(loss), tp(0), tn(0), fp(0), fn(0), 109 | num(0), total_loss(0), keep_results(true) { 110 | } 111 | 112 | 113 | void update(double y, double scr, bool pred_label); 114 | 115 | 116 | template 117 | void update(BinaryClassifier & appl, DataSet & ds) 118 | { 119 | for (size_t i=0; i dp=ds[i]; 121 | double scr=appl.apply(dp); 122 | update(ds.y[i], scr, appl.classify(scr)); 123 | } 124 | } 125 | 126 | 127 | 128 | double accuracy() { 129 | return (tp + tn) / (tp + tn + fp + fn + 1e-10); 
130 | } 131 | 132 | double precision() { 133 | return tp / (tp + fp + 1e-10); 134 | } 135 | 136 | double recall() { 137 | return tp / (tp + fn + 1e-10); 138 | } 139 | 140 | double fb1() { 141 | return 2.0 / (1.0 / precision() + 1.0 / recall()); 142 | } 143 | 144 | 145 | void roc(size_t _tp, size_t _tn, double & tpr, double & fpr); 146 | 147 | 148 | double auc(); 149 | 150 | 151 | double mse(); 152 | 153 | 154 | void print(ostream & os); 155 | 156 | 157 | void clear(); 158 | }; 159 | 160 | 161 | } 162 | 163 | #endif 164 | 165 | -------------------------------------------------------------------------------- /FastRGF/include/forest.h: -------------------------------------------------------------------------------- 1 | /************************************************************************ 2 | * forest.h (2016) by Tong Zhang 3 | * 4 | * For Copyright, see LICENSE. 5 | * 6 | ************************************************************************/ 7 | 8 | 9 | #ifndef _RGF_FOREST_H 10 | 11 | #define _RGF_FOREST_H 12 | 13 | #include "dtree.h" 14 | 15 | namespace rgf { 16 | 17 | 18 | template 19 | class DecisionForest: public BinaryClassifier { 20 | 21 | 22 | vector > _dtree_vec; 23 | 24 | 25 | int _dim_dense; 26 | 27 | int _dim_sparse; 28 | 29 | 30 | unsigned int _ntrees; 31 | 32 | unsigned int _nthreads; 33 | 34 | 35 | int _train_loss; 36 | public: 37 | 38 | 39 | DecisionForest() : _dim_dense(0), _dim_sparse(0), 40 | _ntrees(0), _nthreads(1), _train_loss(TrainLoss::INVALID) {} 41 | 42 | 43 | 44 | int train_loss() { 45 | return _train_loss; 46 | } 47 | 48 | 49 | DecisionTree & operator [] (size_t i) {return _dtree_vec[i];} 50 | 51 | 52 | void set(int nthreads, int ntrees=0) { 53 | _ntrees=ntrees; 54 | _nthreads=nthreads; 55 | } 56 | 57 | virtual double apply(DataPoint & dp) { 58 | return apply(dp, _ntrees,_nthreads); 59 | } 60 | 61 | 62 | double apply(DataPoint & dp, unsigned int ntrees, int nthreads); 63 | 64 | 65 | size_t appendFeatures(DataPoint & dp, vector & 
feat_vec, size_t offset); 66 | 67 | 68 | void write(ostream & os); 69 | 70 | 71 | void read(istream & is); 72 | 73 | 74 | void clear() { 75 | _dtree_vec.clear(); 76 | } 77 | 78 | ~DecisionForest() { 79 | clear(); 80 | } 81 | 82 | 83 | class TrainParam: public ParameterParser { 84 | public: 85 | 86 | 87 | ParamValue loss; 88 | 89 | 90 | ParamValue step_size; 91 | 92 | 93 | 94 | ParamValue opt; 95 | 96 | 97 | ParamValue ntrees; 98 | 99 | 100 | ParamValue eval_frequency; 101 | 102 | 103 | ParamValue write_frequency; 104 | 105 | 106 | ParamValue verbose; 107 | 108 | 109 | TrainParam(string prefix = "rgf.") { 110 | step_size.insert(prefix + "stepsize", 0, "step size of epsilon-greedy boosting (inactive for rgf)", 111 | this,false); 112 | opt.insert(prefix + "opt", "rgf", "optimization method for training forest (rgf or epsilon-greedy)", 113 | this); 114 | ntrees.insert(prefix + "ntrees", 500, "number of trees", 115 | this); 116 | eval_frequency.insert(prefix+ "eval_frequency",50, "evaluate performance on test data at this frequency",this); 117 | write_frequency.insert(prefix+ "save_frequency",0, "save forest models to file 'model_file-iter' at this frequency",this); 118 | } 119 | }; 120 | 121 | 122 | void train( 123 | DataSet & ds, double* scr_arr, 124 | class DecisionTree::TrainParam ¶m_dt, 125 | TrainParam & param_forest, 126 | DataSet & tst, 127 | string model_file="", 128 | DataDiscretizationInt *disc_ptr=0); 129 | 130 | 131 | void revert_discretization(DataDiscretizationInt & disc); 132 | 133 | 134 | void print(ostream & os, vector & feature_names, 135 | bool depth_first=true); 136 | }; 137 | 138 | using DecisionForestFlt=DecisionForest; 139 | using DecisionForestInt=DecisionForest; 140 | using DecisionForestShort=DecisionForest; 141 | 142 | } 143 | 144 | #endif 145 | -------------------------------------------------------------------------------- /FastRGF/include/header.h: -------------------------------------------------------------------------------- 1 | 
/************************************************************************ 2 | * header.h (2016) by Tong Zhang 3 | * 4 | * For Copyright, see LICENSE. 5 | * 6 | ************************************************************************/ 7 | 8 | 9 | #ifndef _RGF_HEADER_H 10 | 11 | #define _RGF_HEADER_H 12 | 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | #include 32 | 33 | #ifdef USE_OMP 34 | #include "omp.h" 35 | #endif 36 | 37 | namespace rgf { 38 | 39 | 40 | #define VER "version 0.7" 41 | 42 | 43 | const int max_thrds=128; 44 | 45 | 46 | const int vect_width=8; 47 | 48 | 49 | using int_t=long; 50 | 51 | 52 | 53 | using src_index_t=int; 54 | 55 | 56 | using disc_dense_value_t=unsigned short; 57 | 58 | 59 | using disc_sparse_index_t=int; 60 | 61 | 62 | using disc_sparse_value_t=unsigned char; 63 | 64 | 65 | 66 | using train_size_t=unsigned int; 67 | } 68 | 69 | using namespace std; 70 | using namespace rgf; 71 | 72 | #endif 73 | -------------------------------------------------------------------------------- /FastRGF/src/base/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # create a library "base" and test program tester_base 2 | 3 | add_library (base utils.cpp data.cpp discretization.cpp classifier.cpp) 4 | #target_include_directories(base PUBLIC ${CMAKE_SOURCE_DIR}/include) 5 | 6 | #add_executable(tester_base tester.cpp) 7 | #target_link_libraries(tester_base base) 8 | 9 | #install(TARGETS tester_base DESTINATION ${CMAKE_SOURCE_DIR}/src/base) 10 | 11 | if(OPENMP) 12 | if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") 13 | target_link_libraries(base OpenMP::OpenMP_CXX) 14 | endif() 15 | endif(OPENMP) 16 | -------------------------------------------------------------------------------- /FastRGF/src/base/classifier.cpp: 
-------------------------------------------------------------------------------- 1 | /************************************************************************ 2 | * classifier.cpp (2016) by Tong Zhang 3 | * 4 | * For Copyright, see LICENSE. 5 | * 6 | ************************************************************************/ 7 | 8 | 9 | #include "classifier.h" 10 | 11 | int TrainLoss::str2loss(string loss_str) 12 | { 13 | int loss= TrainLoss::INVALID; 14 | if (loss_str.compare("MODLS") ==0) loss= TrainLoss::MODLS; 15 | if (loss_str.compare("LOGISTIC") ==0) loss= TrainLoss::LOGISTIC; 16 | if (loss_str.compare("LS") ==0) loss= TrainLoss::LS; 17 | if (loss==TrainLoss::INVALID) { 18 | cerr << "loss " << loss_str << " is invalid" <1)? 0: (tmp-1)*(tmp-1); 49 | case LOGISTIC: 50 | tmp=scr*y; 51 | return log(1+exp(-tmp)); 52 | default: 53 | cerr << "invalid loss" < tp_vec; 125 | 126 | vector tn_vec; 127 | 128 | { 129 | sort(_results.begin(),_results.end()); 130 | tp_vec.clear(); 131 | tn_vec.clear(); 132 | 133 | 134 | size_t my_tp=tp+fn; 135 | size_t my_tn=0; 136 | tp_vec.push_back(my_tp); 137 | tn_vec.push_back(my_tn); 138 | 139 | for (size_t i=0; i<_results.size(); i++) { 140 | bool truth= _y_type.binary_label(_results[i].y); 141 | if (truth) my_tp--; 142 | else my_tn++; 143 | if (i<_results.size()-1 && _results[i].scr==_results[i+1].scr) continue; 144 | tp_vec.push_back(my_tp); 145 | tn_vec.push_back(my_tn); 146 | } 147 | } 148 | 149 | double tpr1, tpr2; 150 | double fpr1, fpr2; 151 | double a=0; 152 | 153 | roc(tp_vec[0],tn_vec[0], tpr1,fpr1); 154 | for (size_t i=1; i< tp_vec.size(); i++) { 155 | roc(tp_vec[i],tn_vec[i], tpr2,fpr2); 156 | 157 | a += 0.5*(tpr1+tpr2)*(fpr1-fpr2); 158 | 159 | tpr1=tpr2; 160 | fpr1=fpr2; 161 | } 162 | return a; 163 | } 164 | -------------------------------------------------------------------------------- /FastRGF/src/base/utils.cpp: -------------------------------------------------------------------------------- 1 | 
/************************************************************************ 2 | * utils.cpp (2016) by Tong Zhang 3 | * 4 | * For Copyright, see LICENSE. 5 | * 6 | ************************************************************************/ 7 | 8 | 9 | #include "utils.h" 10 | 11 | bool ParameterParser::parse_and_assign(string token) 12 | { 13 | size_t pos=token.find_first_of('='); 14 | if (pos ==0 || pos==string::npos) { 15 | return false; 16 | } 17 | string key=token.substr(0,pos); 18 | 19 | string value=string(""); 20 | if (pos+1parsed_value=value; 25 | _kv_table[i].second->set_value(); 26 | return true; 27 | } 28 | } 29 | return false; 30 | } 31 | 32 | void ParameterParser::print_parameters(ostream & os, string indent) 33 | { 34 | 35 | for (auto it=_kv_table.begin(); it!=_kv_table.end(); it++) 36 | { 37 | string key=it->first; 38 | string value=it->second->parsed_value; 39 | if (it->second->is_valid) { 40 | os << indent << key <<"=" << value <first; 51 | os << indent << " " << key << "=value : " 52 | << it->second->description 53 | << " (default=" << it->second->default_value << ")" 54 | << endl; 55 | } 56 | } 57 | 58 | void ParameterParserGroup::command_line_parse(int_t argc, char *argv[]) 59 | { 60 | unparsed_tokens.clear(); 61 | for (int_t c=1; c=2) { 66 | cerr << " ambiguous command line option " << token <=2) { 95 | cerr << " ambigous option " << token <print_options(os,indent); 110 | } 111 | } 112 | 113 | int_t ParameterParserGroup::parse(string token) 114 | { 115 | int_t num_parsed=0; 116 | for (int_t i=0; iparse_and_assign(token)) num_parsed++; 118 | } 119 | return num_parsed; 120 | } 121 | -------------------------------------------------------------------------------- /FastRGF/src/exe/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # create test programs forest_train forest_predict 2 | 3 | add_executable(auc auc.cpp) 4 | target_link_libraries(auc base) 5 | 6 | #install(TARGETS auc DESTINATION 
${CMAKE_SOURCE_DIR}/bin) 7 | 8 | add_executable(discretized_gendata discretized_gendata.cpp) 9 | target_link_libraries(discretized_gendata base) 10 | 11 | #install(TARGETS discretized_gendata DESTINATION ${CMAKE_SOURCE_DIR}/bin) 12 | 13 | 14 | add_executable(discretized_trainer discretized_trainer.cpp) 15 | target_link_libraries(discretized_trainer base forest) 16 | 17 | #install(TARGETS discretized_trainer DESTINATION ${CMAKE_SOURCE_DIR}/bin) 18 | 19 | 20 | ############### 21 | add_executable(forest_train forest_train.cpp) 22 | target_link_libraries(forest_train forest base) 23 | 24 | install(TARGETS forest_train DESTINATION ${CMAKE_SOURCE_DIR}/bin) 25 | 26 | add_executable(forest_predict forest_predict.cpp) 27 | target_link_libraries(forest_predict forest base) 28 | 29 | install(TARGETS forest_predict DESTINATION ${CMAKE_SOURCE_DIR}/bin) 30 | -------------------------------------------------------------------------------- /FastRGF/src/exe/auc.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************ 2 | * auc.cpp (2016) by Tong Zhang 3 | * 4 | * For Copyright, see LICENSE. 
5 | * 6 | ************************************************************************/ 7 | 8 | 9 | #include "classifier.h" 10 | 11 | int main(int argc, char *argv[]) 12 | { 13 | if (argc>1 && argv[1][0]=='-' && argv[1][1]=='h') { 14 | cout << "usage: " << argv[0] << " < result-file " <> y >> score; 29 | label=y_type.binary_label(y); 30 | if (cin.eof()) break; 31 | test_result.update(label,0,score); 32 | nl++; 33 | } 34 | cout << "auc=" << test_result.auc() < ntrees; 19 | ParamValue output_prediction; 20 | ParamValue output_feature; 21 | ParamValue print_forest; 22 | ParamValue feature_names; 23 | 24 | TestParam(string prefix="tst.") : DataSetFlt::IOParam(prefix) { 25 | ntrees.insert(prefix+"ntrees",0,"if nonzero, use ntrees to compute prediction",this); 26 | output_prediction.insert(prefix+"output-prediction","","if nonempty, output predictions to this file",this); 27 | output_feature.insert(prefix+"output-feature","","if nonempty, output features to this file",this); 28 | print_forest.insert(prefix+"print-forest","","if nonempty, print forest to this file",this); 29 | feature_names.insert(prefix+"feature-names","","if nonempty, read feature names from the file in the format:\n feature-0-name\n feature-1-name\n ...\n feature-names are used to print forest when data are in dense or sparse (not mixed) format\n",this); 30 | } 31 | } param_tstfile; 32 | 33 | 34 | ModelParam param_modelfile("model."); 35 | 36 | void parser_init() 37 | { 38 | ppg.add_parser(¶m_config); 39 | ppg.add_parser(¶m_set); 40 | 41 | param_modelfile.set_description("model-file options:"); 42 | ppg.add_parser(¶m_modelfile); 43 | 44 | param_tstfile.set_description("test-data and output options:"); 45 | ppg.add_parser(¶m_tstfile); 46 | 47 | } 48 | 49 | #include "test_output.h" 50 | class MyTestOutput : public TestOutput { 51 | public: 52 | void print_forest(DecisionForestFlt & forest) 53 | { 54 | if (param_tstfile.print_forest.value.size()>0) { 55 | vector feature_names; 56 | if 
(param_tstfile.feature_names.value.size()>0) { 57 | ifstream is(param_tstfile.feature_names.value); 58 | cerr <<"read feature names from <" << param_tstfile.feature_names.value << ">" <" <" <" <=2) { 97 | cerr << " using up to " << nthreads << " threads" << endl; 98 | } 99 | 100 | Timer t; 101 | DecisionForestFlt forest; 102 | 103 | 104 | if (param_modelfile.load_filename.value.size()>0) { 105 | cerr << endl <" < save_filename; 15 | ParamValue load_filename; 16 | 17 | ModelParam(string prefix="model.") 18 | { 19 | save_filename.insert(prefix+"save","","if nonempty, save trained model to file",this); 20 | load_filename.insert(prefix+"load","","if nonempty, load previously trained model from file",this); 21 | } 22 | }; 23 | 24 | 25 | class ConfigParam : public ParameterParser { 26 | public: 27 | ParamValue filename; 28 | 29 | ConfigParam() 30 | { 31 | filename.insert("-config","","if nonempty, read options from config-file",this); 32 | } 33 | }; 34 | 35 | class SetParam : public ParameterParser { 36 | public: 37 | 38 | ParamValue nthreads; 39 | 40 | ParamValue verbose; 41 | SetParam(string prefix="set.") { 42 | nthreads.insert(prefix+"nthreads", 0, "number of threads for training and testing (0 means maximum number of hardware logical threads)",this); 43 | verbose.insert(prefix+ "verbose",2, "verbose level",this); 44 | this->set_description("global options:"); 45 | } 46 | }; 47 | 48 | ConfigParam param_config; 49 | SetParam param_set; 50 | 51 | ParameterParserGroup ppg; 52 | 53 | void usage(int argc, char *argv[]) 54 | { 55 | cerr << argv[0] << " " << VER <0) { 78 | cerr << "unknown option " << ppg.unparsed_tokens[0] <0) { 82 | cerr << endl; 83 | cerr << "reading options from configuration file <" << param_config.filename.value << ">" <0) { 87 | cerr << "unknown option " << ppg.unparsed_tokens[0] < 10 | class TestOutput { 11 | public: 12 | 13 | DataSet tst; 14 | 15 | void read_tstfile() 16 | { 17 | Timer t("loading time"); 18 | if (param_tstfile.fn_x.value.size()>0) 
{ 19 | cerr << endl < & forest, int ntrees,int nthreads=0) 31 | { 32 | if (tst.size()==0) return; 33 | 34 | cerr <0) { 41 | cerr <<"output predictions to <" << param_tstfile.output_prediction.value << ">" <" <0) cerr << "using "<< ntrees << " trees" << endl; 51 | for (i=0; i dp=tst[i]; 53 | double scr=forest.apply(dp,ntrees,nthreads); 54 | bool pred=forest.classify(scr); 55 | if (os.good()) os << scr <0) { 62 | os.close(); 63 | } 64 | if (compute_stat) test_result.print(cout); 65 | } 66 | 67 | if (param_tstfile.output_feature.value.size()>0) { 68 | cerr <<"output features to <" << param_tstfile.output_feature.value << ">" <" < { 76 | public: 77 | DecisionForest * forest_ptr; 78 | virtual void write_datapoint(ostream & os, DataSet & ds, size_t i) { 79 | vector feat_vec; 80 | DataPoint dp=ds[i]; 81 | forest_ptr->appendFeatures(dp,feat_vec,0); 82 | for (int j=0; j 37 | void init(DataSet & ds, int ngrps, int verbose); 38 | 39 | 40 | template 41 | void build_single_tree(DataSet & ds, double *scr_arr, 42 | class DecisionTree::TrainParam & param_dt, 43 | double step_size, 44 | class DecisionTree & dtree); 45 | 46 | template 47 | void fully_corrective_update(DataSet & ds, double *scr_arr, 48 | class DecisionTree::TrainParam ¶m_dt, 49 | DecisionTree * dtree_vec, 50 | int ntrees); 51 | 52 | 53 | template 54 | void finish(DataSet &ds, int verbose); 55 | }; 56 | } 57 | -------------------------------------------------------------------------------- /R-package/DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: RGF 2 | Type: Package 3 | Title: Regularized Greedy Forest 4 | Version: 1.0.9 5 | Date: 2021-09-11 6 | Authors@R: c( person("Lampros", "Mouselimis", email = "mouselimislampros@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "https://orcid.org/0000-0002-8024-1546")), person("Ryosuke", "Fukatani", role = "cph", comment = "Author of the python wrapper of the 'Regularized Greedy Forest' machine learning 
algorithm"), person("Nikita", "Titov", role = "cph", comment = "Author of the python wrapper of the 'Regularized Greedy Forest' machine learning algorithm"), person("Tong", "Zhang", role = "cph", comment = "Author of the 'Regularized Greedy Forest' and of the Multi-core implementation of Regularized Greedy Forest machine learning algorithm"), person("Rie", "Johnson", role = "cph", comment = "Author of the 'Regularized Greedy Forest' machine learning algorithm") ) 7 | BugReports: https://github.com/RGF-team/rgf/issues 8 | URL: https://github.com/RGF-team/rgf/tree/master/R-package 9 | Description: Regularized Greedy Forest wrapper of the 'Regularized Greedy Forest' 'python' package, which also includes a Multi-core implementation (FastRGF). 10 | License: MIT + file LICENSE 11 | SystemRequirements: Python (>= 3.7), rgf_python, scikit-learn (>= 0.18.0), scipy, numpy. Detailed installation instructions for each operating system can be found in the README file. 12 | Depends: 13 | R(>= 3.2.0) 14 | Imports: 15 | reticulate, 16 | R6, 17 | Matrix 18 | Suggests: 19 | testthat, 20 | covr, 21 | knitr, 22 | rmarkdown 23 | Encoding: UTF-8 24 | RoxygenNote: 7.1.1 25 | VignetteBuilder: knitr 26 | -------------------------------------------------------------------------------- /R-package/LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2018 2 | COPYRIGHT HOLDER: Mouselimis Lampros 3 | -------------------------------------------------------------------------------- /R-package/LICENSE.note: -------------------------------------------------------------------------------- 1 | #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 2 | # Regularized Greedy Forest (RGF) 3 | #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 4 | 5 | 6 | See file COPYING: https://github.com/RGF-team/rgf/blob/master/RGF/COPYING 7 | 8 | 9 | 10 | #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 11 | # Fast Regularized Greedy Forest (FastRGF) 12 | 
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | 15 | See file LICENSE: https://github.com/RGF-team/rgf/blob/master/FastRGF/LICENSE 16 | 17 | 18 | 19 | #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 20 | # Python wrapper of RGF and FastRGF 21 | #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | 23 | 24 | See file LICENSE: https://github.com/RGF-team/rgf/blob/master/python-package/LICENSE 25 | -------------------------------------------------------------------------------- /R-package/NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(FastRGF_Classifier) 4 | export(FastRGF_Regressor) 5 | export(RGF_Classifier) 6 | export(RGF_Regressor) 7 | export(RGF_cleanup_temp_files) 8 | export(TO_scipy_sparse) 9 | export(mat_2scipy_sparse) 10 | import(reticulate) 11 | importFrom(Matrix,Matrix) 12 | importFrom(R6,R6Class) 13 | -------------------------------------------------------------------------------- /R-package/R/Internal_class.R: -------------------------------------------------------------------------------- 1 | #' Internal R6 class for all secondary functions used in RGF and FastRGF 2 | #' 3 | #' @importFrom R6 R6Class 4 | #' @keywords internal 5 | Internal_class <- R6::R6Class( 6 | "Internal_class", 7 | lock_objects = FALSE, 8 | public = list( 9 | 10 | # 'fit' function 11 | #---------------- 12 | fit = function(x, y, sample_weight = NULL) { 13 | private$rgf_init$fit(x, y, sample_weight) 14 | return(invisible(NULL)) 15 | }, 16 | 17 | # 'predict' function 18 | #-------------------- 19 | predict = function(x) { 20 | return(private$rgf_init$predict(x)) 21 | }, 22 | 23 | # 'predict' function [ probabilities ] 24 | #-------------------- 25 | predict_proba = function(x) { 26 | return(private$rgf_init$predict_proba(x)) 27 | }, 28 | 29 | # 'cleanup' function 30 | #------------------- 31 | cleanup = function() { 32 | private$rgf_init$cleanup() 33 | 
return(invisible(NULL)) 34 | }, 35 | 36 | # 'get_params' function 37 | #---------------------- 38 | get_params = function(deep = TRUE) { 39 | return(private$rgf_init$get_params(deep)) 40 | }, 41 | 42 | # score function 43 | #--------------- 44 | score = function(x, y, sample_weight = NULL) { 45 | return(private$rgf_init$score(x, y, sample_weight)) 46 | }, 47 | 48 | # feature importance 49 | #------------------- 50 | feature_importances = function() { 51 | return(private$rgf_init$feature_importances_) 52 | }, 53 | 54 | # dump-model 55 | #----------- 56 | dump_model = function() { 57 | return(private$rgf_init$dump_model) 58 | }, 59 | 60 | # save_model 61 | #----------- 62 | save_model = function(filename) { 63 | private$rgf_init$save_model(filename) 64 | return(invisible(NULL)) 65 | } 66 | ), 67 | 68 | private = list( 69 | rgf_init = NULL 70 | ) 71 | ) 72 | -------------------------------------------------------------------------------- /R-package/R/RGF_cleanup_temp_files.R: -------------------------------------------------------------------------------- 1 | #' Delete all temporary files of the created RGF estimators 2 | #' 3 | #' @details 4 | #' This function deletes all temporary files of the created RGF estimators. See the issue \emph{https://github.com/RGF-team/rgf/issues/75} for more details. 5 | #' @export 6 | #' @references \emph{https://github.com/RGF-team/rgf/tree/master/python-package} 7 | #' @examples 8 | #' 9 | #' \dontrun{ 10 | #' library(RGF) 11 | #' 12 | #' RGF_cleanup_temp_files() 13 | #' } 14 | 15 | RGF_cleanup_temp_files = function() { 16 | 17 | RGF_utils$cleanup() 18 | 19 | invisible() 20 | } 21 | -------------------------------------------------------------------------------- /R-package/R/TO_scipy_sparse.R: -------------------------------------------------------------------------------- 1 | #' conversion of an R sparse matrix to a scipy sparse matrix 2 | #' 3 | #' 4 | #' @param R_sparse_matrix an R sparse matrix. 
Acceptable input objects are either a \emph{dgCMatrix} or a \emph{dgRMatrix}. 5 | #' @details 6 | #' This function allows the user to convert either an R \emph{dgCMatrix} or a \emph{dgRMatrix} to a scipy sparse matrix (\emph{scipy.sparse.csc_matrix} or \emph{scipy.sparse.csr_matrix}). This is useful because the \emph{RGF} package accepts python sparse matrices as input in addition to R dense matrices. 7 | #' 8 | #' The \emph{dgCMatrix} class is a class of sparse numeric matrices in the compressed, sparse, \emph{column-oriented format}. The \emph{dgRMatrix} class is a class of sparse numeric matrices in the compressed, sparse, \emph{row-oriented format}. 9 | #' 10 | #' @export 11 | #' @import reticulate 12 | #' @importFrom Matrix Matrix 13 | #' @references https://stat.ethz.ch/R-manual/R-devel/library/Matrix/html/dgCMatrix-class.html, https://stat.ethz.ch/R-manual/R-devel/library/Matrix/html/dgRMatrix-class.html, https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix 14 | #' @examples 15 | #' 16 | #' try({ 17 | #' if (reticulate::py_available(initialize = FALSE)) { 18 | #' if (reticulate::py_module_available("scipy")) { 19 | #' 20 | #' if (Sys.info()["sysname"] != 'Darwin') { 21 | #' 22 | #' library(RGF) 23 | #' 24 | #' 25 | #' # 'dgCMatrix' sparse matrix 26 | #' #-------------------------- 27 | #' 28 | #' data = c(1, 0, 2, 0, 0, 3, 4, 5, 6) 29 | #' 30 | #' dgcM = Matrix::Matrix( 31 | #' data = data 32 | #' , nrow = 3 33 | #' , ncol = 3 34 | #' , byrow = TRUE 35 | #' , sparse = TRUE 36 | #' ) 37 | #' 38 | #' print(dim(dgcM)) 39 | #' 40 | #' res = TO_scipy_sparse(dgcM) 41 | #' 42 | #' print(res$shape) 43 | #' 44 | #' 45 | #' # 'dgRMatrix' sparse matrix 46 | #' #-------------------------- 47 | #' 48 | #' dgrM = as(dgcM, "RsparseMatrix") 49 | #' 50 | #' print(dim(dgrM)) 51 | #' 52 | #' res_dgr = TO_scipy_sparse(dgrM) 53 | #' 54 | #' print(res_dgr$shape) 55 | #' } 56 | #' } 57 | #' } 58 | #' }, silent = TRUE) 59 | 60 | 
TO_scipy_sparse = function(R_sparse_matrix) { 61 | 62 | if (inherits(R_sparse_matrix, "dgCMatrix")) { 63 | py_obj <- SCP$sparse$csc_matrix( 64 | reticulate::tuple( 65 | R_sparse_matrix@x 66 | , R_sparse_matrix@i 67 | , R_sparse_matrix@p 68 | ) 69 | , shape = reticulate::tuple( 70 | R_sparse_matrix@Dim[1] 71 | , R_sparse_matrix@Dim[2] 72 | ) 73 | ) 74 | } 75 | 76 | else if (inherits(R_sparse_matrix, "dgRMatrix")) { 77 | 78 | py_obj <- SCP$sparse$csr_matrix( 79 | reticulate::tuple( 80 | R_sparse_matrix@x 81 | , R_sparse_matrix@j 82 | , R_sparse_matrix@p 83 | ) 84 | , shape = reticulate::tuple( 85 | R_sparse_matrix@Dim[1] 86 | , R_sparse_matrix@Dim[2] 87 | ) 88 | ) 89 | } 90 | 91 | else { 92 | stop("the 'R_sparse_matrix' parameter should be either a 'dgCMatrix' or a 'dgRMatrix' sparse matrix", call. = FALSE) 93 | } 94 | 95 | return(py_obj) 96 | } 97 | -------------------------------------------------------------------------------- /R-package/R/mat_2scipy_sparse.R: -------------------------------------------------------------------------------- 1 | #' conversion of an R matrix to a scipy sparse matrix 2 | #' 3 | #' 4 | #' @param x a data matrix 5 | #' @param format a character string. Either \emph{"sparse_row_matrix"} or \emph{"sparse_column_matrix"} 6 | #' @details 7 | #' This function allows the user to convert an R matrix to a scipy sparse matrix. This is useful because the Regularized Greedy Forest algorithm accepts only python sparse matrices as input. 
8 | #' @export 9 | #' @references https://docs.scipy.org/doc/scipy/reference/sparse.html 10 | #' @examples 11 | #' 12 | #' try({ 13 | #' if (reticulate::py_available(initialize = FALSE)) { 14 | #' if (reticulate::py_module_available("scipy")) { 15 | #' 16 | #' library(RGF) 17 | #' 18 | #' set.seed(1) 19 | #' 20 | #' x = matrix(runif(1000), nrow = 100, ncol = 10) 21 | #' 22 | #' res = mat_2scipy_sparse(x) 23 | #' 24 | #' print(dim(x)) 25 | #' 26 | #' print(res$shape) 27 | #' } 28 | #' } 29 | #' }, silent = TRUE) 30 | 31 | mat_2scipy_sparse = function(x, format = 'sparse_row_matrix') { 32 | 33 | if (!inherits(x, "matrix")) { 34 | stop("the 'x' parameter should be of type 'matrix'", call. = FALSE) 35 | } 36 | 37 | if (format == 'sparse_column_matrix') { 38 | 39 | return(SCP$sparse$csc_matrix(x)) 40 | 41 | } else if (format == 'sparse_row_matrix') { 42 | 43 | return(SCP$sparse$csr_matrix(x)) 44 | 45 | } else { 46 | 47 | stop("the function can take either a 'sparse_row_matrix' or a 'sparse_column_matrix' for the 'format' parameter as input", call. 
= FALSE) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /R-package/R/package.R: -------------------------------------------------------------------------------- 1 | #------------------------- 2 | # Load the python-modules 3 | #------------------------- 4 | 5 | 6 | RGF_mod <- NULL; RGF_utils <- NULL; SCP <- NULL; 7 | 8 | 9 | .onLoad <- function(libname, pkgname) { 10 | 11 | try({ 12 | if (reticulate::py_available(initialize = TRUE)) { 13 | 14 | try({ 15 | RGF_mod <<- reticulate::import("rgf.sklearn", delay_load = TRUE) 16 | }, silent = TRUE) 17 | 18 | try({ 19 | RGF_utils <<- reticulate::import("rgf.utils", delay_load = TRUE) 20 | }, silent = TRUE) 21 | 22 | try({ 23 | SCP <<- reticulate::import("scipy", delay_load = TRUE, convert = FALSE) 24 | }, silent = TRUE) 25 | 26 | } 27 | }, silent=TRUE) 28 | } 29 | 30 | 31 | .onAttach <- function(libname, pkgname) { 32 | packageStartupMessage("If the 'RGF' package gives the following error: 'attempt to apply non-function' then make sure to open a new R session and run 'reticulate::py_config()' before loading the package!") 33 | } 34 | -------------------------------------------------------------------------------- /R-package/inst/CITATION: -------------------------------------------------------------------------------- 1 | citHeader("Please cite both the package and the original articles / software in your publications:") 2 | 3 | year <- sub("-.*", "", meta$Date) 4 | note <- sprintf("R package version %s", meta$Version) 5 | 6 | bibentry( 7 | bibtype = "Manual", 8 | title = "{RGF}: Regularized Greedy Forest", 9 | author = c(person("Lampros", "Mouselimis"), person("Ryosuke", "Fukatani"), person("Nikita", "Titov"), person("Tong", "Zhang"), person("Rie", "Johnson")), 10 | year = year, 11 | note = note, 12 | url = "https://CRAN.R-project.org/package=RGF" 13 | ) 14 | 15 | bibentry( 16 | bibtype = "Manual", 17 | title = "{rgf_python}: The wrapper of machine learning algorithm 
Regularized Greedy Forest (RGF) for Python", 18 | author = c(person("Ryosuke", "Fukatani"), person("Nikita", "Titov"), person("Tong", "Zhang"), person("Rie", "Johnson")), 19 | year = year, 20 | url = "https://pypi.org/project/rgf-python/" 21 | ) 22 | 23 | bibentry( 24 | bibtype = "Article", 25 | title = "Learning Nonlinear Functions Using Regularized Greedy Forest", 26 | author = c(as.person("Rie Johnson"), as.person("Tong Zhang")), 27 | journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", 28 | year = "2014", 29 | volume = "36", 30 | pages = "942--954", 31 | doi = "10.1109/TPAMI.2013.159" 32 | ) 33 | 34 | bibentry( 35 | bibtype = "Article", 36 | title = "Learning Nonlinear Functions Using Regularized Greedy Forest", 37 | author = c(as.person("Rie Johnson"), as.person("Tong Zhang")), 38 | journal = "arXiv.org, stat.ML", 39 | year = "2011", 40 | note = "arXiv:1109.0887" 41 | ) 42 | -------------------------------------------------------------------------------- /R-package/man/RGF_cleanup_temp_files.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/RGF_cleanup_temp_files.R 3 | \name{RGF_cleanup_temp_files} 4 | \alias{RGF_cleanup_temp_files} 5 | \title{Delete all temporary files of the created RGF estimators} 6 | \usage{ 7 | RGF_cleanup_temp_files() 8 | } 9 | \description{ 10 | Delete all temporary files of the created RGF estimators 11 | } 12 | \details{ 13 | This function deletes all temporary files of the created RGF estimators. See the issue \emph{https://github.com/RGF-team/rgf/issues/75} for more details. 
14 | } 15 | \examples{ 16 | 17 | \dontrun{ 18 | library(RGF) 19 | 20 | RGF_cleanup_temp_files() 21 | } 22 | } 23 | \references{ 24 | \emph{https://github.com/RGF-team/rgf/tree/master/python-package} 25 | } 26 | -------------------------------------------------------------------------------- /R-package/man/TO_scipy_sparse.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/TO_scipy_sparse.R 3 | \name{TO_scipy_sparse} 4 | \alias{TO_scipy_sparse} 5 | \title{conversion of an R sparse matrix to a scipy sparse matrix} 6 | \usage{ 7 | TO_scipy_sparse(R_sparse_matrix) 8 | } 9 | \arguments{ 10 | \item{R_sparse_matrix}{an R sparse matrix. Acceptable input objects are either a \emph{dgCMatrix} or a \emph{dgRMatrix}.} 11 | } 12 | \description{ 13 | conversion of an R sparse matrix to a scipy sparse matrix 14 | } 15 | \details{ 16 | This function allows the user to convert either an R \emph{dgCMatrix} or a \emph{dgRMatrix} to a scipy sparse matrix (\emph{scipy.sparse.csc_matrix} or \emph{scipy.sparse.csr_matrix}). This is useful because the \emph{RGF} package accepts besides an R dense matrix also python sparse matrices as input. 17 | 18 | The \emph{dgCMatrix} class is a class of sparse numeric matrices in the compressed, sparse, \emph{column-oriented format}. The \emph{dgRMatrix} class is a class of sparse numeric matrices in the compressed, sparse, \emph{row-oriented format}. 
19 | } 20 | \examples{ 21 | 22 | try({ 23 | if (reticulate::py_available(initialize = FALSE)) { 24 | if (reticulate::py_module_available("scipy")) { 25 | 26 | if (Sys.info()["sysname"] != 'Darwin') { 27 | 28 | library(RGF) 29 | 30 | 31 | # 'dgCMatrix' sparse matrix 32 | #-------------------------- 33 | 34 | data = c(1, 0, 2, 0, 0, 3, 4, 5, 6) 35 | 36 | dgcM = Matrix::Matrix( 37 | data = data 38 | , nrow = 3 39 | , ncol = 3 40 | , byrow = TRUE 41 | , sparse = TRUE 42 | ) 43 | 44 | print(dim(dgcM)) 45 | 46 | res = TO_scipy_sparse(dgcM) 47 | 48 | print(res$shape) 49 | 50 | 51 | # 'dgRMatrix' sparse matrix 52 | #-------------------------- 53 | 54 | dgrM = as(dgcM, "RsparseMatrix") 55 | 56 | print(dim(dgrM)) 57 | 58 | res_dgr = TO_scipy_sparse(dgrM) 59 | 60 | print(res_dgr$shape) 61 | } 62 | } 63 | } 64 | }, silent = TRUE) 65 | } 66 | \references{ 67 | https://stat.ethz.ch/R-manual/R-devel/library/Matrix/html/dgCMatrix-class.html, https://stat.ethz.ch/R-manual/R-devel/library/Matrix/html/dgRMatrix-class.html, https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix 68 | } 69 | -------------------------------------------------------------------------------- /R-package/man/mat_2scipy_sparse.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/mat_2scipy_sparse.R 3 | \name{mat_2scipy_sparse} 4 | \alias{mat_2scipy_sparse} 5 | \title{conversion of an R matrix to a scipy sparse matrix} 6 | \usage{ 7 | mat_2scipy_sparse(x, format = "sparse_row_matrix") 8 | } 9 | \arguments{ 10 | \item{x}{a data matrix} 11 | 12 | \item{format}{a character string. Either \emph{"sparse_row_matrix"} or \emph{"sparse_column_matrix"}} 13 | } 14 | \description{ 15 | conversion of an R matrix to a scipy sparse matrix 16 | } 17 | \details{ 18 | This function allows the user to convert an R matrix to a scipy sparse matrix. 
This is useful because the Regularized Greedy Forest algorithm accepts only python sparse matrices as input. 19 | } 20 | \examples{ 21 | 22 | try({ 23 | if (reticulate::py_available(initialize = FALSE)) { 24 | if (reticulate::py_module_available("scipy")) { 25 | 26 | library(RGF) 27 | 28 | set.seed(1) 29 | 30 | x = matrix(runif(1000), nrow = 100, ncol = 10) 31 | 32 | res = mat_2scipy_sparse(x) 33 | 34 | print(dim(x)) 35 | 36 | print(res$shape) 37 | } 38 | } 39 | }, silent = TRUE) 40 | } 41 | \references{ 42 | https://docs.scipy.org/doc/scipy/reference/sparse.html 43 | } 44 | -------------------------------------------------------------------------------- /R-package/tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(RGF) 3 | 4 | test_check("RGF") 5 | -------------------------------------------------------------------------------- /R-package/tests/testthat/helper-init.R: -------------------------------------------------------------------------------- 1 | 2 | # prefer Python 3 if available [ see: https://github.com/rstudio/reticulate/blob/master/tests/testthat/helper-init.R ] 3 | if (!reticulate::py_available(initialize = FALSE) && 4 | is.na(Sys.getenv("RETICULATE_PYTHON", unset = NA))) 5 | { 6 | python <- Sys.which("python3") 7 | if (nzchar(python)) 8 | reticulate::use_python(python, required = TRUE) 9 | } 10 | -------------------------------------------------------------------------------- /R-package/tests/testthat/helper-skip.R: -------------------------------------------------------------------------------- 1 | 2 | #....................................... 3 | # skip a test if python is not available [ see: https://github.com/rstudio/reticulate/tree/master/tests/testthat ] 4 | #....................................... 
# Skip the calling test when reticulate reports that Python is not available.
# 'initialize = FALSE' only queries the current state; it does not try to start Python.
skip_test_if_no_python <- function() {
  if (!reticulate::py_available(initialize = FALSE))
    testthat::skip("Python bindings not available for testing")
}


#.........................................
# skip a test if a module is not available     [ see: https://github.com/rstudio/reticulate ]
#.........................................

# MODULE : character vector ( length(MODULE) >= 1 ) with the names of the python modules
#          that the calling test requires. The test is skipped unless ALL of them are available.
skip_test_if_no_module <- function(MODULE) {

  # Start from "not available": if the lookup itself raises an error inside try()
  # the variable would otherwise never be bound and the helper would stop with
  # "object 'module_exists' not found" instead of skipping the test.
  module_exists <- FALSE

  try({
    # vapply() covers both the single-module and the multi-module case and always
    # returns a logical vector, so no special-casing on length(MODULE) is needed
    module_exists <- all(vapply(MODULE, function(x) reticulate::py_module_available(x), logical(1)))
  }, silent = TRUE)

  if (!module_exists) {
    # collapse the names so that a single message is produced even for several modules
    testthat::skip(paste0(paste(MODULE, collapse = ", "), " is not available for testthat-testing"))
  }
}
sample(1:2, 100, replace = TRUE) 30 | 31 | 32 | # response "multiclass" classification 33 | #------------------------------------- 34 | 35 | set.seed(5) 36 | y_MULTIclass = sample(1:5, 100, replace = TRUE) 37 | 38 | 39 | # weights for the fit function 40 | #------------------------------ 41 | 42 | set.seed(6) 43 | W = runif(100) 44 | 45 | # Temporary I/O structures 46 | 47 | # default directory where the temporary 'rgf' files are saved 48 | #------------------------------------------------------------ 49 | 50 | default_dir = file.path(dirname(tempdir()), 'rgf') 51 | -------------------------------------------------------------------------------- /R-package/tests/testthat/teardown.R: -------------------------------------------------------------------------------- 1 | context("teardown") 2 | 3 | # remove temporary 'rgf' files 4 | if (dir.exists(default_dir)) unlink(default_dir, recursive = TRUE, force = TRUE) 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Python and R tests](https://github.com/RGF-team/rgf/workflows/Python%20and%20R%20tests/badge.svg?branch=master)](https://github.com/RGF-team/rgf/actions) 2 | [![DOI](https://zenodo.org/badge/DOI/10.1109/TPAMI.2013.159.svg)](https://doi.org/10.1109/TPAMI.2013.159) 3 | [![arXiv.org](https://img.shields.io/badge/arXiv-1109.0887-b31b1b.svg)](https://arxiv.org/abs/1109.0887) 4 | [![Python Versions](https://img.shields.io/pypi/pyversions/rgf_python.svg)](https://pypi.org/project/rgf_python) 5 | [![PyPI Version](https://img.shields.io/pypi/v/rgf_python.svg)](https://pypi.org/project/rgf_python) 6 | [![CRAN Version](https://r-pkg.org/badges/version/RGF)](https://cran.r-project.org/package=RGF) 7 | 8 | # Regularized Greedy Forest 9 | 10 | Regularized Greedy Forest (RGF) is a tree ensemble machine learning method described in [this paper](https://arxiv.org/abs/1109.0887). 
11 | RGF can deliver better results than gradient boosted decision trees (GBDT) on a number of datasets and it has been used to win a few Kaggle competitions. 12 | Unlike the traditional boosted decision tree approach, RGF works directly with the underlying forest structure. 13 | RGF integrates two ideas: one is to include tree-structured regularization into the learning formulation; and the other is to employ the fully-corrective regularized greedy algorithm. 14 | 15 | This repository contains the following implementations of the RGF algorithm: 16 | 17 | - [RGF](https://github.com/RGF-team/rgf/tree/master/RGF): original implementation from the paper; 18 | - [FastRGF](https://github.com/RGF-team/rgf/tree/master/FastRGF): multi-core implementation with some simplifications; 19 | - [rgf_python](https://github.com/RGF-team/rgf/tree/master/python-package): wrapper of both RGF and FastRGF implementations for Python; 20 | - [R package](https://github.com/RGF-team/rgf/tree/master/R-package): wrapper of rgf_python for R. 21 | 22 | You may want to get interesting information about RGF from the posts collected in [Awesome RGF](https://github.com/RGF-team/rgf/blob/master/AWESOME_RGF.md ). 23 | -------------------------------------------------------------------------------- /RGF/CHANGES.md: -------------------------------------------------------------------------------- 1 | # March 2018 (version 1.3) 2 | 3 | 1. Calculation of feature importance has been added. 4 | 5 | 2. Dumping information about the forest model to the console has been added. 6 | 7 | 3. License has been changed from GPLv3 to MIT. 8 | 9 | **Breaking changes:** 10 | 11 | - Due to adding feature importance, old model files are not compatible with version `1.3`. 12 | 13 | # February 2018 14 | 15 | 1. Absolute error loss has been added. 16 | 17 | # August-December 2017 18 | 19 | 1. The executable file for 32-bit Windows has been added. 20 | 21 | 2. Compilation with MinGW on Windows has been fixed. 22 | 23 | 3. 
# Build script for the stand-alone `rgf` executable.
# The binary is produced in, and installed to, <this directory>/bin.
project(rgf)
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_SOURCE_DIR})
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_CURRENT_SOURCE_DIR}/bin)
if(MINGW)
  # CMAKE_EXE_LINKER_FLAGS is a single space-separated string. Quote the new value so the
  # flag is appended to that string; the unquoted form creates a ";"-separated list, which
  # puts a literal ';' on the link command line whenever the variable is already non-empty.
  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static-libstdc++")
endif()
if(MSVC)
  cmake_minimum_required(VERSION 3.1)
  add_definitions(-D__AZ_MSDN__)
  set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_SOURCE_DIR}/bin)
  # Replace the dynamic runtime (/MD) by the static one (/MT) in every configuration
  set(variables
    CMAKE_CXX_FLAGS_DEBUG
    CMAKE_CXX_FLAGS_MINSIZEREL
    CMAKE_CXX_FLAGS_RELEASE
    CMAKE_CXX_FLAGS_RELWITHDEBINFO)
  foreach(variable ${variables})
    if(${variable} MATCHES "/MD")
      string(REGEX REPLACE "/MD" "/MT" ${variable} "${${variable}}")
      set(${variable} "${${variable}}" PARENT_SCOPE)
    endif()
  endforeach()
endif()
if(CMAKE_CONFIGURATION_TYPES)
  # Multi-configuration generators: offer the Release configuration only
  set(CMAKE_CONFIGURATION_TYPES Release)
  set(CMAKE_CONFIGURATION_TYPES "${CMAKE_CONFIGURATION_TYPES}" CACHE STRING
    "Reset the configurations to what we need" FORCE)
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/com ${CMAKE_CURRENT_SOURCE_DIR}/src/tet)
file(GLOB_RECURSE FILES "./src/*.cpp")
add_definitions(-O2)
add_executable(${PROJECT_NAME} ${FILES})
target_link_libraries(${PROJECT_NAME})
install(TARGETS ${PROJECT_NAME} DESTINATION bin)
# Builds the `rgf` executable into ../bin with a single g++ invocation
# (all translation units are compiled and linked in one step).
BIN_NAME = rgf
BIN_DIR = ../bin
TARGET = $(BIN_DIR)/$(BIN_NAME)
CFLAGS = -I../src/com -I../src/tet -O2

ifeq ($(OS), Windows_NT)
	CFLAGS += -static # MinGW fix
endif

CPP_FILES = 	../src/tet/driv_rgf.cpp \
	../src/com/AzDmat.cpp \
	../src/tet/AzFindSplit.cpp \
	../src/com/AzIntPool.cpp \
	../src/com/AzLoss.cpp \
	../src/tet/AzOptOnTree_TreeReg.cpp \
	../src/tet/AzOptOnTree.cpp \
	../src/com/AzParam.cpp \
	../src/tet/AzReg_Tsrbase.cpp \
	../src/tet/AzReg_TsrOpt.cpp \
	../src/tet/AzReg_TsrSib.cpp \
	../src/tet/AzRgf_FindSplit_Dflt.cpp \
	../src/tet/AzRgf_FindSplit_TreeReg.cpp \
	../src/tet/AzRgf_Optimizer_Dflt.cpp \
	../src/tet/AzRgforest.cpp \
	../src/tet/AzRgfTree.cpp \
	../src/com/AzSmat.cpp \
	../src/tet/AzSortedFeat.cpp \
	../src/com/AzStrPool.cpp \
	../src/com/AzSvDataS.cpp \
	../src/com/AzTaskTools.cpp \
	../src/tet/AzTETmain.cpp \
	../src/tet/AzTETproc.cpp \
	../src/com/AzTools.cpp \
	../src/tet/AzTree.cpp \
	../src/tet/AzTreeEnsemble.cpp \
	../src/tet/AzTrTree.cpp \
	../src/tet/AzTrTreeFeat.cpp \
	../src/com/AzUtil.cpp

# `all` names an action, not a file: without this declaration a stray file
# called "all" in this directory would make the target look up to date.
.PHONY: all

all:
	g++ $(CPP_FILES) $(CFLAGS) -o $(TARGET)
3 | # In the command line, enter: 4 | # 5 | # perl call_exe.pl ../bin/rgf train_test sample/regress_train_test 6 | # 7 | 8 | #------------------ Perform 3 runs --------------------# 9 | @reg_L2=1,model_fn_prefix=output/regress.lam1.model # used by 1st run 10 | @reg_L2=0.1,model_fn_prefix=output/regress.lam0.1.model # used by 2nd run 11 | @reg_L2=0.01,model_fn_prefix=output/regress.lam0.01.model # used by 3rd run 12 | #-------------------------------------------------------------------------# 13 | 14 | #--- Other parameters are shared by 3 runs 15 | 16 | train_x_fn=sample/regress.train.x # Training data points 17 | train_y_fn=sample/regress.train.y # Training targets 18 | 19 | test_x_fn=sample/regress.test.x # Test data points 20 | test_y_fn=sample/regress.test.y # Test targets 21 | 22 | algorithm=RGF # RGF with L2 regularization on leaf-only models 23 | loss=LS # Square loss 24 | test_interval=500 # Test (and save) models every time 500 leaves are added 25 | max_leaf_forest=5000 # Stop training when #leaf reaches 5000 26 | Verbose # Display info during training 27 | NormalizeTarget # Normalize targets so that the average becomes zero 28 | 29 | #train_w_fn=?? # User-specified weights of data points 30 | #model_fn_for_warmstart=?? 
# Path to the model file to do warm-start with 31 | -------------------------------------------------------------------------------- /RGF/examples/sample/test.data.x: -------------------------------------------------------------------------------- 1 | 79 93 97 99 68 14 94 90 3 34 2 | 41 40 88 38 87 71 34 82 18 91 3 | 9 73 66 12 88 84 26 21 74 48 4 | 52 82 75 67 98 46 95 5 49 17 5 | 15 67 89 80 73 61 90 19 9 72 6 | 15 25 11 6 54 95 7 37 83 57 7 | 63 2 96 90 16 99 68 99 27 67 8 | 39 45 84 4 61 27 64 60 90 12 9 | 40 80 6 51 50 78 48 20 0 39 10 | 40 35 42 9 25 43 95 57 56 34 11 | 63 91 35 79 7 68 65 16 93 73 12 | 79 31 18 18 86 78 13 50 25 26 13 | 78 73 1 4 83 85 83 29 38 93 14 | 67 68 47 43 47 57 11 73 97 49 15 | 37 95 20 7 68 47 0 61 33 16 16 | 93 7 50 75 18 91 35 78 86 29 17 | 64 7 86 25 0 14 85 31 52 10 18 | 5 74 47 57 50 33 68 73 16 26 19 | 89 33 0 47 51 93 34 61 21 69 20 | 86 79 21 52 28 68 89 63 66 84 21 | 47 87 56 44 55 89 1 92 68 53 22 | 49 41 33 29 61 92 45 27 79 57 23 | 81 36 2 74 50 90 84 34 58 48 24 | 6 53 74 92 7 19 74 86 17 89 25 | 72 63 77 47 1 61 29 3 92 25 26 | 99 89 64 88 95 67 22 70 73 2 27 | 63 92 61 66 8 34 99 34 54 76 28 | 70 88 46 56 54 60 6 51 28 74 29 | 13 83 63 49 97 6 89 97 70 89 30 | 60 93 73 19 54 10 46 99 25 98 31 | 66 51 75 44 56 38 93 3 35 89 32 | 50 85 55 81 91 30 84 64 52 33 33 | 16 57 37 31 99 49 46 25 22 56 34 | 0 58 51 3 14 34 58 95 94 73 35 | 56 62 23 76 63 48 90 4 42 36 36 | 48 63 86 72 44 19 66 70 1 49 37 | 23 93 59 39 67 24 5 93 34 29 38 | 41 5 86 21 79 74 7 26 90 78 39 | 34 97 16 70 27 41 9 61 72 35 40 | 87 3 37 79 21 73 13 70 36 50 41 | 2 58 37 80 11 88 21 31 21 40 42 | 1 5 32 86 62 23 84 91 11 74 43 | 36 94 72 70 45 28 59 27 69 30 44 | 33 85 70 17 8 49 70 14 86 92 45 | 1 35 92 8 53 41 28 85 96 96 46 | 56 82 79 32 64 87 86 10 1 13 47 | 24 58 91 41 36 86 32 5 8 47 48 | 50 93 65 33 83 20 28 33 63 77 49 | 63 15 26 80 39 66 70 85 32 3 50 | 27 21 13 75 41 46 70 64 70 80 51 | 19 2 57 89 35 40 46 12 29 66 52 | 46 6 14 52 98 37 95 10 
7 18 53 | 80 73 37 72 84 34 84 93 63 33 54 | 42 49 38 97 86 97 55 79 8 1 55 | 76 40 67 41 72 95 31 64 75 78 56 | 87 62 25 99 14 37 26 85 24 58 57 | 68 88 83 93 82 11 91 88 31 85 58 | 15 3 66 23 80 5 89 16 39 14 59 | 48 84 66 86 19 76 67 53 10 78 60 | 1 12 36 73 19 99 42 89 58 27 61 | 65 80 65 49 74 83 11 15 9 67 62 | 88 79 65 6 96 14 78 50 88 71 63 | 48 64 26 48 7 84 56 32 22 77 64 | 44 1 88 99 43 72 53 84 53 11 65 | 77 16 31 90 50 22 4 97 47 80 66 | 21 13 69 90 13 85 89 30 59 59 67 | 74 84 9 87 25 26 56 15 2 91 68 | 41 52 60 15 76 56 43 42 51 40 69 | 66 6 98 71 40 70 12 64 83 1 70 | 44 25 23 22 48 89 27 61 30 27 71 | 61 34 74 95 23 96 79 88 93 9 72 | 69 74 65 56 6 54 45 51 8 85 73 | 48 6 47 81 76 64 31 30 42 20 74 | 7 67 86 74 91 83 56 7 56 43 75 | 64 63 26 88 42 34 14 2 77 45 76 | 93 35 40 1 55 52 49 84 24 95 77 | 78 67 89 64 10 8 2 85 81 49 78 | 2 48 58 12 94 58 52 96 15 34 79 | 6 60 90 62 12 59 13 70 71 55 80 | 36 91 76 82 90 90 76 94 81 92 81 | 46 64 70 40 61 60 12 55 79 82 82 | 18 75 83 76 65 39 7 34 13 74 83 | 49 18 7 47 31 62 14 56 72 26 84 | 76 21 93 18 68 62 36 47 58 71 85 | 4 77 90 24 1 33 80 58 69 96 86 | 7 45 99 8 17 54 72 17 23 50 87 | 67 62 73 80 96 53 84 63 56 9 88 | 46 23 34 51 14 22 34 39 40 6 89 | 31 63 27 2 54 28 81 89 85 40 90 | 23 66 36 92 91 64 19 31 98 3 91 | 39 50 93 19 38 21 93 49 65 34 92 | 49 92 24 5 61 42 85 16 85 55 93 | 88 44 80 1 71 39 53 29 56 59 94 | 73 45 83 70 86 90 78 56 84 24 95 | 22 91 15 84 72 37 84 32 12 71 96 | 83 81 25 83 44 34 72 91 52 20 97 | 97 96 62 30 28 38 47 93 48 38 98 | 63 43 14 12 49 37 96 59 57 49 99 | 0 61 0 83 34 67 16 42 52 35 100 | 20 41 25 84 33 69 28 52 70 40 101 | -------------------------------------------------------------------------------- /RGF/examples/sample/test.data.y: -------------------------------------------------------------------------------- 1 | +1 2 | +1 3 | -1 4 | +1 5 | +1 6 | +1 7 | +1 8 | -1 9 | -1 10 | -1 11 | -1 12 | -1 13 | -1 14 | +1 15 | +1 16 | -1 17 | -1 18 | +1 19 | -1 20 | -1 
21 | +1 22 | +1 23 | -1 24 | -1 25 | +1 26 | +1 27 | -1 28 | +1 29 | +1 30 | -1 31 | +1 32 | -1 33 | +1 34 | +1 35 | -1 36 | -1 37 | +1 38 | +1 39 | +1 40 | -1 41 | +1 42 | +1 43 | -1 44 | -1 45 | -1 46 | -1 47 | +1 48 | +1 49 | +1 50 | -1 51 | +1 52 | +1 53 | -1 54 | +1 55 | +1 56 | -1 57 | +1 58 | +1 59 | -1 60 | +1 61 | +1 62 | -1 63 | -1 64 | -1 65 | -1 66 | -1 67 | +1 68 | +1 69 | -1 70 | -1 71 | -1 72 | +1 73 | -1 74 | -1 75 | +1 76 | +1 77 | -1 78 | +1 79 | +1 80 | +1 81 | +1 82 | +1 83 | +1 84 | -1 85 | +1 86 | -1 87 | -1 88 | -1 89 | +1 90 | -1 91 | -1 92 | +1 93 | +1 94 | -1 95 | +1 96 | +1 97 | +1 98 | -1 99 | +1 100 | -1 101 | -------------------------------------------------------------------------------- /RGF/examples/sample/train.data.x: -------------------------------------------------------------------------------- 1 | 58 19 5 92 3 62 26 56 43 35 2 | 19 61 54 81 8 50 55 83 69 81 3 | 79 93 49 17 67 70 33 38 81 53 4 | 66 30 78 67 18 98 72 25 35 53 5 | 86 34 68 72 40 8 89 7 34 0 6 | 9 99 46 86 51 13 74 46 82 62 7 | 82 99 20 47 46 42 40 66 64 55 8 | 99 73 59 94 16 16 80 28 11 30 9 | 96 67 72 16 10 27 79 48 86 54 10 | 23 3 63 91 0 97 72 3 57 46 11 | 25 72 85 34 76 65 68 64 51 61 12 | 98 34 49 83 44 16 3 3 70 63 13 | 63 39 88 20 73 0 52 9 30 57 14 | 99 91 47 60 71 74 79 13 20 92 15 | 75 64 9 60 29 66 38 26 39 94 16 | 36 94 62 55 65 10 42 62 53 35 17 | 19 94 55 68 65 36 95 21 72 52 18 | 16 21 90 23 14 8 69 79 18 78 19 | 44 84 85 32 66 12 26 12 9 81 20 | 86 19 68 21 33 42 95 46 7 64 21 | 85 48 92 11 76 56 43 97 60 70 22 | 68 81 89 89 65 26 87 29 30 7 23 | 95 93 5 27 68 26 25 28 61 8 24 | 20 78 44 80 31 93 55 81 69 39 25 | 91 67 92 15 55 53 27 47 34 9 26 | 91 15 95 30 32 18 35 13 46 76 27 | 68 64 89 0 73 75 6 44 97 84 28 | 64 8 9 50 45 35 68 24 92 28 29 | 74 74 7 31 53 28 92 38 17 34 30 | 94 78 31 96 33 54 50 47 77 44 31 | 18 11 7 93 26 79 5 70 99 46 32 | 16 52 3 27 63 56 85 83 81 63 33 | 55 25 91 67 78 15 47 55 89 87 34 | 24 16 8 12 44 17 46 30 66 91 35 | 
65 31 95 68 13 16 66 65 60 95 36 | 31 55 84 70 74 83 45 28 25 12 37 | 61 3 43 20 5 2 25 50 2 95 38 | 12 7 92 19 89 76 57 77 74 43 39 | 85 93 10 31 37 2 71 5 85 13 40 | 97 20 67 55 53 61 24 93 3 60 41 | 57 83 81 70 87 23 33 67 82 90 42 | 65 45 48 39 21 73 68 35 25 58 43 | 57 45 32 68 66 15 60 51 58 29 44 | 2 36 0 48 92 3 98 45 97 26 45 | 65 87 37 4 82 94 73 81 36 2 46 | 38 33 2 96 1 68 18 39 86 73 47 | 46 22 13 26 81 95 88 74 90 95 48 | 51 92 9 16 8 55 30 87 78 47 49 | 86 48 60 35 3 6 46 87 71 31 50 | 12 52 36 48 34 65 46 49 98 21 51 | 34 43 39 98 78 5 14 80 80 76 52 | 48 54 96 3 64 40 88 47 9 85 53 | 74 80 33 87 24 15 95 18 3 44 54 | 70 57 14 95 75 31 0 67 49 96 55 | 1 51 3 42 69 45 6 99 65 87 56 | 39 16 98 20 35 19 15 99 65 68 57 | 86 98 72 37 95 53 6 36 64 92 58 | 87 65 36 39 94 8 5 32 33 0 59 | 87 98 52 49 15 36 20 28 60 4 60 | 84 92 93 82 91 54 84 19 83 84 61 | 85 26 62 87 41 60 79 8 32 42 62 | 75 12 49 81 68 87 86 97 13 33 63 | 88 38 60 89 15 43 37 47 33 96 64 | 95 1 7 27 20 69 53 1 84 53 65 | 7 98 50 39 28 70 10 47 6 49 66 | 90 30 70 6 36 84 73 15 36 22 67 | 38 48 2 4 39 1 66 60 47 88 68 | 38 12 46 76 95 60 74 62 72 48 69 | 91 94 47 24 94 42 29 25 14 13 70 | 69 63 44 43 28 35 44 70 23 52 71 | 51 36 8 12 59 96 34 96 53 94 72 | 16 88 94 11 68 87 88 51 56 83 73 | 46 1 36 27 3 77 41 94 54 76 74 | 1 4 97 65 32 38 88 97 11 75 75 | 9 24 7 63 86 95 88 38 61 66 76 | 71 55 42 91 56 14 77 82 40 68 77 | 26 38 62 20 2 27 12 29 68 33 78 | 21 63 92 54 49 58 75 60 96 2 79 | 64 22 30 60 64 11 71 59 80 54 80 | 94 38 43 23 23 81 89 70 25 18 81 | 28 76 49 71 4 10 13 49 26 33 82 | 36 43 29 50 95 75 11 5 7 43 83 | 31 53 47 76 85 21 84 71 63 10 84 | 51 29 58 8 32 13 89 86 9 99 85 | 65 58 68 64 32 16 28 35 0 1 86 | 9 86 30 39 7 41 28 63 69 13 87 | 3 96 70 5 13 30 24 82 65 77 88 | 79 69 83 59 76 81 26 51 98 97 89 | 17 39 45 14 54 8 54 85 9 1 90 | 86 44 69 6 2 19 6 58 1 5 91 | 2 7 8 2 61 76 27 11 11 50 92 | 73 12 73 43 66 62 26 17 71 85 93 | 23 99 62 68 33 53 66 30 11 92 94 | 61 23 43 
68 12 71 45 69 43 18 95 | 93 43 93 85 72 22 78 4 79 88 96 | 11 41 99 73 46 67 44 50 87 59 97 | 67 81 50 79 13 84 47 72 42 42 98 | 98 33 98 89 34 48 19 62 39 89 99 | 23 47 30 76 88 19 6 96 61 71 100 | 61 45 70 51 5 26 75 43 68 88 101 | -------------------------------------------------------------------------------- /RGF/examples/sample/train.data.y: -------------------------------------------------------------------------------- 1 | -1 2 | -1 3 | -1 4 | -1 5 | -1 6 | -1 7 | +1 8 | -1 9 | -1 10 | -1 11 | -1 12 | -1 13 | +1 14 | +1 15 | -1 16 | +1 17 | -1 18 | -1 19 | -1 20 | -1 21 | +1 22 | -1 23 | -1 24 | -1 25 | +1 26 | -1 27 | -1 28 | -1 29 | -1 30 | -1 31 | -1 32 | +1 33 | -1 34 | -1 35 | -1 36 | +1 37 | -1 38 | +1 39 | -1 40 | +1 41 | -1 42 | +1 43 | -1 44 | -1 45 | +1 46 | -1 47 | +1 48 | +1 49 | -1 50 | +1 51 | +1 52 | -1 53 | +1 54 | +1 55 | +1 56 | -1 57 | +1 58 | +1 59 | -1 60 | -1 61 | -1 62 | -1 63 | -1 64 | -1 65 | +1 66 | +1 67 | -1 68 | +1 69 | +1 70 | -1 71 | -1 72 | -1 73 | +1 74 | +1 75 | -1 76 | -1 77 | -1 78 | -1 79 | +1 80 | +1 81 | +1 82 | +1 83 | +1 84 | -1 85 | -1 86 | +1 87 | +1 88 | -1 89 | -1 90 | -1 91 | -1 92 | -1 93 | -1 94 | -1 95 | -1 96 | -1 97 | -1 98 | -1 99 | +1 100 | -1 101 | -------------------------------------------------------------------------------- /RGF/examples/sample/train.inp: -------------------------------------------------------------------------------- 1 | #### sample input to "train" #### 2 | 3 | train_x_fn=sample/train.data.x # Training data points 4 | train_y_fn=sample/train.data.y # Training targets 5 | 6 | #--- Save the models with filenames output/sample.model-01, 7 | #--- output/sample.model-02,... 
 8 | model_fn_prefix=output/sample.model 9 | 10 | #--- training parameters 11 | reg_L2=1 # Regularization parameter 12 | algorithm=RGF # RGF with L2 regularization on leaf-only models 13 | loss=LS # Square loss 14 | test_interval=100 # Save models every time 100 leaves are added 15 | max_leaf_forest=500 # Stop training when #leaf reaches 500 16 | Verbose # Display info during training 17 | 18 | #--- other parameters (commented out) 19 | #NormalizeTarget # Normalize targets so that the average becomes zero 20 | #train_w_fn=?? # User-specified weights of data points 21 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with 22 | -------------------------------------------------------------------------------- /RGF/examples/sample/train_predict.inp: -------------------------------------------------------------------------------- 1 | #### sample input to "train_predict" #### 2 | 3 | train_x_fn=sample/train.data.x # Training data points 4 | train_y_fn=sample/train.data.y # Training targets 5 | 6 | test_x_fn=sample/test.data.x # Test data points 7 | 8 | #--- 9 | model_fn_prefix=output/m 10 | #--- Models are saved with filenames output/m-01, 11 | #--- output/m-02,... 12 | #--- Predictions are saved with filenames output/m-01.pred, 13 | #--- output/m-02.pred,... 14 | #--- Model info such as #leaf is saved with filenames output/m-01.info, 15 | #--- output/m-02.info,... 16 | 17 | SaveLastModelOnly # Only the last (largest) model will be saved to a file. 
18 | # Comment this out if all the models should be saved 19 | 20 | #--- training parameters 21 | algorithm=RGF # RGF with L2 regularization on leaf-only models 22 | reg_L2=1 # Regularization parameter 23 | loss=LS # Square loss 24 | test_interval=100 # Test (and save) models every time 100 leaves are added 25 | max_leaf_forest=500 # Stop training when #leaf reaches 500 26 | Verbose # Display info during training 27 | 28 | #--- other parameters (commented out) 29 | #NormalizeTarget # Normalize targets so that the average becomes zero 30 | #train_w_fn=?? # User-specified weights of data points 31 | #model_fn_for_warmstart=?? # Path to the model file to do warm-start with 32 | -------------------------------------------------------------------------------- /RGF/examples/sample/train_test.inp: -------------------------------------------------------------------------------- 1 | #### sample input to "train_test" #### 2 | 3 | train_x_fn=sample/train.data.x # Training data points 4 | train_y_fn=sample/train.data.y # Training targets 5 | 6 | test_x_fn=sample/test.data.x # Test data points 7 | test_y_fn=sample/test.data.y # Test targets 8 | 9 | evaluation_fn=output/sample.evaluation # Where to write evaluation results 10 | 11 | model_fn_prefix=output/m # Comment this out if models should not be saved 12 | 13 | 14 | #--- training parameters 15 | algorithm=RGF # RGF with L2 regularization on leaf-only models 16 | reg_L2=1 # Regularization parameter 17 | loss=LS # Square loss 18 | test_interval=100 # Test (and save) models every time 100 leaves are added 19 | max_leaf_forest=500 # Stop training when #leaf reaches 500 20 | Verbose # Display info during training 21 | 22 | #--- other parameters (commented out) 23 | #NormalizeTarget # Normalize targets so that the average becomes zero 24 | #train_w_fn=?? # User-specified weights of data points 25 | #model_fn_for_warmstart=?? 
# Path to the model file to do warm-start with 26 | -------------------------------------------------------------------------------- /RGF/src/com/AzBmat.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzBmat.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_BMAT_HPP_ 10 | #define _AZ_BMAT_HPP_ 11 | #include "AzUtil.hpp" 12 | 13 | //! binary matrix 14 | class AzBmat { 15 | protected: 16 | int row_num; 17 | AzDataArray a; 18 | 19 | public: 20 | AzBmat() : row_num(0) {} 21 | AzBmat(int inp_row_num, int inp_col_num) : row_num(0) { 22 | resize(inp_col_num); 23 | } 24 | AzBmat(const AzBmat *inp) : row_num(0) { 25 | set(inp); 26 | } 27 | inline void set(const AzBmat *inp) { 28 | row_num = inp->row_num; 29 | a.reset(&inp->a); 30 | } 31 | AzBmat(const AzBmat &inp) : row_num(0) { 32 | set(&inp); 33 | } 34 | AzBmat & operator =(const AzBmat &inp) { 35 | if (this == &inp) return *this; 36 | set(&inp); 37 | return *this; 38 | } 39 | inline void reform(int inp_row_num, int inp_col_num) { 40 | reset(); 41 | row_num = inp_row_num; 42 | resize(inp_col_num); 43 | } 44 | inline void resize(int new_col_num) { 45 | a.resz(new_col_num); 46 | } 47 | 48 | inline void reset() { 49 | row_num = 0; 50 | a.reset(); 51 | } 52 | inline int rowNum() const { 53 | return row_num; 54 | } 55 | inline int colNum() const { 56 | return a.cursor(); 57 | } 58 | 59 | inline const AzIntArr *on_rows(int col) const { 60 | return a.point(col); 61 | } 62 | inline void clear(int fx) { 63 | a.point_u(fx)->reset(); 64 | } 65 | 66 | inline void load(int col, const AzIntArr *ia_on_rows) { 67 | if (ia_on_rows == NULL || ia_on_rows->size() <= 0) return; 68 | 69 | if (ia_on_rows->min() < 0 || 70 | ia_on_rows->max() >= row_num) { 71 | throw new AzException("AzBmat::load", "wrong row#"); 72 | } 73 | 
a.point_u(col)->reset(ia_on_rows); 74 | } 75 | }; 76 | #endif 77 | 78 | 79 | -------------------------------------------------------------------------------- /RGF/src/com/AzException.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzException.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_EXCEPTION_HPP_ 10 | #define _AZ_EXCEPTION_HPP_ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | using namespace std; 18 | 19 | enum AzRetCode { 20 | AzNormal=0, 21 | AzAllocError=10, 22 | AzFileIOError=20, 23 | AzInputError=30, 24 | AzInputMissing=31, 25 | AzInputNotValid=32, 26 | AzConflict=100, /* all others */ 27 | }; 28 | 29 | /*-----------------------------------------------------*/ 30 | class AzException { 31 | public: 32 | AzException(const char *string1, 33 | const char *string2, 34 | const char *string3=NULL) 35 | { 36 | reset(AzConflict, string1, string2, string3); 37 | } 38 | 39 | AzException(AzRetCode retcode, 40 | const char *string1, 41 | const char *string2, 42 | const char *string3=NULL) 43 | { 44 | reset(retcode, string1, string2, string3); 45 | } 46 | 47 | template 48 | AzException(AzRetCode retcode, 49 | const char *string1, 50 | const char *string2, 51 | const char *string3, 52 | T anything) 53 | { 54 | reset(retcode, string1, string2, string3); 55 | s3 << "; " << anything; 56 | } 57 | 58 | void reset(AzRetCode retcode, 59 | const char *str1, 60 | const char *str2, 61 | const char *str3) 62 | { 63 | this->retcode = retcode; 64 | if (str1 != NULL) s1 << str1; 65 | if (str2 != NULL) s2 << str2; 66 | if (str3 != NULL) s3 << str3; 67 | } 68 | 69 | AzRetCode getReturnCode() { 70 | return retcode; 71 | } 72 | 73 | string getMessage() 74 | { 75 | if (retcode == AzNormal) { 76 | 77 | } 78 | else if (retcode == 
AzAllocError) { 79 | message << "!Memory alloc error!"; 80 | } 81 | else if (retcode == AzFileIOError) { 82 | message << "!File I/O error!"; 83 | } 84 | else if (retcode == AzInputError) { 85 | message << "!Input error!"; 86 | } 87 | else if (retcode == AzInputMissing) { 88 | message << "!Missing input!"; 89 | } 90 | else if (retcode == AzInputNotValid) { 91 | message << "!Input value is not valid!"; 92 | } 93 | else if (retcode == AzConflict) { 94 | message << "Conflict"; 95 | } 96 | else { 97 | message << "Unknown error"; 98 | } 99 | 100 | message << ": "; 101 | if (s1.str().find("Az") == 0) { 102 | message << "(Detected in " << s1.str() << ") " << endl; 103 | } 104 | else { 105 | message << s1.str() << " "; 106 | } 107 | message << s2.str(); 108 | if (s3.str().length() > 0) { 109 | message << " " << s3.str(); 110 | } 111 | message << endl; 112 | return message.str(); 113 | } 114 | 115 | protected: 116 | AzRetCode retcode; 117 | 118 | stringstream s1, s2, s3; 119 | stringstream message; 120 | }; 121 | 122 | #endif 123 | -------------------------------------------------------------------------------- /RGF/src/com/AzOut.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOut.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_OUT_HPP_ 10 | #define _AZ_OUT_HPP_ 11 | 12 | class AzOut { 13 | protected: 14 | bool isActive; 15 | int level; 16 | public: 17 | ostream *o; 18 | 19 | inline AzOut() : o(NULL), isActive(true), level(0) {} 20 | inline AzOut(ostream *o_ptr) : isActive(true), level(0) { 21 | o = o_ptr; 22 | } 23 | inline void reset(ostream *o_ptr) { 24 | o = o_ptr; 25 | activate(); 26 | } 27 | 28 | inline void deactivate() { 29 | isActive = false; 30 | } 31 | inline void activate() { 32 | isActive = true; 33 | } 34 | inline void setStdout() { 35 | o = &cout; 36 | activate(); 37 | } 38 | inline void setStderr() { 39 | o = &cerr; 40 | activate(); 41 | } 42 | inline bool isNull() const { 43 | if (!isActive) return true; 44 | if (o == NULL) return true; 45 | return false; 46 | } 47 | inline void flush() const { 48 | if (o != NULL) o->flush(); 49 | } 50 | inline void setLevel(int inp_level) { 51 | level = inp_level; 52 | } 53 | inline int getLevel() const { 54 | return level; 55 | } 56 | }; 57 | #endif 58 | -------------------------------------------------------------------------------- /RGF/src/com/AzPerfResult.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzPerfResult.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_PERF_RESULT_HPP_ 10 | #define _AZ_PERF_RESULT_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | 14 | #define AzOtherCat "_X_" 15 | 16 | /*--------------------------------------------------*/ 17 | enum AzPerfType { 18 | AzPerfType_Acc = 0, 19 | AzPerfType_RMSE = 1, 20 | }; 21 | #define AzPerfType_Num 2 22 | static const char *perf_str[AzPerfType_Num] = { 23 | "acc", "rmse", 24 | }; 25 | 26 | /*--------------------------------------------------*/ 27 | class AzPerfResult { 28 | public: 29 | AzPerfResult() { 30 | p=r=f=acc=breakEven_f=breakEven_acc=rmse=loss=-1; 31 | } 32 | double p, r, f, acc, breakEven_f, breakEven_acc, rmse, loss; 33 | inline void put(double inp_p, double inp_r, double inp_f, double inp_acc, 34 | double inp_be_f, double inp_be_acc, 35 | double inp_rmse, double inp_loss) { 36 | p = inp_p; 37 | r = inp_r; 38 | f = inp_f; 39 | acc = inp_acc; 40 | breakEven_f = inp_be_f; 41 | breakEven_acc = inp_be_acc; 42 | rmse = inp_rmse; 43 | loss=inp_loss; 44 | } 45 | double getPerf(AzPerfType p_type) { 46 | if (p_type == AzPerfType_Acc) return acc; 47 | if (p_type == AzPerfType_RMSE) return rmse; 48 | return -1; 49 | } 50 | static const char *getPerfStr(AzPerfType p_type) { 51 | if (p_type < 0 || 52 | p_type >= AzPerfType_Num) return "???"; 53 | return perf_str[p_type]; 54 | } 55 | 56 | static double isBetter(AzPerfType p_type, 57 | double p, double comp_p) { 58 | /*--- negative means unset ---*/ 59 | if (p < 0) return false; 60 | if (comp_p < 0) return true; 61 | 62 | if (p_type == AzPerfType_RMSE) { 63 | if (p < comp_p) return true; 64 | } 65 | else { 66 | if (p > comp_p) return true; 67 | } 68 | return false; 69 | } 70 | void zeroOut() { 71 | p=r=f=acc=breakEven_f=breakEven_acc=rmse=loss=0; 72 | } 73 | void add(const AzPerfResult *inp) { 74 | p+=inp->p; 75 | r+=inp->r; 76 | f+=inp->f; 77 | acc+=inp->acc; 78 | breakEven_f=inp->breakEven_f; 79 | breakEven_acc=inp->breakEven_acc; 80 | rmse+=inp->rmse; 81 | loss+=inp->loss; 82 | } 83 | 
void multiply(double val) { 84 | p*=val; 85 | r*=val; 86 | f*=val; 87 | acc*=val; 88 | breakEven_f*=val; 89 | breakEven_acc*=val; 90 | rmse*=val; 91 | loss*=val; 92 | } 93 | }; 94 | #endif 95 | -------------------------------------------------------------------------------- /RGF/src/com/AzStrArray.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzStrArray.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_STR_ARRAY_HPP_ 10 | #define _AZ_STR_ARRAY_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | 14 | class AzStrArray { 15 | public: 16 | virtual int size() const = 0; 17 | virtual const char *c_str(int no) const = 0; 18 | void get(int no, AzBytArr *byteq) const { 19 | byteq->reset(); 20 | byteq->concat(c_str(no)); 21 | } 22 | 23 | virtual bool isSame(const AzStrArray *inp) const { 24 | if (size() != inp->size()) { 25 | return false; 26 | } 27 | int ix; 28 | for (ix = 0; ix < size(); ++ix) { 29 | AzBytArr s0; 30 | get(ix, &s0); 31 | if (s0.compare(inp->c_str(ix)) != 0) { 32 | return false; 33 | } 34 | } 35 | return true; 36 | } 37 | 38 | /*---*/ 39 | virtual void writeText(const char *fn) const { 40 | AzFile file(fn); 41 | file.open("wb"); 42 | int ix; 43 | for (ix = 0; ix < size(); ++ix) { 44 | AzBytArr s; 45 | get(ix, &s); 46 | s.nl(); 47 | s.writeText(&file); 48 | } 49 | file.close(true); 50 | } 51 | }; 52 | 53 | #endif 54 | 55 | -------------------------------------------------------------------------------- /RGF/src/com/AzSvFeatInfo.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzSvFeatInfo.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
* * * * */

#ifndef _AZ_SV_FEAT_INFO_HPP_
#define _AZ_SV_FEAT_INFO_HPP_

#include "AzUtil.hpp"
#include "AzStrPool.hpp"
#include "AzPrint.hpp"

//! Abstract class: interface to access feature descriptions.
class AzSvFeatInfo {
public:
  // Concatenate feature description to str_desc.
  virtual void concatDesc(int ex, //!< feature id
                          AzBytArr *str_desc) const = 0;

  //! Return number of features.
  virtual int featNum() const = 0; /* # of features */

  //! Replace the content of str_desc with the description of feature ex.
  void desc(int ex, AzBytArr *str_desc) const {
    str_desc->reset();
    concatDesc(ex, str_desc);
  }

  //! Return the id of the first feature whose description equals fnm; -1 if none.
  int desc2fno(const char *fnm) const {
    int fx;
    for (fx = 0; fx < featNum(); ++fx) {
      AzBytArr s;
      desc(fx, &s);
      if (s.compare(fnm) == 0) {
        return fx;
      }
    }
    return -1;
  }

  void show(const AzOut &out, const AzIntArr *ia_fxs) const {
    int ix;
    for (ix = 0; ix < ia_fxs->size(); ++ix) {
      int fx = ia_fxs->get(ix);
      AzBytArr s("???");
      /* NOTE(review): this copy of the file is damaged here.  Everything from
       * "fx < ..." on original line 48 up to "ia_fxs->" on original line 69 is
       * missing: the rest of show(), and the header of the method whose body
       * follows (it collects features whose description begins with a keyword,
       * like contains() below).  Restore the span from the upstream source
       * before compiling. */
      if (fx>=0 && fxreset();
    if (sp_kw->size()==0) return;
    int fx;
    for (fx = 0; fx < featNum(); ++fx) {
      AzBytArr s;
      desc(fx, &s);
      int ix;
      for (ix = 0; ix < sp_kw->size(); ++ix) {
        if (s.beginsWith(sp_kw->c_str(ix))) {
          ia_fxs->put(fx);
          break;  /* one match is enough; record the feature once */
        }
      }
    }
  }

  //! Put into ia_fxs the ids of all features whose description contains
  //! at least one keyword of sp_kw.  ia_fxs is emptied first.
  void contains(const AzStrArray *sp_kw,
                AzIntArr *ia_fxs) const {
    ia_fxs->reset();
    if (sp_kw->size()==0) return;
    int fx;
    for (fx = 0; fx < featNum(); ++fx) {
      AzBytArr s;
      desc(fx, &s);
      int ix;
      for (ix = 0; ix < sp_kw->size(); ++ix) {
        if (s.contains(sp_kw->c_str(ix))) {
          ia_fxs->put(fx);
          break;  /* one match is enough; record the feature once */
        }
      }
    }
  }

  //! Same lookup as desc2fno(): id of the first feature whose description
  //! equals kw; -1 if none.
  int equals(const char *kw) const {
    int fx;
    for (fx = 0; fx < featNum(); ++fx) {
      AzBytArr s;
      desc(fx, &s);
      if (s.compare(kw) == 0) {
        return fx;
      }
    }
    return -1;
  }
};

#endif




--------------------------------------------------------------------------------
/RGF/src/com/AzSvFeatInfoClone.hpp:
--------------------------------------------------------------------------------
/* * * * *
 *  AzSvFeatInfoClone.hpp
 *  Copyright (C) 2011, 2012 Rie Johnson
 *
 *  This software may be modified and distributed under the terms
 *  of the MIT license.  See the COPYING file for details.
 * * * * */

#ifndef _AZ_SV_FEAT_INFO_CLONE_HPP_
#define _AZ_SV_FEAT_INFO_CLONE_HPP_

#include "AzSvFeatInfo.hpp"
#include "AzStrArray.hpp"

/* Feature descriptions held as an in-memory copy; usable both as an
 * AzSvFeatInfo and as an AzStrArray. */
class AzSvFeatInfoClone : /* implements */ public virtual AzSvFeatInfo,
                          /* implements */ public virtual AzStrArray
{
protected:
  /* NOTE(review): the template argument appears to have been stripped from
   * this copy; new_slot() results are used as AzBytArr* below, so the element
   * type is presumably AzBytArr -- confirm against upstream. */
  AzDataPool arr_desc;

public:
  AzSvFeatInfoClone() {}
  AzSvFeatInfoClone(const AzSvFeatInfo *inp) {
    reset(inp);
  }
  AzSvFeatInfoClone(const AzStrArray *inp) {
    reset(inp);
  }
  inline int featNum() const {
    return arr_desc.size();
  }
  /* Appends the stored description of fx to *desc; for an out-of-range fx
   * the text appended is "?<fx>?" instead. */
  inline void concatDesc(int fx, AzBytArr *desc) const {
    if (fx < 0 || fx >= featNum()) {
      desc->c("?"); desc->cn(fx); desc->c("?");
      return;
    }
    desc->concat(arr_desc.point(fx));
  }
  /* Replace the content with a copy of every description in *inp. */
  void reset(const AzSvFeatInfo *inp) {
    int f_num = inp->featNum();
    arr_desc.reset();
    int fx;
    for (fx = 0; fx < f_num; ++fx) {
      AzBytArr *ptr = arr_desc.new_slot();
      inp->desc(fx, ptr);
    }
  }
  /* Replace the content with a copy of every string in *inp. */
  void reset(const AzStrArray *inp) {
    int f_num = inp->size();
    arr_desc.reset();
    int fx;
    for (fx = 0; fx < f_num; ++fx) {
      AzBytArr *ptr = arr_desc.new_slot();
      ptr->reset(inp->c_str(fx));
    }
  }
  /* Replace the content with inp_f_num generated names: "F" followed by the
   * feature number. */
  void reset(int inp_f_num) {
    arr_desc.reset();
    int fx;
    for (fx = 0; fx < inp_f_num; ++fx) {
      AzBytArr s("F");
      s.cn(fx, 3, true); /* width=3, fillWithZero */
      arr_desc.new_slot()->reset(&s);
    }
  }

  /* Add the descriptions of *inp after the existing ones. */
  void append(const AzSvFeatInfo *inp) {
    int fx;
    for (fx = 0; fx < inp->featNum(); ++fx) {
      inp->desc(fx, arr_desc.new_slot());
    }
  }

  /*--- to implement AzStrArray ---*/
  int size() const { return featNum(); }
  /* Returns "???" for an out-of-range index. */
  const char *c_str(int fx) const {
    if (fx < 0 || fx >= featNum()) {
      return "???";
    }
    return arr_desc.point(fx)->c_str();
  }
};
#endif

--------------------------------------------------------------------------------
/RGF/src/com/AzTaskTools.hpp:
--------------------------------------------------------------------------------
/* * * * *
 *  AzTaskTools.hpp
 *  Copyright (C) 2011, 2012 Rie Johnson
 *
 *  This software may be modified and distributed under the terms
 *  of the MIT license.  See the COPYING file for details.
 * * * * */

#ifndef _AZ_TASK_TOOLS_HPP_
#define _AZ_TASK_TOOLS_HPP_

#include "AzUtil.hpp"
#include "AzSmat.hpp"
#include "AzSvFeatInfo.hpp"
#include "AzLoss.hpp"
#include "AzPerfResult.hpp"
#include "AzPrint.hpp"

/* Some functions are overlapping with the legacy code AzCatProc */
/*--------------------------------------------------*/
//! Tools related to classification or regression tasks.
22 | class AzTaskTools 23 | { 24 | public: 25 | static double analyzeLoss(AzLossType loss_type, 26 | const AzDvect *v_p, 27 | const AzDvect *v_y, 28 | const AzIntArr *inp_ia_dx, 29 | double p_coeff); 30 | 31 | static void showDist(const AzStrArray *sp_cat, 32 | const AzIntArr *ia_cat, 33 | const char *header, 34 | const AzOut &out); 35 | 36 | static double eval_breakEven( 37 | const AzIntArr *ia_gold, 38 | const AzDvect *v_pval, 39 | const AzStrArray *sp_cat, 40 | const char *eyecatcher, 41 | double *out_best_f=NULL, 42 | double *out_best_acc=NULL); 43 | 44 | static AzPerfResult eval(const AzDvect *v_p, 45 | const AzDvect *v_y, /* assume y in {+1,-1} */ 46 | AzLossType loss_type) { 47 | AzOut null_out; 48 | AzPerfResult res; 49 | eval("", loss_type, NULL, NULL, v_p, v_y, "", null_out, &res); 50 | return res; 51 | } 52 | static void eval(const AzDvect *v_p, 53 | const AzDvect *v_y, /* assume y in {+1,-1} */ 54 | AzPerfResult *result) { 55 | AzOut null_out; 56 | eval("", AzLoss_None, NULL, NULL, v_p, v_y, "", null_out, result); 57 | } 58 | static void eval(const AzDvect *v_p, 59 | const AzDvect *v_y, /* assume y in {+1,-1} */ 60 | const AzOut &test_out, 61 | AzPerfResult *result=NULL) { 62 | AzOut null_out; 63 | eval("", AzLoss_None, NULL, NULL, v_p, v_y, "", test_out, result); 64 | } 65 | static void eval(const char *ite_str, 66 | AzLossType loss_type, 67 | const AzIntArr *ia_dx, 68 | const double p_coeff[2], 69 | const AzDvect *v_test_pval, 70 | const AzDvect *v_test_yval, /* assume y in {+1,-1} */ 71 | const char *tt_eyec, 72 | const AzOut &test_out, 73 | AzPerfResult *result=NULL); 74 | 75 | static int genY(const AzIntArr *ia_cat, 76 | int focus_cat, 77 | double y_posi_val, 78 | double y_nega_val, 79 | AzDvect *v_yval); /* output */ 80 | 81 | /*--- for displaying the weights, feature names, etc. 
---*/ 82 | static void dumpWeights(const AzOut &out, 83 | const AzDvect *v_w, 84 | const char *name, 85 | const AzSvFeatInfo *feat, 86 | int print_max, 87 | bool changeLine); 88 | static void printPR(AzPrint &o, 89 | int ok, 90 | int t, 91 | int g); 92 | protected: 93 | static void formatWeight(const AzSvFeatInfo *feat, 94 | int ex, 95 | double val, 96 | AzBytArr *str_out); 97 | }; 98 | #endif 99 | -------------------------------------------------------------------------------- /RGF/src/com/AzTimer.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTimer.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TIMER_HPP_ 10 | #define _AZ_TIMER_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | 14 | class AzTimer { 15 | public: 16 | int chk; /* next check point */ 17 | int inc; /* increment: negative means no checking */ 18 | 19 | AzTimer() : chk(-1), inc(-1) {} 20 | ~AzTimer() {} 21 | 22 | inline void reset(int inp_inc) { 23 | chk = -1; 24 | inc = inp_inc; 25 | if (inc > 0) { 26 | chk = inc; 27 | } 28 | } 29 | 30 | inline bool ringing(bool isRinging, int inp) { /* timer is ringing */ 31 | if (isRinging) return true; 32 | 33 | if (chk > 0 && inp >= chk) { 34 | while(chk <= inp) { 35 | chk += inc; /* next check point */ 36 | } 37 | return true; 38 | } 39 | return false; 40 | } 41 | 42 | inline bool reachedMax(int inp, 43 | const char *msg, 44 | const AzOut &out) const { 45 | bool yes_no = reachedMax(inp); 46 | if (yes_no) { 47 | AzTimeLog::print(msg, " reached max", out); 48 | } 49 | return yes_no; 50 | } 51 | inline bool reachedMax(int inp) const { 52 | if (chk > 0 && inp >= chk) return true; 53 | else return false; 54 | } 55 | }; 56 | 57 | #endif 58 | 59 | -------------------------------------------------------------------------------- 
/RGF/src/tet/AzFindSplit.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzFindSplit.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_FIND_SPLIT_HPP_ 10 | #define _AZ_FIND_SPLIT_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzDataForTrTree.hpp" 14 | #include "AzTrTtarget.hpp" 15 | #include "AzTrTsplit.hpp" 16 | #include "AzTrTree.hpp" 17 | 18 | class Az_forFindSplit { 19 | public: 20 | double wy_sum, w_sum; 21 | Az_forFindSplit() : wy_sum(0), w_sum(0) {} 22 | void reset() { 23 | wy_sum = w_sum = 0; 24 | } 25 | }; 26 | 27 | //! Abstract class: provides building blocks for node split search. 28 | /*------------------------------------------*/ 29 | class AzFindSplit 30 | { 31 | protected: 32 | const AzTrTtarget *target; 33 | const AzDataForTrTree *data; 34 | const AzTrTree_ReadOnly *tree; 35 | int min_size; 36 | 37 | AzIntArr ia_feats; 38 | const AzIntArr *ia_fx; 39 | 40 | public: 41 | AzFindSplit() : target(NULL), data(NULL), tree(NULL), ia_fx(NULL), 42 | min_size(-1) {} 43 | ~AzFindSplit() {} 44 | void reset() { 45 | target = NULL; 46 | data = NULL; 47 | tree = NULL; 48 | min_size = -1; 49 | } 50 | 51 | void _begin(const AzTrTree_ReadOnly *inp_tree, 52 | const AzDataForTrTree *inp_data, 53 | const AzTrTtarget *inp_target, 54 | int inp_min_size); 55 | void _end() { 56 | reset(); 57 | } 58 | 59 | //---------------------------------------------------------------- 60 | // void findBestSplit(const AzTrTtarget *tar, 61 | // const AzIntArr *ia_dx, 62 | // ... parameters ... 
63 | // AzTrTsplit *best_split); /* output */ 64 | //---------------------------------------------------------------- 65 | 66 | virtual void _pickFeats(int pick_num, int f_num); 67 | 68 | protected: 69 | /*----------------------------------------------------------------*/ 70 | virtual double getBestGain(double w_sum, 71 | double wy_sum, 72 | double *out_best_p) /* must not be null */ 73 | const = 0; 74 | virtual double evalSplit(const Az_forFindSplit i[2], 75 | double bestP[2]) /* output */ const; 76 | /*----------------------------------------------------------------*/ 77 | 78 | void _findBestSplit(int nx, 79 | /*--- output ---*/ 80 | AzTrTsplit *best_split); 81 | void loop(AzTrTsplit *best_split, 82 | int fx, /* feature# */ 83 | const AzSortedFeat *sorted, 84 | int dxs_num, 85 | const Az_forFindSplit *total); 86 | }; 87 | 88 | #endif 89 | -------------------------------------------------------------------------------- /RGF/src/tet/AzOptOnTree_TreeReg.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOptOnTree_TreeReg.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #include "AzOptOnTree_TreeReg.hpp" 10 | #include "AzPrint.hpp" 11 | 12 | /*--------------------------------------------------------*/ 13 | void AzOptOnTree_TreeReg::optimize(AzRgfTreeEnsemble *inp_rgf_ens, 14 | const AzTrTreeFeat *inp_tree_feat, 15 | int ite_num, 16 | double lam, 17 | double sig) 18 | { 19 | ens = inp_rgf_ens; 20 | tree_feat = inp_tree_feat; 21 | rgf_ens = inp_rgf_ens; 22 | 23 | synchronize(); 24 | updateTreeWeights(rgf_ens); 25 | 26 | const int tree_num = ens->size(); 27 | if (reg_arr->size() < tree_num) { 28 | throw new AzException("AzOptOnTree_TreeReg::optimize", 29 | "max #tree has changed??"); 30 | } 31 | for (int tx = 0; tx < tree_num; ++tx) { 32 | AzReg_TreeReg *reg = reg_arr->reg(tx); 33 | reg->reset(ens->tree(tx), reg_depth); 34 | } 35 | iterate(ite_num, lam, sig); 36 | 37 | ens = NULL; 38 | tree_feat = NULL; 39 | rgf_ens = NULL; 40 | } 41 | 42 | /*--------------------------------------------------------*/ 43 | void AzOptOnTree_TreeReg::update_with_features( 44 | double nlam, 45 | double nsig, 46 | double py_avg, 47 | AzRgf_forDelta *for_delta) /* updated */ 48 | { 49 | const int tree_num = ens->size(); 50 | for (int tx = 0; tx < tree_num; ++tx) { 51 | ens->tree_u(tx)->restoreDataIndexes(); 52 | AzReg_TreeReg *reg = reg_arr->reg(tx); 53 | reg->clearFocusNode(); 54 | 55 | AzIIarr iia_nx_fx; 56 | tree_feat->featIds(tx, &iia_nx_fx); 57 | for (int ix = 0; ix < iia_nx_fx.size(); ++ix) { 58 | int nx, fx; 59 | iia_nx_fx.get(ix, &nx, &fx); 60 | 61 | const double delta = bestDelta(nx, fx, reg, nlam, nsig, py_avg, for_delta); 62 | update_weight(nx, fx, delta, reg); 63 | } 64 | ens->tree_u(tx)->releaseDataIndexes(); 65 | } 66 | } 67 | /*--------------------------------------------------------*/ 68 | void AzOptOnTree_TreeReg::update_weight(int nx, 69 | int fx, 70 | double delta, 71 | AzReg_TreeReg *reg) 72 | { 73 | const double new_w = v_w.get(fx) + delta; 74 | v_w.set(fx, new_w); 75 | 76 | int dxs_num; 77 | const int *dxs = 
data_points(fx, &dxs_num); 78 | updatePred(dxs, dxs_num, delta, &v_p); 79 | 80 | /*--- update the weight in the ensemble ---*/ 81 | const AzTrTreeFeatInfo *fp = tree_feat->featInfo(fx); 82 | rgf_ens->tree_u(fp->tx)->setWeight(fp->nx, new_w); 83 | reg->changeWeight(nx, delta); 84 | } 85 | 86 | /*--------------------------------------------------------*/ 87 | double AzOptOnTree_TreeReg::bestDelta( 88 | int nx, 89 | int fx, 90 | AzReg_TreeReg *reg, 91 | double nlam, 92 | double nsig, 93 | double py_avg, 94 | AzRgf_forDelta *for_delta) /* updated */ 95 | const 96 | { 97 | const char *eyec = "AzOptOnTree_TI::bestDelta"; 98 | 99 | int dxs_num; 100 | const int *dxs = data_points(fx, &dxs_num); 101 | if (dxs_num <= 0) { 102 | throw new AzException(eyec, "no data indexes"); 103 | } 104 | 105 | const double *fixed_dw = NULL; 106 | if (!AzDvect::isNull(&v_fixed_dw)) fixed_dw = v_fixed_dw.point(); 107 | const double *p = v_p.point(); 108 | const double *y = v_y.point(); 109 | double nega_dL = 0, ddL= 0; 110 | if (fixed_dw == NULL) { 111 | AzLoss::sum_deriv(loss_type, dxs, dxs_num, p, y, py_avg, 112 | nega_dL, ddL); 113 | } 114 | else { 115 | AzLoss::sum_deriv_weighted(loss_type, dxs, dxs_num, p, y, fixed_dw, py_avg, 116 | nega_dL, ddL); 117 | } 118 | 119 | double dR, ddR; 120 | reg->penalty_deriv(nx, &dR, &ddR); 121 | 122 | double dd = ddL + nlam*ddR; 123 | if (dd == 0) dd = 1; 124 | double delta = (nega_dL-nlam*dR)*eta/dd; 125 | for_delta->check_delta(&delta, max_delta); 126 | 127 | return delta; 128 | } 129 | -------------------------------------------------------------------------------- /RGF/src/tet/AzOptOnTree_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOptOnTree_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_OPT_ON_TREE_TREE_REG_HPP_ 10 | #define _AZ_OPT_ON_TREE_TREE_REG_HPP_ 11 | 12 | #include "AzOptOnTree.hpp" 13 | #include "AzReg_TreeRegArr.hpp" 14 | 15 | //! coordinate descent with regulatization using tree structure 16 | /*--------------------------------------------------------*/ 17 | class AzOptOnTree_TreeReg : /* extends */ public virtual AzOptOnTree 18 | { 19 | protected: 20 | AzRgfTreeEnsemble *rgf_ens; 21 | AzReg_TreeRegArr *reg_arr; 22 | 23 | public: 24 | AzOptOnTree_TreeReg() : rgf_ens(NULL), reg_arr(NULL) {} 25 | void reset(AzReg_TreeRegArr *inp_reg_arr) { 26 | reg_arr = inp_reg_arr; 27 | } 28 | 29 | virtual void optimize(AzRgfTreeEnsemble *ens, /* weights are updated */ 30 | const AzTrTreeFeat *tree_feat, 31 | int inp_ite_num=-1, 32 | double lam=-1, 33 | double sig=-1); 34 | 35 | /*--- ---*/ 36 | virtual void reset(const AzOptOnTree_TreeReg *inp) { 37 | AzOptOnTree::reset(inp); 38 | rgf_ens = inp->rgf_ens; 39 | reg_arr = inp->reg_arr; 40 | } 41 | 42 | protected: 43 | //! override 44 | virtual void update_with_features(double nlam, double nsig, double py_avg, 45 | AzRgf_forDelta *for_delta); 46 | 47 | virtual void update_weight(int nx, 48 | int fx, 49 | double delta, 50 | AzReg_TreeReg *reg); 51 | virtual double bestDelta( 52 | int nx, 53 | int fx, 54 | AzReg_TreeReg *reg, 55 | double nlam, 56 | double nsig, 57 | double py_avg, 58 | AzRgf_forDelta *for_delta) /* updated */ 59 | const; 60 | }; 61 | 62 | #endif 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /RGF/src/tet/AzOptimizerT.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzOptimizerT.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_OPTIMIZER_T_HPP_ 10 | #define _AZ_OPTIMIZER_T_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzSmat.hpp" 14 | #include "AzDmat.hpp" 15 | #include "AzBmat.hpp" 16 | #include "AzLoss.hpp" 17 | #include "AzTrTreeFeat.hpp" 18 | #include "AzTrTreeEnsemble_ReadOnly.hpp" 19 | #include "AzRgfTreeEnsemble.hpp" 20 | #include "AzRegDepth.hpp" 21 | #include "AzParam.hpp" 22 | 23 | //! coordinate descent for weight optimization. 24 | /*--------------------------------------------------------*/ 25 | class AzOptimizerT 26 | { 27 | public: 28 | virtual void reset(AzLossType loss_type, 29 | const AzDvect *v_y, 30 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 31 | const AzRegDepth *reg_depth, 32 | AzParam ¶m, 33 | bool beVerbose, 34 | const AzOut out_req, 35 | /*--- for warm start ---*/ 36 | const AzTrTreeEnsemble_ReadOnly *ens=NULL, 37 | const AzTrTreeFeat *tree_feat=NULL, 38 | const AzDvect *inp_v_p=NULL) = 0; 39 | 40 | virtual void copyPred_to(AzDvect *out_v_p) const = 0; 41 | 42 | virtual void resetPred(const AzBmat *m_tran, 43 | AzDvect *v_p) /* output */ 44 | const = 0; 45 | virtual void optimize(AzRgfTreeEnsemble *ens, 46 | const AzTrTreeFeat *tree_feat, 47 | int inp_ite_num=-1, 48 | double lam=-1, 49 | double sig=-1) = 0; 50 | virtual void optimize(AzRgfTreeEnsemble *ens, 51 | const AzTrTreeFeat *tree_feat, 52 | bool doRefreshP, 53 | int inp_ite_num=-1, 54 | double lam=-1, 55 | double sig=-1) { 56 | throw new AzException("AzOptimizerT::optimize(...,doRefreshP,...)", "No support"); 57 | } 58 | virtual const AzDvect *weights() const = 0; 59 | virtual double constant() const = 0; 60 | virtual void printHelp(AzHelp &h) const = 0; 61 | }; 62 | 63 | #endif 64 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRegDepth.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRegDepth.hpp 3 | * Copyright (C) 2011, 2012 Rie 
Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_REG_DEPTH_HPP_ 10 | #define _AZ_REG_DEPTH_HPP_ 11 | 12 | #include "AzRgf_kw.hpp" 13 | #include "AzUtil.hpp" 14 | #include "AzParam.hpp" 15 | #include "AzHelp.hpp" 16 | #include "AzRegDepth.hpp" 17 | 18 | #define depth_base_dflt 1 19 | /* #define depth_base_min_penalty_dflt 2 */ 20 | 21 | //! Regularizer using node depth. 22 | class AzRegDepth { 23 | protected: 24 | double depth_base; 25 | AzDvect v_dep2pow; /* to avoid repetitive calls of pow() */ 26 | const double *dep2pow; 27 | 28 | public: 29 | AzRegDepth() : depth_base(depth_base_dflt), dep2pow(NULL) {} 30 | 31 | virtual void set_default_for_min_penalty() { 32 | /* depth_base = depth_base_min_penalty_dflt; */ 33 | } 34 | 35 | virtual inline 36 | void check_if_nonincreasing(const char *who) const { 37 | if (depth_base < 1) { 38 | AzBytArr s(kw_depth_base); s.c(" must be no smaller than 1 for "); 39 | s.c(who); s.c("."); 40 | throw new AzException(AzInputNotValid, "AzRegDepth::check_if_nonincreasing", 41 | s.c_str()); 42 | } 43 | } 44 | 45 | virtual inline 46 | double apply(double val, int dep) const { 47 | if (depth_base == 1) return val; 48 | if (dep >= 0 && dep < v_dep2pow.rowNum()) { 49 | return val * dep2pow[dep]; 50 | } 51 | else { 52 | return val * pow(depth_base, (double)dep); 53 | } 54 | } 55 | 56 | virtual void reset(AzParam ¶m, 57 | const AzOut &out) { 58 | resetParam(param); 59 | if (depth_base <= 0) { 60 | throw new AzException(AzInputNotValid, "AzRegDepth::reset", 61 | kw_depth_base, "must be no smaller than 1."); 62 | } 63 | if (depth_base < 1) { 64 | AzBytArr s("!Warning! 
"); s.c(kw_depth_base); s.c(" should be no smaller than 1."); 65 | AzPrint::writeln(out, s); 66 | } 67 | 68 | printParam(out); 69 | } 70 | virtual void printHelp(AzHelp &h) const { 71 | h.begin(Azforest_config, "AzRegDepth", "Regularization on node depth"); 72 | h.item(kw_depth_base, help_depth_base, depth_base); 73 | h.end(); 74 | } 75 | virtual void printParam(const AzOut &out) const { 76 | if (out.isNull()) return; 77 | if (depth_base != 1) { 78 | AzPrint o(out); 79 | o.ppBegin("AzRegDepth", "Reg. on depth", ", "); 80 | o.printV(kw_depth_base, depth_base); 81 | o.ppEnd(); 82 | } 83 | } 84 | 85 | protected: 86 | virtual void resetParam(AzParam &p) { 87 | bool doCheck = false; 88 | p.vFloat(kw_depth_base, &depth_base); 89 | 90 | /*--- ---*/ 91 | v_dep2pow.reform(50); 92 | int dep; 93 | for (dep = 0; dep < v_dep2pow.rowNum(); ++dep) { 94 | v_dep2pow.set(dep, pow(depth_base, (double)dep)); 95 | } 96 | dep2pow = v_dep2pow.point(); 97 | } 98 | }; 99 | #endif 100 | -------------------------------------------------------------------------------- /RGF/src/tet/AzReg_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_REG_TREE_REG_HPP_ 10 | #define _AZ_REG_TREE_REG_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzDmat.hpp" 14 | #include "AzTrTree_ReadOnly.hpp" 15 | #include "AzRegDepth.hpp" 16 | #include "AzParam.hpp" 17 | 18 | class AzReg_TreeRegShared { 19 | public: 20 | virtual AzDmat *share() = 0; 21 | virtual bool create(const AzTrTree_ReadOnly *tree, const AzDmat *info) = 0; 22 | virtual AzDmat *share(const AzTrTree_ReadOnly *tree) = 0; 23 | }; 24 | 25 | //! 
Default implementation of AzReg_TreeRegShared 26 | class AzReg_TreeRegShared_Dflt : /* implements */ public virtual AzReg_TreeRegShared 27 | { 28 | protected: 29 | AzDmat m_by_alltree; /* info not specific to individual tree */ 30 | 31 | public: 32 | /*--- override these to store tree-specific info ---*/ 33 | virtual AzDmat *share(const AzTrTree_ReadOnly *tree) { return NULL; } 34 | virtual bool create(const AzTrTree_ReadOnly *tree, const AzDmat *) { return false; } 35 | /*----------------------------------------------------*/ 36 | virtual AzDmat *share() { 37 | return &m_by_alltree; 38 | } 39 | }; 40 | 41 | //! Abstract class: interface to tree-structured regularizer 42 | class AzReg_TreeReg { 43 | public: 44 | virtual void set_shared(AzReg_TreeRegShared *shared) {} 45 | virtual void check_reg_depth(const AzRegDepth *) const {} 46 | 47 | virtual void reset(const AzTrTree_ReadOnly *inp_tree, 48 | const AzRegDepth *inp_reg_depth) = 0; 49 | 50 | virtual void penalty_deriv(int nx, double *dr, 51 | double *ddr) = 0; 52 | 53 | virtual void changeWeight(int nx, double w_diff) = 0; 54 | 55 | virtual void clearFocusNode() = 0; 56 | 57 | /*--- for node split ---*/ 58 | //! called by AzRgf_FindSplit_TR::begin 59 | virtual void reset_forNewLeaf(const AzTrTree_ReadOnly *t, 60 | const AzRegDepth *rdep) = 0; 61 | 62 | //! 
called by AzRgf_FindSplit_TR::findSplit 63 | virtual void reset_forNewLeaf(int f_nx, 64 | const AzTrTree_ReadOnly *t, 65 | const AzRegDepth *rdep) = 0; 66 | 67 | virtual double penalty_diff(const double leaf_w_delta[2]) const = 0; 68 | virtual void penalty_deriv(double *dr, 69 | double *ddr) const = 0; 70 | 71 | /*--- for maintenance ---*/ 72 | virtual void show(const AzOut &out, 73 | const char *header) const = 0; 74 | virtual double penalty() const { 75 | return -1; 76 | } 77 | 78 | /*---------------------------------------------------------*/ 79 | virtual void resetParam(AzParam ¶m) = 0; 80 | virtual void printParam(const AzOut &out) const = 0; 81 | virtual void printHelp(AzHelp &h) const = 0; 82 | 83 | virtual const char *signature() const = 0; 84 | virtual const char *description() const = 0; 85 | }; 86 | #endif 87 | 88 | -------------------------------------------------------------------------------- /RGF/src/tet/AzReg_TreeRegArr.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TreeRegArr.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_REG_TREE_REG_ARR_HPP_ 10 | #define _AZ_REG_TREE_REG_ARR_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzReg_TreeReg.hpp" 14 | 15 | class AzReg_TreeRegArr 16 | { 17 | public: 18 | virtual void reset(int tree_num) = 0; 19 | virtual AzReg_TreeReg *reg(int tx) = 0; 20 | virtual AzReg_TreeReg *reg_forNewLeaf(int tx) = 0; 21 | virtual int size() const = 0; 22 | }; 23 | #endif 24 | -------------------------------------------------------------------------------- /RGF/src/tet/AzReg_TreeRegArrImp.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TreeRegArrImp.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_REG_TREE_REG_ARR_IMP_HPP_ 10 | #define _AZ_REG_TREE_REG_ARR_IMP_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzRgfTreeEnsemble.hpp" 14 | #include "AzReg_TreeRegArr.hpp" 15 | 16 | template 17 | class AzReg_TreeRegArrImp : /* implements */ public virtual AzReg_TreeRegArr 18 | { 19 | protected: 20 | AzPtrPool areg; 21 | T template_reg; 22 | T temporary_reg; 23 | AzReg_TreeRegShared_Dflt shared; 24 | 25 | public: 26 | T *tmpl_u() { return &template_reg; } 27 | const T *tmpl() const { return &template_reg; } 28 | 29 | inline int size() const { return areg.size(); } 30 | void reset(int tree_num) { 31 | areg.reset(); 32 | int tx; 33 | for (tx = 0; tx < tree_num; ++tx) { 34 | T *reg = areg.new_slot(); 35 | reg->copyParam_from(&template_reg); 36 | reg->set_shared(&shared); 37 | } 38 | temporary_reg.set_shared(&shared); 39 | } 40 | inline AzReg_TreeReg *reg(int tx) { 41 | return areg.point_u(tx); 42 | } 43 | inline AzReg_TreeReg *reg_forNewLeaf(int tx) { 44 | int t_num = areg.size(); 45 | if (tx < t_num) return areg.point_u(tx); 46 | 47 | temporary_reg.copyParam_from(&template_reg); 48 | return 
&temporary_reg; /* should be root-only tree */ 49 | } 50 | }; 51 | #endif 52 | -------------------------------------------------------------------------------- /RGF/src/tet/AzReg_TsrSib.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzReg_TsrSib.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_REG_TSRSIB_HPP_ 10 | #define _AZ_REG_TSRSIB_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzDmat.hpp" 14 | #include "AzTrTree_ReadOnly.hpp" 15 | #include "AzRegDepth.hpp" 16 | #include "AzReg_TreeReg.hpp" 17 | 18 | //! 19 | class AzReg_TsrSib : /* implements */ public virtual AzReg_TreeReg { 20 | protected: 21 | const AzTrTree_ReadOnly *tree; 22 | 23 | int focus_nx; 24 | const AzRegDepth *reg_depth; 25 | bool forNewLeaf; 26 | 27 | AzDataArray av_dv; 28 | AzDvect v_v; 29 | 30 | double vdv_sum, dv2_sum, dr, ddr; 31 | double newleaf_dep_factor; 32 | 33 | virtual void reset_values() { 34 | vdv_sum = dv2_sum = dr = ddr = 0; 35 | newleaf_dep_factor = 1; 36 | } 37 | 38 | public: 39 | AzReg_TsrSib() 40 | : tree(NULL), forNewLeaf(false), focus_nx(-1), 41 | reg_depth(NULL), newleaf_dep_factor(1), 42 | vdv_sum(0), dv2_sum(0), dr(0), ddr(0) {} 43 | 44 | void copyParam_from(const AzReg_TsrSib *inp) {} 45 | 46 | /*---------------------------------------------------------*/ 47 | virtual void reset(const AzTrTree_ReadOnly *inp_tree, 48 | const AzRegDepth *inp_reg_depth); 49 | /*---------------------------------------------------------*/ 50 | 51 | virtual void penalty_deriv(int nx, double *dr, 52 | double *ddr); 53 | 54 | virtual void changeWeight(int nx, double w_diff); 55 | 56 | inline void clearFocusNode() { 57 | focus_nx = -1; 58 | } 59 | 60 | /*--- for node split ---*/ 61 | /*---------------------------------------------------------*/ 62 | /*--- called by 
AzRgf_FindSplit_TR::begin ---*/ 63 | //! set current penalty 64 | virtual void reset_forNewLeaf(const AzTrTree_ReadOnly *t, 65 | const AzRegDepth *rdep); 66 | /*--- called by AzRgf_FindSplit_TR::findSplit ---*/ 67 | virtual void reset_forNewLeaf(int f_nx, 68 | const AzTrTree_ReadOnly *t, 69 | const AzRegDepth *rdep); 70 | /*---------------------------------------------------------*/ 71 | 72 | virtual double penalty_diff(const double leaf_w_delta[2]) const; 73 | virtual void penalty_deriv(double *dr, 74 | double *ddr) const; 75 | 76 | /*--- for maintenance ---*/ 77 | virtual void show(const AzOut &out, 78 | const char *header) const { 79 | 80 | } 81 | 82 | /*---------------------------------------------------------*/ 83 | virtual void resetParam(AzParam ¶m) {} 84 | virtual void printParam(const AzOut &out) const {} 85 | virtual void printHelp(AzHelp &h) const {} 86 | 87 | virtual inline const char *signature() const { 88 | return "-___-_RGF_TsrSib_"; 89 | } 90 | virtual inline const char *description() const { 91 | return "RGF w/min-penalty regularization w/sum-to-zero sibling constraints"; 92 | } 93 | /*---------------------------------------------------------*/ 94 | 95 | protected: 96 | void checkLeaf(const char *msg) const { 97 | if (!forNewLeaf || focus_nx < 0) { 98 | throw new AzException("AzReg_TsrSib::checkLeaf", msg); 99 | } 100 | } 101 | virtual void deriv_v(const AzTrTree_ReadOnly *tree, 102 | int leaf_nx, 103 | bool forNewLeaf, 104 | /* output */ 105 | AzSvect *v_dv, 106 | /* inout */ 107 | AzDvect *v_v) const; 108 | 109 | inline int get_newleaf_depth() const { 110 | return tree->node(focus_nx)->depth + 1; 111 | } 112 | virtual void update(); 113 | }; 114 | #endif 115 | 116 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgfTrainerSel.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgfTrainerSel.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | 
* 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_TRAINER_SEL_HPP_ 10 | #define _AZ_RGF_TRAINER_SEL_HPP_ 11 | 12 | #include "AzRgforest.hpp" 13 | #include "AzRgforest_TreeReg.hpp" 14 | #include "AzReg_TsrOpt.hpp" 15 | #include "AzReg_TsrSib.hpp" 16 | 17 | #include "AzTETselector.hpp" 18 | #include "AzPrint.hpp" 19 | 20 | //! Training algorithm selector. 21 | class AzRgfTrainerSel : /* implements */ public virtual AzTETselector { 22 | protected: 23 | AzRgforest rgf; 24 | AzRgforest_TreeReg rgf_sib; 25 | AzRgforest_TreeReg rgf_opt; 26 | 27 | #define kw_rgf "RGF" 28 | #define kw_rgf_sib "RGF_Sib" 29 | #define kw_rgf_opt "RGF_Opt" 30 | 31 | AzStrPool sp_name; 32 | AzDataArray alg; 33 | 34 | virtual void reset() { 35 | int id = 0; 36 | sp_name.putv(kw_rgf, id++); *alg.new_slot() = &rgf; 37 | sp_name.putv(kw_rgf_sib, id++); *alg.new_slot() = &rgf_sib; 38 | sp_name.putv(kw_rgf_opt, id++); *alg.new_slot() = &rgf_opt; 39 | sp_name.commit(); 40 | } 41 | 42 | public: 43 | AzRgfTrainerSel() { 44 | reset(); 45 | } 46 | 47 | virtual const char *dflt_name() const { 48 | return kw_rgf; 49 | } 50 | virtual const char *another_name() const { 51 | return kw_rgf_sib; 52 | } 53 | const AzStrArray *names() const { 54 | return &sp_name; 55 | } 56 | 57 | virtual void printOptions(const char *dlm, AzBytArr *s) const { 58 | int ix; 59 | for (ix = 0; ix < sp_name.size(); ++ix) { 60 | if (ix > 0) s->c(dlm); 61 | s->c(sp_name.c_str(ix)); 62 | } 63 | } 64 | 65 | virtual void printHelp(AzHelp &h) const { 66 | h.begin("", "", ""); 67 | int ix; 68 | for (ix = 0; ix < sp_name.size(); ++ix) { 69 | int id = sp_name.getValue(ix); 70 | const AzTETrainer *trainer = *alg.point(id); 71 | h.item(sp_name.c_str(ix), trainer->description()); 72 | } 73 | h.end(); 74 | } 75 | 76 | virtual AzTETrainer *select(const char *alg_name, //!< name of algorithm. 77 | //! 
if true, don't throw exception at error. 78 | bool dontThrow=false) const 79 | { 80 | AzTETrainer *trainer = NULL; 81 | 82 | int ex = sp_name.find(alg_name); 83 | if (ex < 0 && !dontThrow) { 84 | throw new AzException(AzInputNotValid, "algorithm name", alg_name); 85 | } 86 | if (ex >= 0) { 87 | int id = sp_name.getValue(ex); 88 | trainer = *alg.point(id); 89 | } 90 | return trainer; 91 | } 92 | }; 93 | 94 | #endif 95 | 96 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgfTreeEnsImp.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgfTreeEnsImp.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_TREE_ENS_IMP_HPP_ 10 | #define _AZ_RGF_TREE_ENS_IMP_HPP_ 11 | 12 | #include "AzRgfTreeEnsemble.hpp" 13 | #include "AzTrTreeEnsemble.hpp" 14 | 15 | //! implement AzRgfTreeEnsemble. T must be AzRgfTree or its extension. 
16 | template 17 | class AzRgfTreeEnsImp : /*implements */public virtual AzRgfTreeEnsemble 18 | { 19 | protected: 20 | AzTrTreeEnsemble ens; 21 | 22 | public: 23 | AzRgfTreeEnsImp() {} 24 | ~AzRgfTreeEnsImp() {} 25 | inline bool usingTempFile() const { 26 | return ens.usingTempFile(); 27 | } 28 | inline void reset() { 29 | ens.reset(); 30 | } 31 | inline const char *param_c_str() const { 32 | return ens.param_c_str(); 33 | } 34 | 35 | inline double constant() const { 36 | return ens.constant(); 37 | } 38 | inline int orgdim() const { 39 | return ens.orgdim(); 40 | } 41 | inline void set_constant(double inp) { 42 | ens.set_constant(inp); 43 | } 44 | inline AzRgfTree *new_tree(int *out_tx=NULL) { 45 | return ens.new_tree(out_tx); 46 | } 47 | 48 | inline const AzRgfTree *tree(int tx) const { 49 | return ens.tree(tx); 50 | } 51 | inline AzRgfTree *tree_u(int tx) const { 52 | return ens.tree_u(tx); 53 | } 54 | 55 | inline T *rawtree_u(int tx) const { 56 | return ens.tree_u(tx); 57 | } 58 | 59 | inline int leafNum() const { 60 | return ens.leafNum(); 61 | } 62 | inline int leafNum(int tx0, int tx1) const { 63 | return ens.leafNum(tx0, tx1); 64 | } 65 | 66 | inline int lastIndex() const { 67 | return ens.lastIndex(); 68 | } 69 | inline int nextIndex() const { /* next slot */ 70 | return ens.nextIndex(); 71 | } 72 | 73 | inline int size() const { 74 | return ens.size(); 75 | } 76 | inline int max_size() const { 77 | return ens.max_size(); 78 | } 79 | inline bool isFull() const { 80 | return ens.isFull(); 81 | } 82 | inline void printHelp(AzHelp &h) const { 83 | ens.printHelp(h); 84 | } 85 | inline void copy_to(AzTreeEnsemble *out_ens, 86 | const char *config, const char *sign) const { 87 | ens.copy_to(out_ens, config, sign); 88 | } 89 | inline void copy_nodes_from(const AzTrTreeEnsemble_ReadOnly *inp) { 90 | ens.copy_nodes_from(inp); 91 | } 92 | inline void show(const AzSvFeatInfo *feat, 93 | const AzOut &out, const char *header="") const { 94 | ens.show(feat, out, 
header); 95 | } 96 | 97 | inline virtual void cold_start(AzParam ¶m, 98 | const AzBytArr *s_temp_prefix, 99 | int data_num, 100 | const AzOut &out, 101 | int tree_num_max, 102 | int inp_org_dim) { 103 | ens.cold_start(param, s_temp_prefix, data_num, 104 | out, tree_num_max, inp_org_dim); 105 | } 106 | inline virtual void warm_start(const AzTreeEnsemble *inp_ens, 107 | const AzDataForTrTree *data, 108 | AzParam ¶m, 109 | const AzBytArr *s_temp_prefix, 110 | const AzOut &out, 111 | int max_t_num, 112 | int search_t_num, /* to release work areas for the fixed trees */ 113 | AzDvect *v_p, /* inout */ 114 | const AzIntArr *inp_ia_tr_dx=NULL) { 115 | ens.warm_start(inp_ens, data, param, s_temp_prefix, out, max_t_num, search_t_num, 116 | v_p, inp_ia_tr_dx); 117 | } 118 | }; 119 | #endif 120 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgfTreeEnsemble.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgfTreeEnsemble.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_TREE_ENSEMBLE_HPP_ 10 | #define _AZ_RGF_TREE_ENSEMBLE_HPP_ 11 | 12 | #include "AzTrTreeEnsemble_ReadOnly.hpp" 13 | #include "AzRgfTree.hpp" 14 | #include "AzParam.hpp" 15 | #include "AzHelp.hpp" 16 | 17 | //! Abstract class: interface for ensemble of RGF-trees. 
18 | /** 19 | * implemented by AzRgfTreeEnsImp 20 | **/ 21 | class AzRgfTreeEnsemble : /* extends */ public virtual AzTrTreeEnsemble_ReadOnly 22 | { 23 | public: 24 | virtual void set_constant(double inp) = 0; 25 | virtual AzRgfTree *new_tree(int *out_tx=NULL) = 0; 26 | virtual AzRgfTree *tree_u(int tx) const = 0; 27 | 28 | virtual int nextIndex() const = 0; 29 | virtual bool isFull() const = 0; 30 | 31 | virtual void copy_nodes_from(const AzTrTreeEnsemble_ReadOnly *inp) = 0; 32 | virtual void printHelp(AzHelp &h) const = 0; 33 | 34 | virtual void cold_start(AzParam ¶m, 35 | const AzBytArr *s_temp_prefix, /* may be NULL */ 36 | int data_num, 37 | const AzOut &out, 38 | int tree_num_max, 39 | int inp_org_dim) = 0; 40 | virtual void warm_start(const AzTreeEnsemble *inp_ens, 41 | const AzDataForTrTree *data, 42 | AzParam ¶m, 43 | const AzBytArr *s_temp_prefix, /* may be NULL */ 44 | const AzOut &out, 45 | int max_t_num, 46 | int search_t_num, 47 | AzDvect *v_p, /* inout */ 48 | const AzIntArr *inp_ia_tr_dx=NULL) = 0; 49 | }; 50 | #endif 51 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_FindSplit.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_FIND_SPLIT_HPP_ 10 | #define _AZ_RGF_FIND_SPLIT_HPP_ 11 | 12 | #include "AzTrTree_ReadOnly.hpp" 13 | #include "AzTrTsplit.hpp" 14 | #include "AzTrTtarget.hpp" 15 | #include "AzRgf_kw.hpp" 16 | #include "AzRegDepth.hpp" 17 | #include "AzParam.hpp" 18 | #include "AzFsinfo.hpp" 19 | #include "AzHelp.hpp" 20 | 21 | class AzRgf_FindSplit_input { 22 | public: 23 | int tx; 24 | const AzDataForTrTree *data; 25 | const AzTrTtarget *target; 26 | double lam_scale; /*!< for numerical stability of exp loss */ 27 | double nn; /* sum of data point weights if weighted */ 28 | 29 | AzRgf_FindSplit_input(int inp_tx, 30 | const AzDataForTrTree *inp_data, 31 | const AzTrTtarget *inp_target, 32 | double inp_lam_scale, 33 | double inp_nn) { 34 | tx = inp_tx; 35 | data = inp_data; 36 | target = inp_target; 37 | lam_scale = inp_lam_scale; 38 | nn = (double)inp_nn; 39 | } 40 | }; 41 | 42 | /*--------------------------------------------------------*/ 43 | //! Abstract class: interface for node split search for RGF. 
44 | /** 45 | * Implemented by AzRgf_FindSplit_Dflt, AzRgf_FindSplit_TreeReg 46 | **/ 47 | class AzRgf_FindSplit { 48 | public: 49 | virtual void reset(AzParam ¶m, 50 | const AzRegDepth *reg_depth, 51 | const AzOut &out) = 0; 52 | 53 | virtual void begin(const AzTrTree_ReadOnly *tree, 54 | const AzRgf_FindSplit_input &inp, 55 | int min_size) 56 | = 0; 57 | virtual void begin(const AzTrTree_ReadOnly *tree, 58 | const AzRgf_FindSplit_input &inp, 59 | int min_size, 60 | AzFsinfoOnTree *fot) { /* added for TreeRegFast */ 61 | throw new AzException("AzRgf_FindSplit::begin(...fot)", 62 | "no appropriate override"); 63 | } 64 | 65 | virtual void pickFeats(int f_num, int data_num) = 0; 66 | 67 | virtual void end() = 0; 68 | virtual 69 | void findSplit(int nx, //!< node id 70 | /*--- output ---*/ 71 | AzTrTsplit *best_split) = 0; 72 | 73 | virtual void printParam(const AzOut &out) const = 0; 74 | virtual void printHelp(AzHelp &h) const = 0; 75 | }; 76 | #endif 77 | 78 | 79 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_FindSplit_Dflt.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_Dflt.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #include "AzRgf_FindSplit_Dflt.hpp" 10 | #include "AzHelp.hpp" 11 | #include "AzRgf_kw.hpp" 12 | 13 | /*--------------------------------------------------------*/ 14 | void AzRgf_FindSplit_Dflt::begin( 15 | const AzTrTree_ReadOnly *inp_tree, 16 | const AzRgf_FindSplit_input &inp, /* tx is not used */ 17 | int inp_min_size) 18 | { 19 | AzFindSplit::_begin(inp_tree, inp.data, inp.target, inp_min_size); 20 | 21 | nlam = inp.nn*lambda; 22 | nsig = inp.nn*sigma; 23 | if (inp.lam_scale != 1) { /* for numerical stability for expo loss */ 24 | nlam *= inp.lam_scale; 25 | nsig *= inp.lam_scale; 26 | } 27 | 28 | doUseInternalNodes = tree->usingInternalNodes(); 29 | } 30 | 31 | /*--------------------------------------------------------*/ 32 | double AzRgf_FindSplit_Dflt::getBestGain(double wsum, /* some of data weights */ 33 | double wrsum, /* weighted sum of residual */ 34 | double *best_q) const 35 | { 36 | double p = p_node->weight; /* parent's weight */ 37 | double gain = 0; 38 | double q = 0; 39 | 40 | if (doUseInternalNodes) { 41 | q = wrsum/(wsum+c_nlam); 42 | gain = q*q*(wsum+c_nlam); /* n*gain */ 43 | } 44 | else if (nsig <= 0) { /* L2 only */ 45 | q = (wrsum-c_nlam*p)/(wsum+c_nlam); 46 | gain = q*q*(wsum+c_nlam)+(p_nlam-2*c_nlam)*p*p/2; /* "/2" for two child nodes */ 47 | /* n*gain */ 48 | q += p; 49 | } 50 | else { /* L1 and L2; not tested after code change */ 51 | double _wysum = wrsum + wsum*p; 52 | if (_wysum > c_nsig) { 53 | q = (_wysum-c_nsig)/(wsum+c_nlam); 54 | } else if (_wysum < -c_nsig) { 55 | q = (_wysum+c_nsig)/(wsum+c_nlam); 56 | } else { 57 | q = 0; 58 | } 59 | double org_losshat = -2*p*_wysum+p*p*(wsum+p_nlam)+2*p_nsig*fabs(p); 60 | double new_losshat = -q*q*(wsum+c_nlam); 61 | gain = org_losshat - new_losshat; 62 | } 63 | *best_q = q; 64 | return gain; 65 | } 66 | 67 | /*--------------------------------------------------------*/ 68 | /*--------------------------------------------------------*/ 69 | void 
AzRgf_FindSplit_Dflt::resetParam(AzParam &p) 70 | { 71 | /*--- reg param shared with optimizer ---*/ 72 | p.vFloat(kw_lambda, &lambda); 73 | p.vFloat(kw_sigma, &sigma); 74 | 75 | /*--- override ... ---*/ 76 | p.vFloat(kw_s_lambda, &lambda); 77 | p.vFloat(kw_s_sigma, &sigma); 78 | 79 | if (lambda < 0) { 80 | throw new AzException(AzInputMissing, "AzRgf_FindSplit_Dflt", 81 | kw_lambda, "must be non-negative"); 82 | } 83 | if (sigma < 0) { 84 | throw new AzException(AzInputNotValid, "AzRgf_FindSplit_Dflt", 85 | kw_sigma, "must be non-negative"); 86 | } 87 | } 88 | 89 | /*--------------------------------------------------------*/ 90 | void AzRgf_FindSplit_Dflt::printParam(const AzOut &out) const 91 | { 92 | if (out.isNull()) return; 93 | 94 | AzPrint o(out); 95 | o.reset_options(); 96 | o.set_precision(5); 97 | o.ppBegin("AzRgf_FindSplit_Dflt", "Node split", ", "); 98 | o.printV(kw_lambda, lambda); 99 | o.printV_posiOnly(kw_sigma, sigma); 100 | o.ppEnd(); 101 | } 102 | 103 | /*--------------------------------------------------------*/ 104 | void AzRgf_FindSplit_Dflt::printHelp(AzHelp &h) const 105 | { 106 | h.begin(Azsplit_config, "AzRgf_FindSplit_Dflt", "Regularization at node split"); 107 | h.item_required(kw_lambda, help_lambda); 108 | h.item_experimental(kw_sigma, help_sigma, sigma_dflt); 109 | h.item(kw_s_lambda, help_s_lambda); 110 | h.item_experimental(kw_s_sigma, help_s_sigma); 111 | h.end(); 112 | } 113 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_FindSplit_Dflt.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_Dflt.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_FIND_SPLIT_DFLT_HPP_ 10 | #define _AZ_RGF_FIND_SPLIT_DFLT_HPP_ 11 | 12 | #include "AzFindSplit.hpp" 13 | #include "AzTrTree_ReadOnly.hpp" 14 | #include "AzRgf_FindSplit.hpp" 15 | #include "AzRegDepth.hpp" 16 | #include "AzParam.hpp" 17 | 18 | //! Node split search for RGF. L2 regularization. 19 | /*--------------------------------------------------------*/ 20 | class AzRgf_FindSplit_Dflt : /* extends */ public virtual AzFindSplit, 21 | /* implements */ public virtual AzRgf_FindSplit 22 | { 23 | protected: 24 | double lambda, sigma; 25 | const AzRegDepth *reg_depth; 26 | 27 | double nlam, nsig; 28 | double p_nlam; //!< L2 reg param for parent (node to be split) 29 | double c_nlam; //!< L2 reg param for child (new node after split) 30 | double p_nsig, c_nsig; 31 | bool doUseInternalNodes; 32 | const AzTrTreeNode *p_node; //!< parent node (node to be split) 33 | 34 | public: 35 | AzRgf_FindSplit_Dflt() : reg_depth(NULL), 36 | lambda(-1), sigma(sigma_dflt), 37 | doUseInternalNodes(false), nlam(0), nsig(0), 38 | p_nlam(0), c_nlam(0), p_nsig(0), c_nsig(0), 39 | p_node(NULL) {} 40 | virtual void begin(const AzTrTree_ReadOnly *tree, 41 | const AzRgf_FindSplit_input &inp, 42 | int inp_min_size); 43 | virtual void end() { 44 | _end(); 45 | } 46 | virtual inline 47 | void findSplit(int nx, 48 | /*--- output ---*/ 49 | AzTrTsplit *best_split) { 50 | p_node = tree->node(nx); 51 | p_nlam = reg_depth->apply(nlam, p_node->depth); 52 | c_nlam = reg_depth->apply(nlam, p_node->depth+1); 53 | p_nsig = reg_depth->apply(nsig, p_node->depth); 54 | c_nsig = reg_depth->apply(nsig, p_node->depth+1); 55 | AzFindSplit::_findBestSplit(nx, best_split); 56 | } 57 | virtual void reset(AzParam ¶m, 58 | const AzRegDepth *inp_reg_depth, 59 | const AzOut &out) 60 | { 61 | reg_depth = inp_reg_depth; 62 | if (reg_depth == NULL) throw new AzException("AzRgf_FindSplit_Dflt", 63 | "null reg_depth"); 64 | resetParam(param); 65 | printParam(out); 66 | } 67 | 68 | 
virtual void pickFeats(int pick_num, int f_num) { 69 | AzFindSplit::_pickFeats(pick_num, f_num); 70 | } 71 | 72 | virtual void printParam(const AzOut &out) const; 73 | virtual void printHelp(AzHelp &h) const; 74 | 75 | protected: 76 | virtual void resetParam(AzParam ¶m); 77 | virtual double getBestGain(double wsum, 78 | double wysum, 79 | double *best_q) const; 80 | }; 81 | #endif 82 | 83 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_FindSplit_TreeReg.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_TreeReg.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #include "AzRgf_FindSplit_TreeReg.hpp" 10 | 11 | /*--------------------------------------------------------*/ 12 | void AzRgf_FindSplit_TreeReg::findSplit(int nx, 13 | /*--- output ---*/ 14 | AzTrTsplit *best_split) 15 | { 16 | if (tree->usingInternalNodes()) { 17 | throw new AzException("AzRgf_FindSplit_TreeReg::findSplit", 18 | "can't coexist with UseInternalNodes"); 19 | } 20 | 21 | reg->reset_forNewLeaf(nx, tree, reg_depth); 22 | dR = ddR = 0; 23 | reg->penalty_deriv(&dR, &ddR); 24 | AzRgf_FindSplit_Dflt::findSplit(nx, best_split); 25 | } 26 | 27 | /*--------------------------------------------------------*/ 28 | double AzRgf_FindSplit_TreeReg::evalSplit( 29 | const Az_forFindSplit i[2], 30 | double bestP[2]) const 31 | { 32 | double d[2]; /* delta */ 33 | for (int ix = 0; ix < 2; ++ix) { 34 | double wrsum = i[ix].wy_sum; 35 | d[ix] = (wrsum-nlam*dR)/(i[ix].w_sum+nlam*ddR); 36 | bestP[ix] = p_node->weight + d[ix]; 37 | } 38 | 39 | double penalty_diff = reg->penalty_diff(d); /* new - old */ 40 | 41 | double gain = 2*d[0]*i[0].wy_sum - d[0]*d[0]*i[0].w_sum - nlam * penalty_diff; 42 | gain += 2*d[1]*i[1].wy_sum - 
d[1]*d[1]*i[1].w_sum - nlam * penalty_diff; 43 | 44 | /* "2*" b/c penalty is sum v^2/2 */ 45 | 46 | return gain; 47 | } 48 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_FindSplit_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_FindSplit_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_FIND_SPLIT_TREE_REG_HPP_ 10 | #define _AZ_RGF_FIND_SPLIT_TREE_REG_HPP_ 11 | 12 | #include "AzRgf_FindSplit_Dflt.hpp" 13 | #include "AzReg_TreeReg.hpp" 14 | #include "AzReg_TreeRegArr.hpp" 15 | 16 | //! Node split search for RGF. L2 and tree structure regularization 17 | /*--------------------------------------------------------*/ 18 | class AzRgf_FindSplit_TreeReg : /* extends */ public virtual AzRgf_FindSplit_Dflt 19 | { 20 | protected: 21 | AzReg_TreeRegArr *reg_arr; 22 | AzReg_TreeReg *reg; 23 | double dR, ddR; 24 | 25 | public: 26 | AzRgf_FindSplit_TreeReg() : dR(0), ddR(0), reg(NULL), reg_arr(NULL) {} 27 | void reset(AzReg_TreeRegArr *inp_reg_arr) { 28 | reg_arr = inp_reg_arr; 29 | } 30 | 31 | //! override 32 | virtual void begin(const AzTrTree_ReadOnly *tree, 33 | const AzRgf_FindSplit_input &inp, 34 | int inp_min_size) 35 | { 36 | AzRgf_FindSplit_Dflt::begin(tree, inp, inp_min_size); 37 | reg = reg_arr->reg_forNewLeaf(inp.tx); 38 | reg->reset_forNewLeaf(tree, reg_depth); 39 | } 40 | 41 | //! override 42 | virtual void end() { 43 | AzRgf_FindSplit_Dflt::end(); 44 | reg = NULL; 45 | } 46 | 47 | //! override 48 | virtual void findSplit(int nx, AzTrTsplit *best_split); 49 | 50 | //! 
override AzFindSplit::evalSplit 51 | virtual double evalSplit(const Az_forFindSplit i[2], 52 | double bestP[2]) const; 53 | }; 54 | #endif 55 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_Optimizer.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_Optimizer.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_OPTIMIZER_HPP_ 10 | #define _AZ_RGF_OPTIMIZER_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzDmat.hpp" 14 | #include "AzBmat.hpp" 15 | #include "AzDataForTrTree.hpp" 16 | #include "AzLoss.hpp" 17 | #include "AzTrTreeEnsemble.hpp" 18 | #include "AzRgfTreeEnsemble.hpp" 19 | #include "AzRegDepth.hpp" 20 | #include "AzParam.hpp" 21 | #include "AzHelp.hpp" 22 | 23 | //! Abstract class: Weight optimizer interface. 24 | /*-------------------------------------------------------*/ 25 | class AzRgf_Optimizer 26 | { 27 | public: 28 | /*! Initialization */ 29 | virtual void cold_start(AzLossType loss_type, 30 | const AzDataForTrTree *training_data, /*!< training data */ 31 | const AzRegDepth *reg_depth, /*!< regularizer using tree attributes */ 32 | AzParam ¶m, /*!< confignuration */ 33 | const AzDvect *v_y, /*!< training targets */ 34 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 35 | const AzOut out, /*!< where to wrige log */ 36 | /*! output: prediction on training data. typically zeroes. 
*/ 37 | AzDvect *v_p) 38 | = 0; 39 | 40 | virtual void warm_start(AzLossType loss_type, 41 | const AzDataForTrTree *training_data, /*!< training data */ 42 | const AzRegDepth *reg_depth, /*!< regularizer using tree attributes */ 43 | AzParam ¶m, /*!< confignuration */ 44 | const AzDvect *v_y, /*!< training targets */ 45 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 46 | const AzOut out, /*!< where to wrige log */ 47 | /*--- for warm start ---*/ 48 | const AzTrTreeEnsemble_ReadOnly *inp_ens, 49 | const AzDvect *inp_v_p) 50 | = 0; 51 | 52 | /*! Optimize weights. */ 53 | virtual void 54 | update(const AzDataForTrTree *data, /*!< training data */ 55 | AzRgfTreeEnsemble *ens, /*!< inout: tree ensmeble. */ 56 | /*--- output ---*/ 57 | AzDvect *v_p) /*feat1); 47 | out = inp->out; 48 | trainer_dflt.reset(&inp->trainer_dflt); 49 | trainer = &trainer_dflt; 50 | } 51 | /*------------------------------------------------------*/ 52 | 53 | /*------------------------------------------------------*/ 54 | /* derived classes must override this */ 55 | /*------------------------------------------------------*/ 56 | virtual void temp_update_apply(const AzDataForTrTree *tr_data, 57 | AzRgfTreeEnsemble *temp_ens, 58 | const AzDataForTrTree *test_data, 59 | AzBmat *temp_b, AzDvect *v_test_p, 60 | int *f_num, int *nz_f_num) const { 61 | AzRgf_Optimizer_Dflt temp_opt(this); 62 | temp_opt.update(tr_data, temp_ens); 63 | if (test_data != NULL) temp_opt.apply(test_data, temp_b, temp_ens, 64 | v_test_p, f_num, nz_f_num); 65 | } 66 | /*--------------------------------------------------------*/ 67 | 68 | 69 | virtual void cold_start(AzLossType loss_type, 70 | const AzDataForTrTree *data, 71 | const AzRegDepth *reg_depth, 72 | AzParam ¶m, 73 | const AzDvect *v_yval, 74 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 75 | const AzOut out_req, 76 | AzDvect *v_pval); /* output */ 77 | virtual void warm_start(AzLossType loss_type, 78 | const AzDataForTrTree 
*data, 79 | const AzRegDepth *reg_depth, 80 | AzParam ¶m, 81 | const AzDvect *v_y, /* for training */ 82 | const AzDvect *v_fixed_dw, /* user-assigned data point weights */ 83 | const AzOut out_req, 84 | /*--- for warm start ---*/ 85 | const AzTrTreeEnsemble_ReadOnly *ens, 86 | const AzDvect *v_p); 87 | virtual void 88 | update(const AzDataForTrTree *data, 89 | AzRgfTreeEnsemble *ens, /* inout */ 90 | /*--- output ---*/ 91 | AzDvect *v_p=NULL); /* prediction */ 92 | 93 | virtual void apply(const AzDataForTrTree *data, 94 | AzBmat *b_test_tran, /* inout */ 95 | const AzTrTreeEnsemble_ReadOnly *ens, 96 | /*--- output ---*/ 97 | AzDvect *v_p, 98 | int *f_num, int *nz_f_num) const; 99 | 100 | virtual void printHelp(AzHelp &h) const; 101 | 102 | protected: 103 | virtual bool resetParam(AzParam ¶m); 104 | static void _info(const AzTrTreeEnsemble_ReadOnly *ens, 105 | const AzOptimizerT *my_trainer, 106 | const AzTrTreeFeat *my_feat, 107 | int *f_num, int *nz_f_num); 108 | }; 109 | #endif 110 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgf_Optimizer_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgf_Optimizer_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_RGF_OPTIMIZER_TREE_REG_HPP_ 10 | #define _AZ_RGF_OPTIMIZER_TREE_REG_HPP_ 11 | 12 | #include "AzRgf_Optimizer_Dflt.hpp" 13 | #include "AzOptOnTree_TreeReg.hpp" 14 | 15 | //! for regularization using tree structure. 
16 | /*-------------------------------------------------------*/ 17 | class AzRgf_Optimizer_TreeReg : /* extends */ public virtual AzRgf_Optimizer_Dflt 18 | { 19 | protected: 20 | AzOptOnTree_TreeReg trainer_tr; 21 | 22 | public: 23 | AzRgf_Optimizer_TreeReg() { 24 | trainer = &trainer_tr; 25 | } 26 | void reset(AzReg_TreeRegArr *reg_arr) { 27 | trainer_tr.reset(reg_arr); 28 | } 29 | AzRgf_Optimizer_TreeReg(const AzRgf_Optimizer_TreeReg *inp) { 30 | reset(inp); 31 | } 32 | 33 | /*------------------------------------------------------*/ 34 | /* override this to replace trainer */ 35 | /*------------------------------------------------------*/ 36 | virtual void reset(const AzRgf_Optimizer_TreeReg *inp) { 37 | if (inp == NULL) return; 38 | AzRgf_Optimizer_Dflt::reset(inp); 39 | 40 | trainer_tr.reset(&inp->trainer_tr); 41 | trainer = &trainer_tr; 42 | } 43 | 44 | /*------------------------------------------------------*/ 45 | /* derived classes must override this */ 46 | /*------------------------------------------------------*/ 47 | virtual void temp_update_apply(const AzDataForTrTree *tr_data, 48 | AzRgfTreeEnsemble *temp_ens, 49 | const AzDataForTrTree *test_data, 50 | AzBmat *temp_b, AzDvect *v_test_p, 51 | int *f_num, int *nz_f_num) const { 52 | AzRgf_Optimizer_TreeReg temp_opt(this); 53 | temp_opt.update(tr_data, temp_ens); 54 | if (test_data != NULL) temp_opt.apply(test_data, temp_b, temp_ens, 55 | v_test_p, f_num, nz_f_num); 56 | } 57 | /*--------------------------------------------------------*/ 58 | }; 59 | #endif 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /RGF/src/tet/AzRgforest_TreeReg.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzRgforest_TreeReg.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_RGFOREST_TREEREG_HPP_ 10 | #define _AZ_RGFOREST_TREEREG_HPP_ 11 | 12 | #include "AzRgforest.hpp" 13 | 14 | #include "AzRgf_Optimizer_TreeReg.hpp" 15 | #include "AzRgf_FindSplit_TreeReg.hpp" 16 | #include "AzReg_TreeRegArrImp.hpp" 17 | 18 | //! RGF with min-penalty regularization. 19 | 20 | template 21 | class AzRgforest_TreeReg : /* implements */ public virtual AzRgforest { 22 | protected: 23 | AzRgf_FindSplit_TreeReg tr_fs; 24 | AzRgf_Optimizer_TreeReg tr_opt; /* weight optimizer */ 25 | AzReg_TreeRegArrImp reg_arr; 26 | AzBytArr s_sign, s_desc; 27 | 28 | public: 29 | AzRgforest_TreeReg() 30 | { 31 | tr_fs.reset(®_arr); 32 | tr_opt.reset(®_arr); 33 | opt = &tr_opt; 34 | fs = &tr_fs; 35 | s_sign.reset(reg_arr.tmpl()->signature()); 36 | s_desc.reset(reg_arr.tmpl()->description()); 37 | reg_depth->set_default_for_min_penalty(); 38 | } 39 | virtual inline const char *signature() const { 40 | return s_sign.c_str(); 41 | } 42 | virtual inline const char *description() const { 43 | return s_desc.c_str(); 44 | } 45 | 46 | virtual void printHelp(AzHelp &h) const { 47 | AzRgforest::printHelp(h); 48 | reg_arr.tmpl()->printHelp(h); 49 | h.begin(Azforest_config, "AzRgforest_TreeReg", "For min-penalty regularization"); 50 | h.item(kw_doApproxTsr, help_doApproxTsr); 51 | h.end(); 52 | } 53 | 54 | protected: 55 | virtual int resetParam(AzParam ¶m) { /* returns max #tree */ 56 | int max_tree_num = AzRgforest::resetParam(param); 57 | 58 | bool doApproxTsr = false; 59 | param.swOn(&doApproxTsr, kw_doApproxTsr); 60 | if (doApproxTsr) { 61 | AzPrint o(out); 62 | o.ppBegin("AzRgforest_TreeReg", "Approximation", ", "); 63 | o.printSw(kw_doApproxTsr, doApproxTsr); 64 | o.ppEnd(); 65 | if (doForceToRefreshAll) { 66 | doForceToRefreshAll = false; 67 | AzPrint::writeln(out, "Turning off ", kw_doForceToRefreshAll); 68 | } 69 | } 70 | else { 71 | if (!doForceToRefreshAll) { 72 | doForceToRefreshAll = true; 73 | AzPrint::writeln(out, "Turning on ", 
kw_doForceToRefreshAll); 74 | } 75 | } 76 | 77 | reg_arr.tmpl_u()->resetParam(param); 78 | reg_arr.tmpl()->printParam(out); 79 | reg_arr.reset(max_tree_num); 80 | return max_tree_num; 81 | } 82 | 83 | virtual void end_of_initialization() { 84 | AzRgforest::end_of_initialization(); 85 | reg_arr.tmpl()->check_reg_depth(reg_depth); 86 | } 87 | }; 88 | #endif 89 | 90 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTET_Eval.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTET_Eval.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TET_EVAL_HPP_ 10 | #define _AZ_TET_EVAL_HPP_ 11 | 12 | #include "AzDataForTrTree.hpp" 13 | #include "AzLoss.hpp" 14 | #include "AzTE_ModelInfo.hpp" 15 | #include "AzPerfResult.hpp" 16 | 17 | //! Abstract class: interface for evaluation modules for Tree Ensemble Trainer. 
18 | /*-------------------------------------------------------*/ 19 | class AzTET_Eval { 20 | public: 21 | virtual void reset(const AzDvect *inp_v_y, 22 | const char *perf_fn, 23 | bool inp_doAppend) = 0; 24 | virtual void begin(const char *config="", 25 | AzLossType loss_type=AzLoss_None) = 0; 26 | virtual void resetConfig(const char *config) = 0; 27 | virtual void end() = 0; 28 | virtual void evaluate(const AzDvect *v_p, const AzTE_ModelInfo *info, 29 | const char *user_str=NULL) = 0; 30 | virtual bool isActive() const = 0; 31 | }; 32 | #endif 33 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTET_Eval_Dflt.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTET_Eval_Dflt.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TET_EVAL_DFLT_HPP_ 10 | #define _AZ_TET_EVAL_DFLT_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzTaskTools.hpp" 14 | #include "AzPerfResult.hpp" 15 | #include "AzDataForTrTree.hpp" 16 | #include "AzLoss.hpp" 17 | #include "AzTET_Eval.hpp" 18 | 19 | //! Evaluationt module for Tree Ensemble Trainer. 
20 | /*-------------------------------------------------------*/ 21 | class AzTET_Eval_Dflt : /* implements */ public virtual AzTET_Eval { 22 | protected: 23 | /*--- targets ---*/ 24 | const AzDvect *v_y; 25 | 26 | /*--- to output evaluation results ---*/ 27 | AzBytArr s_perf_fn; 28 | AzPerfType perf_type; 29 | 30 | AzLossType loss_type; 31 | AzBytArr s_config; 32 | 33 | AzOfs ofs; 34 | AzOut perf_out; 35 | bool doAppend; 36 | 37 | public: 38 | AzTET_Eval_Dflt() : v_y(NULL), 39 | loss_type(AzLoss_None), doAppend(false) {} 40 | ~AzTET_Eval_Dflt() { 41 | end(); 42 | } 43 | inline virtual bool isActive() const { 44 | if (v_y != NULL) return true; 45 | return false; 46 | } 47 | 48 | virtual void reset() { 49 | v_y= NULL; 50 | s_perf_fn.reset(); 51 | s_config.reset(); 52 | if (ofs.is_open()) { 53 | ofs.close(); 54 | } 55 | } 56 | void reset(const AzDvect *inp_v_y, 57 | const char *perf_fn, 58 | bool inp_doAppend) 59 | { 60 | v_y = inp_v_y; 61 | s_perf_fn.reset(perf_fn); 62 | doAppend = inp_doAppend; 63 | } 64 | virtual void resetConfig(const char *config) { 65 | s_config.reset(config); 66 | _clean(&s_config); 67 | } 68 | 69 | virtual void begin(const char *config, 70 | AzLossType inp_loss_type) { 71 | if (!isActive()) return; 72 | _begin(config, inp_loss_type); 73 | _clean(&s_config); 74 | } 75 | virtual void end() { 76 | if (ofs.is_open()) { 77 | ofs.close(); 78 | } 79 | } 80 | 81 | virtual void evaluate(const AzDvect *v_p, 82 | const AzTE_ModelInfo *info, 83 | const char *user_str=NULL) { 84 | if (!isActive()) return; 85 | AzPerfResult result=AzTaskTools::eval(v_p, v_y, loss_type); 86 | 87 | /*--- signature and configuration ---*/ 88 | AzBytArr s_sign_config(info->s_sign); 89 | s_sign_config.concat(":"); 90 | concat_config(info, &s_sign_config); 91 | 92 | /*--- print ---*/ 93 | AzPrint o(perf_out); 94 | o.printBegin("", ",", ","); 95 | o.print("#tree", info->tree_num); 96 | o.print("#leaf", info->leaf_num); 97 | o.print("acc", result.acc, 4); 98 | o.print("rmse", 
result.rmse, 4); 99 | o.print("sqerr", result.rmse*result.rmse, 6); 100 | o.print(loss_str[loss_type]); 101 | o.print("loss", result.loss, 6); 102 | o.print("#test", v_p->rowNum()); 103 | o.print("cfg"); /* for compatibility */ 104 | o.print(s_sign_config); 105 | if (user_str != NULL) { 106 | o.print(user_str); 107 | } 108 | o.printEnd(); 109 | } 110 | 111 | protected: 112 | virtual void _begin(const char *config, AzLossType inp_loss_type) { 113 | s_config.reset(config); 114 | loss_type = inp_loss_type; 115 | 116 | if (ofs.is_open()) { 117 | ofs.close(); 118 | } 119 | const char *perf_fn = s_perf_fn.c_str(); 120 | if (AzTools::isSpecified(perf_fn)) { 121 | ios_base::openmode mode = ios_base::out; 122 | if (doAppend) { 123 | mode = ios_base::app | ios_base::out; 124 | } 125 | ofs.open(perf_fn, mode); 126 | ofs.set_to(perf_out); 127 | } 128 | else { 129 | perf_out.setStdout(); 130 | } 131 | } 132 | 133 | virtual void _clean(AzBytArr *s) const { 134 | /*-- replace comma with : for convenience later --*/ 135 | s->replace(',', ';'); 136 | } 137 | 138 | virtual void concat_config(const AzTE_ModelInfo *info, AzBytArr *s) const { 139 | if (s_config.length() > 0) { 140 | s->concat(&s_config); 141 | } 142 | else { 143 | AzBytArr s_cfg(&info->s_config); 144 | _clean(&s_cfg); 145 | s->concat(&s_cfg); 146 | } 147 | } 148 | }; 149 | #endif 150 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTETselector.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTETselector.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TET_SELECTOR_HPP_ 10 | #define _AZ_TET_SELECTOR_HPP_ 11 | 12 | #include "AzTETrainer.hpp" 13 | 14 | class AzTETselector { 15 | public: 16 | //! 
Return trainer 17 | virtual AzTETrainer *select(const char *alg_name, //! algorithm name 18 | //! if true, don't throw exception on error 19 | bool dontThrow=false 20 | ) const = 0; 21 | 22 | virtual const char *dflt_name() const = 0; 23 | virtual const char *another_name() const = 0; 24 | virtual const AzStrArray *names() const = 0; 25 | virtual bool isRGFfamily(const char *name) const { 26 | AzBytArr s(name); 27 | return s.beginsWith("RGF"); 28 | } 29 | virtual bool isGBfamily(const char *name) const { 30 | AzBytArr s(name); 31 | return s.beginsWith("GB"); 32 | } 33 | 34 | //! Return algorithm names. 35 | virtual void printOptions(const char *dlm, //! delimiter between algorithm names. 36 | AzBytArr *s) //!< output: algorithm names separated by dlm. 37 | const = 0; 38 | 39 | //! Help 40 | virtual void printHelp(AzHelp &h) const = 0; 41 | }; 42 | 43 | #endif 44 | 45 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTE_ModelInfo.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTE_ModelInfo.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TE_MODEL_INFO_HPP_ 10 | #define _AZ_TE_MODEL_INFO_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | 14 | /*-------------------------------------------------------*/ 15 | /*! Tree Ensemble model info */ 16 | class AzTE_ModelInfo { 17 | public: 18 | int tree_num; //!< number of trees 19 | int leaf_num; //!< number of features/leaf nodes 20 | int f_num; //!< number of features including removed ones. 
21 | int nz_f_num; //!< number of non-zero-weight features after consolidation 22 | AzBytArr s_sign; 23 | AzBytArr s_config; //!< configuration 24 | const char *sign; //!< signature of trainer 25 | AzTE_ModelInfo() : f_num(-1),leaf_num(-1),nz_f_num(-1),tree_num(-1) {} 26 | 27 | void reset() { 28 | tree_num = leaf_num = f_num = nz_f_num = -1; 29 | s_sign.reset(); 30 | s_config.reset(); 31 | } 32 | }; 33 | #endif 34 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTrTreeEnsemble_ReadOnly.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTreeEnsemble_ReadOnly.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TR_TREE_ENSEMBLE_READONLY_HPP_ 10 | #define _AZ_TR_TREE_ENSEMBLE_READONLY_HPP_ 11 | 12 | #include "AzTrTree_ReadOnly.hpp" 13 | #include "AzTreeEnsemble.hpp" 14 | #include "AzSvFeatInfo.hpp" 15 | 16 | //! Abstract class: interface for read-only access to trainalbe tree ensemble. 
17 | class AzTrTreeEnsemble_ReadOnly { 18 | public: 19 | virtual bool usingTempFile() const { return false; } 20 | virtual const AzTrTree_ReadOnly *tree(int tx) const = 0; 21 | virtual int leafNum() const = 0; 22 | virtual int leafNum(int tx0, int tx1) const = 0; 23 | virtual int size() const = 0; 24 | virtual int max_size() const = 0; 25 | virtual int lastIndex() const = 0; 26 | virtual void copy_to(AzTreeEnsemble *out_ens, 27 | const char *config, const char *sign) const = 0; 28 | virtual void show(const AzSvFeatInfo *feat, 29 | const AzOut &out, const char *header="") const = 0; 30 | virtual double constant() const = 0; 31 | virtual int orgdim() const = 0; 32 | virtual const char *param_c_str() const = 0; 33 | }; 34 | #endif 35 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTrTreeNode.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTreeNode.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TR_TREE_NODE_HPP_ 10 | #define _AZ_TR_TREE_NODE_HPP_ 11 | 12 | #include "AzTreeNodes.hpp" 13 | 14 | class AzTrTree; 15 | 16 | /*---------------------------------------------*/ 17 | /*! 
used only for training */ 18 | class AzTrTreeNode : /* extends */ public virtual AzTreeNode { 19 | protected: 20 | const int *dxs; /* data indexes belonging to this node */ 21 | 22 | public: 23 | int dxs_offset; /* position in the data indexes at the root */ 24 | int dxs_num; 25 | int depth; //!< node depth 26 | 27 | AzTrTreeNode() : depth(-1), dxs(NULL), dxs_offset(-1), dxs_num(-1) {} 28 | void reset() { 29 | AzTreeNode::reset(); 30 | depth = dxs_offset = dxs_num = -1; 31 | dxs = NULL; 32 | } 33 | void transfer_from(AzTrTreeNode *inp) { 34 | AzTreeNode::transfer_from(inp); 35 | dxs = inp->dxs; 36 | dxs_offset = inp->dxs_offset; 37 | dxs_num = inp->dxs_num; 38 | depth = inp->depth; 39 | gain = inp->gain; 40 | } 41 | 42 | inline const int *data_indexes() const { 43 | if (dxs_num > 0 && dxs == NULL) { 44 | throw new AzException("AzTrTreeNode::data_indexes", 45 | "data indexes are unavailable"); 46 | } 47 | return dxs; 48 | } 49 | inline void reset_data_indexes(const int *ptr) { 50 | dxs = ptr; 51 | } 52 | 53 | friend class AzTrTree; 54 | }; 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTrTree_ReadOnly.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTree_ReadOnly.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TR_TREE_READONLY_HPP_ 10 | #define _AZ_TR_TREE_READONLY_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzDataForTrTree.hpp" 14 | #include "AzTreeRule.hpp" 15 | #include "AzSvFeatInfo.hpp" 16 | #include "AzTreeNodes.hpp" 17 | #include "AzTrTreeNode.hpp" 18 | #include "AzSortedFeat.hpp" 19 | 20 | //! Abstract class: interface for read-only (information-seeking) access to trainable tree. 
21 | /*------------------------------------------*/ 22 | /* Trainable tree; read only */ 23 | class AzTrTree_ReadOnly : /* implements */ public virtual AzTreeNodes 24 | { 25 | public: 26 | /*--- information seeking ... ---*/ 27 | virtual int nodeNum() const = 0; 28 | virtual int countLeafNum() const = 0; 29 | virtual int maxDepth() const = 0; 30 | virtual void show(const AzSvFeatInfo *feat, const AzOut &out) const = 0; 31 | virtual void concat_stat(AzBytArr *o) const = 0; 32 | virtual double getRule(int inp_nx, AzTreeRule *rule) const = 0; 33 | virtual void concatDesc(const AzSvFeatInfo *feat, int nx, 34 | AzBytArr *str_desc, /* output */ 35 | int max_len=-1) const = 0; 36 | virtual void isActiveNode(bool doAllowZeroWeightLeaf, 37 | AzIntArr *ia_isDecisionNode) const = 0; /* output */ 38 | virtual bool usingInternalNodes() const = 0; 39 | 40 | virtual const AzSortedFeatArr *sorted_array(int nx, 41 | const AzDataForTrTree *data) const = 0; 42 | /*--- (NOTE) this is const but changes sorted_arr[nx] ---*/ 43 | 44 | virtual const AzIntArr *root_dx() const = 0; 45 | 46 | /*--- apply ... ---*/ 47 | virtual double apply(const AzDataForTrTree *data, int dx, 48 | AzIntArr *ia_nx=NULL) const /* node path */ 49 | = 0; 50 | 51 | virtual const AzTrTreeNode *node(int nx) const = 0; 52 | }; 53 | #endif 54 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTrTsplit.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTsplit.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TRT_SPLIT_HPP_ 10 | #define _AZ_TRT_SPLIT_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzTools.hpp" 14 | 15 | //! Node split information. 
16 | class AzTrTsplit { 17 | public: 18 | int fx; 19 | double border_val; 20 | double gain; 21 | double bestP[2]; /* le gt */ 22 | int weighted_n_samples[2]; 23 | AzBytArr str_desc; 24 | 25 | int tx, nx; /* set only by Rgf; not used by Std */ 26 | 27 | AzTrTsplit() : fx(-1), border_val(0), gain(0), tx(-1), nx(-1) { 28 | bestP[0] = bestP[1] = 0; 29 | weighted_n_samples[0] = weighted_n_samples[1] = 0; 30 | } 31 | 32 | virtual void print(const char *header) { 33 | #if 0 34 | printf("%s, fx=%d,border_val=%e,gain=%e,bestP[0]=%e,bestP[1]=%e,tx=%d,nx=%d\n", 35 | header, fx, border_val, gain, bestP[0], bestP[1], tx, nx); 36 | #endif 37 | } 38 | 39 | virtual 40 | void reset() { 41 | fx = -1; 42 | border_val = 0; 43 | bestP[0] = bestP[1] = 0; 44 | weighted_n_samples[0] = weighted_n_samples[1] = 0; 45 | gain = 0; 46 | str_desc.reset(); 47 | tx = nx = -1; 48 | } 49 | AzTrTsplit(int fx, double border_val, 50 | double gain, 51 | double bestP_L, double bestP_G) { 52 | reset_values(fx, border_val, gain, bestP_L, bestP_G); 53 | } 54 | AzTrTsplit(const AzTrTsplit *inp) { /* copy */ 55 | copy(inp); 56 | } 57 | virtual 58 | inline bool isEmpty() const { 59 | if (fx < 0) return true; 60 | return false; 61 | } 62 | 63 | virtual 64 | inline void reset(const AzTrTsplit *inp) { 65 | copy(inp); 66 | } 67 | virtual 68 | inline void reset(const AzTrTsplit *inp, int inp_tx, int inp_nx) { 69 | reset(inp); 70 | tx = inp_tx; 71 | nx = inp_nx; 72 | } 73 | virtual 74 | void copy(const AzTrTsplit *inp) { 75 | if (inp == NULL) return; 76 | fx = inp->fx; 77 | border_val = inp->border_val; 78 | gain = inp->gain; 79 | str_desc.clear(); 80 | str_desc.concat(&inp->str_desc); 81 | bestP[0] = inp->bestP[0]; 82 | bestP[1] = inp->bestP[1]; 83 | tx = inp->tx; 84 | nx = inp->nx; 85 | } 86 | 87 | virtual 88 | void reset_values(int inp_fx, double inp_border_val, 89 | double inp_gain, 90 | double bestP_L, double bestP_G) 91 | { 92 | fx = inp_fx; 93 | border_val = inp_border_val; 94 | gain = inp_gain; 95 | bestP[0] 
= bestP_L; 96 | bestP[1] = bestP_G; 97 | tx = nx = -1; 98 | } 99 | virtual 100 | void release() { 101 | str_desc.reset(); 102 | } 103 | }; 104 | #endif 105 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTrTtarget.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTrTtarget.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TRT_TARGET_HPP_ 10 | #define _AZ_TRT_TARGET_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzDmat.hpp" 14 | 15 | //! Targets and data point weights for node split search. 16 | /*--------------------------------------------------------*/ 17 | class AzTrTtarget { 18 | protected: 19 | AzDvect v_tar_dw; // v_y - v_p(red) 20 | AzDvect v_dw; 21 | AzDvect v_y; 22 | AzDvect v_fixed_dw; /* data point weights assigned by users */ 23 | double fixed_dw_sum; 24 | 25 | public: 26 | AzTrTtarget() : fixed_dw_sum(-1) {} 27 | AzTrTtarget(const AzDvect *inp_v_y, 28 | const AzDvect *inp_v_fixed_dw=NULL) { 29 | reset(inp_v_y, inp_v_fixed_dw); 30 | } 31 | void reset(const AzDvect *inp_v_y, 32 | const AzDvect *inp_v_fixed_dw=NULL) { 33 | v_dw.reform(inp_v_y->rowNum()); 34 | v_dw.set(1); 35 | v_tar_dw.set(inp_v_y); 36 | v_y.set(inp_v_y); 37 | fixed_dw_sum = -1; 38 | 39 | v_fixed_dw.reset(); 40 | if (!AzDvect::isNull(inp_v_fixed_dw)) { 41 | v_fixed_dw.set(inp_v_fixed_dw); 42 | if (v_fixed_dw.rowNum() != v_y.rowNum()) { 43 | throw new AzException(AzInputError, "AzTrTtarget::reset", 44 | "conlict in dimensionality: y and data point weights"); 45 | } 46 | fixed_dw_sum = v_fixed_dw.sum(); 47 | } 48 | } 49 | inline bool isWeighted() const { 50 | return !AzDvect::isNull(&v_fixed_dw); 51 | } 52 | inline double sum_fixed_dw() const { 53 | return fixed_dw_sum; 54 | } 55 | inline void weight_tarDw() { 56 | 
v_tar_dw.scale(&v_fixed_dw); 57 | } 58 | inline void weight_dw() { 59 | v_dw.scale(&v_fixed_dw); 60 | } 61 | 62 | AzTrTtarget(const AzTrTtarget *inp) { 63 | reset(inp); 64 | } 65 | 66 | void reset(const AzTrTtarget *inp) { 67 | if (inp != NULL) { 68 | v_tar_dw.set(&inp->v_tar_dw); 69 | v_dw.set(&inp->v_dw); 70 | v_y.set(&inp->v_y); 71 | v_fixed_dw.set(&inp->v_fixed_dw); 72 | fixed_dw_sum = inp->fixed_dw_sum; 73 | } 74 | } 75 | 76 | void resetTarDw_residual(const AzDvect *v_p) { /* only for LS */ 77 | v_tar_dw.set(&v_y); 78 | v_tar_dw.add(v_p, -1); 79 | } 80 | inline const double *dw_arr() const { 81 | return v_dw.point(); 82 | } 83 | inline const double *tarDw_arr() const { 84 | return v_tar_dw.point(); 85 | } 86 | inline const AzDvect *y() const { 87 | return &v_y; 88 | } 89 | inline int dataNum() const { 90 | return v_tar_dw.rowNum(); 91 | } 92 | inline AzDvect *tarDw_forUpdate() { 93 | return &v_tar_dw; 94 | } 95 | inline const AzDvect *tarDw() const { 96 | return &v_tar_dw; 97 | } 98 | inline AzDvect *dw_forUpdate() { 99 | return &v_dw; 100 | } 101 | inline const AzDvect *dw() { 102 | return &v_dw; 103 | } 104 | inline double getTarDwSum(const int *dxs, int dxs_num) const { 105 | return v_tar_dw.sum(dxs, dxs_num); 106 | } 107 | inline double getDwSum(const int *dxs, int dxs_num) const { 108 | return v_dw.sum(dxs, dxs_num); 109 | } 110 | inline double getTarDwSum(const AzIntArr *ia_dx=NULL) const { 111 | return v_tar_dw.sum(ia_dx); 112 | } 113 | inline double getDwSum(const AzIntArr *ia_dx=NULL) const { 114 | return v_dw.sum(ia_dx); 115 | } 116 | 117 | int dim() const { 118 | return v_tar_dw.rowNum(); 119 | } 120 | }; 121 | #endif 122 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTree.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTree.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and 
distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TREE_HPP_ 10 | #define _AZ_TREE_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzSmat.hpp" 14 | #include "AzSvFeatInfo.hpp" 15 | #include "AzTreeNodes.hpp" 16 | 17 | //! Untrainalbe regression tree. 18 | /*------------------------------------------*/ 19 | class AzTree : /* implements */ public virtual AzTreeNodes { 20 | protected: 21 | int root_nx; 22 | int nodes_used; 23 | AzBaseArray a_nodes; 24 | 25 | inline void _checkNode(int nx, const char *eyec) const { 26 | if (nodes == NULL || nx < 0 || nx >= nodes_used) { 27 | throw new AzException(eyec, "nx is out of range"); 28 | } 29 | } 30 | 31 | public: 32 | AzTree() : root_nx(-1), nodes_used(0), nodes(NULL) {} 33 | AzTree(AzFile *file) 34 | : root_nx(-1), nodes_used(0), nodes(NULL) { 35 | _read(file); 36 | } 37 | AzTree(const AzTreeNodes *inp) : root_nx(-1), nodes_used(0), nodes(NULL) { 38 | if (inp != NULL) { 39 | copy_from(inp); 40 | } 41 | } 42 | ~AzTree() {} 43 | 44 | AzTreeNode *nodes; 45 | 46 | void copy_from(const AzTreeNodes *tree_nodes); 47 | 48 | inline void transfer_from(AzTree *) { 49 | throw new AzException("AzTree::transfer_from", "no support"); 50 | } 51 | inline AzTree & operator =(const AzTree &inp) { 52 | if (this == &inp) return *this; 53 | throw new AzException("AzTree:=", "Don't use ="); 54 | } 55 | void reset() { 56 | _release(); 57 | } 58 | 59 | void write(AzFile *file); 60 | void read(AzFile *file); 61 | 62 | inline int nodeNum() const { 63 | return nodes_used; 64 | } 65 | 66 | static double apply(const AzReadOnlyVector *v_data, 67 | const AzTreeNodes *nodes, 68 | AzIntArr *ia_node=NULL); 69 | 70 | double apply(const AzReadOnlyVector *v_data, 71 | AzIntArr *ia_node=NULL) const { 72 | checkNodes("apply"); 73 | return apply(v_data, this, ia_node); 74 | } 75 | 76 | void show(const AzSvFeatInfo *feat, const AzOut &out, 77 | const char *header="") const; 78 | int 
leafNum() const; 79 | void clean_up(); 80 | 81 | inline const AzTreeNode *node(int nx) const { 82 | checkNode(nx, "point"); 83 | return &nodes[nx]; 84 | } 85 | inline int root() const { 86 | return root_nx; 87 | } 88 | void finfo(AzIFarr *ifa_fx_count, 89 | AzIFarr *ifa_fx_sum) const; /* appended */ 90 | void finfo(AzIntArr *ia_fxs) const; /* appended */ 91 | 92 | void cooccurrences(AzIIFarr *iifa_fx1_fx2_count) const; 93 | 94 | virtual void genDesc(const AzSvFeatInfo *feat, 95 | int nx, 96 | AzBytArr *s) /* output */ 97 | const; 98 | 99 | protected: 100 | /*--- functions ---*/ 101 | void _read(AzFile *file); 102 | 103 | inline void checkNode(int nx, const char *eyec) const { 104 | if (nodes == NULL || nx < 0 || nx >= nodes_used) { 105 | throw new AzException(eyec, "AzTree, nx is out of range"); 106 | } 107 | } 108 | inline void checkNodes(const char *msg) const { 109 | if (nodes == NULL && nodes_used > 0) { 110 | throw new AzException("AzTree, no nodes", msg); 111 | } 112 | } 113 | void _show(const AzSvFeatInfo *feat, 114 | int nx, int depth, 115 | const AzOut &out) const; 116 | 117 | void _release(); 118 | virtual void _genDesc(const AzSvFeatInfo *feat, 119 | int nx, 120 | AzBytArr *s) /* output */ 121 | const; 122 | }; 123 | 124 | #endif 125 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTreeEnsemble.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTreeEnsemble.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TREE_ENSEMBLE_HPP_ 10 | #define _AZ_TREE_ENSEMBLE_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | #include "AzTree.hpp" 14 | #include "AzDmat.hpp" 15 | #include "AzTE_ModelInfo.hpp" 16 | 17 | //! Untrainable tree ensemble. Like applier. Generated as a result of training. 
18 | class AzTreeEnsemble 19 | { 20 | protected: 21 | AzObjPtrArray a_tree; 22 | AzTree **t; 23 | int t_num; 24 | AzTree empty_tree; 25 | double const_val; 26 | 27 | AzBytArr s_config, s_sign; 28 | int org_dim; /* dimension of original features */ 29 | static const int kReservedLength = 256; 30 | 31 | public: 32 | AzTreeEnsemble() : t(NULL), t_num(0), const_val(0), org_dim(-1) {} 33 | ~AzTreeEnsemble() {} 34 | 35 | AzTreeEnsemble(const char *fn) 36 | : t(NULL), t_num(0), const_val(0), org_dim(-1) { 37 | read(fn); 38 | } 39 | AzTreeEnsemble(AzFile *file) 40 | : t(NULL), t_num(0), const_val(0), org_dim(-1) { 41 | _read(file); 42 | } 43 | 44 | void transfer_from(AzTree *inp_tree[], /* destroys input */ 45 | int inp_tree_num, 46 | double const_val, 47 | int orgdim, 48 | const char *config, 49 | const char *sign); 50 | 51 | void read(const char *fn); 52 | void read(AzFile *file) { 53 | _release(); 54 | _read(file); 55 | } 56 | void write(const char *fn); 57 | void write(AzFile *file); 58 | 59 | inline void destroy() { 60 | _release(); 61 | } 62 | 63 | inline const AzTree *tree(int tx) const { 64 | checkIndex(tx, "tree"); 65 | if (t[tx] == NULL) { 66 | return &empty_tree; 67 | } 68 | return t[tx]; 69 | } 70 | 71 | int leafNum() const { 72 | return leafNum(0, t_num); 73 | } 74 | int leafNum(int tx0, int tx1) const; 75 | inline int size() const { return t_num; } 76 | 77 | void apply(const AzSmat *m_data, 78 | AzDvect *v_pred) /* output */ 79 | const; 80 | double apply(const AzSvect *v_data) const; 81 | 82 | inline double constant() const { return const_val; } 83 | inline int orgdim() const { return org_dim; } 84 | const char *signature() const { return s_sign.c_str(); } 85 | const char *configuration() const { return s_config.c_str(); } 86 | 87 | /*---*/ 88 | void info(AzTE_ModelInfo *out_info) const; 89 | 90 | void show(const AzSvFeatInfo *feat, //!< may be NULL 91 | const AzOut &out, const char *header="") const; 92 | void finfo(AzIFarr *ifa_fx_count, 93 | AzIFarr 
*ifa_fx_sum) const { 94 | finfo(0, t_num, ifa_fx_count, ifa_fx_sum); 95 | } 96 | void finfo(int tx0, int tx1, 97 | AzIFarr *ifa_fx_count, 98 | AzIFarr *ifa_fx_sum) const; 99 | void finfo(AzIntArr *ia_fx2tx) const; 100 | void cooccurrences(AzIIFarr *iifa_fx1_fx2_count) const; 101 | 102 | void show_weights(const AzOut &out, AzSvFeatInfo *fi) const; 103 | 104 | protected: 105 | void _read(AzFile *file); 106 | inline void _release() { 107 | a_tree.free(&t); t_num = 0; 108 | s_config.reset(); 109 | s_sign.reset(); 110 | const_val = 0; 111 | org_dim = -1; 112 | } 113 | inline void checkIndex(int tx, const char *msg) const { 114 | if (tx < 0 || tx >= t_num) { 115 | throw new AzException("AzTreeEnsemble::checkIndex", msg); 116 | } 117 | } 118 | void clean_up(); 119 | }; 120 | #endif 121 | 122 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTreeNodes.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTreeNodes.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 7 | * * * * */ 8 | 9 | #ifndef _AZ_TREE_NODES_HPP_ 10 | #define _AZ_TREE_NODES_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | 14 | /*! 
Tree node */ 15 | class AzTreeNode { 16 | public: 17 | int fx; //!< feature id 18 | double border_val; 19 | int le_nx; //!< x[fx] <= border_val 20 | int gt_nx; //!< x[fx] > border_val 21 | int parent_nx; //!< pointing parent node 22 | double weight; //!< weight 23 | double gain; //!< impurity for calc feature importances 24 | 25 | /*--- ---*/ 26 | AzTreeNode() { 27 | reset(); 28 | } 29 | void reset() { 30 | border_val = 0; 31 | weight = 0; 32 | gain = 0; 33 | fx = le_nx = gt_nx = parent_nx = -1; 34 | } 35 | AzTreeNode(AzFile *file) { 36 | read(file); 37 | } 38 | inline bool isLeaf() const { 39 | if (le_nx < 0) return true; 40 | return false; 41 | } 42 | void write(AzFile *file); 43 | void read(AzFile *file); 44 | 45 | void transfer_from(AzTreeNode *inp) { 46 | *this = *inp; 47 | } 48 | }; 49 | 50 | class AzTreeNodes { 51 | public: 52 | virtual const AzTreeNode *node(int nx) const = 0; 53 | virtual int nodeNum() const = 0; 54 | virtual int root() const = 0; 55 | }; 56 | #endif 57 | -------------------------------------------------------------------------------- /RGF/src/tet/AzTreeRule.hpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * AzTreeRule.hpp 3 | * Copyright (C) 2011, 2012 Rie Johnson 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #ifndef _AZ_TREE_RULE_HPP_ 10 | #define _AZ_TREE_RULE_HPP_ 11 | 12 | #include "AzUtil.hpp" 13 | 14 | class AzTreeRule { 15 | protected: 16 | AzBytArr ba; 17 | 18 | public: 19 | inline void reset() { 20 | ba.clear(); 21 | } 22 | inline const AzBytArr *bytarr() { 23 | return &ba; 24 | } 25 | inline void reset(const AzTreeRule *inp) { 26 | ba.reset(); 27 | if (inp == NULL) return; 28 | ba.reset(&inp->ba); 29 | } 30 | inline void append(int fx, 31 | bool isLE, 32 | double border_val) 33 | { 34 | /*--- feat#, isLE, border_val ---*/ 35 | ba.concat((AzByte *)(&fx), sizeof(fx)); 36 | ba.concat((AzByte *)(&isLE), sizeof(isLE)); 37 | ba.concat((AzByte *)(&border_val), sizeof(border_val)); 38 | } 39 | inline void append(const AzTreeRule *inp) { 40 | if (inp != NULL) { 41 | ba.concat(&inp->ba); 42 | } 43 | } 44 | inline void finalize() { 45 | if (ba.getLen() == 0) { 46 | ba.concat('_'); /* root node (CONST) */ 47 | } 48 | } 49 | inline const AzBytArr *byteArr() { 50 | return &ba; 51 | } 52 | }; 53 | #endif 54 | -------------------------------------------------------------------------------- /RGF/src/tet/driv_rgf.cpp: -------------------------------------------------------------------------------- 1 | /* * * * * 2 | * driv_rgf.cpp 3 | * Copyright (C) 2011, 2012 Rie Johnson, 2018 RGF-team 4 | * 5 | * This software may be modified and distributed under the terms 6 | * of the MIT license. See the COPYING file for details. 
7 | * * * * */ 8 | 9 | #define _AZ_MAIN_ 10 | #include "AzUtil.hpp" 11 | #include "AzTETmain.hpp" 12 | #include "AzRgfTrainerSel.hpp" 13 | #include "AzTET_Eval_Dflt.hpp" 14 | #include "AzHelp.hpp" 15 | 16 | /*-----------------------------------------------------------------*/ 17 | void help(int argc, const char *argv[]) 18 | { 19 | cout << "Arguments: action parameters" <getMessage() << endl; 96 | return -1; 97 | } 98 | 99 | return 0; 100 | } 101 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: off 2 | 3 | fixes: 4 | - "::R-package/" 5 | -------------------------------------------------------------------------------- /python-package/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 RGF-team and other contributors 4 | 5 | Each contributor holds copyright over their respective contributions. 6 | The project versioning (Git) records all such contribution source information. 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
FROM ubuntu:16.04

# All steps share one RUN instruction so the clean-up at the end happens in
# the same layer that created the files it removes.
# There are deliberately no blank lines between the continued commands:
# Docker reports empty continuation lines as deprecated.
RUN apt-get update && \
# apt dependency
    apt-get install -y cmake build-essential gcc g++ git wget && \
# python-package
# miniconda
    wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    /bin/bash Miniconda3-latest-Linux-x86_64.sh -f -b -p /opt/conda && \
# the installer is not needed once /opt/conda exists
    rm -f Miniconda3-latest-Linux-x86_64.sh && \
# PATH exported here only lasts for this RUN; the ENV below makes it permanent
    export PATH="/opt/conda/bin:$PATH" && \
# rgf_python
    conda install -y -q numpy joblib scipy scikit-learn pandas && \
    git clone https://github.com/RGF-team/rgf.git && \
    cd rgf/python-package && python setup.py install && \
# clean
    apt-get autoremove -y && apt-get clean && \
    rm -rf /var/lib/apt/lists/* && \
    conda clean -i -l -t -y && \
    rm -rf /usr/local/src/*

ENV PATH=/opt/conda/bin:$PATH
import time

from sklearn import datasets
from sklearn.utils.validation import check_random_state
from sklearn.ensemble import GradientBoostingClassifier
from rgf.sklearn import RGFClassifier, FastRGFClassifier

# Load iris and shuffle samples and labels with the same seeded permutation.
iris = datasets.load_iris()
rng = check_random_state(0)
shuffled = rng.permutation(iris.target.size)
iris.data = iris.data[shuffled]
iris.target = iris.target[shuffled]

# (timing message, estimator class), listed in the order they are run.
contenders = (
    ("RGF: {} sec", RGFClassifier),
    ("FastRGF: {} sec", FastRGFClassifier),
    ("Gradient Boosting: {} sec", GradientBoostingClassifier),
)

for timing_message, estimator_cls in contenders:
    # The clock covers construction, fitting and scoring; every model is
    # scored on the same data it was fitted on.
    started = time.time()
    model = estimator_cls()
    model.fit(iris.data, iris.target)
    model_score = model.score(iris.data, iris.target)
    finished = time.time()
    print(timing_message.format(finished - started))
    print("score: {}".format(model_score))
from sklearn import datasets
from sklearn.utils.validation import check_random_state
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.ensemble import GradientBoostingClassifier
from rgf.sklearn import RGFClassifier

# Load iris and shuffle samples and labels with the same seeded permutation.
iris = datasets.load_iris()
rng = check_random_state(0)
order = rng.permutation(iris.target.size)
iris.data = iris.data[order]
iris.target = iris.target[order]

rgf_params = dict(max_leaf=400,
                  algorithm="RGF_Sib",
                  test_interval=100,
                  verbose=True)
# The gradient boosting model draws from the same RandomState instance that
# produced the permutation above.
gb_params = dict(n_estimators=20,
                 learning_rate=0.01,
                 subsample=0.6,
                 random_state=rng)

rgf = RGFClassifier(**rgf_params)
gb = GradientBoostingClassifier(**gb_params)

n_folds = 3


def fold_scores(estimator):
    # One score per fold, each call using its own StratifiedKFold splitter.
    return cross_val_score(estimator,
                           iris.data,
                           iris.target,
                           cv=StratifiedKFold(n_folds))


# Both evaluations run (RGF first) before anything is reported.
rgf_scores = fold_scores(rgf)
gb_scores = fold_scores(gb)

rgf_score = sum(rgf_scores) / n_folds
gb_score = sum(gb_scores) / n_folds

print('RGF Classifier score: {0:.5f}'.format(rgf_score))
# >>>RGF Classifier score: 0.95997

print('Gradient Boosting Classifier score: {0:.5f}'.format(gb_score))
# >>>Gradient Boosting Classifier score: 0.95997
import fnmatch
import os
import unittest


def find_files(directory, pattern='*.py'):
    """Yield the absolute path of every file below ``directory`` matching ``pattern``.

    The tree is walked recursively with ``os.walk``; ``pattern`` is a shell-style
    wildcard applied to the bare file name, not to the full path.
    """
    for root, _, files in os.walk(directory):
        for filename in fnmatch.filter(files, pattern):
            yield os.path.abspath(os.path.join(root, filename))


def run_example(filename):
    """Execute one example script and return the namespace it ran in.

    The file is read inside a ``with`` block so the handle is closed before the
    script starts. The script gets a fresh namespace, so it can neither see
    nor overwrite names of this test module or of other examples. Compiling
    with the real file name makes tracebacks point at the example itself.
    """
    with open(filename) as f:
        source = f.read()
    namespace = {'__name__': '__main__', '__file__': filename}
    exec(compile(source, filename, 'exec'), namespace)
    return namespace


class TestExamples(unittest.TestCase):
    def test_examples(self):
        """Run every ``*.py`` script under ``../examples``; any exception fails the test."""
        examples_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                                    os.path.pardir, 'examples')
        # Sorted so the run order does not depend on directory listing order.
        for filename in sorted(find_files(examples_dir)):
            # One sub-test per script: a failure names the script and the
            # remaining examples still run.
            with self.subTest(example=filename):
                run_example(filename)