├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── R ├── README.md └── gbm.R ├── README.md ├── dataset ├── machine.conf ├── test_dataset.txt └── test_dataset.txt.group ├── docs ├── Makefile ├── README.md ├── _static │ ├── default.css │ ├── lang-logo-tgbm.png │ ├── overall.png │ └── tgbm-logo.png ├── conf.py ├── faq.md ├── how-to.md ├── index.md ├── make.bat ├── parameters.md └── requirements.txt ├── include └── thundergbm │ ├── booster.h │ ├── builder │ ├── exact_tree_builder.h │ ├── function_builder.h │ ├── hist_tree_builder.h │ ├── hist_tree_builder_single.h │ ├── shard.h │ └── tree_builder.h │ ├── common.h │ ├── config.h.in │ ├── dataset.h │ ├── hist_cut.h │ ├── ins_stat.h │ ├── metric │ ├── metric.h │ ├── multiclass_metric.h │ ├── pointwise_metric.h │ └── ranking_metric.h │ ├── objective │ ├── multiclass_obj.h │ ├── objective_function.h │ ├── ranking_obj.h │ └── regression_obj.h │ ├── parser.h │ ├── predictor.h │ ├── quantile_sketch.h │ ├── row_sampler.h │ ├── sparse_columns.h │ ├── syncarray.h │ ├── syncmem.h │ ├── trainer.h │ ├── tree.h │ └── util │ ├── cub_wrapper.h │ ├── device_lambda.cuh │ ├── log.h │ └── multi_device.h ├── python ├── LICENSE ├── README.md ├── benchmarks │ ├── README.md │ ├── convert_dataset_plk.py │ ├── experiments.py │ ├── model │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── catboost_model.py │ │ ├── datasets.py │ │ ├── lightgbm_model.py │ │ ├── thundergbm_model.py │ │ └── xgboost_model.py │ ├── run_exp.sh │ └── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ └── file_utils.py ├── dist │ ├── thundergbm-0.3.12-py2-none-win_amd64.whl │ ├── thundergbm-0.3.12-py3-none-win_amd64.whl │ ├── thundergbm-0.3.16-py3-none-any.whl │ └── thundergbm-0.3.4-py3-none-win_amd64.whl ├── examples │ ├── classification_demo.py │ ├── ranking_demo.py │ └── regression_demo.py ├── requirements.txt ├── setup.py └── thundergbm │ ├── __init__.py │ └── thundergbm.py ├── src ├── test │ ├── CMakeLists.txt │ ├── test_csr2csc.cpp │ ├── test_cub_wrapper.cu │ ├── test_dataset.cpp │ ├── test_for_refactor.cpp │ ├── test_gbdt.cpp │ ├── test_get_cut_point.cpp │ ├── test_gradient.cu │ ├── test_main.cpp │ ├── test_metrics.cpp │ ├── test_parser.cpp │ ├── test_synarray.cpp │ ├── test_synmem.cpp │ └── test_tree.cpp └── thundergbm │ ├── CMakeLists.txt │ ├── builder │ ├── exact_tree_builder.cu │ ├── function_builder.cu │ ├── hist_tree_builder.cu │ ├── hist_tree_builder_single.cu │ ├── shard.cu │ └── tree_builder.cu │ ├── dataset.cpp │ ├── gbm_R_interface.cpp │ ├── hist_cut.cu │ ├── metric │ ├── metric.cu │ ├── multiclass_metric.cu │ ├── pointwise_metric.cu │ └── rank_metric.cpp │ ├── objective │ ├── multiclass_obj.cu │ ├── objective_function.cu │ └── ranking_obj.cpp │ ├── parser.cpp │ ├── predictor.cu │ ├── quantile_sketch.cpp │ ├── row_sampler.cu │ ├── scikit_tgbm.cpp │ ├── sparse_columns.cu │ ├── syncmem.cpp │ ├── thundergbm_predict.cpp │ ├── thundergbm_train.cpp │ ├── trainer.cu │ ├── tree.cu │ └── util │ ├── common.cpp │ └── log.cpp └── thundergbm-full.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | .project 2 | *logs* 3 | *.pyc 4 | *.model 5 | bin 6 | *.binary 7 | *.out 8 | *.idea 9 | tags 10 | .cproject 11 | .settings 12 | Debug 13 | Release 14 | dataset* 15 | *~ 16 | core 17 | *.DS_Store 18 | *.o 19 | *.swp 20 | *build* 21 | !builder 22 | !*build*.cu 23 | !*build*.h 24 | -------------------------------------------------------------------------------- /.gitmodules: 
-------------------------------------------------------------------------------- 1 | [submodule "cub"] 2 | path = cub 3 | url = https://github.com/NVlabs/cub.git 4 | [submodule "src/test/googletest"] 5 | path = src/test/googletest 6 | url = https://github.com/google/googletest.git 7 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MSVC) 2 | cmake_minimum_required(VERSION 3.4) 3 | else() 4 | cmake_minimum_required(VERSION 2.8) 5 | endif() 6 | project(thundergbm) 7 | if(MSVC) 8 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) 9 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) 10 | endif() 11 | 12 | find_package(CUDA REQUIRED QUIET) 13 | find_package(OpenMP REQUIRED QUIET) 14 | 15 | if (NOT CMAKE_BUILD_TYPE) 16 | set(CMAKE_BUILD_TYPE Release) 17 | endif () 18 | 19 | # CUDA 11 20 | if(CUDA_VERSION VERSION_LESS "11.0") 21 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11 -lineinfo --expt-extended-lambda --default-stream per-thread") 22 | if (CMAKE_VERSION VERSION_LESS "3.1") 23 | add_compile_options("-std=c++11") 24 | else () 25 | set(CMAKE_CXX_STANDARD 11) 26 | endif () 27 | else() 28 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++14 -lineinfo --expt-extended-lambda --default-stream per-thread") 29 | endif() 30 | 31 | if (OPENMP_FOUND) 32 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 33 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 34 | endif () 35 | 36 | # for easylogging++ configuration 37 | add_definitions("-DELPP_FEATURE_PERFORMANCE_TRACKING") 38 | add_definitions("-DELPP_THREAD_SAFE") 39 | add_definitions("-DELPP_STL_LOGGING") 40 | add_definitions("-DELPP_NO_LOG_TO_FILE") 41 | # includes 42 | set(COMMON_INCLUDES ${PROJECT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR} ${PROJECT_SOURCE_DIR}/cub) 43 | include_directories(${COMMON_INCLUDES}) 44 | 45 | add_subdirectory(src/thundergbm) 46 | 47 | set(BUILD_TESTS OFF CACHE BOOL "Build Tests") 48 | if (BUILD_TESTS) 49 | add_subdirectory(src/test) 50 | endif () 51 | 52 | # configuration file 53 | set(DATASET_DIR ${PROJECT_SOURCE_DIR}/datasets/) 54 | configure_file(include/thundergbm/config.h.in config.h) 55 | 56 | -------------------------------------------------------------------------------- /R/README.md: -------------------------------------------------------------------------------- 1 | # R interface for thundergbm. 2 | Before you use the R interface, you must build ThunderGBM. Please refer to [Installation](https://thundergbm.readthedocs.io/en/latest/how-to.html) for building ThunderGBM. 3 | 4 | ## Methods 5 | By default, the directory for storing the training data and results is the working directory. 6 | 7 | *gbm_train_R(depth = 6, n_trees = 40, n_gpus = 1, verbose = 1, 8 | profiling = 0, data = 'None', max_num_bin = 255, column_sampling_rate = 1, 9 | bagging = 0, n_parallel_trees = 1, learning_rate = 1, objective = 'reg:linear', 10 | num_class = 1, min_child_weight = 1, lambda_tgbm = 1, gamma = 1, 11 | tree_method = 'auto', model_out = 'tgbm.model')* 12 | 13 | *gbm_predict_R(test_data = 'None', model_in = 'tgbm.model', verbose = 1)* 14 | 15 | ## Example 16 | * Step 1: go to the R interface. 17 | ```bash 18 | # in ThunderGBM root directory 19 | cd R 20 | ``` 21 | * Step 2: start R in the terminal by typing ```R``` and press the Enter key. 22 | * Step 3: execute the following code in R. 
23 | ```R 24 | source("gbm.R") 25 | gbm_train_R(data = "../dataset/test_dataset.txt") 26 | gbm_predict_R(test_data = "../dataset/test_dataset.txt") 27 | ``` 28 | 29 | ## Parameters 30 | Please refer to [Parameters](https://github.com/Xtra-Computing/thundergbm/blob/master/docs/parameters.md). -------------------------------------------------------------------------------- /R/gbm.R: -------------------------------------------------------------------------------- 1 | # Created by: Qinbin 2 | # Created on: 3/15/20 3 | 4 | check_location <- function(){ 5 | if(Sys.info()['sysname'] == 'Windows'){ 6 | if(!file.exists("../build/bin/Debug/thundergbm.dll")){ 7 | stop("Please build the library first (or check you called this while your workspace is set to the thundergbm/R/ directory)!") 8 | } 9 | dyn.load("../build/bin/Debug/thundergbm.dll") 10 | } else if(Sys.info()['sysname'] == 'Linux'){ 11 | if(!file.exists("../build/lib/libthundergbm.so")){ 12 | stop("Please build the library first (or check you called this while your workspace is set to the thundergbm/R/ directory)!") 13 | } 14 | dyn.load("../build/lib/libthundergbm.so") 15 | } else if(Sys.info()['sysname'] == 'Darwin'){ 16 | if(!file.exists("../build/lib/libthundergbm.dylib")){ 17 | stop("Please build the library first (or check you called this while your workspace is set to the thundergbm/R/ directory)!") 18 | } 19 | dyn.load("../build/lib/libthundergbm.dylib") 20 | } else{ 21 | stop("OS not supported!") 22 | } 23 | } 24 | check_location() # Run this when the file is sourced 25 | 26 | gbm_train_R <- 27 | function( 28 | depth = 6, n_trees = 40, n_gpus = 1, verbose = 1, 29 | profiling = 0, data = 'None', max_num_bin = 255, column_sampling_rate = 1, 30 | bagging = 0, n_parallel_trees = 1, learning_rate = 1, objective = 'reg:linear', 31 | num_class = 1, min_child_weight = 1, lambda_tgbm = 1, gamma = 1, 32 | tree_method = 'auto', model_out = 'tgbm.model' 33 | ) 34 | { 35 | check_location() 36 | if(!file.exists(data)){stop("The file containing the training dataset provided as an argument in 'data' does not exist")} 37 | res <- .C("train_R", as.integer(depth), as.integer(n_trees), as.integer(n_gpus), as.integer(verbose), 38 | as.integer(profiling), as.character(data), as.integer(max_num_bin), as.double(column_sampling_rate), 39 | as.integer(bagging), as.integer(n_parallel_trees), as.double(learning_rate), as.character(objective), 40 | as.integer(num_class), as.integer(min_child_weight), as.double(lambda_tgbm), as.double(gamma), 41 | as.character(tree_method), as.character(model_out)) 42 | } 43 | 44 | gbm_predict_R <- 45 | function( 46 | test_data = 'None', model_in = 'tgbm.model', verbose = 1 47 | ) 48 | { 49 | check_location() 50 | if(!file.exists(test_data)){stop("The file containing the test dataset provided as an argument in 'test_data' does not exist")} 51 | if(!file.exists(model_in)){stop("The file containing the model provided as an argument in 'model_in' does not exist")} 52 | res <- .C("predict_R", as.character(test_data), as.character(model_in), as.integer(verbose)) 53 | } 54 | 55 | -------------------------------------------------------------------------------- /dataset/machine.conf: -------------------------------------------------------------------------------- 1 | data=../dataset/test_dataset.txt -------------------------------------------------------------------------------- /dataset/test_dataset.txt.group: -------------------------------------------------------------------------------- 1 | 100 2 | 300 3 | 400 4 | 200 5 | 605
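A note on the `.group` file above: it is only needed for the ranking objectives (``rank:pairwise`` and ``rank:ndcg``). Each line gives the number of instances belonging to one query group, and the group sizes must sum to the number of instances in the dataset (here 100 + 300 + 400 + 200 + 605 = 1605). The following is a minimal, hypothetical sanity-check script — `check_group_file` is not part of ThunderGBM, just an illustration of this file contract:

```python
# Hypothetical helper (not part of ThunderGBM): verify that a ranking
# .group file is consistent with its LibSVM-format dataset.
def check_group_file(data_path, group_path):
    with open(data_path) as f:
        # one LibSVM line per instance
        n_instances = sum(1 for line in f if line.strip())
    with open(group_path) as f:
        group_sizes = [int(line) for line in f if line.strip()]
    assert sum(group_sizes) == n_instances, (
        "group sizes sum to %d but the dataset has %d instances"
        % (sum(group_sizes), n_instances))
    return group_sizes

# Example: check_group_file("test_dataset.txt", "test_dataset.txt.group")
```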
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = ThunderGBM 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Documentation of ThunderGBM 2 | The documentations of ThunderGBM are written using Markdown, and generated by sphinx with recommonmark. 3 | 4 | ### Generate html files 5 | 6 | * go to this ```docs``` directory 7 | * type make html 8 | -------------------------------------------------------------------------------- /docs/_static/default.css: -------------------------------------------------------------------------------- 1 | .section #basic-2-flip-flop-synchronizer{ 2 | text-align:justify; 3 | } 4 | -------------------------------------------------------------------------------- /docs/_static/lang-logo-tgbm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/docs/_static/lang-logo-tgbm.png -------------------------------------------------------------------------------- /docs/_static/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/docs/_static/overall.png -------------------------------------------------------------------------------- /docs/_static/tgbm-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/docs/_static/tgbm-logo.png -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # ThunderSVM documentation build configuration file, created by 4 | # sphinx-quickstart on Sat Oct 28 23:38:46 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
18 | # 19 | # sys.path.insert(0, os.path.abspath('.')) 20 | 21 | import recommonmark 22 | from recommonmark import transform 23 | AutoStructify = transform.AutoStructify 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # 29 | # needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ['sphinx.ext.mathjax'] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # The suffix(es) of source filenames. 40 | # You can specify multiple suffix as a list of string: 41 | # 42 | 43 | source_parsers = { 44 | '.md': 'recommonmark.parser.CommonMarkParser', 45 | } 46 | 47 | source_suffix = ['.rst', '.md'] 48 | 49 | # The master toctree document. 50 | master_doc = 'index' 51 | 52 | # General information about the project. 53 | project = u'ThunderGBM' 54 | copyright = u'2019, ThunderGBM Developers' 55 | author = u'ThunderGBM Developers' 56 | 57 | # The version info for the project you're documenting, acts as replacement for 58 | # |version| and |release|, also used in various other places throughout the 59 | # built documents. 60 | # 61 | # The short X.Y version. 62 | version = u'0.1' 63 | # The full version, including alpha/beta/rc tags. 64 | release = u'0.1' 65 | 66 | # The language for content autogenerated by Sphinx. Refer to documentation 67 | # for a list of supported languages. 68 | # 69 | # This is also used if you do content translation via gettext catalogs. 70 | # Usually you set "language" from the command line for these cases. 71 | language = None 72 | 73 | # List of patterns, relative to source directory, that match files and 74 | # directories to ignore when looking for source files. 75 | # This patterns also effect to html_static_path and html_extra_path 76 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 77 | 78 | # The name of the Pygments (syntax highlighting) style to use. 79 | pygments_style = 'sphinx' 80 | 81 | # If true, `todo` and `todoList` produce output, else they produce nothing. 82 | todo_include_todos = False 83 | 84 | 85 | # -- Options for HTML output ---------------------------------------------- 86 | 87 | # The theme to use for HTML and HTML Help pages. See the documentation for 88 | # a list of builtin themes. 89 | # 90 | html_theme = 'sphinx_rtd_theme' 91 | 92 | # Theme options are theme-specific and customize the look and feel of a theme 93 | # further. For a list of options available for each theme, see the 94 | # documentation. 95 | # 96 | # html_theme_options = {} 97 | 98 | # Add any paths that contain custom static files (such as style sheets) here, 99 | # relative to this directory. They are copied after the builtin static files, 100 | # so a file named "default.css" will overwrite the builtin "default.css". 101 | html_static_path = ['_static'] 102 | 103 | # Custom sidebar templates, must be a dictionary that maps document names 104 | # to template names. 
105 | # 106 | # This is required for the alabaster theme 107 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 108 | html_sidebars = { 109 | '**': [ 110 | 'relations.html', # needs 'show_related': True theme option to display 111 | 'searchbox.html', 112 | ] 113 | } 114 | 115 | 116 | # -- Options for HTMLHelp output ------------------------------------------ 117 | 118 | # Output file base name for HTML help builder. 119 | htmlhelp_basename = 'ThunderGBMdoc' 120 | 121 | 122 | # -- Options for LaTeX output --------------------------------------------- 123 | 124 | latex_elements = { 125 | # The paper size ('letterpaper' or 'a4paper'). 126 | # 127 | # 'papersize': 'letterpaper', 128 | 129 | # The font size ('10pt', '11pt' or '12pt'). 130 | # 131 | # 'pointsize': '10pt', 132 | 133 | # Additional stuff for the LaTeX preamble. 134 | # 135 | # 'preamble': '', 136 | 137 | # Latex figure (float) alignment 138 | # 139 | # 'figure_align': 'htbp', 140 | } 141 | 142 | # Grouping the document tree into LaTeX files. List of tuples 143 | # (source start file, target name, title, 144 | # author, documentclass [howto, manual, or own class]). 145 | latex_documents = [ 146 | (master_doc, 'ThunderGBM.tex', u'ThunderGBM Documentation', 147 | u'Zeyi Wen', 'manual'), 148 | ] 149 | 150 | 151 | # -- Options for manual page output --------------------------------------- 152 | 153 | # One entry per manual page. List of tuples 154 | # (source start file, name, description, authors, manual section). 155 | man_pages = [ 156 | (master_doc, 'thundergbm', u'ThunderGBM Documentation', 157 | [author], 1) 158 | ] 159 | 160 | 161 | # -- Options for Texinfo output ------------------------------------------- 162 | 163 | # Grouping the document tree into Texinfo files. List of tuples 164 | # (source start file, target name, title, author, 165 | # dir menu entry, description, category) 166 | texinfo_documents = [ 167 | (master_doc, 'ThunderGBM', u'ThunderGBM Documentation', 168 | author, 'ThunderGBM', 'One line description of project.', 169 | 'Miscellaneous'), 170 | ] 171 | 172 | github_doc_root = 'https://github.com/rtfd/recommonmark/tree/master/doc/' 173 | def setup(app): 174 | app.add_config_value('recommonmark_config', { 175 | 'url_resolver': lambda url: github_doc_root + url, 176 | 'enable_eval_rst': True, 177 | }, True) 178 | app.add_transform(AutoStructify) 179 | 180 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | Frequently Asked Questions (FAQs) 2 | ====== 3 | This page summarizes some frequently asked questions about ThunderGBM. 4 | 5 | ## FAQs of users 6 | 7 | * **What is the data format of the input file?** 8 | ThunderGBM uses the LibSVM format. You can also use ThunderGBM via the scikit-learn interface. 9 | 10 | * **Can ThunderGBM run on CPUs?** 11 | No. ThunderGBM is specifically optimized for GBDTs and Random Forests on GPUs. 12 | 13 | * **Can ThunderGBM run on multiple GPUs?** 14 | Yes. You can use the ``n_gpus`` option to specify how many GPUs you want to use. Please refer to [Parameters](parameters.md) for more information. -------------------------------------------------------------------------------- /docs/how-to.md: -------------------------------------------------------------------------------- 1 | ThunderGBM How To 2 | ====== 3 | This page provides key instructions for installing, using and contributing to ThunderGBM.
Everyone in the community can contribute to ThunderGBM to make it better. 4 | 5 | ## How to install ThunderGBM 6 | First of all, you need to install the prerequisite libraries and tools. Then you can download and install ThunderGBM. 7 | ### Prerequisites 8 | * cmake 2.8 or above 9 | * gcc 4.8 or above for Linux | [CUDA](https://developer.nvidia.com/cuda-downloads) 8 or above 10 | * Visual C++ for Windows | CUDA 10 11 | 12 | 13 | ### Download 14 | ```bash 15 | git clone https://github.com/zeyiwen/thundergbm.git 16 | cd thundergbm 17 | #under the directory of thundergbm 18 | git submodule init cub && git submodule update 19 | ``` 20 | ### Build on Linux 21 | ```bash 22 | #under the directory of "thundergbm" 23 | mkdir build && cd build && cmake .. && make -j 24 | ``` 25 | 26 | ### Quick Start 27 | ```bash 28 | ./bin/thundergbm-train ../dataset/machine.conf 29 | ./bin/thundergbm-predict ../dataset/machine.conf 30 | ``` 31 | You will see `RMSE = 0.489562` after a successful run. 32 | 33 | ### Build on Windows 34 | You can build the ThunderGBM library as follows: 35 | ```bash 36 | cd thundergbm 37 | mkdir build 38 | cd build 39 | cmake .. -DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE -DBUILD_SHARED_LIBS=TRUE -G "Visual Studio 15 2017 Win64" 40 | ``` 41 | You need to change the Visual Studio version in the generator name if you are using a different version of Visual Studio. Visual Studio can be downloaded from [this link](https://www.visualstudio.com/vs/). The above commands generate Visual Studio project files; open the generated project in Visual Studio to build ThunderGBM. Please note that CMake should be 3.4 or above for Windows. 42 | 43 | ## How to use ThunderGBM using the command line 44 | First of all, please refer to the above instructions for installing ThunderGBM. Then, you can run the demo with the following commands. 45 | ```bash 46 | ./bin/thundergbm-train ../dataset/machine.conf 47 | ./bin/thundergbm-predict ../dataset/machine.conf 48 | ``` 49 | 50 | If you would like to know more about the detailed options of running the binary, please use the `-help` option as follows. 51 | ```bash 52 | ./bin/thundergbm-train -help 53 | ``` 54 | 55 | In ThunderGBM, the command line options can be added in the `machine.conf` file under the `dataset` folder, as shown in the sketch after this section. All the options are listed in the [Parameters](parameters.md) page. 56 | 57 | ## How to improve the documentation 58 | Most of the documents can be viewed on GitHub. The documents can also be viewed on Read the Docs. The HTML files of our documents are generated by [Sphinx](http://www.sphinx-doc.org/en/stable/), and the source files of the documents are written using [Markdown](http://commonmark.org/). In the following, we describe how to set up the Sphinx environment. 59 | 60 | * Install sphinx 61 | ```bash 62 | pip install sphinx 63 | ``` 64 | 65 | * Install the Markdown parser 66 | ```bash 67 | pip install recommonmark 68 | ``` 69 | Note that ```recommonmark``` has a bug when working with Sphinx on some platforms, so you may need to hack into transform.py to fix the problem by yourself. You can find the instructions in [this link](https://github.com/sphinx-doc/sphinx/issues/3800). 70 | 71 | * Install Sphinx theme 72 | ```bash 73 | pip install sphinx_rtd_theme 74 | ``` 75 | 76 | * Generate HTML 77 | 78 | Go to the "docs" directory of ThunderGBM and run: 79 | ```bash 80 | make html 81 | ``` 82 | 83 | At this point, you should have generated the ThunderGBM documentation. You can build the documents on your machine to check the outcome.
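As mentioned in the command-line section above, training options live in a plain key=value configuration file such as `machine.conf`. A hypothetical example of such a file is sketched below — the parameter names come from the [Parameters](parameters.md) page and the `data` line matches the bundled `dataset/machine.conf`; the other values are illustrative defaults, not a recommended configuration:

```
data=../dataset/test_dataset.txt
objective=reg:linear
depth=6
n_trees=40
learning_rate=1
tree_method=auto
model_out=tgbm.model
```

Passing this file to `./bin/thundergbm-train` applies the listed options; options left out fall back to the defaults listed on the Parameters page.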
84 | 85 | ## Contribute to ThunderGBM 86 | You need to fetch the latest version of ThunderGBM before submitting a pull request. 87 | ```bash 88 | git remote add upstream https://github.com/Xtra-Computing/thundergbm.git 89 | git fetch upstream 90 | git rebase upstream/master 91 | ``` 92 | 93 | ## How to build tests for ThunderGBM 94 | For building test cases, you also need to obtain ``googletest`` using the following command. 95 | ```bash 96 | #under the thundergbm directory 97 | git submodule update --init src/test/googletest 98 | ``` 99 | After obtaining the ``googletest`` submodule, you can build the test cases with the following commands. 100 | ```bash 101 | cd thundergbm 102 | mkdir build && cd build && cmake -DBUILD_TESTS=ON .. && make -j 103 | ``` 104 | 105 | ## How to use ThunderGBM for ranking 106 | 107 | There are two key steps to using ThunderGBM for ranking. 108 | * First, you need to choose ``rank:pairwise`` or ``rank:ndcg`` to set the ``objective`` of ThunderGBM. 109 | * Second, you need to have a file called ``[train_file_name].group`` to specify the number of instances in each query. 110 | 111 | The remaining steps are the same as for classification and regression. Please refer to [Parameters](parameters.md) for more information about setting the parameters. 112 | 113 | ## How to build the Python wheel file for Linux 114 | Make sure your local repository is up to date with the latest version. 115 | * Clone the ThunderGBM repository 116 | ```bash 117 | git clone https://github.com/zeyiwen/thundergbm.git 118 | cd thundergbm 119 | #under the directory of thundergbm 120 | git submodule init cub && git submodule update 121 | ``` 122 | * Build the binary 123 | ```bash 124 | mkdir build && cd build && cmake .. && make -j 125 | ``` 126 | * Build the python wheel file 127 | - change directory to python by `cd ../python` 128 | - update the version you are going to release in [setup.py](https://github.com/Xtra-Computing/thundergbm/blob/c89d6da6008f945c09aae521c95cfe5b8bdd8db5/python/setup.py#L20) 129 | - you may need to install the `wheel` dependency by ``pip3 install wheel`` 130 | ```bash 131 | python3 setup.py bdist_wheel 132 | ``` 133 | ## How to build the Python wheel file for Windows 134 | Make sure your local repository is up to date with the latest version. 135 | * Requirements 136 | - Visual Studio 137 | - CUDA 10.0 or above 138 | - python3.x 139 | * Clone the ThunderGBM repository 140 | ```bash 141 | git clone https://github.com/zeyiwen/thundergbm.git 142 | cd thundergbm 143 | #under the directory of thundergbm 144 | git submodule init && git submodule update 145 | ``` 146 | * Run CMake using the Visual Studio Developer Command Prompt 147 | ```bash 148 | mkdir build && cd build 149 | cmake .. -DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE -DBUILD_SHARED_LIBS=TRUE -G "Visual Studio 15 2017 Win64" 150 | ``` 151 | You may need to change the generator name if you are using a different version of Visual Studio.
152 | * Build the binary files using Visual Studio 153 | - Open 'thundergbm/build/thundergbm.sln' with Visual Studio 154 | - Click 'Build all' in Visual Studio 155 | * Build the python wheel file 156 | - change directory to python by `cd ../python` 157 | - update the version you are going to release in [setup.py](https://github.com/Xtra-Computing/thundergbm/blob/c89d6da6008f945c09aae521c95cfe5b8bdd8db5/python/setup.py#L20) 158 | - you may need to install the `wheel` dependency by ``pip3 install wheel`` 159 | ```bash 160 | python3 setup.py bdist_wheel 161 | ``` 162 | * Upload the wheel file to [Pypi.org](https://pypi.org) 163 | - you may need to install the `twine` dependency by ``pip3 install twine`` 164 | - you need to use ``python3 -m twine upload dist/*`` if ``twine`` is not included in ``PATH`` 165 | ```bash 166 | twine upload dist/* --verbose 167 | ``` 168 | * [Recommended] Draft a new release on [Release](https://github.com/Xtra-Computing/thundergbm/releases) 169 | * state the bugs fixed or the new functions added. 170 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ThunderGBM: Fast GBDTs and Random Forests on GPUs 2 | ====================================== 3 | ThunderGBM is dedicated to helping users apply GBDTs and Random Forests to solve problems efficiently and easily using GPUs. Key features of ThunderGBM are as follows. 4 | * Support regression, classification and ranking. 5 | * Use the same command line options as XGBoost, and support the Python (scikit-learn) interface. 6 | * Supported Operating Systems: Linux and Windows. 7 | * ThunderGBM is often 10 times faster than XGBoost, LightGBM and CatBoost. It has excellent performance in handling high-dimensional and sparse problems. 8 | 9 |
10 | <!-- image block stripped during extraction (HTML <img>; the figures live in docs/_static/, likely overall.png) --> 11 | 12 | 13
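The Python (scikit-learn) interface mentioned in the feature list can be used in the usual scikit-learn style. A minimal sketch, assuming the ``TGBMRegressor`` class exported by the ``thundergbm`` pip package (parameter names follow the Parameters page; the data file is the bundled demo dataset):

```python
# Minimal sketch of the scikit-learn-style API; TGBMRegressor and its
# parameter names are assumed from python/thundergbm/thundergbm.py.
from sklearn.datasets import load_svmlight_file
from thundergbm import TGBMRegressor

X, y = load_svmlight_file("dataset/test_dataset.txt")
clf = TGBMRegressor(depth=6, n_trees=40, objective="reg:linear")
clf.fit(X, y)
y_pred = clf.predict(X)
```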
14 | 15 | ## More information about using ThunderGBM 16 | * [Parameters](parameters.md) | [How To](how-to.md) | [FAQ](faq.md) 17 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=ThunderGBM 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/parameters.md: -------------------------------------------------------------------------------- 1 | ThunderGBM Parameters 2 | ===================== 3 | This page is for parameter specification in ThunderGBM. The parameters used in ThunderGBM are identical to those of XGBoost (except some newly introduced parameters), so existing XGBoost users can easily get used to ThunderGBM. 4 | 5 | ### Key parameters for both *python* and *c++* command line 6 | * ``verbose`` [default=1] 7 | 8 | - Printing information: 0 for silence, 1 for key information and 2 for more information. 9 | 10 | * ``depth`` [default=6] 11 | 12 | - The maximum depth of the decision trees. Shallow trees tend to have better generality, and deep trees are more likely to overfit the training data. 13 | 14 | * ``n_trees`` [default=40] 15 | 16 | - The number of training iterations. ``n_trees`` equals the number of trees in GBDTs. 17 | 18 | * ``n_gpus`` [default=1] 19 | 20 | - The number of GPUs to be used in the training. 21 | 22 | * ``max_num_bin`` [default=255] 23 | 24 | - The maximum number of bins in a histogram. 25 | 26 | * ``column_sampling_rate`` [default=1] 27 | 28 | - The sampling ratio for subsampling columns (i.e., features). 29 | 30 | * ``bagging`` [default=0] 31 | 32 | - This option is for training random forests. Set it to 1 to perform bagging. 33 | 34 | * ``n_parallel_trees`` [default=1] 35 | 36 | - This option is used for random forests to specify how many trees are trained per iteration. 37 | 38 | * ``learning_rate`` [default=1, alias(only for c++): ``eta``] 39 | 40 | - valid domain: [0,1]. This option sets the weight of each newly trained tree. Use ``eta < 1`` to mitigate overfitting. 41 | 42 | * ``objective`` [default="reg:linear"] 43 | 44 | - valid options include ``reg:linear``, ``reg:logistic``, ``multi:softprob``, ``multi:softmax``, ``rank:pairwise`` and ``rank:ndcg``. 45 | - ``reg:linear`` is for regression, ``reg:logistic`` and ``binary:logistic`` are for binary classification. 46 | - ``multi:softprob`` and ``multi:softmax`` are for multi-class classification. ``multi:softprob`` outputs the probability for each class, and ``multi:softmax`` outputs the label only.
47 | - ``rank:pairwise`` and ``rank:ndcg`` are for ranking problems. 48 | 49 | * ``num_class`` [default=1] 50 | - Set the number of classes for multi-class classification. This option is not compulsory. 51 | 52 | * ``min_child_weight`` [default=1] 53 | 54 | - The minimum sum of instance weight (measured by the second order derivative) needed in a child node. 55 | 56 | * ``lambda_tgbm`` [default=1, alias(only for c++): ``lambda`` or ``reg_lambda``] 57 | 58 | - L2 regularization term on weights. 59 | 60 | * ``gamma`` [default=1, alias(only for c++): ``min_split_loss``] 61 | 62 | - The minimum loss reduction required to make a further split on a leaf node of the tree. ``gamma`` is used in the pruning stage. 63 | * ``tree_method`` [default="auto"] 64 | 65 | - "auto": selects the approach for finding the best splits using built-in heuristics. 66 | 67 | - "exact": finds the best split by enumerating all possible feature values. 68 | 69 | - "hist": finds the best split using a histogram-based approach. 70 | 71 | ### Parameters only for *c++* command line: 72 | * ``data`` [default="../dataset/test_dataset.txt"] 73 | 74 | - The path to the training dataset. 75 | 76 | * ``model_out`` [default="tgbm.model"] 77 | 78 | - The file name of the output model. This option is used in training. 79 | 80 | * ``model_in`` [default="tgbm.model"] 81 | 82 | - The file name of the input model. This option is used in prediction. 83 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==1.5.6 2 | commonmark 3 | mock 4 | 5 | -------------------------------------------------------------------------------- /include/thundergbm/booster.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-17.
3 | // 4 | 5 | #ifndef THUNDERGBM_BOOSTER_H 6 | #define THUNDERGBM_BOOSTER_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "thundergbm/common.h" 13 | #include "syncarray.h" 14 | #include "tree.h" 15 | #include "row_sampler.h" 16 | 17 | std::mutex mtx; 18 | 19 | class Booster { 20 | public: 21 | void init(const DataSet &dataSet, GBMParam ¶m); 22 | 23 | void boost(vector> &boosted_model, int epoch,int total_epoch); 24 | 25 | private: 26 | MSyncArray gradients; 27 | std::unique_ptr obj; 28 | std::unique_ptr metric; 29 | MSyncArray y; 30 | std::unique_ptr fbuilder; 31 | RowSampler rowSampler; 32 | GBMParam param; 33 | int n_devices; 34 | }; 35 | 36 | void Booster::init(const DataSet &dataSet, GBMParam ¶m) { 37 | int n_available_device; 38 | cudaGetDeviceCount(&n_available_device); 39 | CHECK_GE(n_available_device, param.n_device) << "only " << n_available_device 40 | << " GPUs available; please set correct number of GPUs to use"; 41 | this->param = param; 42 | //fbuilder.reset(FunctionBuilder::create(param.tree_method)); 43 | //if method is hist, and n_available_device is 1 44 | if(param.n_device==1 && param.tree_method == "hist"){ 45 | fbuilder.reset(FunctionBuilder::create("hist_single")); 46 | } 47 | else{ 48 | fbuilder.reset(FunctionBuilder::create(param.tree_method)); 49 | } 50 | 51 | fbuilder->init(dataSet, param); 52 | obj.reset(ObjectiveFunction::create(param.objective)); 53 | obj->configure(param, dataSet); 54 | metric.reset(Metric::create(obj->default_metric_name())); 55 | metric->configure(param, dataSet); 56 | 57 | n_devices = param.n_device; 58 | int n_outputs = param.num_class * dataSet.n_instances(); 59 | gradients = MSyncArray(n_devices, n_outputs); 60 | y = MSyncArray(n_devices, dataSet.n_instances()); 61 | 62 | DO_ON_MULTI_DEVICES(n_devices, [&](int device_id) { 63 | y[device_id].copy_from(dataSet.y.data(), dataSet.n_instances()); 64 | }); 65 | 66 | //init base score 67 | //only support histogram-based method and single device now 68 | //TODO support exact and multi-device 69 | if(param.n_device && param.tree_method == "hist"){ 70 | DO_ON_MULTI_DEVICES(n_devices, [&](int device_id){ 71 | param.base_score = obj->init_base_score(y[device_id], fbuilder->get_raw_y_predict()[device_id], gradients[device_id]); 72 | }); 73 | } 74 | } 75 | 76 | void Booster::boost(vector> &boosted_model,int epoch,int total_epoch) { 77 | TIMED_FUNC(timerObj); 78 | std::unique_lock lock(mtx); 79 | 80 | //update gradients 81 | DO_ON_MULTI_DEVICES(n_devices, [&](int device_id) { 82 | obj->get_gradient(y[device_id], fbuilder->get_y_predict()[device_id], gradients[device_id]); 83 | }); 84 | if (param.bagging) rowSampler.do_bagging(gradients); 85 | PERFORMANCE_CHECKPOINT(timerObj); 86 | //build new model/approximate function 87 | boosted_model.push_back(fbuilder->build_approximate(gradients)); 88 | 89 | PERFORMANCE_CHECKPOINT(timerObj); 90 | //show metric on training set 91 | auto res = metric->get_score(fbuilder->get_y_predict().front()); 92 | LOG(INFO) <<"["<get_name() << " = " < 9 | #include "thundergbm/common.h" 10 | #include "thundergbm/sparse_columns.h" 11 | 12 | class FunctionBuilder { 13 | public: 14 | virtual vector build_approximate(const MSyncArray &gradients) = 0; 15 | 16 | virtual void init(const DataSet &dataset, const GBMParam ¶m) { 17 | this->param = param; 18 | }; 19 | 20 | virtual const MSyncArray &get_y_predict(){ return y_predict; }; 21 | MSyncArray &get_raw_y_predict(){ return y_predict; }; 22 | 23 | virtual ~FunctionBuilder(){}; 24 | 25 | static 
FunctionBuilder *create(std::string name); 26 | 27 | protected: 28 | MSyncArray y_predict; 29 | GBMParam param; 30 | }; 31 | 32 | #endif //THUNDERGBM_FUNCTION_BUILDER_H 33 | -------------------------------------------------------------------------------- /include/thundergbm/builder/hist_tree_builder.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-17. 3 | // 4 | 5 | #ifndef THUNDERGBM_HIST_TREE_BUILDER_H 6 | #define THUNDERGBM_HIST_TREE_BUILDER_H 7 | 8 | #include 9 | #include "thundergbm/common.h" 10 | #include "shard.h" 11 | #include "tree_builder.h" 12 | 13 | 14 | class HistTreeBuilder : public TreeBuilder { 15 | public: 16 | 17 | void init(const DataSet &dataset, const GBMParam ¶m) override; 18 | 19 | void get_bin_ids(); 20 | 21 | void find_split(int level, int device_id) override; 22 | 23 | virtual ~HistTreeBuilder(){}; 24 | 25 | void update_ins2node_id() override; 26 | 27 | 28 | private: 29 | vector cut; 30 | // MSyncArray char_dense_bin_id; 31 | MSyncArray dense_bin_id; 32 | MSyncArray last_hist; 33 | 34 | double build_hist_used_time=0; 35 | int build_n_hist = 0; 36 | int total_hist_num = 0; 37 | double total_dp_time = 0; 38 | double total_copy_time = 0; 39 | }; 40 | 41 | 42 | #endif //THUNDERGBM_HIST_TREE_BUILDER_H 43 | -------------------------------------------------------------------------------- /include/thundergbm/builder/hist_tree_builder_single.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-17. 3 | // 4 | 5 | #ifndef THUNDERGBM_HIST_TREE_BUILDER_H_single 6 | #define THUNDERGBM_HIST_TREE_BUILDER_H_single 7 | 8 | #include 9 | #include "thundergbm/common.h" 10 | #include "shard.h" 11 | #include "tree_builder.h" 12 | 13 | 14 | class HistTreeBuilder_single : public TreeBuilder { 15 | public: 16 | 17 | void init(const DataSet &dataset, const GBMParam ¶m) override; 18 | 19 | void get_bin_ids(); 20 | 21 | void find_split(int level, int device_id) override; 22 | 23 | virtual ~HistTreeBuilder_single(){}; 24 | 25 | void update_ins2node_id() override; 26 | 27 | void update_tree() override; 28 | 29 | 30 | private: 31 | vector cut; 32 | // MSyncArray char_dense_bin_id; 33 | MSyncArray dense_bin_id; 34 | MSyncArray last_hist; 35 | 36 | //store csr dense_bin_id 37 | MSyncArray csr_bin_id; 38 | MSyncArray bin_id_origin; 39 | MSyncArray csr_row_ptr; 40 | MSyncArray csr_col_idx; 41 | 42 | double build_hist_used_time=0; 43 | int build_n_hist = 0; 44 | int total_hist_num = 0; 45 | double total_dp_time = 0; 46 | double total_copy_time = 0; 47 | bool use_gpu = 1; 48 | }; 49 | 50 | 51 | #endif //THUNDERGBM_HIST_TREE_BUILDER_H 52 | -------------------------------------------------------------------------------- /include/thundergbm/builder/shard.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-2. 
3 | // 4 | 5 | #ifndef THUNDERGBM_SHARD_H 6 | #define THUNDERGBM_SHARD_H 7 | 8 | 9 | #include "thundergbm/sparse_columns.h" 10 | #include "thundergbm/tree.h" 11 | 12 | class SplitPoint; 13 | 14 | struct Shard { 15 | SparseColumns columns;//a subset of columns (or features) 16 | SyncArray<bool> ignored_set;//for column sampling 17 | 18 | void column_sampling(float rate); 19 | }; 20 | 21 | 22 | 23 | class SplitPoint { 24 | public: 25 | float_type gain; 26 | GHPair fea_missing_gh;//missing gh in this segment 27 | GHPair rch_sum_gh;//right child total gh (missing gh included if default2right) 28 | GHPair lch_sum_gh;//left child total gh (missing gh included if default2left) 29 | bool default_right; 30 | int nid; 31 | 32 | //split condition 33 | int split_fea_id; 34 | float_type fval;//split on this feature value (for exact) 35 | unsigned char split_bid;//split on this bin id (for hist) 36 | 37 | SplitPoint() { 38 | nid = -1; 39 | split_fea_id = -1; 40 | gain = 0; 41 | default_right = true; 42 | } 43 | 44 | friend std::ostream &operator<<(std::ostream &output, const SplitPoint &sp) { 45 | output << sp.gain << "/" << sp.split_fea_id << "/" << sp.nid << "/" << sp.lch_sum_gh; 46 | return output; 47 | } 48 | }; 49 | 50 | #endif //THUNDERGBM_SHARD_H 51 | -------------------------------------------------------------------------------- /include/thundergbm/builder/tree_builder.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 19-1-23. 3 | // 4 | 5 | #ifndef THUNDERGBM_TREEBUILDER_H 6 | #define THUNDERGBM_TREEBUILDER_H 7 | 8 | #include 9 | #include "thundergbm/common.h" 10 | #include "shard.h" 11 | #include "function_builder.h" 12 | 13 | 14 | class TreeBuilder : public FunctionBuilder { 15 | public: 16 | virtual void find_split(int level, int device_id) = 0; 17 | 18 | virtual void update_ins2node_id() = 0; 19 | 20 | vector<Tree> build_approximate(const MSyncArray<GHPair> &gradients) override; 21 | 22 | void init(const DataSet &dataset, const GBMParam &param) override; 23 | 24 | virtual void update_tree(); 25 | 26 | void predict_in_training(int k); 27 | 28 | virtual void split_point_all_reduce(int depth); 29 | 30 | virtual void ins2node_id_all_reduce(int depth); 31 | 32 | virtual ~TreeBuilder(){}; 33 | 34 | protected: 35 | vector<Shard> shards; 36 | int n_instances; 37 | vector<Tree> trees; 38 | MSyncArray<int> ins2node_id; 39 | MSyncArray<SplitPoint> sp; 40 | MSyncArray<GHPair> gradients; 41 | vector<bool> has_split; 42 | }; 43 | 44 | 45 | #endif //THUNDERGBM_TREEBUILDER_H 46 | -------------------------------------------------------------------------------- /include/thundergbm/common.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 18-1-16.
3 | // 4 | 5 | #ifndef THUNDERGBM_COMMON_H 6 | #define THUNDERGBM_COMMON_H 7 | 8 | #include "thundergbm/util/log.h" 9 | #include "cuda_runtime_api.h" 10 | #include "cstdlib" 11 | #include "config.h" 12 | #include "thrust/tuple.h" 13 | 14 | using std::vector; 15 | using std::string; 16 | 17 | //CUDA macro 18 | #define USE_CUDA 19 | #define NO_GPU \ 20 | LOG(FATAL)<<"Cannot use GPU when compiling without GPU" 21 | #define CUDA_CHECK(condition) \ 22 | /* Code block avoids redefinition of cudaError_t error */ \ 23 | do { \ 24 | cudaError_t error = condition; \ 25 | CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ 26 | } while (false) 27 | 28 | //https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf 29 | template 30 | std::string string_format(const std::string &format, Args ... args) { 31 | size_t size = snprintf(nullptr, 0, format.c_str(), args ...) + 1; // Extra space for '\0' 32 | std::unique_ptr buf(new char[size]); 33 | snprintf(buf.get(), size, format.c_str(), args ...); 34 | return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside 35 | } 36 | 37 | //data types 38 | typedef float float_type; 39 | //typedef double float_type; 40 | 41 | #define HOST_DEVICE __host__ __device__ 42 | 43 | struct GHPair { 44 | float_type g; 45 | float_type h; 46 | 47 | HOST_DEVICE GHPair operator+(const GHPair &rhs) const { 48 | GHPair res; 49 | res.g = this->g + rhs.g; 50 | res.h = this->h + rhs.h; 51 | return res; 52 | } 53 | 54 | HOST_DEVICE const GHPair operator-(const GHPair &rhs) const { 55 | GHPair res; 56 | res.g = this->g - rhs.g; 57 | res.h = this->h - rhs.h; 58 | return res; 59 | } 60 | 61 | HOST_DEVICE bool operator==(const GHPair &rhs) const { 62 | return this->g == rhs.g && this->h == rhs.h; 63 | } 64 | 65 | HOST_DEVICE bool operator!=(const GHPair &rhs) const { 66 | return !(*this == rhs); 67 | } 68 | 69 | HOST_DEVICE GHPair() : g(0), h(0) {}; 70 | 71 | HOST_DEVICE GHPair(float_type v) : g(v), h(v) {}; 72 | 73 | HOST_DEVICE GHPair(float_type g, float_type h) : g(g), h(h) {}; 74 | 75 | friend std::ostream &operator<<(std::ostream &os, 76 | const GHPair &p) { 77 | os << string_format("%f/%f", p.g, p.h); 78 | return os; 79 | } 80 | }; 81 | 82 | typedef thrust::tuple int_float; 83 | 84 | std::ostream &operator<<(std::ostream &os, const int_float &rhs); 85 | 86 | struct GBMParam { 87 | int depth; 88 | int n_trees; 89 | float_type min_child_weight; 90 | float_type lambda; 91 | float_type gamma; 92 | float_type rt_eps; 93 | float column_sampling_rate; 94 | std::string path; 95 | int verbose; 96 | bool profiling; 97 | bool bagging; 98 | int n_parallel_trees; 99 | float learning_rate; 100 | std::string objective; 101 | int num_class; 102 | int tree_per_rounds; // #tree of each round, depends on #class 103 | 104 | //for histogram 105 | int max_num_bin; 106 | float base_score=0; 107 | 108 | int n_device; 109 | 110 | std::string tree_method; 111 | }; 112 | #endif //THUNDERGBM_COMMON_H 113 | -------------------------------------------------------------------------------- /include/thundergbm/config.h.in: -------------------------------------------------------------------------------- 1 | #cmakedefine DATASET_DIR "@DATASET_DIR@" 2 | -------------------------------------------------------------------------------- /include/thundergbm/dataset.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 18-1-17. 
3 | // 4 | 5 | #ifndef THUNDERGBM_DATASET_H 6 | #define THUNDERGBM_DATASET_H 7 | 8 | #include "common.h" 9 | #include "syncarray.h" 10 | 11 | class DataSet { 12 | public: 13 | ///load dataset from file 14 | void load_from_sparse(int n_instances, float *csr_val, int *csr_row_ptr, int *csr_col_idx, float *y, 15 | int *group, int num_group, GBMParam ¶m); 16 | void load_from_file(string file_name, GBMParam ¶m); 17 | void load_csc_from_file(string file_name, GBMParam ¶m, int const nfeatures=500); 18 | void load_group_file(string file_name); 19 | void group_label(); 20 | 21 | size_t n_features() const; 22 | 23 | size_t n_instances() const; 24 | 25 | vector csr_val; 26 | vector csr_row_ptr; 27 | vector csr_col_idx; 28 | vector y; 29 | size_t n_features_; 30 | vector group; 31 | vector label; 32 | 33 | 34 | // csc variables 35 | vector csc_val; 36 | vector csc_row_idx; 37 | vector csc_col_ptr; 38 | 39 | // whether the dataset is to big 40 | int use_cpu = false; 41 | }; 42 | 43 | #endif //THUNDERGBM_DATASET_H 44 | -------------------------------------------------------------------------------- /include/thundergbm/hist_cut.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by qinbin on 2018/5/9. 3 | // 4 | 5 | #ifndef THUNDERGBM_HIST_CUT_H 6 | #define THUNDERGBM_HIST_CUT_H 7 | 8 | #include "common.h" 9 | #include "sparse_columns.h" 10 | #include "thundergbm/dataset.h" 11 | #include "thundergbm/tree.h" 12 | #include "ins_stat.h" 13 | 14 | class HistCut { 15 | public: 16 | //split_point[i] stores the split points of feature i 17 | //std::vector> split_points; 18 | vector cut_points; 19 | vector row_ptr; 20 | //for gpu 21 | SyncArray cut_points_val; 22 | SyncArray cut_row_ptr; 23 | SyncArray cut_fid; 24 | 25 | HistCut() = default; 26 | 27 | HistCut(const HistCut &cut) { 28 | cut_points = cut.cut_points; 29 | row_ptr = cut.row_ptr; 30 | cut_points_val.copy_from(cut.cut_points_val); 31 | cut_row_ptr.copy_from(cut.cut_row_ptr); 32 | } 33 | 34 | void get_cut_points2(SparseColumns &columns, int max_num_bins, int n_instances); 35 | void get_cut_points3(SparseColumns &columns, int max_num_bins, int n_instances); 36 | 37 | //for hist on single device 38 | void get_cut_points_single(SparseColumns &columns, int max_num_bins, int n_instances); 39 | }; 40 | 41 | #endif //THUNDERGBM_HIST_CUT_H 42 | -------------------------------------------------------------------------------- /include/thundergbm/ins_stat.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by shijiashuai on 5/7/18. 
3 | // 4 | 5 | #ifndef THUNDERGBM_INS_STAT_H 6 | #define THUNDERGBM_INS_STAT_H 7 | 8 | #include "syncarray.h" 9 | #include "thundergbm/objective/objective_function.h" 10 | 11 | class InsStat { 12 | public: 13 | 14 | ///gradient and hessian 15 | SyncArray gh_pair; 16 | 17 | ///backup for bagging 18 | SyncArray gh_pair_backup; 19 | ///node id 20 | SyncArray nid; 21 | ///target value 22 | SyncArray y; 23 | ///predict value 24 | SyncArray y_predict; 25 | 26 | std::unique_ptr obj; 27 | 28 | int n_instances; 29 | 30 | InsStat() = default; 31 | 32 | explicit InsStat(size_t n_instances) { 33 | resize(n_instances); 34 | } 35 | 36 | void resize(size_t n_instances); 37 | 38 | void update_gradient(); 39 | 40 | void reset_nid(); 41 | 42 | void do_bagging(); 43 | }; 44 | 45 | #endif //THUNDERGBM_INS_STAT_H 46 | -------------------------------------------------------------------------------- /include/thundergbm/metric/metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-13. 3 | // 4 | 5 | #ifndef THUNDERGBM_METRIC_H 6 | #define THUNDERGBM_METRIC_H 7 | 8 | #include 9 | #include 10 | 11 | class Metric { 12 | public: 13 | virtual float_type get_score(const SyncArray &y_p) const = 0; 14 | 15 | virtual void configure(const GBMParam ¶m, const DataSet &dataset); 16 | 17 | static Metric *create(string name); 18 | 19 | virtual string get_name() const = 0; 20 | 21 | protected: 22 | SyncArray y; 23 | }; 24 | 25 | 26 | #endif //THUNDERGBM_METRIC_H 27 | -------------------------------------------------------------------------------- /include/thundergbm/metric/multiclass_metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-15. 3 | // 4 | 5 | #ifndef THUNDERGBM_MULTICLASS_METRIC_H 6 | #define THUNDERGBM_MULTICLASS_METRIC_H 7 | 8 | #include "thundergbm/common.h" 9 | #include "metric.h" 10 | 11 | class MulticlassMetric: public Metric { 12 | public: 13 | void configure(const GBMParam ¶m, const DataSet &dataset) override { 14 | Metric::configure(param, dataset); 15 | num_class = param.num_class; 16 | CHECK_EQ(num_class, dataset.label.size()); 17 | label.resize(num_class); 18 | label.copy_from(dataset.label.data(), num_class); 19 | } 20 | 21 | protected: 22 | int num_class; 23 | SyncArray label; 24 | }; 25 | 26 | class MulticlassAccuracy: public MulticlassMetric { 27 | public: 28 | float_type get_score(const SyncArray &y_p) const override; 29 | 30 | string get_name() const override { return "multi-class accuracy"; } 31 | }; 32 | 33 | class BinaryClassMetric: public MulticlassAccuracy{ 34 | public: 35 | float_type get_score(const SyncArray &y_p) const override; 36 | string get_name() const override { return "test error";} 37 | }; 38 | 39 | 40 | #endif //THUNDERGBM_MULTICLASS_METRIC_H 41 | -------------------------------------------------------------------------------- /include/thundergbm/metric/pointwise_metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-13. 
3 | // 4 | 5 | #ifndef THUNDERGBM_POINTWISE_METRIC_H 6 | #define THUNDERGBM_POINTWISE_METRIC_H 7 | 8 | #include "metric.h" 9 | 10 | class RMSE : public Metric { 11 | public: 12 | float_type get_score(const SyncArray &y_p) const override; 13 | 14 | string get_name() const override { return "RMSE"; } 15 | }; 16 | 17 | #endif //THUNDERGBM_POINTWISE_METRIC_H 18 | -------------------------------------------------------------------------------- /include/thundergbm/metric/ranking_metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-13. 3 | // 4 | 5 | #ifndef THUNDERGBM_RANKING_METRIC_H 6 | #define THUNDERGBM_RANKING_METRIC_H 7 | 8 | #include "metric.h" 9 | 10 | class RankListMetric : public Metric { 11 | public: 12 | float_type get_score(const SyncArray &y_p) const override; 13 | 14 | void configure(const GBMParam ¶m, const DataSet &dataset) override; 15 | 16 | static void configure_gptr(const vector &group, vector &gptr); 17 | 18 | protected: 19 | virtual float_type eval_query_group(vector &y, vector &y_p, int group_id) const = 0; 20 | 21 | vector gptr; 22 | int n_group; 23 | int topn; 24 | }; 25 | 26 | 27 | class MAP : public RankListMetric { 28 | public: 29 | string get_name() const override { return "MAP"; } 30 | 31 | protected: 32 | float_type eval_query_group(vector &y, vector &y_p, int group_id) const override; 33 | }; 34 | 35 | class NDCG : public RankListMetric { 36 | public: 37 | string get_name() const override { return "NDCG"; }; 38 | 39 | void configure(const GBMParam ¶m, const DataSet &dataset) override; 40 | 41 | inline HOST_DEVICE static float_type discounted_gain(int label, int rank) { 42 | return ((1 << label) - 1) / log2f(rank + 1 + 1); 43 | } 44 | 45 | static void get_IDCG(const vector &gptr, const vector &y, vector &idcg); 46 | 47 | protected: 48 | float_type eval_query_group(vector &y, vector &y_p, int group_id) const override; 49 | 50 | private: 51 | vector idcg; 52 | }; 53 | 54 | 55 | #endif //THUNDERGBM_RANKING_METRIC_H 56 | -------------------------------------------------------------------------------- /include/thundergbm/objective/multiclass_obj.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-3. 3 | // 4 | 5 | #ifndef THUNDERGBM_MULTICLASS_OBJ_H 6 | #define THUNDERGBM_MULTICLASS_OBJ_H 7 | 8 | #include "objective_function.h" 9 | #include "thundergbm/util/device_lambda.cuh" 10 | 11 | class Softmax : public ObjectiveFunction { 12 | public: 13 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 14 | SyncArray &gh_pair) override; 15 | 16 | void predict_transform(SyncArray &y) override; 17 | 18 | void configure(GBMParam param, const DataSet &dataset) override; 19 | 20 | string default_metric_name() override { return "macc"; } 21 | 22 | virtual ~Softmax() override = default; 23 | 24 | protected: 25 | int num_class; 26 | SyncArray label; 27 | }; 28 | 29 | 30 | class SoftmaxProb : public Softmax { 31 | public: 32 | void predict_transform(SyncArray &y) override; 33 | 34 | ~SoftmaxProb() override = default; 35 | 36 | }; 37 | 38 | 39 | 40 | #endif //THUNDERGBM_MULTICLASS_OBJ_H 41 | -------------------------------------------------------------------------------- /include/thundergbm/objective/objective_function.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-1. 
3 | // 4 | 5 | #ifndef THUNDERGBM_OBJECTIVE_FUNCTION_H 6 | #define THUNDERGBM_OBJECTIVE_FUNCTION_H 7 | 8 | #include 9 | #include 10 | 11 | class ObjectiveFunction { 12 | public: 13 | //todo different target type 14 | virtual void 15 | get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair) = 0; 16 | virtual void 17 | predict_transform(SyncArray<float_type> &y){}; 18 | 19 | //init base score 20 | //SyncArray<GHPair> &gh_pair is used for temporary storage 21 | virtual float init_base_score(const SyncArray<float_type> &y, SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair){ 22 | LOG(INFO) << "init_base_score() is not implemented for this objective"; 23 | return 0; 24 | } 25 | 26 | virtual void configure(GBMParam param, const DataSet &dataset) = 0; 27 | virtual string default_metric_name() = 0; 28 | 29 | static ObjectiveFunction* create(string name); 30 | 31 | //a file containing the number of instances per query; similar to XGBoost 32 | static bool need_load_group_file(string name); 33 | static bool need_group_label(string name); 34 | virtual ~ObjectiveFunction() = default; 35 | }; 36 | 37 | #endif //THUNDERGBM_OBJECTIVE_FUNCTION_H 38 | -------------------------------------------------------------------------------- /include/thundergbm/objective/ranking_obj.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-10. 3 | // 4 | 5 | #ifndef THUNDERGBM_RANKING_OBJ_H 6 | #define THUNDERGBM_RANKING_OBJ_H 7 | 8 | #include "objective_function.h" 9 | 10 | /** 11 | * 12 | * https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf 13 | */ 14 | class LambdaRank : public ObjectiveFunction { 15 | public: 16 | void get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p, 17 | SyncArray<GHPair> &gh_pair) override; 18 | 19 | void configure(GBMParam param, const DataSet &dataset) override; 20 | 21 | string default_metric_name() override; 22 | 23 | virtual ~LambdaRank() override = default; 24 | 25 | protected: 26 | virtual inline float_type get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) { return 1; }; 27 | 28 | vector<int> gptr;//group start position 29 | int n_group; 30 | 31 | float_type sigma; 32 | }; 33 | 34 | class LambdaRankNDCG : public LambdaRank { 35 | public: 36 | void configure(GBMParam param, const DataSet &dataset) override; 37 | 38 | string default_metric_name() override; 39 | 40 | protected: 41 | float_type get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) override; 42 | 43 | private: 44 | vector<float_type> idcg; 45 | }; 46 | 47 | 48 | #endif //THUNDERGBM_RANKING_OBJ_H 49 | -------------------------------------------------------------------------------- /include/thundergbm/objective/regression_obj.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-1.
3 | // 4 | 5 | #ifndef THUNDERGBM_REGRESSION_OBJ_H 6 | #define THUNDERGBM_REGRESSION_OBJ_H 7 | 8 | #include "objective_function.h" 9 | #include "thundergbm/util/device_lambda.cuh" 10 | #include "thrust/reduce.h" 11 | 12 | template class Loss> 13 | class RegressionObj : public ObjectiveFunction { 14 | public: 15 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 16 | SyncArray &gh_pair) override { 17 | CHECK_EQ(y.size(), y_p.size())< class Loss> 65 | class LogClsObj: public RegressionObj{ 66 | public: 67 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 68 | SyncArray &gh_pair) override { 69 | auto y_data = y.device_data(); 70 | auto y_p_data = y_p.device_data(); 71 | auto gh_pair_data = gh_pair.device_data(); 72 | device_loop(y.size(), [=]__device__(int i) { 73 | gh_pair_data[i] = Loss::gradient(y_data[i], y_p_data[i]); 74 | }); 75 | } 76 | void predict_transform(SyncArray &y) { 77 | //this method transform y(#class * #instances) into y(#instances) 78 | auto yp_data = y.device_data(); 79 | auto label_data = label.device_data(); 80 | // int num_class = this->num_class; 81 | int n_instances = y.size(); 82 | device_loop(n_instances, [=]__device__(int i) { 83 | //yp_data[i] = Loss::predict_transform(yp_data[i]); 84 | int max_k = (yp_data[i] > 0) ? 1 : 0; 85 | yp_data[i] = label_data[max_k]; 86 | }); 87 | //TODO not to make a temp_y? 88 | SyncArray < float_type > temp_y(n_instances); 89 | temp_y.copy_from(y.device_data(), n_instances); 90 | y.resize(n_instances); 91 | y.copy_from(temp_y); 92 | } 93 | 94 | //base score 95 | float init_base_score(const SyncArray &y,SyncArray &y_p, SyncArray &gh_pair){ 96 | 97 | //get gradients first, SyncArray &gh_pair for temporal storage 98 | get_gradient(y,y_p,gh_pair); 99 | 100 | //get sum gh_pair 101 | GHPair sum_gh = thrust::reduce(thrust::cuda::par, gh_pair.device_data(), gh_pair.device_end()); 102 | 103 | //get weight 104 | float weight = -sum_gh.g / fmax(sum_gh.h, (double)(1e-6)); 105 | //sigmod transform 106 | weight = 1 / (1 + expf(-weight)); 107 | float base_score = -logf(1.0f / weight - 1.0f); 108 | LOG(INFO)<<"base_score "<> &boosted_model, DataSet &dataSet); 15 | void save_model(string model_path, GBMParam &model_param, vector> &boosted_model, DataSet &dataSet); 16 | }; 17 | 18 | #endif //THUNDERGBM_PARAM_PARSER_H 19 | -------------------------------------------------------------------------------- /include/thundergbm/predictor.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by zeyi on 1/12/19. 3 | // 4 | 5 | #ifndef THUNDERGBM_PREDICTOR_H 6 | #define THUNDERGBM_PREDICTOR_H 7 | 8 | #include "thundergbm/tree.h" 9 | #include 10 | 11 | class Predictor{ 12 | public: 13 | vector predict(const GBMParam &model_param, const vector> &boosted_model, const DataSet &dataSet); 14 | void predict_raw(const GBMParam &model_param, const vector> &boosted_model, 15 | const DataSet &dataSet, SyncArray &y_predict); 16 | }; 17 | 18 | #endif //THUNDERGBM_PREDICTOR_H 19 | -------------------------------------------------------------------------------- /include/thundergbm/quantile_sketch.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by qinbin on 2018/5/9. 
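A side note on `LogClsObj::init_base_score` above: the sigmoid followed by the logit are inverses, so `base_score` equals the raw Newton step `-sum_g / max(sum_h, 1e-6)` up to float rounding. A small self-contained check (a sketch, using example sums):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    float w = -(-30.f) / fmaxf(100.f, 1e-6f);  // Newton step with sum_g = -30, sum_h = 100: 0.3
    float p = 1.f / (1.f + expf(-w));          // sigmoid: ~0.574
    float base_score = -logf(1.f / p - 1.f);   // logit: back to ~0.3
    printf("%f %f %f\n", w, p, base_score);    // w and base_score agree
    return 0;
}
```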
3 | // 4 | #ifndef THUNDERGBM_QUANTILE_SKETCH_H 5 | #define THUNDERGBM_QUANTILE_SKETCH_H 6 | 7 | #include "common.h" 8 | #include 9 | #include 10 | 11 | using std::pair; 12 | using std::tuple; 13 | using std::vector; 14 | 15 | 16 | class entry{ 17 | public: 18 | float_type val;//a cut point candidate 19 | float_type rmin;//total weights of feature values less than val 20 | float_type rmax;//total weights of feature values less than or equal to val 21 | float_type w; 22 | entry() {}; 23 | entry(float_type val, float_type rmin, float_type rmax, float_type w) : val(val), rmin(rmin), rmax(rmax), w(w) {}; 24 | }; 25 | 26 | class summary{ 27 | public: 28 | int entry_size; 29 | int entry_reserve_size; 30 | vector entries; 31 | summary(): entry_size(0),entry_reserve_size(0) { 32 | //entries.clear(); 33 | }; 34 | summary(int entry_size, int reserve_size): entry_size(entry_size), entry_reserve_size(reserve_size) {entries.resize(reserve_size);}; 35 | void Reserve(int size);//reserve memory for the summary 36 | void Prune(summary& src,int size);//reduce the number of cut point candidates of the summary 37 | void Merge(summary& src1, summary& src2);//merge two summaries 38 | void Copy(summary& src); 39 | 40 | }; 41 | 42 | /** 43 | * @brief: store the pairs before constructing a summary 44 | */ 45 | class Qitem{ 46 | public: 47 | int tail; 48 | vector> data; 49 | Qitem(): tail(0) { 50 | //data.clear(); 51 | }; 52 | void GetSummary(summary& ret); 53 | }; 54 | 55 | 56 | class quanSketch{ 57 | public: 58 | int numOfLevel;//the summary has multiple levels 59 | int summarySize;//max size of the first level summary 60 | Qitem Qentry; 61 | vector summaries; 62 | summary t_summary; //for process_nodes 63 | void Init(int maxn, float_type eps); 64 | void Add(float_type, float_type); 65 | void GetSummary(summary& dest); 66 | quanSketch(): numOfLevel(0), summarySize(0) { 67 | //summaries.clear(); 68 | }; 69 | 70 | }; 71 | #endif //THUNDERGBM_QUANTILE_SKETCH_H 72 | -------------------------------------------------------------------------------- /include/thundergbm/row_sampler.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by shijiashuai on 2019-02-15. 3 | // 4 | 5 | #ifndef THUNDERGBM_ROW_SAMPLER_H 6 | #define THUNDERGBM_ROW_SAMPLER_H 7 | 8 | #include "thundergbm/common.h" 9 | #include "syncarray.h" 10 | 11 | class RowSampler { 12 | public: 13 | void do_bagging(MSyncArray &gradients); 14 | }; 15 | #endif //THUNDERGBM_ROW_SAMPLER_H 16 | -------------------------------------------------------------------------------- /include/thundergbm/sparse_columns.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by shijiashuai on 5/7/18. 
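To illustrate the `entry` bookkeeping above (a sketch): for the values {1, 2, 2, 3}, each with weight 1, `rmin` holds the total weight strictly below `val` and `rmax` the total weight at or below it.

```cpp
// entries for the weighted values {1, 2, 2, 3} (unit weights):
entry e1(1, 0, 1, 1);  // weight 0 below 1, weight 1 at or below it
entry e2(2, 1, 3, 2);  // weight 1 below 2, weight 3 at or below, val weight 2
entry e3(3, 3, 4, 1);  // weight 3 below 3, weight 4 at or below
// Prune() keeps a bounded subset of such candidates while preserving the
// rmin/rmax rank guarantees; Merge() roughly adds the ranks of two summaries.
```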
3 | // 4 | 5 | #ifndef THUNDERGBM_SPARSE_COLUMNS_H 6 | #define THUNDERGBM_SPARSE_COLUMNS_H 7 | 8 | #include "syncarray.h" 9 | #include "dataset.h" 10 | 11 | class SparseColumns {//one feature corresponding to one column 12 | public: 13 | SyncArray csc_val; 14 | SyncArray csc_row_idx; 15 | SyncArray csc_col_ptr; 16 | 17 | //original order without sort 18 | SyncArray csc_val_origin; 19 | SyncArray csc_row_idx_origin; 20 | SyncArray csc_col_ptr_origin; 21 | 22 | //csr data 23 | SyncArray csr_val; 24 | SyncArray csr_row_ptr; 25 | SyncArray csr_col_idx; 26 | 27 | int max_trick_depth = -1; 28 | int max_trick_nodes = -1; 29 | 30 | int n_column; 31 | int n_row; 32 | int column_offset; 33 | size_t nnz; 34 | 35 | void csr2csc_gpu(const DataSet &dataSet, vector> &); 36 | 37 | //function for data transfer to gpu , this function is only for single GPU device 38 | void to_gpu(const DataSet &dataSet, vector> &); 39 | 40 | void csr2csc_cpu(const DataSet &dataset, vector> &); 41 | void csc_by_default(const DataSet &dataset, vector> &v_columns); 42 | void to_multi_devices(vector> &) const; 43 | 44 | }; 45 | #endif //THUNDERGBM_SPARSE_COLUMNS_H 46 | -------------------------------------------------------------------------------- /include/thundergbm/syncarray.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 17-9-17. 3 | // 4 | 5 | #ifndef THUNDERGBM_SYNCDATA_H 6 | #define THUNDERGBM_SYNCDATA_H 7 | 8 | #include "thundergbm/util/log.h" 9 | #include "syncmem.h" 10 | 11 | /** 12 | * @brief Wrapper of SyncMem with a type 13 | * @tparam T type of element 14 | */ 15 | template 16 | class SyncArray : public el::Loggable { 17 | public: 18 | /** 19 | * initialize class that can store given count of elements 20 | * @param count the given count 21 | */ 22 | explicit SyncArray(size_t count) : mem(new SyncMem(sizeof(T) * count)), size_(count) { 23 | } 24 | 25 | SyncArray() : mem(nullptr), size_(0) {} 26 | 27 | ~SyncArray() { delete mem; }; 28 | 29 | const T *host_data() const { 30 | to_host(); 31 | return static_cast(mem->host_data()); 32 | }; 33 | 34 | const T *device_data() const { 35 | to_device(); 36 | return static_cast(mem->device_data()); 37 | }; 38 | 39 | T *host_data() { 40 | to_host(); 41 | return static_cast(mem->host_data()); 42 | }; 43 | 44 | T *device_data() { 45 | to_device(); 46 | return static_cast(mem->device_data()); 47 | }; 48 | 49 | T *device_end() { 50 | return device_data() + size(); 51 | }; 52 | 53 | const T *device_end() const { 54 | return device_data() + size(); 55 | }; 56 | 57 | void set_host_data(T *host_ptr) { 58 | mem->set_host_data(host_ptr); 59 | } 60 | 61 | void set_device_data(T *device_ptr) { 62 | mem->set_device_data(device_ptr); 63 | } 64 | 65 | void to_host() const { 66 | CHECK_GT(size_, 0); 67 | mem->to_host(); 68 | } 69 | 70 | void to_device() const { 71 | CHECK_GT(size_, 0); 72 | mem->to_device(); 73 | } 74 | 75 | /** 76 | * copy device data. This will call to_device() implicitly. 
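For the `SparseColumns` layout above, a tiny worked example (a sketch) of the same non-zeros in both formats, using the 2x3 matrix {{10, 0, 20}, {0, 30, 0}}:

```cpp
float csr_val[]     = {10, 20, 30};    // non-zeros in row-major order
int   csr_col_idx[] = {0, 2, 1};
int   csr_row_ptr[] = {0, 2, 3};       // row i spans [row_ptr[i], row_ptr[i+1])

float csc_val[]     = {10, 30, 20};    // the same non-zeros in column-major order
int   csc_row_idx[] = {0, 1, 0};
int   csc_col_ptr[] = {0, 1, 2, 3};    // column j spans [col_ptr[j], col_ptr[j+1])
// csr2csc_gpu()/csr2csc_cpu() produce the csc_* arrays from the csr_* ones,
// one column per feature.
```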
77 | * @param source source data pointer (data can be on host or device) 78 | * @param count the count of elements 79 | */ 80 | void copy_from(const T *source, size_t count) { 81 | 82 | #ifdef USE_CUDA 83 | thunder::device_mem_copy(mem->device_data(), source, sizeof(T) * count); 84 | #else 85 | memcpy(mem->host_data(), source, sizeof(T) * count); 86 | #endif 87 | }; 88 | 89 | void copy_from(const SyncArray &source) { 90 | 91 | CHECK_EQ(size(), source.size()) << "destination and source count doesn't match"; 92 | #ifdef USE_CUDA 93 | if (get_owner_id() == source.get_owner_id()) 94 | copy_from(source.device_data(), source.size()); 95 | else 96 | CUDA_CHECK(cudaMemcpyPeer(mem->device_data(), get_owner_id(), source.device_data(), source.get_owner_id(), 97 | source.mem_size())); 98 | #else 99 | copy_from(source.host_data(), source.size()); 100 | #endif 101 | }; 102 | 103 | /** 104 | * resize to a new size. This will also clear all data. 105 | * @param count 106 | */ 107 | void resize(size_t count) { 108 | if(mem != nullptr || mem != NULL) { 109 | delete mem; 110 | } 111 | mem = new SyncMem(sizeof(T) * count); 112 | this->size_ = count; 113 | }; 114 | 115 | /* 116 | * resize to a new size. This will not clear the origin data. 117 | * @param count 118 | */ 119 | void resize_without_delete(size_t count) { 120 | // delete mem; 121 | mem = new SyncMem(sizeof(T) * count); 122 | this->size_ = count; 123 | }; 124 | 125 | 126 | size_t mem_size() const {//number of bytes 127 | return mem->size(); 128 | } 129 | 130 | size_t size() const {//number of values 131 | return size_; 132 | } 133 | 134 | SyncMem::HEAD head() const { 135 | return mem->head(); 136 | } 137 | 138 | void log(el::base::type::ostream_t &ostream) const override { 139 | int i; 140 | ostream << "["; 141 | const T *data = host_data(); 142 | for (i = 0; i < size() - 1 && i < el::base::consts::kMaxLogPerContainer - 1; ++i) { 143 | // for (i = 0; i < size() - 1; ++i) { 144 | ostream << data[i] << ","; 145 | } 146 | ostream << host_data()[i]; 147 | if (size() <= el::base::consts::kMaxLogPerContainer) { 148 | ostream << "]"; 149 | } else { 150 | ostream << ", ...(" << size() - el::base::consts::kMaxLogPerContainer << " more)"; 151 | } 152 | }; 153 | 154 | int get_owner_id() const { 155 | return mem->get_owner_id(); 156 | } 157 | 158 | //move constructor 159 | SyncArray(SyncArray &&rhs) noexcept : mem(rhs.mem), size_(rhs.size_) { 160 | rhs.mem = nullptr; 161 | rhs.size_ = 0; 162 | } 163 | 164 | //move assign 165 | SyncArray &operator=(SyncArray &&rhs) noexcept { 166 | delete mem; 167 | mem = rhs.mem; 168 | size_ = rhs.size_; 169 | 170 | rhs.mem = nullptr; 171 | rhs.size_ = 0; 172 | return *this; 173 | } 174 | 175 | SyncArray(const SyncArray &) = delete; 176 | 177 | SyncArray &operator=(const SyncArray &) = delete; 178 | 179 | 180 | //new function clear gpu mem 181 | void clear_device(){ 182 | to_host(); 183 | mem->free_device(); 184 | } 185 | 186 | private: 187 | SyncMem *mem; 188 | size_t size_; 189 | }; 190 | 191 | //SyncArray for multiple devices 192 | template 193 | class MSyncArray : public vector> { 194 | public: 195 | explicit MSyncArray(size_t n_device) : base_class(n_device) {}; 196 | 197 | explicit MSyncArray(size_t n_device, size_t size) : base_class(n_device) { 198 | for (int i = 0; i < n_device; ++i) { 199 | this->at(i) = SyncArray(size); 200 | } 201 | }; 202 | 203 | MSyncArray() : base_class() {}; 204 | 205 | //move constructor and assign 206 | MSyncArray(MSyncArray &&) = default; 207 | 208 | MSyncArray &operator=(MSyncArray &&) = 
default; 209 | 210 | MSyncArray(const MSyncArray &) = delete; 211 | 212 | MSyncArray &operator=(const MSyncArray &) = delete; 213 | 214 | private: 215 | typedef vector> base_class; 216 | }; 217 | 218 | #endif //THUNDERGBM_SYNCDATA_H 219 | -------------------------------------------------------------------------------- /include/thundergbm/syncmem.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 17-9-16. 3 | // 4 | 5 | #ifndef THUNDERGBM_SYNCMEM_H 6 | #define THUNDERGBM_SYNCMEM_H 7 | 8 | #include "common.h" 9 | #include "cub/util_allocator.cuh" 10 | 11 | using namespace cub; 12 | namespace thunder { 13 | 14 | inline void device_mem_copy(void *dst, const void *src, size_t size) { 15 | #ifdef USE_CUDA 16 | CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyDefault)); 17 | #else 18 | NO_GPU; 19 | #endif 20 | } 21 | 22 | class DeviceAllocator : public CachingDeviceAllocator { 23 | public: 24 | DeviceAllocator(unsigned int bin_growth, unsigned int min_bin, unsigned int max_bin, size_t max_cached_bytes, 25 | bool skip_cleanup, bool debug) : CachingDeviceAllocator(bin_growth, min_bin, max_bin, 26 | max_cached_bytes, skip_cleanup, 27 | debug) {}; 28 | 29 | cudaError_t DeviceAllocate(int device, void **d_ptr, size_t bytes, cudaStream_t active_stream = 0); 30 | 31 | cudaError_t DeviceAllocate(void **d_ptr, size_t bytes, cudaStream_t active_stream = 0); 32 | }; 33 | 34 | class HostAllocator : public CachingDeviceAllocator { 35 | public: 36 | HostAllocator(unsigned int bin_growth, unsigned int min_bin, unsigned int max_bin, size_t max_cached_bytes, 37 | bool skip_cleanup, bool debug) : CachingDeviceAllocator(bin_growth, min_bin, max_bin, 38 | max_cached_bytes, skip_cleanup, debug) {}; 39 | 40 | cudaError_t DeviceAllocate(int device, void **d_ptr, size_t bytes, cudaStream_t active_stream = 0); 41 | 42 | cudaError_t DeviceAllocate(void **d_ptr, size_t bytes, cudaStream_t active_stream = 0); 43 | 44 | cudaError_t DeviceFree(int device, void *d_ptr); 45 | 46 | cudaError_t DeviceFree(void *d_ptr); 47 | 48 | cudaError_t FreeAllCached(); 49 | 50 | ~HostAllocator() override; 51 | }; 52 | 53 | /** 54 | * @brief Auto-synced memory for CPU and GPU 55 | */ 56 | class SyncMem { 57 | public: 58 | SyncMem(); 59 | 60 | /** 61 | * create a piece of synced memory with given size. The GPU/CPU memory will not be allocated immediately, but 62 | * allocated when it is used at first time. 
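A minimal usage sketch of the host/device mirroring in `SyncArray` above (assuming a CUDA build):

```cpp
SyncArray<float_type> a(4);
float_type *h = a.host_data();         // lazily allocates the host buffer
for (int i = 0; i < 4; ++i) h[i] = i;

float_type *d = a.device_data();       // allocates on the GPU and copies host -> device
// ... launch kernels that read or write d ...

const float_type *h2 = a.host_data();  // copies any device-side updates back to the host
```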
63 | * @param size the size of memory (in Bytes) 64 | */ 65 | explicit SyncMem(size_t size); 66 | 67 | ~SyncMem(); 68 | 69 | ///return raw host pointer 70 | void *host_data(); 71 | 72 | ///return raw device pointer 73 | void *device_data(); 74 | 75 | /** 76 | * set host data pointer to another host pointer, and its memory will not be managed by this class 77 | * @param data another host pointer 78 | */ 79 | void set_host_data(void *data); 80 | 81 | /** 82 | * set device data pointer to another device pointer, and its memory will not be managed by this class 83 | * @param data another device pointer 84 | */ 85 | void set_device_data(void *data); 86 | 87 | ///transfer data to host 88 | void to_host(); 89 | 90 | ///transfer data to device 91 | void to_device(); 92 | 93 | ///return the size of memory 94 | size_t size() const; 95 | 96 | ///to determine the where the newest data locates in 97 | enum HEAD { 98 | HOST, DEVICE, UNINITIALIZED 99 | }; 100 | 101 | HEAD head() const; 102 | 103 | int get_owner_id() const { 104 | return device_id; 105 | } 106 | 107 | static void clear_cache() { 108 | device_allocator.FreeAllCached(); 109 | host_allocator.FreeAllCached(); 110 | }; 111 | 112 | //new func to clear gpu mem 113 | void free_device(){ 114 | device_allocator.DeviceFree(device_ptr); 115 | } 116 | 117 | private: 118 | void *device_ptr; 119 | void *host_ptr; 120 | bool own_device_data; 121 | bool own_host_data; 122 | size_t size_; 123 | HEAD head_; 124 | int device_id; 125 | static DeviceAllocator device_allocator; 126 | static HostAllocator host_allocator; 127 | 128 | inline void malloc_host(void **ptr, size_t size); 129 | 130 | inline void free_host(void *ptr); 131 | 132 | }; 133 | } 134 | using thunder::SyncMem; 135 | #endif //THUNDERGBM_SYNCMEM_H 136 | -------------------------------------------------------------------------------- /include/thundergbm/trainer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by zeyi on 1/9/19. 3 | // 4 | 5 | #ifndef THUNDERGBM_TRAINER_H 6 | #define THUNDERGBM_TRAINER_H 7 | 8 | #include "common.h" 9 | #include "tree.h" 10 | #include "dataset.h" 11 | 12 | class TreeTrainer{ 13 | public: 14 | vector > train(GBMParam ¶m, const DataSet &dataset); 15 | // float_type train(GBMParam ¶m); 16 | // float_type train_exact(GBMParam ¶m); 17 | // float_type train_hist(GBMParam ¶m); 18 | 19 | }; 20 | 21 | #endif //THUNDERGBM_TRAINER_H 22 | -------------------------------------------------------------------------------- /include/thundergbm/tree.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 18-1-18. 
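A rough sketch of how the `head()` flag above drives the lazy copying in `SyncMem` (assuming a CUDA build):

```cpp
SyncMem m(16);              // head() == UNINITIALIZED; nothing allocated yet
void *h = m.host_data();    // allocates the host side; head() == HOST
void *d = m.device_data();  // allocates the device side, copies host -> device; head() == DEVICE
m.to_host();                // copies device -> host; head() == HOST again
// set_host_data()/set_device_data() adopt caller-owned pointers instead,
// so this class will not free them.
```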
3 | // 4 | 5 | #ifndef THUNDERGBM_TREE_H 6 | #define THUNDERGBM_TREE_H 7 | 8 | #include "syncarray.h" 9 | #include "sstream" 10 | 11 | 12 | class Tree { 13 | public: 14 | struct TreeNode { 15 | int final_id;// node id after pruning, may not equal to node index 16 | int lch_index;// index of left child 17 | int rch_index;// index of right child 18 | int parent_index;// index of parent node 19 | float_type gain;// gain of splitting this node 20 | float_type base_weight; 21 | int split_feature_id; 22 | float_type split_value; 23 | unsigned char split_bid; 24 | bool default_right; 25 | bool is_leaf; 26 | bool is_valid;// non-valid nodes are those that are "children" of leaf nodes 27 | bool is_pruned;// pruned after pruning 28 | 29 | GHPair sum_gh_pair; 30 | 31 | friend std::ostream &operator<<(std::ostream &os, 32 | const TreeNode &node); 33 | 34 | HOST_DEVICE void calc_weight(float_type lambda) { 35 | this->base_weight = -sum_gh_pair.g / (sum_gh_pair.h + lambda); 36 | } 37 | 38 | HOST_DEVICE bool splittable() const { 39 | return !is_leaf && is_valid; 40 | } 41 | 42 | }; 43 | 44 | Tree() = default; 45 | 46 | Tree(const Tree &tree) { 47 | nodes.resize(tree.nodes.size()); 48 | nodes.copy_from(tree.nodes); 49 | } 50 | 51 | Tree &operator=(const Tree &tree) { 52 | nodes.resize(tree.nodes.size()); 53 | nodes.copy_from(tree.nodes); 54 | return *this; 55 | } 56 | 57 | void init2(const SyncArray &gradients, const GBMParam ¶m); 58 | 59 | string dump(int depth) const; 60 | 61 | SyncArray nodes; 62 | 63 | void prune_self(float_type gamma); 64 | 65 | private: 66 | void preorder_traversal(int nid, int max_depth, int depth, string &s) const; 67 | 68 | int try_prune_leaf(int nid, int np, float_type gamma, vector &leaf_child_count); 69 | 70 | void reorder_nid(); 71 | }; 72 | 73 | #endif //THUNDERGBM_TREE_H 74 | -------------------------------------------------------------------------------- /include/thundergbm/util/device_lambda.cuh: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 18-1-19. 
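`TreeNode::calc_weight` above is the closed-form optimal leaf value from the second-order GBDT derivation, w* = -G / (H + lambda). A quick numeric sketch (assuming `GHPair` has a `(g, h)` constructor as in `common.h`):

```cpp
Tree::TreeNode node;
node.sum_gh_pair = GHPair(-10, 5);  // sums of gradients/hessians of instances in the node
node.calc_weight(1.0);              // lambda = 1
// node.base_weight == -(-10) / (5 + 1) ≈ 1.667; a larger lambda shrinks it toward 0
```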
3 | // 4 | 5 | #ifndef THUNDERGBM_DEVICE_LAMBDA_H 6 | #define THUNDERGBM_DEVICE_LAMBDA_H 7 | 8 | #include "thundergbm/common.h" 9 | 10 | template 11 | __global__ void lambda_kernel(size_t len, L lambda) { 12 | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += blockDim.x * gridDim.x) { 13 | lambda(i); 14 | } 15 | } 16 | 17 | template 18 | __global__ void anonymous_kernel_k(L lambda) { 19 | lambda(); 20 | } 21 | 22 | template 23 | __global__ void lambda_2d_sparse_kernel(const int *len2, L lambda) { 24 | int i = blockIdx.x; 25 | int begin = len2[i]; 26 | int end = len2[i + 1]; 27 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) { 28 | lambda(i, j); 29 | } 30 | } 31 | 32 | template 33 | __global__ void lambda_2d_maximum_sparse_kernel(const int *len2, const int maximum, L lambda) { 34 | int i = blockIdx.x; 35 | int begin = len2[i]; 36 | int end = len2[i + 1]; 37 | int interval = (end - begin) / maximum; 38 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) { 39 | lambda(i, j, interval); 40 | } 41 | } 42 | 43 | ///p100 has 56 MPs, using 32*56 thread blocks, one MP supports 2048 threads 44 | //a6000 has 84 mps 45 | template 46 | inline void device_loop(size_t len, L lambda) { 47 | if (len > 0) { 48 | lambda_kernel << < NUM_BLOCK, BLOCK_SIZE >> > (len, lambda); 49 | cudaDeviceSynchronize(); 50 | /*cudaError_t error = cudaPeekAtLastError();*/ 51 | if(cudaPeekAtLastError() == cudaErrorInvalidResourceHandle){ 52 | cudaGetLastError(); 53 | LOG(INFO) << "warning: cuda invalid resource handle, potential issue of driver version and cuda version mismatch"; 54 | } else { 55 | CUDA_CHECK(cudaPeekAtLastError()); 56 | } 57 | } 58 | } 59 | 60 | template 61 | inline void anonymous_kernel(L lambda, size_t num_fv, size_t smem_size = 0, int NUM_BLOCK = 32 * 84, int BLOCK_SIZE = 256) { 62 | size_t tmp_num_block = num_fv / (BLOCK_SIZE * 8); 63 | NUM_BLOCK = std::min(NUM_BLOCK, (int)std::max(tmp_num_block, (size_t)32)); 64 | anonymous_kernel_k<< < NUM_BLOCK, BLOCK_SIZE, smem_size >> > (lambda); 65 | cudaDeviceSynchronize(); 66 | if(cudaPeekAtLastError() == cudaErrorInvalidResourceHandle){ 67 | cudaGetLastError(); 68 | LOG(INFO) << "warning: cuda invalid resource handle, potential issue of driver version and cuda version mismatch"; 69 | } else { 70 | CUDA_CHECK(cudaPeekAtLastError()); 71 | } 72 | } 73 | 74 | /** 75 | * @brief: (len1 x NUM_BLOCK) is the total number of blocks; len2 is an array of lengths. 76 | */ 77 | template 78 | void device_loop_2d(int len1, const int *len2, L lambda, unsigned int NUM_BLOCK = 4 * 84, 79 | unsigned int BLOCK_SIZE = 256) { 80 | if (len1 > 0) { 81 | lambda_2d_sparse_kernel << < dim3(len1, NUM_BLOCK), BLOCK_SIZE >> > (len2, lambda); 82 | cudaDeviceSynchronize(); 83 | CUDA_CHECK(cudaPeekAtLastError()); 84 | } 85 | } 86 | 87 | /** 88 | * @brief: (len1 x NUM_BLOCK) is the total number of blocks; len2 is an array of lengths. 
89 | */ 90 | template 91 | void device_loop_2d_with_maximum(int len1, const int *len2, const int maximum, L lambda, 92 | unsigned int NUM_BLOCK = 4 * 84, 93 | unsigned int BLOCK_SIZE = 256) { 94 | if (len1 > 0) { 95 | lambda_2d_maximum_sparse_kernel << < dim3(len1, NUM_BLOCK), BLOCK_SIZE >> > (len2, maximum, lambda); 96 | cudaDeviceSynchronize(); 97 | CUDA_CHECK(cudaPeekAtLastError()); 98 | } 99 | } 100 | 101 | //for sparse formate bin id 102 | //func for new sparse loop 103 | template 104 | __global__ void lambda_hist_csr_root_kernel(const int *csr_row_ptr, L lambda) { 105 | 106 | int i = blockIdx.x; 107 | int begin = csr_row_ptr[i]; 108 | int end = csr_row_ptr[i + 1]; 109 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) { 110 | //i for instance 111 | //j for feature 112 | lambda(i, j ); 113 | } 114 | } 115 | 116 | 117 | 118 | template 119 | void device_loop_hist_csr_root(int n_instances, const int *csr_row_ptr, L lambda , unsigned int NUM_BLOCK = 4 * 84, 120 | unsigned int BLOCK_SIZE = 256) { 121 | if (n_instances > 0) { 122 | lambda_hist_csr_root_kernel << < dim3(n_instances, NUM_BLOCK), BLOCK_SIZE >> > (csr_row_ptr, lambda); 123 | cudaDeviceSynchronize(); 124 | CUDA_CHECK(cudaPeekAtLastError()); 125 | } 126 | } 127 | 128 | template 129 | __global__ void lambda_hist_csr_node_kernel(L lambda) { 130 | int i = blockIdx.x; 131 | int current_pos = blockIdx.y * blockDim.x + threadIdx.x; 132 | int stride = blockDim.x * gridDim.y; 133 | lambda(i,current_pos,stride); 134 | } 135 | 136 | template 137 | void device_loop_hist_csr_node(int n_instances, const int *csr_row_ptr, L lambda , unsigned int NUM_BLOCK = 4 * 84, 138 | unsigned int BLOCK_SIZE = 256) { 139 | if (n_instances > 0) { 140 | lambda_hist_csr_node_kernel << < dim3(n_instances, NUM_BLOCK), BLOCK_SIZE >> > (lambda); 141 | cudaDeviceSynchronize(); 142 | CUDA_CHECK(cudaPeekAtLastError()); 143 | } 144 | } 145 | 146 | 147 | #endif //THUNDERGBM_DEVICE_LAMBDA_H 148 | -------------------------------------------------------------------------------- /include/thundergbm/util/multi_device.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 18-6-18. 3 | // 4 | 5 | #ifndef THUNDERGBM_MULTI_DEVICE_H 6 | #define THUNDERGBM_MULTI_DEVICE_H 7 | 8 | #include "thundergbm/common.h" 9 | 10 | //switch to specific device and do something, then switch back to the original device 11 | //FIXME make this macro into a function? 
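A minimal usage sketch of the grid-stride `device_loop` from `device_lambda.cuh` above; the raw device pointer is hoisted out first because the extended lambda must capture it by value:

```cpp
SyncArray<float_type> v(1024);
float_type *v_data = v.device_data();  // capture the raw pointer, not the SyncArray
device_loop(v.size(), [=] __device__(int i) {
    v_data[i] *= 2;                    // the lambda body runs once per index in [0, len)
});
```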
12 | #define DO_ON_DEVICE(device_id, something) \ 13 | do { \ 14 | int org_device_id = 0; \ 15 | CUDA_CHECK(cudaGetDevice(&org_device_id)); \ 16 | CUDA_CHECK(cudaSetDevice(device_id)); \ 17 | something; \ 18 | CUDA_CHECK(cudaSetDevice(org_device_id)); \ 19 | } while (false) 20 | 21 | /** 22 | * Do something on multiple devices, then switch back to the original device 23 | * 24 | * 25 | * example: 26 | * 27 | * DO_ON_MULTI_DEVICES(n_devices, [&](int device_id){ 28 | * //do_something_on_device(device_id); 29 | * }); 30 | */ 31 | 32 | template<typename L> 33 | void DO_ON_MULTI_DEVICES(int n_devices, L do_something) { 34 | int org_device_id = 0; 35 | CUDA_CHECK(cudaGetDevice(&org_device_id)); 36 | #pragma omp parallel for num_threads(n_devices) 37 | for (int device_id = 0; device_id < n_devices; device_id++) { 38 | CUDA_CHECK(cudaSetDevice(device_id)); 39 | do_something(device_id); 40 | } 41 | CUDA_CHECK(cudaSetDevice(org_device_id)); 42 | 43 | } 44 | 45 | #endif //THUNDERGBM_MULTI_DEVICE_H 46 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | We provide a scikit-learn wrapper interface. Before you use the Python interface, you must build ThunderGBM. Note that both dense and sparse matrices are supported in ThunderGBM. Dense matrices are converted into the CSR format, on which the parallel training algorithm on GPUs is based. 2 | 3 | ## Instructions for building ThunderGBM 4 | * Please refer to [Installation](http://thundergbm.readthedocs.io/en/latest/how-to.html) for building ThunderGBM. 5 | 6 | * Then, if you want to install the Python package from source, go to the project root directory and run: 7 | ```bash 8 | cd python && python setup.py install 9 | ``` 10 | Or you can install it via pip: 11 | ```bash 12 | pip3 install thundergbm 13 | ``` 14 | * After you have successfully installed ThunderGBM, you can import a model such as ``TGBMClassifier``: 15 | ```python 16 | from thundergbm import TGBMClassifier 17 | from sklearn import datasets 18 | clf = TGBMClassifier() 19 | X, y = datasets.load_digits(return_X_y=True) 20 | clf.fit(X, y) 21 | clf.save_model(model_path) 22 | ``` 23 | * Load model 24 | ```python 25 | from thundergbm import TGBMClassifier 26 | clf = TGBMClassifier(objective="your-objective") 27 | # You should specify the same objective here as in the training stage 28 | clf.load_model(model_path) 29 | y_pred = clf.predict(X) 30 | ``` 31 | ## Prerequisites 32 | * numpy | scipy | sklearn 33 | 34 | ## Example 35 | 36 | * Step 1: Create a file called ```tgbm_test.py``` which has the following content. 37 | ```python 38 | from thundergbm import * 39 | from sklearn.datasets import * 40 | from sklearn.metrics import mean_squared_error 41 | from math import sqrt 42 | 43 | x, y = load_svmlight_file("../dataset/test_dataset.txt") 44 | clf = TGBMRegressor() 45 | clf.fit(x, y) 46 | 47 | x2, y2 = load_svmlight_file("../dataset/test_dataset.txt") 48 | y_predict = clf.predict(x2) 49 | 50 | rms = sqrt(mean_squared_error(y2, y_predict)) 51 | print(rms) 52 | 53 | ``` 54 | * Step 2: Run the Python script.
55 | ```bash 56 | python tgbm_test.py 57 | ``` 58 | 59 | ## ThunderGBM class 60 | *class TGBMModel(depth = 6, num_round = 40, n_device = 1, min_child_weight = 1.0, lambda_tgbm = 1.0, gamma = 1.0, max_num_bin = 255, verbose = 0, column_sampling_rate = 1.0, bagging = 0, n_parallel_trees = 1, learning_rate = 1.0, objective = "reg:linear", num_class = 1, path = "../dataset/test_dataset.txt")* 61 | 62 | Please note that ``TGBMClassifier`` and ``TGBMRegressor`` are wrappers of ``TGBMModel``, and their constructors have the same parameters as ``TGBMModel``. 63 | 64 | ### Parameters 65 | Please refer to [Parameters](https://github.com/zeyiwen/thundergbm/blob/master/docs/parameters.md) in the ThunderGBM documentation. 66 | 67 | ### Methods 68 | *fit(X, y)*:\ 69 | Fit the GBM model according to the given training data. 70 | 71 | *predict(X)*:\ 72 | Perform prediction on samples in X. 73 | 74 | *save_model(path)*:\ 75 | Save the model to the file path. 76 | 77 | *load_model(path)*:\ 78 | Load the model from the file path. 79 | 80 | *cv(X, y, folds=None, nfold=5, shuffle=True, seed=0)*:\ 81 | * folds: A length-*n* list of tuples. Each tuple is (in, out), where *in* is a list of indices to be used as the training samples for the *n*-th fold and *out* 82 | is a list of indices to be used as the testing samples for the *n*-th fold. 83 | * shuffle (bool): Whether to shuffle the data before creating folds. 84 | * seed (int): Seed used to generate the folds. 85 | 86 | *get_shap_trees()*:\ 87 | Return the model used for the SHAP explainer. 88 | -------------------------------------------------------------------------------- /python/benchmarks/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Introduction 3 | 4 | These are Python scripts that can help you reproduce our experimental results. We would like to thank the author of [GBM-Benchmarks](https://github.com/RAMitchell/GBM-Benchmarks), on which our scripts are built. 5 | 6 | ## Requirements 7 | _python3.x_, _numpy_, _sklearn_, _xgboost_, _lightgbm_, _catboost_, _thundergbm_ 8 | 9 | ## How to run 10 | 11 | You can use the command `python3 experiments.py [model_name] [device_type] [dataset_name]` to run the scripts. The candidate values of each parameter are as follows: 12 | 13 | - model_name 14 | - xgboost 15 | - lightgbm 16 | - catboost 17 | - thundergbm 18 | - device_type 19 | - cpu 20 | - gpu 21 | - dataset_name 22 | - news20 23 | - higgs 24 | - log1p 25 | - cifar 26 | 27 | ## File descriptions 28 | - model 29 | - base_model.py 30 | - datasets.py 31 | - catboost_model.py 32 | - lightgbm_model.py 33 | - thundergbm_model.py 34 | - xgboost_model.py 35 | - utils 36 | - data_utils.py 37 | - file_utils.py 38 | - experiments.py 39 | - convert_dataset_plk.py 40 | 41 | The folder **_model_** contains the model files of the libraries, which inherit from `BaseModel` in `base_model.py`. The folder **_utils_** contains a few tools, including a data helper and a file I/O helper. `convert_dataset_plk.py` converts a normal `libsvm` file to a Python pickle file; the datasets we used for the experiments are sometimes large, which makes loading them time-consuming, and using pickle sharply reduces that loading time. `experiments.py` is the main entry point of our scripts. 42 | 43 | ## How to add more datasets 44 | 45 | As the optional datasets of our scripts are limited, you can add the datasets you want.
You can achieve this by modifying the file `utils/data_utils.py`. There are some dataset templates in that script that may help you add your own dataset easily. 46 | -------------------------------------------------------------------------------- /python/benchmarks/convert_dataset_plk.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from sklearn.datasets import load_svmlight_file 3 | import os 4 | from sklearn.datasets import dump_svmlight_file 5 | 6 | 7 | base_dir = './raw_dataset' 8 | plk_dir = './datasets' 9 | for dataset_name in os.listdir(base_dir): 10 | print(dataset_name) 11 | # if 'binary' not in dataset_name: 12 | # continue 13 | X, y = load_svmlight_file(os.path.join(base_dir, dataset_name)) 14 | print(X.shape) 15 | print(y.shape) 16 | # y2 = [1 if x % 2 == 0 else 0 for x in y] 17 | # dump_svmlight_file(X, y2, open('SVNH.2classes', 'wb')) 18 | pickle.dump((X, y), open(os.path.join(plk_dir, dataset_name+'.plk'), 'wb'), protocol=4) 19 | # X, y = load_svmlight_file('datasets/YearPredictMSD.train') -------------------------------------------------------------------------------- /python/benchmarks/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/benchmarks/model/__init__.py -------------------------------------------------------------------------------- /python/benchmarks/model/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import mean_squared_error, accuracy_score 3 | 4 | 5 | class BaseModel(object): 6 | """ 7 | Base model used to run the benchmarks 8 | """ 9 | 10 | def __init__(self): 11 | self.max_depth = 6 12 | self.learning_rate = 1 13 | self.min_split_loss = 1 14 | self.min_weight = 1 15 | self.L1_reg = 1 16 | self.L2_reg = 1 17 | self.num_rounds = 40 18 | self.max_bin = 255 19 | self.use_gpu = True 20 | self.params = {} 21 | 22 | self.model = None  # self.model differs across libraries 23 | 24 | def _config_model(self, data): 25 | """ 26 | Configure the model with library-specific params 27 | """ 28 | pass 29 | 30 | def _train_model(self, data): 31 | """ 32 | Train the model; returns the elapsed training time 33 | :param data: 34 | :return: 35 | """ 36 | pass 37 | 38 | def _predict(self, data): 39 | pass 40 | 41 | def eval(self, data, pred): 42 | """ 43 | Evaluate the predictions with the specified metric 44 | :param data: 45 | :param pred: 46 | :return: 47 | """ 48 | if data.metric == "RMSE": 49 | with open('pred', 'w') as f: 50 | for x in pred: 51 | f.write(str(x) + '\n') 52 | return np.sqrt(mean_squared_error(data.y_test, pred)) 53 | elif data.metric == "Accuracy": 54 | # Threshold prediction if binary classification 55 | if data.task == "Classification": 56 | pred = pred > 0.5 57 | elif data.task == "Multiclass classification": 58 | if pred.ndim > 1: 59 | pred = np.argmax(pred, axis=1) 60 | return accuracy_score(data.y_test, pred) 61 | else: 62 | raise ValueError("Unknown metric: " + data.metric) 63 | 64 | def run_model(self, data): 65 | """ 66 | Configure, train, and evaluate the model 67 | :param data: 68 | :return: 69 | """ 70 | self._config_model(data) 71 | elapsed = self._train_model(data) 72 | # metric = 0 73 | metric = self._predict(data) 74 | print("##### Elapsed time: %.5f #####" % (elapsed)) 75 | print("##### Predict %s: %.4f #####" % (data.metric, metric)) 76 | 77 | return elapsed, metric 78 | 79 | def model_name(self): 80 | 
pass 81 | 82 | -------------------------------------------------------------------------------- /python/benchmarks/model/catboost_model.py: -------------------------------------------------------------------------------- 1 | from model.base_model import BaseModel 2 | import numpy as np 3 | import catboost as cb 4 | import time 5 | import utils.data_utils as du 6 | from model.datasets import Dataset 7 | 8 | 9 | class CatboostModel(BaseModel): 10 | 11 | def __init__(self): 12 | BaseModel.__init__(self) 13 | 14 | def _config_model(self, data): 15 | self.params['eta'] = self.learning_rate 16 | self.params['depth'] = self.max_depth 17 | self.params['l2_leaf_reg'] = self.L2_reg 18 | self.params['devices'] = "0" 19 | self.params['max_bin'] = self.max_bin 20 | self.params['thread_count'] = 20 21 | self.params['iterations'] = self.num_rounds 22 | 23 | if self.use_gpu: 24 | self.params['task_type'] = 'GPU' 25 | else: 26 | self.params['task_type'] = 'CPU' 27 | if data.task == "Multiclass classification": 28 | self.params['loss_function'] = 'MultiClass' 29 | self.params["classes_count"] = int(np.max(data.y_test) + 1) 30 | self.params["eval_metric"] = 'MultiClass' 31 | if data.task == "Classification": 32 | self.params['loss_function'] = 'Logloss' 33 | if data.task == "Ranking": 34 | self.params['loss_function'] = 'YetiRank' 35 | self.params["eval_metric"] = 'NDCG' 36 | 37 | def _train_model(self, data): 38 | print(self.params) 39 | dtrain = cb.Pool(data.X_train, data.y_train) 40 | if data.task == 'Ranking': 41 | dtrain.set_group_id(data.groups) 42 | 43 | start = time.time() 44 | self.model = cb.train(pool=dtrain, params=self.params, ) 45 | elapsed = time.time() - start 46 | 47 | return elapsed 48 | 49 | 50 | def _predict(self, data): 51 | # test dataset for catboost 52 | cb_test = cb.Pool(data.X_test, data.y_test) 53 | if data.task == 'Ranking': 54 | cb_test.set_group_id(data.groups) 55 | preds = self.model.predict(cb_test) 56 | metric = self.eval(data, preds) 57 | 58 | return metric 59 | 60 | def model_name(self): 61 | name = "catboost_" 62 | use_cpu = "gpu_" if self.use_gpu else "cpu_" 63 | nr = str(self.num_rounds) + "_" 64 | return name + use_cpu + nr + str(self.max_depth) 65 | 66 | 67 | if __name__ == "__main__": 68 | X, y, groups = du.get_yahoo() 69 | dataset = Dataset(name='yahoo', task='Ranking', metric='NDCG', get_func=du.get_yahoo) 70 | print(dataset.X_train.shape) 71 | print(dataset.y_test.shape) 72 | 73 | t_start = time.time() 74 | xgbModel = CatboostModel() 75 | xgbModel.use_gpu = False 76 | xgbModel.run_model(data=dataset) 77 | 78 | eplased = time.time() - t_start 79 | print("--------->> " + str(eplased)) -------------------------------------------------------------------------------- /python/benchmarks/model/datasets.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split 2 | import utils.data_utils as du 3 | 4 | class Dataset(object): 5 | 6 | def __init__(self, name, task, metric, X=None, y=None, get_func=None): 7 | """ 8 | Please notice that the training set and test set are the same here. 
9 | """ 10 | group = None 11 | self.name = name 12 | self.task = task 13 | self.metric = metric 14 | if task == 'Ranking': 15 | if get_func is not None: 16 | X, y, group = get_func() 17 | else: 18 | if get_func is not None: 19 | X, y = get_func() 20 | self.X_train = X 21 | self.X_test = X 22 | self.y_train = y 23 | self.y_test = y 24 | self.groups = group 25 | 26 | def split_dataset(self, test_size=0.1): 27 | """ 28 | Split the dataset in a certain proportion 29 | :param test_size: the proportion of test set 30 | :return: 31 | """ 32 | self.X_train, self.X_test, self.y_train, self.y_test = \ 33 | train_test_split(self.X_train, self.X_test, test_size=test_size) 34 | 35 | 36 | if __name__ == "__main__": 37 | X, y = du.get_higgs() 38 | dataset = Dataset(name='higgs', task='Regression', metric='RMSE', get_func=du.get_higgs) 39 | print(dataset.X_train.shape) 40 | print(dataset.y_test.shape) 41 | 42 | dataset.split_dataset(test_size=0.5) 43 | print(dataset.X_train.shape) 44 | print(dataset.y_test.shape) 45 | -------------------------------------------------------------------------------- /python/benchmarks/model/lightgbm_model.py: -------------------------------------------------------------------------------- 1 | from model.base_model import BaseModel 2 | import numpy as np 3 | import lightgbm as lgb 4 | import time 5 | import utils.data_utils as du 6 | from model.datasets import Dataset 7 | 8 | 9 | class LightGBMModel(BaseModel): 10 | 11 | def __init__(self): 12 | BaseModel.__init__(self) 13 | 14 | def _config_model(self, data): 15 | self.params['task'] = 'train' 16 | self.params['boosting_type'] = 'gbdt' 17 | self.params['max_depth'] = 6 18 | self.params['num_leaves'] = 2 ** self.params['max_depth'] # for max_depth is 6 19 | # self.params['min_sum_hessian+in_leaf'] = 1 20 | self.params['min_split_gain'] = self.min_split_loss 21 | self.params['min_child_weight'] = self.min_weight 22 | self.params['lambda_l1'] = self.L1_reg 23 | self.params['lambda_l2'] = self.L2_reg 24 | self.params['max_bin'] = self.max_bin 25 | self.params['num_threads'] = 20 26 | 27 | if self.use_gpu: 28 | self.params['device'] = 'gpu' 29 | else: 30 | self.params['device'] = 'cpu' 31 | if data.task == "Regression": 32 | self.params["objective"] = "regression" 33 | elif data.task == "Multiclass classification": 34 | self.params["objective"] = "multiclass" 35 | self.params["num_class"] = int(np.max(data.y_test) + 1) 36 | elif data.task == "Classification": 37 | self.params["objective"] = "binary" 38 | elif data.task == "Ranking": 39 | self.params["objective"] = "lambdarank" 40 | else: 41 | raise ValueError("Unknown task: " + data.task) 42 | 43 | 44 | def _train_model(self, data): 45 | print(self.params) 46 | lgb_train = lgb.Dataset(data.X_train, data.y_train) 47 | if data.task == 'Ranking': 48 | lgb_train.set_group(data.groups) 49 | 50 | start = time.time() 51 | self.model = lgb.train(self.params, 52 | lgb_train, 53 | num_boost_round=self.num_rounds) 54 | elapsed = time.time() - start 55 | 56 | return elapsed 57 | 58 | def _predict(self, data): 59 | pred = self.model.predict(data.X_test) 60 | metric = self.eval(data, pred) 61 | 62 | return metric 63 | 64 | def model_name(self): 65 | name = "lightgbm_" 66 | use_cpu = "gpu_" if self.use_gpu else "cpu_" 67 | nr = str(self.num_rounds) + "_" 68 | return name + use_cpu + nr + str(self.max_depth) 69 | 70 | 71 | if __name__ == "__main__": 72 | X, y, groups = du.get_yahoo() 73 | dataset = Dataset(name='yahoo', task='Ranking', metric='NDCG', get_func=du.get_yahoo) 74 | 
print(dataset.X_train.shape) 75 | print(dataset.y_test.shape) 76 | 77 | t_start = time.time() 78 | lgbModel = LightGBMModel() 79 | lgbModel.use_gpu = False 80 | lgbModel.run_model(data=dataset) 81 | 82 | elapsed = time.time() - t_start 83 | print("--------->> " + str(elapsed)) 84 | -------------------------------------------------------------------------------- /python/benchmarks/model/thundergbm_model.py: -------------------------------------------------------------------------------- 1 | from model.base_model import BaseModel 2 | import thundergbm as tgb 3 | import time 4 | import numpy as np 5 | import utils.data_utils as du 6 | from model.datasets import Dataset 7 | 8 | 9 | class ThunderGBMModel(BaseModel): 10 | 11 | def __init__(self, depth=6, n_device=1, n_parallel_trees=1, 12 | verbose=0, column_sampling_rate=1.0, bagging=0, tree_method='auto'): 13 | BaseModel.__init__(self) 14 | self.verbose = verbose 15 | self.n_device = n_device 16 | self.column_sampling_rate = column_sampling_rate 17 | self.bagging = bagging 18 | self.n_parallel_trees = n_parallel_trees 19 | self.tree_method = tree_method 20 | self.objective = "" 21 | self.num_class = 1 22 | 23 | def _config_model(self, data): 24 | if data.task == "Regression": 25 | self.objective = "reg:linear" 26 | elif data.task == "Multiclass classification": 27 | self.objective = "multi:softmax" 28 | self.num_class = int(np.max(data.y_test) + 1) 29 | elif data.task == "Classification": 30 | self.objective = "binary:logistic" 31 | elif data.task == "Ranking": 32 | self.objective = "rank:ndcg" 33 | else: 34 | raise ValueError("Unknown task: " + data.task) 35 | 36 | 37 | def _train_model(self, data): 38 | if data.task == 'Regression': 39 | self.model = tgb.TGBMRegressor(tree_method=self.tree_method, depth=self.max_depth, n_trees=40, n_gpus=1, \ 40 | min_child_weight=1.0, lambda_tgbm=1.0, gamma=1.0, \ 41 | max_num_bin=255, verbose=0, column_sampling_rate=1.0, \ 42 | bagging=0, n_parallel_trees=1, learning_rate=1.0, \ 43 | objective="reg:linear", num_class=1) 44 | else: 45 | self.model = tgb.TGBMClassifier(bagging=1, lambda_tgbm=1, learning_rate=0.07, min_child_weight=1.2, n_gpus=1, verbose=0, 46 | n_parallel_trees=40, gamma=0.2, depth=self.max_depth, n_trees=40, tree_method=self.tree_method, objective='multi:softprob') 47 | start = time.time() 48 | self.model.fit(data.X_train, data.y_train) 49 | elapsed = time.time() - start 50 | print("Training time: 
%.5f" % elapsed) 51 | 52 | return elapsed 53 | 54 | def _predict(self, data): 55 | pred = self.model.predict(data.X_test) 56 | metric = self.eval(data, pred) 57 | return metric 58 | 59 | 60 | def model_name(self): 61 | name = "thundergbm_" 62 | use_cpu = "gpu_" if self.use_gpu else "cpu_" 63 | nr = str(self.num_rounds) + "_" 64 | return name + use_cpu + nr + str(self.max_depth) 65 | 66 | 67 | if __name__ == "__main__": 68 | # X, y = du.get_higgs() 69 | dataset = Dataset(name='higgs', task='Regression', metric='RMSE', get_func=du.get_realsim) 70 | print(dataset.X_train.shape) 71 | print(dataset.y_test.shape) 72 | 73 | t_start = time.time() 74 | 75 | tgmModel = ThunderGBMModel() 76 | tgmModel.tree_method = 'hist' 77 | tgmModel.run_model(data=dataset) 78 | 79 | eplased = time.time() - t_start 80 | 81 | print("--------->> " + str(eplased)) -------------------------------------------------------------------------------- /python/benchmarks/model/xgboost_model.py: -------------------------------------------------------------------------------- 1 | from model.base_model import BaseModel 2 | import numpy as np 3 | import xgboost as xgb 4 | import time 5 | import utils.data_utils as du 6 | from model.datasets import Dataset 7 | 8 | 9 | class XGboostModel(BaseModel): 10 | 11 | def __init__(self, use_exact=False, debug_verose=1): 12 | BaseModel.__init__(self) 13 | self.use_exact = use_exact 14 | self.debug_verose = debug_verose 15 | 16 | def _config_model(self, data): 17 | self.params['max_depth'] = self.max_depth 18 | self.params['learning_rate'] = self.learning_rate 19 | self.params['min_split_loss'] = self.min_split_loss 20 | self.params['min_child_weight'] = self.min_weight 21 | self.params['alpha'] = self.L1_reg 22 | self.params['lambda'] = self.L2_reg 23 | self.params['debug_verbose'] = self.debug_verose 24 | self.params['max_bin'] = self.max_bin 25 | 26 | if self.use_gpu: 27 | self.params['tree_method'] = ('gpu_exact' if self.use_exact 28 | else 'gpu_hist') 29 | self.params['n_gpus'] = 1 30 | else: 31 | self.params['nthread'] = 20 32 | self.params['tree_method'] = ('exact' if self.use_exact else 'hist') 33 | 34 | self.params["predictor"] = "gpu_predictor" 35 | if data.task == "Regression": 36 | self.params["objective"] = "reg:squarederror" 37 | elif data.task == "Multiclass classification": 38 | self.params["objective"] = "multi:softprob" 39 | self.params["num_class"] = int(np.max(data.y_test) + 1) 40 | elif data.task == "Classification": 41 | self.params["objective"] = "binary:logistic" 42 | elif data.task == "Ranking": 43 | self.params["objective"] = "rank:ndcg" 44 | else: 45 | raise ValueError("Unknown task: " + data.task) 46 | 47 | def _train_model(self, data): 48 | print(self.params) 49 | dtrain = xgb.DMatrix(data.X_train, data.y_train) 50 | if data.task == 'Ranking': 51 | dtrain.set_group(data.groups) 52 | t_start = time.time() 53 | self.model = xgb.train(self.params, dtrain, self.num_rounds, [(dtrain, "train")]) 54 | elapsed_time = time.time() - t_start 55 | 56 | return elapsed_time 57 | 58 | def _predict(self, data): 59 | dtest = xgb.DMatrix(data.X_test, data.y_test) 60 | if data.task == 'Ranking': 61 | dtest.set_group(data.groups) 62 | pred = self.model.predict(dtest) 63 | metric = self.eval(data=data, pred=pred) 64 | 65 | return metric 66 | 67 | def model_name(self): 68 | name = "xgboost_" 69 | use_cpu = "gpu_" if self.use_gpu else "cpu_" 70 | nr = str(self.num_rounds) + "_" 71 | return name + use_cpu + nr + str(self.max_depth) 72 | 73 | 74 | 75 | 76 | if __name__ == "__main__": 77 | 
X, y, groups = du.get_yahoo() 78 | dataset = Dataset(name='yahoo', task='Ranking', metric='NDCG', get_func=du.get_yahoo) 79 | print(dataset.X_train.shape) 80 | print(dataset.y_test.shape) 81 | 82 | t_start = time.time() 83 | xgbModel = XGboostModel() 84 | xgbModel.use_gpu = False 85 | xgbModel.run_model(data=dataset) 86 | 87 | eplased = time.time() - t_start 88 | print("--------->> " + str(eplased)) -------------------------------------------------------------------------------- /python/benchmarks/run_exp.sh: -------------------------------------------------------------------------------- 1 | # lgb 2 | #python3 experiments.py lightgbm cpu higgs 3 | #python3 experiments.py lightgbm cpu log1p 4 | #python3 experiments.py lightgbm cpu cifar 5 | python3 experiments.py lightgbm cpu news20 6 | 7 | # lgb gpu 8 | #python3 experiments.py lightgbm gpu higgs 9 | #python3 experiments.py lightgbm gpu log1p 10 | #python3 experiments.py lightgbm gpu cifar 11 | python3 experiments.py lightgbm gpu news20 12 | 13 | # xgboost 14 | #python3 experiments.py xgboost cpu higgs 15 | #python3 experiments.py xgboost cpu log1p 16 | #python3 experiments.py xgboost cpu cifar 17 | python3 experiments.py xgboost cpu news20 18 | 19 | # xgboost gpu 20 | #python3 experiments.py xgboost gpu higgs 21 | #python3 experiments.py xgboost gpu log1p 22 | #python3 experiments.py xgboost gpu cifar 23 | python3 experiments.py xgboost gpu news20 24 | 25 | # catboost 26 | #python3 experiments.py catboost cpu higgs 27 | #python3 experiments.py catboost cpu log1p 28 | #python3 experiments.py catboost cpu cifar 29 | python3 experiments.py catboost cpu news20 30 | 31 | # catboost gpu 32 | #python3 experiments.py catboost gpu higgs 33 | #python3 experiments.py catboost gpu log1p 34 | #python3 experiments.py catboost gpu cifar 35 | python3 experiments.py catboost gpu news20 36 | -------------------------------------------------------------------------------- /python/benchmarks/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/benchmarks/utils/__init__.py -------------------------------------------------------------------------------- /python/benchmarks/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pandas as pd 4 | from sklearn import datasets 5 | import pickle 6 | import numpy as np 7 | 8 | if sys.version_info[0] >= 3: 9 | from urllib.request import urlretrieve # pylint: disable=import-error,no-name-in-module 10 | else: 11 | from urllib import urlretrieve # pylint: disable=import-error,no-name-in-module 12 | 13 | # to make a base directory of datasets 14 | BASE_DATASET_DIR = 'datasets' 15 | if not os.path.exists(BASE_DATASET_DIR): 16 | print("Made dir: %s" % BASE_DATASET_DIR) 17 | os.mkdir(BASE_DATASET_DIR) 18 | 19 | 20 | def convert_to_plk(filename, plk_filename, data_url): 21 | if not os.path.isfile(plk_filename): 22 | print('Downloading dataset, please wait...') 23 | urlretrieve(data_url, filename) 24 | X, y = datasets.load_svmlight_file(filename) 25 | pickle.dump((X, y), open(plk_filename, 'wb')) 26 | os.remove(filename) 27 | print("----------- loading dataset %s -----------" % filename) 28 | X, y = pickle.load(open(plk_filename, 'rb')) 29 | print("----------- Finished loading dataset %s -----------" % filename) 30 | return X, y 31 | 32 | 33 | # 
------------------------------------------------------------------------------------------------- 34 | get_higgs_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/HIGGS.bz2' # pylint: disable=line-too-long 35 | def get_higgs(num_rows=None): 36 | filename = os.path.join(BASE_DATASET_DIR, 'HIGGS.bz2') 37 | plk_filename = filename + '.plk' 38 | X, y = convert_to_plk(filename, plk_filename, data_url=get_higgs_url) 39 | 40 | return X.toarray(), y 41 | 42 | 43 | # ------------------------------------------------------------------------------------------------- 44 | def get_covtype(num_rows=None): 45 | data = datasets.fetch_covtype() 46 | X = data.data 47 | y = data.target 48 | if num_rows is not None: 49 | X = X[0:num_rows] 50 | y = y[0:num_rows] 51 | 52 | return X, y 53 | 54 | 55 | # ------------------------------------------------------------------------------------------------- 56 | get_e2006_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/E2006.train.bz2' # pylint: disable=line-too-long 57 | def get_e2006(num_rows=None): 58 | filename = os.path.join(BASE_DATASET_DIR, 'E2006.train') 59 | plk_filename = filename + '.plk' 60 | X, y = convert_to_plk(filename, plk_filename, data_url=get_e2006_url) 61 | 62 | return X.toarray(), y 63 | 64 | 65 | # ------------------------------------------------------------------------------------------------- 66 | get_lop1p_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/log1p.E2006.train.bz2' 67 | def get_log1p(num_rows=None): 68 | filename = os.path.join(BASE_DATASET_DIR, 'log1p.train.bz2') 69 | plk_filename = filename + '.plk' 70 | X, y = convert_to_plk(filename, plk_filename, data_url=get_lop1p_url) 71 | 72 | return X, y 73 | 74 | 75 | # ------------------------------------------------------------------------------------------------- 76 | get_news20_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/news20.binary.bz2' 77 | def get_news20(num_rows=None): 78 | filename = os.path.join(BASE_DATASET_DIR, 'news20.binary.bz2') 79 | plk_filename = filename + '.plk' 80 | X, y = convert_to_plk(filename, plk_filename, data_url=get_news20_url) 81 | 82 | return X.toarray(), y 83 | 84 | 85 | # ------------------------------------------------------------------------------------------------- 86 | get_realsim_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/real-sim.bz2' 87 | def get_realsim(num_rows=None): 88 | filename = os.path.join(BASE_DATASET_DIR, 'real-sim.bz2') 89 | plk_filename = filename + '.plk' 90 | X, y = convert_to_plk(filename, plk_filename, data_url=get_realsim_url) 91 | 92 | return X.toarray(), y 93 | 94 | 95 | # ------------------------------------------------------------------------------------------------- 96 | get_susy_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/SUSY.bz2' 97 | def get_susy(num_rows=None): 98 | filename = os.path.join(BASE_DATASET_DIR, 'susy.bz2') 99 | plk_filename = filename + '.plk' 100 | X, y = convert_to_plk(filename, plk_filename, data_url=get_susy_url) 101 | 102 | return X.toarray(), y 103 | 104 | 105 | # ------------------------------------------------------------------------------------------------- 106 | get_epslion_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/epsilon_normalized.bz2' 107 | def get_epsilon(num_rows=None): 108 | filename = os.path.join(BASE_DATASET_DIR, 'epsilon_normalized') 109 | plk_filename = filename + '.plk' 110 | X, y = convert_to_plk(filename, 
plk_filename, data_url=get_epslion_url) 111 | 112 | return X.toarray(), y 113 | 114 | 115 | # ------------------------------------------------------------------------------------------------- 116 | get_cifar_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/cifar10.bz2' 117 | def get_cifar(num_rows=None): 118 | filename = os.path.join(BASE_DATASET_DIR, 'cifar') 119 | plk_filename = filename + '.plk' 120 | X, y = convert_to_plk(filename, plk_filename, data_url=get_cifar_url) 121 | 122 | return X.toarray(), y 123 | 124 | 125 | # ------------------------------------------------------------------------------------------------- 126 | get_ins_url = '' 127 | def get_ins(num_rows=None): 128 | filename = BASE_DATASET_DIR + '/' + 'ins' 129 | plk_filename = filename + '.plk' 130 | X, y = convert_to_plk(filename, plk_filename, data_url=get_ins_url) 131 | 132 | return X.toarray(), y 133 | 134 | 135 | get_cifar_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/cifar10.bz2' 136 | def get_cifar10(num_rows=None): 137 | filename = BASE_DATASET_DIR + '/' + 'cifar10' 138 | plk_filename = filename + '.plk' 139 | X, y = convert_to_plk(filename, plk_filename, data_url=get_cifar_url) 140 | 141 | return X.toarray(), y 142 | 143 | 144 | 145 | get_news20_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.bz2' 146 | def get_news20(num_rows=None): 147 | filename = BASE_DATASET_DIR + '/' + 'news20.bz2' 148 | plk_filename = filename + '.plk' 149 | X, y = convert_to_plk(filename, plk_filename, data_url=get_news20_url) 150 | 151 | return X.toarray(), y 152 | 153 | 154 | def load_groups(filename): 155 | groups = [] 156 | with open(filename) as f: 157 | for line in f: 158 | if line.strip() != '': 159 | groups.append(int(line.strip())) 160 | 161 | return np.asarray(groups) 162 | 163 | 164 | get_yahoo_url = '' 165 | def get_yahoo(): 166 | filename = BASE_DATASET_DIR + "/" + 'yahoo-ltr-libsvm' 167 | group_filename = BASE_DATASET_DIR + "/" + 'yahoo-ltr-libsvm.group' 168 | plk_filename = filename + '.plk' 169 | X, y = convert_to_plk(filename, plk_filename, data_url=get_yahoo_url) 170 | groups = load_groups(group_filename) 171 | 172 | return X.toarray(), y, groups 173 | 174 | if __name__ == "__main__": 175 | X, y, groups = get_yahoo(); 176 | print(X.shape) 177 | print(y.shape) -------------------------------------------------------------------------------- /python/benchmarks/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def add_data(df, algorithm, data, elapsed, metric): 5 | time_col = (data.name, 'Time(s)') 6 | metric_col = (data.name, data.metric) 7 | try: 8 | df.insert(len(df.columns), time_col, '-') 9 | df.insert(len(df.columns), metric_col, '-') 10 | except: 11 | pass 12 | 13 | df.at[algorithm, time_col] = elapsed 14 | df.at[algorithm, metric_col] = metric 15 | 16 | 17 | def write_results(df, filename, format): 18 | if format == "latex": 19 | tmp_df = df.copy() 20 | tmp_df.columns = pd.MultiIndex.from_tuples(tmp_df.columns) 21 | with open(filename, "a") as file: 22 | file.write(tmp_df.to_latex()) 23 | elif format == "csv": 24 | with open(filename, "a") as file: 25 | file.write(df.to_csv()) 26 | else: 27 | raise ValueError("Unknown format: " + format) 28 | 29 | print(format + " results written to: " + filename) -------------------------------------------------------------------------------- /python/dist/thundergbm-0.3.12-py2-none-win_amd64.whl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.12-py2-none-win_amd64.whl -------------------------------------------------------------------------------- /python/dist/thundergbm-0.3.12-py3-none-win_amd64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.12-py3-none-win_amd64.whl -------------------------------------------------------------------------------- /python/dist/thundergbm-0.3.16-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.16-py3-none-any.whl -------------------------------------------------------------------------------- /python/dist/thundergbm-0.3.4-py3-none-win_amd64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.4-py3-none-win_amd64.whl -------------------------------------------------------------------------------- /python/examples/classification_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | 4 | from thundergbm import TGBMClassifier 5 | from sklearn.datasets import load_digits 6 | from sklearn.metrics import accuracy_score 7 | 8 | if __name__ == '__main__': 9 | x, y = load_digits(return_X_y=True) 10 | clf = TGBMClassifier() 11 | clf.fit(x, y) 12 | y_pred = clf.predict(x) 13 | accuracy = accuracy_score(y, y_pred) 14 | print(accuracy) -------------------------------------------------------------------------------- /python/examples/ranking_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | 4 | import thundergbm 5 | from sklearn.datasets import load_svmlight_file 6 | 7 | def load_test_data(file_path): 8 | X, y = load_svmlight_file(file_path) 9 | 10 | return X, y 11 | 12 | def load_groups(file_path): 13 | groups = [] 14 | with open(file_path, 'r') as f: 15 | for line in f.readlines(): 16 | if line.strip() != '': 17 | groups.append(int(line.strip())) 18 | 19 | return groups 20 | 21 | if __name__ == "__main__": 22 | X, y = load_test_data('../../dataset/test_dataset.txt') 23 | groups = load_groups('../../dataset/test_dataset.txt.group') 24 | # tgbm_model = thundergbm.TGBMClassifier(depth=6, n_trees=40, objective='rank:ndcg') 25 | tgbm_model = thundergbm.TGBMRanker(depth=6, n_trees=40, objective='rank:ndcg') 26 | tgbm_model.fit(X, y, groups) 27 | pred_result = tgbm_model.predict(X, groups) 28 | print(pred_result) 29 | 30 | -------------------------------------------------------------------------------- /python/examples/regression_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | 4 | from thundergbm import TGBMRegressor 5 | from sklearn.datasets import load_boston 6 | from sklearn.metrics import mean_squared_error 7 | from math import sqrt 8 | 9 | if __name__ == '__main__': 10 | x, y = load_boston(return_X_y=True) 11 | clf = TGBMRegressor() 12 | clf.fit(x, y) 13 | 
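    # The demo evaluates on its own training data; a held-out variant is a small change
    # (train_test_split comes from scikit-learn, which is already a dependency):
    #   from sklearn.model_selection import train_test_split
    #   x_tr, x_te, y_tr, y_te = train_test_split(x, y, test_size=0.2, random_state=0)
    #   clf.fit(x_tr, y_tr)
    #   print(sqrt(mean_squared_error(y_te, clf.predict(x_te))))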
    y_pred = clf.predict(x)
14 |     rmse = sqrt(mean_squared_error(y, y_pred))
15 |     print(rmse)
-------------------------------------------------------------------------------- /python/requirements.txt: --------------------------------------------------------------------------------
1 | scipy
2 | scikit-learn
3 | numpy
4 | 
-------------------------------------------------------------------------------- /python/setup.py: --------------------------------------------------------------------------------
1 | from os import path
2 | import setuptools
3 | from shutil import copyfile
4 | from sys import platform
5 | 
6 | dirname = path.dirname(path.abspath(__file__))
7 | 
8 | if platform == "linux" or platform == "linux2":
9 |     lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.so'))
10 | elif platform == "win32":
11 |     lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/thundergbm.dll'))
12 | elif platform == "darwin":
13 |     lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.dylib'))
14 | else:
15 |     print("OS not supported!")
16 |     exit()
17 | if not path.exists(path.join(dirname, "thundergbm", path.basename(lib_path))):
18 |     copyfile(lib_path, path.join(dirname, "thundergbm", path.basename(lib_path)))
19 | setuptools.setup(name="thundergbm",
20 |                  version="0.3.16",
21 |                  packages=["thundergbm"],
22 |                  package_dir={"python": "thundergbm"},
23 |                  description="A Fast GBM Library on GPUs and CPUs",
24 |                  long_description="""The mission of ThunderGBM is to help users easily and efficiently apply GBDTs and Random Forests to solve problems. ThunderGBM exploits GPUs and multi-core CPUs to achieve high efficiency""",
25 |                  long_description_content_type="text/plain",
26 |                  license='Apache-2.0',
27 |                  author='Xtra Computing Group',
28 |                  maintainer='thundergbm contributors',
29 |                  maintainer_email='wenzy@comp.nus.edu.sg',
30 |                  url="https://github.com/zeyiwen/thundergbm",
31 |                  package_data={"thundergbm": [path.basename(lib_path)]},
32 |                  install_requires=['numpy', 'scipy', 'scikit-learn'],
33 |                  classifiers=[
34 |                      "Programming Language :: Python :: 3",
35 |                      "License :: OSI Approved :: Apache Software License",
36 |                  ],
37 |                  )
38 | 
-------------------------------------------------------------------------------- /python/thundergbm/__init__.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*-coding:utf-8-*-
3 | """
4 | * Name        : __init__.py
5 | * Author      : Locke
6 | * Version     : 0.0.1
7 | * Description :
8 | """
9 | name = "thundergbm"
10 | from .thundergbm import *
11 | 
-------------------------------------------------------------------------------- /src/test/CMakeLists.txt: --------------------------------------------------------------------------------
1 | add_subdirectory(googletest)
2 | 
3 | include_directories(googletest/googletest/include)
4 | include_directories(googletest/googlemock/include)
5 | 
6 | file(GLOB TEST_SRC *)
7 | 
8 | cuda_add_executable(${PROJECT_NAME}-test ${TEST_SRC} ${COMMON_INCLUDES})
9 | target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME} gtest)
10 | 
-------------------------------------------------------------------------------- /src/test/test_cub_wrapper.cu: --------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "thundergbm/util/cub_wrapper.h"
3 | #include "thundergbm/syncarray.h"
4 | 
5 | class CubWrapperTest : public ::testing::Test {
6 | public:
7 |     SyncArray<float_type> values;
8 |     SyncArray<int> keys;
9 | 
10 | protected:
11 |     void 
SetUp() override { 12 | values.resize(4); 13 | keys.resize(4); 14 | auto values_data = values.host_data(); 15 | values_data[0] = 1; 16 | values_data[1] = 4; 17 | values_data[2] = -1; 18 | values_data[3] = 0; 19 | auto keys_data = keys.host_data(); 20 | keys_data[0] = 1; 21 | keys_data[1] = 3; 22 | keys_data[2] = 2; 23 | keys_data[3] = 4; 24 | } 25 | 26 | }; 27 | 28 | TEST_F(CubWrapperTest, test_cub_sort_by_key) { 29 | // keys from [1, 3, 2, 4] to [1, 2, 3, 4] 30 | // values from [1, 4, -1, 0] to [1, -1, 4, 0] 31 | cub_sort_by_key(keys, values); 32 | auto keys_data = keys.host_data(); 33 | auto values_data = values.host_data(); 34 | EXPECT_EQ(keys_data[0], 1); 35 | EXPECT_EQ(keys_data[1], 2); 36 | EXPECT_EQ(keys_data[2], 3); 37 | EXPECT_EQ(keys_data[3], 4); 38 | EXPECT_EQ(values_data[0], 1); 39 | EXPECT_EQ(values_data[1], -1); 40 | EXPECT_EQ(values_data[2], 4); 41 | EXPECT_EQ(values_data[3], 0); 42 | } 43 | 44 | TEST_F(CubWrapperTest, test_cub_seg_sort_by_key) { 45 | SyncArray seg_ptr(3); // 2 segments 46 | auto seg_ptr_data = seg_ptr.host_data(); 47 | seg_ptr_data[0] = 0; 48 | seg_ptr_data[1] = 2; 49 | seg_ptr_data[2] = 4; 50 | // keys from [1, 3, 2, 4] to [1, 3, 2, 4] 51 | // values from [1, 4, -1, 0] to [1, 4, -1, 0] 52 | /*cub_seg_sort_by_key(keys, values, seg_ptr);*/ 53 | auto keys_data = keys.host_data(); 54 | auto values_data = values.host_data(); 55 | EXPECT_EQ(keys_data[0], 1); 56 | EXPECT_EQ(keys_data[1], 3); 57 | EXPECT_EQ(keys_data[2], 2); 58 | EXPECT_EQ(keys_data[3], 4); 59 | EXPECT_EQ(values_data[0], 1); 60 | EXPECT_EQ(values_data[1], 4); 61 | EXPECT_EQ(values_data[2], -1); 62 | EXPECT_EQ(values_data[3], 0); 63 | } 64 | 65 | TEST_F(CubWrapperTest, test_sort_array) { 66 | sort_array(values); 67 | auto values_data = values.host_data(); 68 | EXPECT_EQ(values_data[0], -1); 69 | EXPECT_EQ(values_data[1], 0); 70 | EXPECT_EQ(values_data[2], 1); 71 | EXPECT_EQ(values_data[3], 4); 72 | } 73 | 74 | TEST_F(CubWrapperTest, test_max_elem) { 75 | int max_elem = max_elements(values); 76 | EXPECT_EQ(max_elem, 4); 77 | } 78 | 79 | -------------------------------------------------------------------------------- /src/test/test_dataset.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 17-9-17. 
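// The fixture below re-implements the LIBSVM parser on the host so that the
// load_dataset test can cross-check DataSet::load_from_file entry by entry.
// Input lines follow the usual "<label> <index>:<value> ..." layout, e.g.
// (illustrative values, not taken from test_dataset.txt):
//   -1 3:1 11:1 14:1 19:1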
3 | // 4 | #include "gtest/gtest.h" 5 | #include "thundergbm/dataset.h" 6 | #include "omp.h" 7 | 8 | class DatasetTest : public ::testing::Test { 9 | public: 10 | GBMParam param; 11 | vector csr_val; 12 | vector csr_row_ptr; 13 | vector csr_col_idx; 14 | vector y; 15 | size_t n_features_; 16 | vector label; 17 | protected: 18 | void SetUp() override { 19 | param.depth = 6; 20 | param.n_trees = 40; 21 | param.n_device = 1; 22 | param.min_child_weight = 1; 23 | param.lambda = 1; 24 | param.gamma = 1; 25 | param.rt_eps = 1e-6; 26 | param.max_num_bin = 255; 27 | param.verbose = false; 28 | param.profiling = false; 29 | param.column_sampling_rate = 1; 30 | param.bagging = false; 31 | param.n_parallel_trees = 1; 32 | param.learning_rate = 1; 33 | param.objective = "reg:linear"; 34 | param.num_class = 1; 35 | param.path = "../dataset/test_dataset.txt"; 36 | param.tree_method = "auto"; 37 | if (!param.verbose) { 38 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 39 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 40 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 41 | } 42 | if (!param.profiling) { 43 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 44 | } 45 | } 46 | 47 | void load_from_file(string file_name, GBMParam ¶m) { 48 | LOG(INFO) << "loading LIBSVM dataset from file \"" << file_name << "\""; 49 | y.clear(); 50 | csr_row_ptr.resize(1, 0); 51 | csr_col_idx.clear(); 52 | csr_val.clear(); 53 | n_features_ = 0; 54 | 55 | std::ifstream ifs(file_name, std::ifstream::binary); 56 | CHECK(ifs.is_open()) << "file " << file_name << " not found"; 57 | 58 | int buffer_size = 4 << 20; 59 | char *buffer = (char *)malloc(buffer_size); 60 | //array may cause stack overflow in windows 61 | //std::array buffer{}; 62 | const int nthread = omp_get_max_threads(); 63 | 64 | auto find_last_line = [](char *ptr, const char *begin) { 65 | while (ptr != begin && *ptr != '\n' && *ptr != '\r' && *ptr != '\0') --ptr; 66 | return ptr; 67 | }; 68 | 69 | while (ifs) { 70 | ifs.read(buffer, buffer_size); 71 | char *head = buffer; 72 | //ifs.read(buffer.data(), buffer.size()); 73 | //char *head = buffer.data(); 74 | size_t size = ifs.gcount(); 75 | vector> y_(nthread); 76 | vector> col_idx_(nthread); 77 | vector> row_len_(nthread); 78 | vector> val_(nthread); 79 | 80 | vector max_feature(nthread, 0); 81 | bool is_zeor_base = false; 82 | 83 | #pragma omp parallel num_threads(nthread) 84 | { 85 | //get working area of this thread 86 | int tid = omp_get_thread_num(); 87 | size_t nstep = (size + nthread - 1) / nthread; 88 | size_t sbegin = (std::min)(tid * nstep, size - 1); 89 | size_t send = (std::min)((tid + 1) * nstep, size - 1); 90 | char *pbegin = find_last_line(head + sbegin, head); 91 | char *pend = find_last_line(head + send, head); 92 | 93 | //move stream start position to the end of last line 94 | if (tid == nthread - 1) { 95 | if (ifs.eof()) 96 | pend = head + send; 97 | else 98 | ifs.seekg(-(head + send - pend), std::ios_base::cur); 99 | } 100 | 101 | //read instances line by line 102 | //TODO optimize parse line 103 | char *lbegin = pbegin; 104 | char *lend = lbegin; 105 | while (lend != pend) { 106 | //get one line 107 | lend = lbegin + 1; 108 | while (lend != pend && *lend != '\n' && *lend != '\r' && *lend != '\0') { 109 | ++lend; 110 | } 111 | string line(lbegin, lend); 112 | if (line != "\n") { 113 | std::stringstream ss(line); 
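                    // The stream now holds one LIBSVM instance of the form
                    // "<label> <index>:<value> <index>:<value> ..."; the label
                    // is extracted first, then the index:value tuples are
                    // parsed one at a time below.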
114 | 115 | //read label of an instance 116 | y_[tid].push_back(0); 117 | ss >> y_[tid].back(); 118 | 119 | row_len_[tid].push_back(0); 120 | string tuple; 121 | while (ss >> tuple) { 122 | int i; 123 | float v; 124 | CHECK_EQ(sscanf(tuple.c_str(), "%d:%f", &i, &v), 2) 125 | << "read error, using [index]:[value] format"; 126 | //TODO one-based and zero-based 127 | col_idx_[tid].push_back(i - 1);//one based 128 | if(i - 1 == -1){ 129 | is_zeor_base = true; 130 | } 131 | CHECK_GE(i - 1, -1) << "dataset format error"; 132 | val_[tid].push_back(v); 133 | if (i > max_feature[tid]) { 134 | max_feature[tid] = i; 135 | } 136 | row_len_[tid].back()++; 137 | } 138 | } 139 | //read next instance 140 | lbegin = lend; 141 | 142 | } 143 | } 144 | for (int i = 0; i < nthread; i++) { 145 | if (max_feature[i] > n_features_) 146 | n_features_ = max_feature[i]; 147 | } 148 | for (int tid = 0; tid < nthread; tid++) { 149 | csr_val.insert(csr_val.end(), val_[tid].begin(), val_[tid].end()); 150 | if(is_zeor_base){ 151 | for (int i = 0; i < col_idx_[tid].size(); ++i) { 152 | col_idx_[tid][i]++; 153 | } 154 | } 155 | csr_col_idx.insert(csr_col_idx.end(), col_idx_[tid].begin(), col_idx_[tid].end()); 156 | for (int row_len : row_len_[tid]) { 157 | csr_row_ptr.push_back(csr_row_ptr.back() + row_len); 158 | } 159 | } 160 | for (int i = 0; i < nthread; i++) { 161 | this->y.insert(y.end(), y_[i].begin(), y_[i].end()); 162 | this->label.insert(label.end(), y_[i].begin(), y_[i].end()); 163 | } 164 | } 165 | ifs.close(); 166 | free(buffer); 167 | } 168 | }; 169 | 170 | TEST_F(DatasetTest, load_dataset){ 171 | DataSet dataset; 172 | load_from_file(param.path, param); 173 | dataset.load_from_file(param.path, param); 174 | printf("### Dataset: %s, num_instances: %d, num_features: %d. ###\n", 175 | param.path.c_str(), 176 | dataset.n_instances(), 177 | dataset.n_features()); 178 | EXPECT_EQ(dataset.n_instances(), 1605); 179 | EXPECT_EQ(dataset.n_features_, 119); 180 | EXPECT_EQ(dataset.label[0], -1); 181 | EXPECT_EQ(dataset.csr_val[1], 1); 182 | 183 | for(int i = 0; i < csr_val.size(); i++) 184 | EXPECT_EQ(csr_val[i], dataset.csr_val[i]); 185 | for(int i = 0; i < csr_row_ptr.size(); i++) 186 | EXPECT_EQ(csr_row_ptr[i], dataset.csr_row_ptr[i]); 187 | for(int i = 0; i < csr_col_idx.size(); i++) 188 | EXPECT_EQ(csr_col_idx[i], dataset.csr_col_idx[i]); 189 | } 190 | 191 | -------------------------------------------------------------------------------- /src/test/test_for_refactor.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-14. 
3 | // 4 | #include 5 | #include 6 | #include "gtest/gtest.h" 7 | #include "thundergbm/common.h" 8 | 9 | class TrainTest : public ::testing::Test { 10 | public: 11 | GBMParam param; 12 | TreeTrainer trainer; 13 | protected: 14 | void SetUp() override { 15 | param.depth = 6; 16 | param.n_trees = 40; 17 | param.n_device = 1; 18 | param.min_child_weight = 1; 19 | param.lambda = 1; 20 | param.gamma = 1; 21 | param.rt_eps = 1e-6; 22 | param.max_num_bin = 255; 23 | param.verbose = false; 24 | param.profiling = false; 25 | param.column_sampling_rate = 1; 26 | param.bagging = false; 27 | param.n_parallel_trees = 1; 28 | param.learning_rate = 1; 29 | param.objective = "reg:linear"; 30 | param.num_class = 1; 31 | param.path = "../dataset/test_dataset.txt"; 32 | param.tree_method = "auto"; 33 | if (!param.verbose) { 34 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 35 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 36 | } 37 | if (!param.profiling) { 38 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 39 | } 40 | } 41 | 42 | void TearDown() override { 43 | } 44 | }; 45 | 46 | TEST_F(TrainTest, news20) { 47 | param.path = DATASET_DIR "news20.scale"; 48 | // trainer.train(param, <#initializer#>); 49 | // float_type rmse = trainer.train(param); 50 | // EXPECT_NEAR(rmse, 2.55274, 1e-5); 51 | } 52 | 53 | TEST_F(TrainTest, covtype) { 54 | param.path = DATASET_DIR "covtype"; 55 | // float_type rmse = trainer.train2(param); 56 | // trainer.train(param, <#initializer#>); 57 | param.bagging = true; 58 | // trainer.train(param, <#initializer#>); 59 | param.column_sampling_rate = 0.5; 60 | // trainer.train(param, <#initializer#>); 61 | // EXPECT_NEAR(rmse, 0.730795, 1e-5); 62 | } 63 | 64 | TEST_F(TrainTest, covtype_multiclass) { 65 | param.path = DATASET_DIR "covtype"; 66 | param.num_class = 7; 67 | param.objective = "multi:softprob"; 68 | // trainer.train(param, <#initializer#>); 69 | } 70 | 71 | TEST_F(TrainTest, mnist_multiclass) { 72 | param.path = DATASET_DIR "mnist.scale"; 73 | param.objective = "multi:softprob"; 74 | param.num_class = 10; 75 | // trainer.train(param, <#initializer#>); 76 | } 77 | 78 | TEST_F(TrainTest, cifar10_multiclass) { 79 | param.path = DATASET_DIR "cifar10"; 80 | param.objective = "multi:softprob"; 81 | param.num_class = 10; 82 | // trainer.train(param, <#initializer#>); 83 | } 84 | 85 | TEST_F(TrainTest, sector_multiclass) { 86 | param.path = DATASET_DIR "sector.scale"; 87 | param.objective = "multi:softprob"; 88 | param.tree_method = "hist"; 89 | param.num_class = 105; 90 | // trainer.train(param, <#initializer#>); 91 | } 92 | 93 | TEST_F(TrainTest, news20_multiclass) { 94 | param.path = DATASET_DIR "news20.scale"; 95 | param.objective = "multi:softprob"; 96 | param.num_class = 20; 97 | // trainer.train(param, <#initializer#>); 98 | } 99 | 100 | TEST_F(TrainTest, rcv1_multiclass) { 101 | param.path = DATASET_DIR "rcv1_train.multiclass"; 102 | param.objective = "multi:softprob"; 103 | param.tree_method = "hist"; 104 | param.num_class = 51; 105 | // trainer.train(param, <#initializer#>); 106 | } 107 | 108 | TEST_F(TrainTest, yahoo_ranking) { 109 | // param.path = "dataset/rank.train"; 110 | param.path = DATASET_DIR "yahoo-ltr-libsvm"; 111 | param.objective = "rank:ndcg"; 112 | // trainer.train(param, <#initializer#>); 113 | // float_type rmse = trainer.train(param); 114 | } 115 | 
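}

// The trainer calls in this file are disabled: they predate the refactor that
// changed TreeTrainer::train to take (param, dataset). A re-enabled body would
// follow the pattern used in test_gbdt.cpp, sketched here for reference:
//   DataSet dataset;
//   dataset.load_from_file(param.path, param);
//   trainer.train(param, dataset);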
-------------------------------------------------------------------------------- /src/test/test_gbdt.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "thundergbm/trainer.h" 3 | #include "thundergbm/predictor.h" 4 | #include "thundergbm/dataset.h" 5 | #include "thundergbm/tree.h" 6 | 7 | class GBDTTest: public ::testing::Test { 8 | public: 9 | GBMParam param; 10 | protected: 11 | void SetUp() override { 12 | param.depth = 3; 13 | param.n_trees = 5; 14 | param.n_device = 1; 15 | param.min_child_weight = 1; 16 | param.lambda = 1; 17 | param.gamma = 1; 18 | param.rt_eps = 1e-6; 19 | param.max_num_bin = 255; 20 | param.verbose = false; 21 | param.profiling = false; 22 | param.column_sampling_rate = 1; 23 | param.bagging = false; 24 | param.n_parallel_trees = 1; 25 | param.learning_rate = 1; 26 | param.objective = "reg:linear"; 27 | param.num_class = 1; 28 | param.path = "../dataset/test_dataset.txt"; 29 | param.tree_method = "hist"; 30 | param.tree_per_rounds = 1; 31 | if (!param.verbose) { 32 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 33 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 34 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 35 | } 36 | if (!param.profiling) { 37 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 38 | } 39 | } 40 | }; 41 | 42 | // training GBDTs in hist manner 43 | TEST_F(GBDTTest, test_hist) { 44 | param.tree_method = "hist"; 45 | DataSet dataset; 46 | dataset.load_from_file(param.path, param); 47 | TreeTrainer trainer; 48 | trainer.train(param, dataset); 49 | } 50 | 51 | // training GBDTs in exact manner 52 | TEST_F(GBDTTest, test_exact) { 53 | param.tree_method = "exact"; 54 | DataSet dataset; 55 | dataset.load_from_file(param.path, param); 56 | TreeTrainer trainer; 57 | trainer.train(param, dataset); 58 | } 59 | 60 | // test bagging 61 | TEST_F(GBDTTest,test_bagging) { 62 | param.bagging = true; 63 | DataSet dataset; 64 | dataset.load_from_file(param.path, param); 65 | TreeTrainer trainer; 66 | trainer.train(param, dataset); 67 | } 68 | 69 | // test different number of parallel trees 70 | TEST_F(GBDTTest, test_n_parallel_trees) { 71 | param.n_parallel_trees = 2; 72 | DataSet dataset; 73 | dataset.load_from_file(param.path, param); 74 | TreeTrainer trainer; 75 | trainer.train(param, dataset); 76 | } 77 | 78 | // test predictor 79 | TEST_F(GBDTTest, test_predictor) { 80 | DataSet dataset; 81 | dataset.load_from_file(param.path, param); 82 | TreeTrainer trainer; 83 | vector> boosted_model; 84 | boosted_model = trainer.train(param, dataset); 85 | Predictor predictor; 86 | predictor.predict(param, boosted_model, dataset); 87 | } 88 | -------------------------------------------------------------------------------- /src/test/test_get_cut_point.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by hanfeng on 6/11/19. 
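// The fixture's get_cut_points2 builds the candidate split values per feature:
// when a feature has at most max_num_bins unique values they are all kept,
// otherwise max_num_bins values are sampled at a stride of
// unique_len / max_num_bins; cut_row_ptr records the per-feature offsets in
// CSR style. It serves as the host-side reference for the (commented-out)
// GPU HistCut::get_cut_points2 checks in the tests below.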
3 | // 4 | 5 | #include "gtest/gtest.h" 6 | #include "thundergbm/common.h" 7 | #include "thundergbm/sparse_columns.h" 8 | #include "thundergbm/hist_cut.h" 9 | #include "thrust/unique.h" 10 | #include "thrust/execution_policy.h" 11 | #include "thundergbm/builder/shard.h" 12 | 13 | 14 | class GetCutPointTest : public ::testing::Test { 15 | public: 16 | GBMParam param; 17 | HistCut cut; 18 | 19 | vector cut_points; 20 | vector row_ptr; 21 | //for gpu 22 | SyncArray cut_points_val; 23 | SyncArray cut_row_ptr; 24 | SyncArray cut_fid; 25 | 26 | 27 | protected: 28 | void SetUp() override { 29 | param.max_num_bin = 255; 30 | param.depth = 6; 31 | param.n_trees = 40; 32 | param.n_device = 1; 33 | param.min_child_weight = 1; 34 | param.lambda = 1; 35 | param.gamma = 1; 36 | param.rt_eps = 1e-6; 37 | param.verbose = false; 38 | param.profiling = false; 39 | param.column_sampling_rate = 1; 40 | param.bagging = false; 41 | param.n_parallel_trees = 1; 42 | param.learning_rate = 1; 43 | param.objective = "reg:linear"; 44 | param.num_class = 1; 45 | param.path = "../dataset/test_dataset.txt"; 46 | param.tree_method = "auto"; 47 | if (!param.verbose) { 48 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 49 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 50 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false"); 51 | } 52 | if (!param.profiling) { 53 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 54 | } 55 | } 56 | 57 | void get_cut_points2(SparseColumns &columns, int max_num_bins, int n_instances){ 58 | int n_column = columns.n_column; 59 | auto csc_val_data = columns.csc_val.host_data(); 60 | auto csc_col_ptr_data = columns.csc_col_ptr.host_data(); 61 | cut_points.clear(); 62 | row_ptr.clear(); 63 | row_ptr.resize(1, 0); 64 | 65 | //TODO do this on GPU 66 | for (int fid = 0; fid < n_column; ++fid) { 67 | int col_start = csc_col_ptr_data[fid]; 68 | int col_len = csc_col_ptr_data[fid + 1] - col_start; 69 | auto val_data = csc_val_data + col_start; 70 | vector unique_val(col_len); 71 | 72 | int unique_len = thrust::unique_copy(thrust::host, val_data, val_data + col_len, unique_val.data()) - unique_val.data(); 73 | if (unique_len <= max_num_bins) { 74 | row_ptr.push_back(unique_len + row_ptr.back()); 75 | for (int i = 0; i < unique_len; ++i) { 76 | cut_points.push_back(unique_val[i]); 77 | } 78 | } else { 79 | row_ptr.push_back(max_num_bins + row_ptr.back()); 80 | for (int i = 0; i < max_num_bins; ++i) { 81 | cut_points.push_back(unique_val[unique_len / max_num_bins * i]); 82 | } 83 | } 84 | } 85 | 86 | cut_points_val.resize(cut_points.size()); 87 | cut_points_val.copy_from(cut_points.data(), cut_points.size()); 88 | cut_row_ptr.resize(row_ptr.size()); 89 | cut_row_ptr.copy_from(row_ptr.data(), row_ptr.size()); 90 | cut_fid.resize(cut_points.size()); 91 | } 92 | }; 93 | 94 | 95 | TEST_F(GetCutPointTest, covtype) { 96 | //param.path = "../dataset/covtype"; 97 | DataSet dataset; 98 | dataset.load_from_file(param.path, param); 99 | SparseColumns columns; 100 | //vector> v_columns(1); 101 | //columns.csr2csc_gpu(dataset, v_columns); 102 | //cut.get_cut_points2(columns, param.max_num_bin, dataset.n_instances()); 103 | 104 | printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. 
###\n", 105 | param.path.c_str(), 106 | dataset.n_instances(), 107 | dataset.n_features()); 108 | 109 | //this->get_cut_points2(columns, param.max_num_bin,dataset.n_instances()); 110 | 111 | // --- test cut_points_val 112 | //auto gpu_cut_points_val = cut.cut_points_val.host_data(); 113 | //auto cpu_cut_points_val = this->cut_points_val.host_data(); 114 | //for(int i = 0; i < cut.cut_points_val.size(); i++) 115 | //EXPECT_EQ(gpu_cut_points_val[i], cpu_cut_points_val[i]); 116 | 117 | // --- test cut_row_ptr 118 | //auto gpu_cut_row_ptr = cut.cut_row_ptr.host_data(); 119 | //auto cpu_cut_row_ptr = this->cut_row_ptr.host_data(); 120 | //for(int i = 0; i < cut.cut_row_ptr.size(); i++) 121 | //EXPECT_EQ(gpu_cut_row_ptr[i], cpu_cut_row_ptr[i]); 122 | 123 | // --- test cut_fid 124 | //EXPECT_EQ(cut.cut_fid.size(), this->cut_fid.size()); 125 | } 126 | 127 | 128 | TEST_F(GetCutPointTest, real_sim) { 129 | DataSet dataset; 130 | dataset.load_from_file(param.path, param); 131 | SparseColumns columns; 132 | //vector> v_columns(1); 133 | //columns.csr2csc_gpu(dataset, v_columns); 134 | //cut.get_cut_points2(columns, param.max_num_bin, dataset.n_instances()); 135 | 136 | //printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. ###\n", 137 | //param.path.c_str(), 138 | //dataset.n_instances(), 139 | //dataset.n_features()); 140 | 141 | //this->get_cut_points2(columns, param.max_num_bin,dataset.n_instances()); 142 | 143 | // --- test cut_points_val 144 | //auto gpu_cut_points_val = cut.cut_points_val.host_data(); 145 | //auto cpu_cut_points_val = this->cut_points_val.host_data(); 146 | //for(int i = 0; i < cut.cut_points_val.size(); i++) 147 | //EXPECT_EQ(gpu_cut_points_val[i], cpu_cut_points_val[i]); 148 | 149 | // --- test cut_row_ptr 150 | //auto gpu_cut_row_ptr = cut.cut_row_ptr.host_data(); 151 | //auto cpu_cut_row_ptr = this->cut_row_ptr.host_data(); 152 | //for(int i = 0; i < cut.cut_row_ptr.size(); i++) 153 | //EXPECT_EQ(gpu_cut_row_ptr[i], cpu_cut_row_ptr[i]); 154 | 155 | //// --- test cut_fid 156 | //EXPECT_EQ(cut.cut_fid.size(), this->cut_fid.size()); 157 | } 158 | 159 | TEST_F(GetCutPointTest, susy) { 160 | //param.path = "../dataset/SUSY"; 161 | DataSet dataset; 162 | dataset.load_from_file(param.path, param); 163 | SparseColumns columns; 164 | //vector> v_columns(1); 165 | //columns.csr2csc_gpu(dataset, v_columns); 166 | //cut.get_cut_points2(columns, param.max_num_bin, dataset.n_instances()); 167 | 168 | //printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. 
###\n", 169 | //param.path.c_str(), 170 | //dataset.n_instances(), 171 | //dataset.n_features()); 172 | 173 | //this->get_cut_points2(columns, param.max_num_bin,dataset.n_instances()); 174 | 175 | // --- test cut_points_val 176 | //auto gpu_cut_points_val = cut.cut_points_val.host_data(); 177 | //auto cpu_cut_points_val = this->cut_points_val.host_data(); 178 | //for(int i = 0; i < cut.cut_points_val.size(); i++) 179 | //EXPECT_EQ(gpu_cut_points_val[i], cpu_cut_points_val[i]); 180 | 181 | // --- test cut_row_ptr 182 | //auto gpu_cut_row_ptr = cut.cut_row_ptr.host_data(); 183 | //auto cpu_cut_row_ptr = this->cut_row_ptr.host_data(); 184 | //for(int i = 0; i < cut.cut_row_ptr.size(); i++) 185 | //EXPECT_EQ(gpu_cut_row_ptr[i], cpu_cut_row_ptr[i]); 186 | 187 | // --- test cut_fid 188 | //EXPECT_EQ(cut.cut_fid.size(), this->cut_fid.size()); 189 | } 190 | -------------------------------------------------------------------------------- /src/test/test_gradient.cu: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "thundergbm/objective/multiclass_obj.h" 3 | #include "thundergbm/objective/regression_obj.h" 4 | #include "thundergbm/objective/ranking_obj.h" 5 | #include "thundergbm/dataset.h" 6 | #include "thundergbm/parser.h" 7 | #include "thundergbm/syncarray.h" 8 | 9 | 10 | class GradientTest: public ::testing::Test { 11 | public: 12 | GBMParam param; 13 | protected: 14 | void SetUp() override { 15 | param.depth = 6; 16 | param.n_trees = 40; 17 | param.n_device = 1; 18 | param.min_child_weight = 1; 19 | param.lambda = 1; 20 | param.gamma = 1; 21 | param.rt_eps = 1e-6; 22 | param.max_num_bin = 255; 23 | param.verbose = false; 24 | param.profiling = false; 25 | param.column_sampling_rate = 1; 26 | param.bagging = false; 27 | param.n_parallel_trees = 1; 28 | param.learning_rate = 1; 29 | param.objective = "reg:linear"; 30 | param.num_class = 2; 31 | param.path = "../dataset/test_dataset.txt"; 32 | param.tree_method = "hist"; 33 | if (!param.verbose) { 34 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 35 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 36 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 37 | } 38 | if (!param.profiling) { 39 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 40 | } 41 | } 42 | }; 43 | 44 | TEST_F(GradientTest, test_softmax_obj) { 45 | DataSet dataset; 46 | dataset.load_from_file(param.path, param); 47 | param.num_class = 2; 48 | dataset.label.resize(2); 49 | dataset.label[0] = 0; 50 | dataset.label[1] = 1; 51 | Softmax softmax; 52 | softmax.configure(param, dataset); 53 | 54 | // check the metric name 55 | EXPECT_EQ(softmax.default_metric_name(), "macc"); 56 | SyncArray y_true(4); 57 | SyncArray y_pred(8); 58 | auto y_pred_data = y_pred.host_data(); 59 | for(int i = 0; i < 8; i++) 60 | y_pred_data[i] = -i; 61 | SyncArray gh_pair(8); 62 | softmax.get_gradient(y_true, y_pred, gh_pair); 63 | 64 | // test the transform function 65 | EXPECT_EQ(y_pred.size(), 8); 66 | softmax.predict_transform(y_pred); 67 | EXPECT_EQ(y_pred.size(), 4); 68 | } 69 | 70 | TEST_F(GradientTest, test_softmaxprob_obj) { 71 | DataSet dataset; 72 | dataset.load_from_file(param.path, param); 73 | param.num_class = 2; 74 | dataset.label.resize(2); 75 | dataset.label[0] = 0; 76 | dataset.label[1] = 1; 77 | SoftmaxProb smp; 78 | smp.configure(param, 
dataset); 79 | 80 | // check the metric name 81 | EXPECT_EQ(smp.default_metric_name(), "macc"); 82 | SyncArray y_true(4); 83 | SyncArray y_pred(8); 84 | auto y_pred_data = y_pred.host_data(); 85 | for(int i = 0; i < 8; i++) 86 | y_pred_data[i] = -i; 87 | SyncArray gh_pair(8); 88 | smp.get_gradient(y_true, y_pred, gh_pair); 89 | 90 | // test the transform function 91 | EXPECT_EQ(y_pred.size(), 8); 92 | smp.predict_transform(y_pred); 93 | EXPECT_EQ(y_pred.size(), 8); 94 | } 95 | 96 | TEST_F(GradientTest, test_regression_obj) { 97 | DataSet dataset; 98 | dataset.load_from_file(param.path, param); 99 | RegressionObj rmse; 100 | SyncArray y_true(4); 101 | SyncArray y_pred(4); 102 | auto y_pred_data = y_pred.host_data(); 103 | for(int i = 0; i < 4; i++) 104 | y_pred_data[i] = -i; 105 | SyncArray gh_pair(4); 106 | EXPECT_EQ(rmse.default_metric_name(), "rmse"); 107 | rmse.get_gradient(y_true, y_pred, gh_pair); 108 | } 109 | 110 | TEST_F(GradientTest, test_logcls_obj) { 111 | DataSet dataset; 112 | dataset.load_from_file(param.path, param); 113 | LogClsObj logcls; 114 | SyncArray y_true(4); 115 | SyncArray y_pred(4); 116 | auto y_pred_data = y_pred.host_data(); 117 | for(int i = 0; i < 4; i++) 118 | y_pred_data[i] = -i; 119 | SyncArray gh_pair(4); 120 | EXPECT_EQ(logcls.default_metric_name(), "error"); 121 | logcls.get_gradient(y_true, y_pred, gh_pair); 122 | } 123 | 124 | /*TEST_F(GradientTest, test_squareloss_obj) {*/ 125 | /*DataSet dataset;*/ 126 | /*dataset.load_from_file(param.path, param);*/ 127 | /*}*/ 128 | -------------------------------------------------------------------------------- /src/test/test_main.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 17-9-15. 3 | // 4 | #include 5 | #include "gtest/gtest.h" 6 | #ifdef _WIN32 7 | INITIALIZE_EASYLOGGINGPP 8 | #endif 9 | int main(int argc, char **argv) { 10 | ::testing::InitGoogleTest(&argc, argv); 11 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg"); 12 | el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput); 13 | el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat); 14 | return RUN_ALL_TESTS(); 15 | } 16 | -------------------------------------------------------------------------------- /src/test/test_metrics.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "thundergbm/metric/metric.h" 3 | #include "thundergbm/metric/multiclass_metric.h" 4 | #include "thundergbm/metric/pointwise_metric.h" 5 | #include "thundergbm/metric/ranking_metric.h" 6 | 7 | 8 | class MetricTest : public ::testing::Test { 9 | public: 10 | GBMParam param; 11 | vector csr_val; 12 | vector csr_row_ptr; 13 | vector csr_col_idx; 14 | vector y; 15 | size_t n_features_; 16 | vector label; 17 | protected: 18 | void SetUp() override { 19 | param.depth = 6; 20 | param.n_trees = 40; 21 | param.n_device = 1; 22 | param.min_child_weight = 1; 23 | param.lambda = 1; 24 | param.gamma = 1; 25 | param.rt_eps = 1e-6; 26 | param.max_num_bin = 255; 27 | param.verbose = false; 28 | param.profiling = false; 29 | param.column_sampling_rate = 1; 30 | param.bagging = false; 31 | param.n_parallel_trees = 1; 32 | param.learning_rate = 1; 33 | param.objective = "reg:linear"; 34 | param.num_class = 2; 35 | param.path = "../dataset/test_dataset.txt"; 36 | param.tree_method = "hist"; 37 | if (!param.verbose) { 38 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, 
el::ConfigurationType::Enabled, "false"); 39 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 40 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 41 | } 42 | if (!param.profiling) { 43 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 44 | } 45 | } 46 | }; 47 | 48 | 49 | TEST_F(MetricTest, test_multiclass_metric_config) { 50 | DataSet dataset; 51 | dataset.load_from_file(param.path, param); 52 | MulticlassAccuracy mmetric; 53 | dataset.y.resize(4); 54 | dataset.label.resize(2); 55 | mmetric.configure(param, dataset); 56 | EXPECT_EQ(mmetric.get_name(), "multi-class accuracy"); 57 | } 58 | 59 | TEST_F(MetricTest, test_multiclass_metric_score) { 60 | DataSet dataset; 61 | dataset.load_from_file(param.path, param); 62 | MulticlassAccuracy mmetric; 63 | dataset.y.resize(4); 64 | dataset.label.resize(2); 65 | mmetric.configure(param, dataset); 66 | 67 | SyncArray y_pred(8); 68 | auto y_pred_data = y_pred.host_data(); 69 | y_pred_data[0] = 1; 70 | EXPECT_EQ(mmetric.get_score(y_pred), 0) << mmetric.get_score(y_pred); 71 | } 72 | 73 | TEST_F(MetricTest, test_binaryclass_metric_score) { 74 | DataSet dataset; 75 | dataset.load_from_file(param.path, param); 76 | BinaryClassMetric mmetric; 77 | dataset.y.resize(4); 78 | dataset.label.resize(2); 79 | mmetric.configure(param, dataset); 80 | 81 | SyncArray y_pred(4); 82 | auto y_pred_data = y_pred.host_data(); 83 | y_pred_data[0] = 1; 84 | EXPECT_EQ(mmetric.get_score(y_pred), 1) << mmetric.get_score(y_pred); 85 | } 86 | 87 | 88 | TEST_F(MetricTest, test_pointwise_metric_score) { 89 | DataSet dataset; 90 | dataset.load_from_file(param.path, param); 91 | RMSE mmetric; 92 | dataset.y.resize(4); 93 | mmetric.configure(param, dataset); 94 | 95 | SyncArray y_pred(4); 96 | EXPECT_EQ(mmetric.get_score(y_pred), 1) << mmetric.get_score(y_pred); 97 | } 98 | -------------------------------------------------------------------------------- /src/test/test_parser.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "thundergbm/parser.h" 3 | #include "thundergbm/dataset.h" 4 | #include "thundergbm/tree.h" 5 | 6 | class ParserTest: public ::testing::Test { 7 | public: 8 | GBMParam param; 9 | vector csr_val; 10 | vector csr_row_ptr; 11 | vector csr_col_idx; 12 | vector y; 13 | size_t n_features_; 14 | vector label; 15 | protected: 16 | void SetUp() override { 17 | param.depth = 6; 18 | param.n_trees = 40; 19 | param.n_device = 1; 20 | param.min_child_weight = 1; 21 | param.lambda = 1; 22 | param.gamma = 1; 23 | param.rt_eps = 1e-6; 24 | param.max_num_bin = 255; 25 | param.verbose = false; 26 | param.profiling = false; 27 | param.column_sampling_rate = 1; 28 | param.bagging = false; 29 | param.n_parallel_trees = 1; 30 | param.learning_rate = 1; 31 | param.objective = "reg:linear"; 32 | param.num_class = 1; 33 | param.path = "../dataset/test_dataset.txt"; 34 | param.tree_method = "auto"; 35 | if (!param.verbose) { 36 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 37 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 38 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 39 | } 40 | if (!param.profiling) { 41 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 42 | } 43 | } 44 | }; 45 | 46 | 
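// The fixture wires up the same GBMParam defaults used throughout this test
// suite (depth 6, 40 trees, reg:linear on ../dataset/test_dataset.txt); the
// tests below assert on those defaults and on the model save/load round trip.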
TEST_F(ParserTest, test_parser){ 47 | EXPECT_EQ(param.depth, 6); 48 | EXPECT_EQ(param.gamma, 1); 49 | EXPECT_EQ(param.learning_rate, 1); 50 | EXPECT_EQ(param.num_class, 1); 51 | EXPECT_EQ(param.tree_method, "auto"); 52 | EXPECT_EQ(param.max_num_bin, 255); 53 | } 54 | 55 | TEST_F(ParserTest, test_save_model) { 56 | string model_path = "tgbm.model"; 57 | vector> boosted_model; 58 | DataSet dataset; 59 | dataset.load_from_file(param.path, param); 60 | Parser parser; 61 | parser.save_model(model_path, param, boosted_model, dataset); 62 | } 63 | 64 | TEST_F(ParserTest, test_load_model) { 65 | string model_path = "tgbm.model"; 66 | vector> boosted_model; 67 | DataSet dataset; 68 | dataset.load_from_file(param.path, param); 69 | Parser parser; 70 | parser.load_model(model_path, param, boosted_model, dataset); 71 | // the size of he boosted model should be zero 72 | EXPECT_EQ(boosted_model.size(), 0); 73 | } 74 | -------------------------------------------------------------------------------- /src/test/test_synarray.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jiashuai on 17-9-17. 3 | // 4 | #include "gtest/gtest.h" 5 | #include "thundergbm/syncarray.h" 6 | 7 | TEST(SyncDataTest, host_allocate){ 8 | SyncArray syncData(100); 9 | EXPECT_NE(syncData.host_data(), nullptr); 10 | EXPECT_EQ(syncData.head(), SyncMem::HEAD::HOST); 11 | EXPECT_EQ(syncData.mem_size(), sizeof(int) * 100); 12 | EXPECT_EQ(syncData.size(), 100); 13 | syncData.resize(20); 14 | EXPECT_EQ(syncData.head(), SyncMem::UNINITIALIZED); 15 | EXPECT_EQ(syncData.size(), 20); 16 | } 17 | 18 | #ifdef USE_CUDA 19 | TEST(SyncDataTest, device_allocate){ 20 | SyncArray syncData(100); 21 | EXPECT_NE(syncData.device_data(), nullptr); 22 | EXPECT_EQ(syncData.head(), SyncMem::HEAD::DEVICE); 23 | EXPECT_EQ(syncData.mem_size(), sizeof(int) * 100); 24 | EXPECT_EQ(syncData.size(), 100); 25 | syncData.resize(20); 26 | EXPECT_EQ(syncData.head(), SyncMem::UNINITIALIZED); 27 | EXPECT_EQ(syncData.size(), 20); 28 | } 29 | 30 | TEST(SyncDataTest, host_to_device){ 31 | SyncArray syncData(10); 32 | SyncArray syncData1(10); 33 | syncData1.set_host_data(syncData.host_data()); 34 | syncData1.set_device_data(syncData.device_data()); 35 | int *data = syncData.host_data(); 36 | for (int i = 0; i < 10; ++i) { 37 | data[i] = i; 38 | } 39 | syncData.to_device(); 40 | EXPECT_EQ(syncData.head(), SyncMem::HEAD::DEVICE); 41 | for (int i = 0; i < 10; ++i) { 42 | data[i] = -1; 43 | } 44 | syncData.to_host(); 45 | EXPECT_EQ(syncData.head(), SyncMem::HEAD::HOST); 46 | data = syncData.host_data(); 47 | for (int i = 0; i < 10; ++i) { 48 | EXPECT_EQ(data[i] , i); 49 | } 50 | for (int i = 0; i < 10; ++i) { 51 | EXPECT_EQ(syncData.host_data()[i], i); 52 | EXPECT_EQ(syncData1.host_data()[i], i); 53 | } 54 | } 55 | 56 | #endif -------------------------------------------------------------------------------- /src/test/test_synmem.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "thundergbm/syncmem.h" 3 | 4 | 5 | TEST(SyncmemTest, test_constructor) { 6 | SyncMem smem1(10); 7 | EXPECT_NE(smem1.host_data(), nullptr); 8 | EXPECT_EQ(smem1.size(), 10); 9 | EXPECT_EQ(smem1.head(), SyncMem::HEAD::HOST); 10 | 11 | SyncMem smem2(10); 12 | EXPECT_EQ(smem2.size(), 10); 13 | smem2.to_host(); 14 | EXPECT_EQ(smem2.size(), 10) << "The size of smem2 is " << smem2.size(); 15 | } 16 | 17 | 18 | TEST(SyncmemTest, test_set_host_data) { 19 | SyncMem smem1(10); 20 | 
SyncMem smem2(10); 21 | // copy the data from smem1 to smem2 22 | smem2.set_host_data(smem1.host_data()); 23 | EXPECT_EQ(smem2.size(), 10); 24 | } 25 | 26 | TEST(SyncmemTest, test_set_device_data) { 27 | SyncMem smem1(20); 28 | SyncMem smem2(20); 29 | // copy the data from smem1 to smem2 30 | smem2.set_device_data(smem2.device_data()); 31 | EXPECT_EQ(smem2.size(), 20); 32 | // the head flag of smem2 should become DEVICE 33 | EXPECT_EQ(smem2.head(), SyncMem::HEAD::DEVICE); 34 | } 35 | 36 | TEST(SyncmemTest, test_host_to_device) { 37 | SyncMem smem1(sizeof(int) * 10); 38 | int *data = static_cast(smem1.host_data()); 39 | for(int i = 0; i < 10; i++) 40 | data[i] = i; 41 | smem1.to_device(); 42 | // the head flag of smem1 should become DEVICE 43 | EXPECT_EQ(smem1.head(), SyncMem::HEAD::DEVICE); 44 | 45 | // change the data on the host 46 | for(int i = 0; i < 10; i++) 47 | data[i] = -10; 48 | smem1.to_host(); 49 | // the head flag of smem1 should become HOST 50 | EXPECT_EQ(smem1.head(), SyncMem::HEAD::HOST); 51 | 52 | // reset the data and check if the data has been changed 53 | data = static_cast(smem1.host_data()); 54 | for(int i = 0; i < 10; i++) 55 | EXPECT_EQ(data[i], i); 56 | } 57 | 58 | TEST(SyncMem, test_get_device_id) { 59 | SyncMem smem1(10); 60 | smem1.to_device(); 61 | EXPECT_EQ(smem1.get_owner_id(), 0) << "the default device id is 0"; 62 | } 63 | 64 | TEST(SyncMem, test_clear_cache) { 65 | SyncMem smem(sizeof(int) * 10); 66 | int *data = static_cast(smem.host_data()); 67 | for(int i = 0; i < 10; i++) 68 | data[i] = i; 69 | for(int i = 0; i < 10; i++) 70 | EXPECT_EQ(data[i], i); 71 | smem.clear_cache(); 72 | data = static_cast(smem.host_data()); 73 | for(int i = 0; i < 10; i++) 74 | EXPECT_EQ(data[i], i); 75 | } 76 | -------------------------------------------------------------------------------- /src/test/test_tree.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "thundergbm/tree.h" 3 | #include "thundergbm/dataset.h" 4 | #include "thundergbm/booster.h" 5 | #include "thundergbm/syncarray.h" 6 | 7 | class TreeTest: public ::testing::Test { 8 | public: 9 | GBMParam param; 10 | vector csr_val; 11 | vector csr_row_ptr; 12 | vector csr_col_idx; 13 | vector y; 14 | size_t n_features_; 15 | vector label; 16 | protected: 17 | void SetUp() override { 18 | param.depth = 6; 19 | param.n_trees = 40; 20 | param.n_device = 1; 21 | param.min_child_weight = 1; 22 | param.lambda = 1; 23 | param.gamma = 1; 24 | param.rt_eps = 1e-6; 25 | param.max_num_bin = 255; 26 | param.verbose = false; 27 | param.profiling = false; 28 | param.column_sampling_rate = 1; 29 | param.bagging = false; 30 | param.n_parallel_trees = 1; 31 | param.learning_rate = 1; 32 | param.objective = "reg:linear"; 33 | param.num_class = 1; 34 | param.path = "../dataset/test_dataset.txt"; 35 | param.tree_method = "hist"; 36 | if (!param.verbose) { 37 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 38 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 39 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 40 | } 41 | if (!param.profiling) { 42 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 43 | } 44 | } 45 | }; 46 | 47 | TEST_F(TreeTest, treenode){ 48 | int max_nodes = 8; 49 | SyncArray nodes; 50 | nodes = SyncArray(max_nodes); 51 | auto node_data = nodes.host_data(); 52 | 
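    // host_data() hands back a host-side view; SyncArray tracks which copy is
    // current via its head flag, so a later device_data() call can sync these
    // writes to the GPU (the to_host()/to_device() contract exercised in
    // test_synarray.cpp above).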
for(int i =0; i < max_nodes; i++) { 53 | node_data[i].final_id = i; 54 | node_data[i].split_feature_id = -1; 55 | } 56 | 57 | EXPECT_EQ(nodes.size(), 8); 58 | EXPECT_EQ(node_data[5].final_id, 5); 59 | EXPECT_EQ(node_data[6].split_feature_id, -1); 60 | } 61 | 62 | TEST_F(TreeTest, tree_init){ 63 | SyncArray gradients(10); 64 | Tree tree; 65 | tree.init2(gradients, param); 66 | 67 | // check the amount of tree nodes 68 | EXPECT_EQ(tree.nodes.size(), 127); 69 | 70 | // check the value of nodes' attributes 71 | auto nodes_data = tree.nodes.host_data(); 72 | EXPECT_EQ(nodes_data[5].final_id, 5); 73 | EXPECT_EQ(nodes_data[1].split_feature_id, -1); 74 | } 75 | 76 | TEST_F(TreeTest, tree_prune) { 77 | SyncArray gradients(10); 78 | Tree tree; 79 | tree.init2(gradients, param); 80 | tree.prune_self(0.5); 81 | } 82 | -------------------------------------------------------------------------------- /src/thundergbm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 2 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 3 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 4 | 5 | file(GLOB SRC util/*.c* builder/*.c* objective/*.c* metric/*.c* *.c*) 6 | list(REMOVE_ITEM SRC "${CMAKE_CURRENT_LIST_DIR}/thundergbm_train.cpp") 7 | list(REMOVE_ITEM SRC "${CMAKE_CURRENT_LIST_DIR}/thundergbm_predict.cpp") 8 | 9 | cuda_add_library(${PROJECT_NAME} SHARED ${SRC}) 10 | target_link_libraries(${PROJECT_NAME} ${CUDA_cusparse_LIBRARY}) 11 | 12 | cuda_add_executable(${PROJECT_NAME}-train thundergbm_train.cpp) 13 | target_link_libraries(${PROJECT_NAME}-train ${PROJECT_NAME}) 14 | 15 | cuda_add_executable(${PROJECT_NAME}-predict thundergbm_predict.cpp) 16 | target_link_libraries(${PROJECT_NAME}-predict ${PROJECT_NAME}) 17 | 18 | -------------------------------------------------------------------------------- /src/thundergbm/builder/function_builder.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-19. 3 | // 4 | 5 | #include 6 | #include "thundergbm/builder/exact_tree_builder.h" 7 | #include "thundergbm/builder/hist_tree_builder.h" 8 | #include "thundergbm/builder/hist_tree_builder_single.h" 9 | 10 | FunctionBuilder *FunctionBuilder::create(std::string name) { 11 | if (name == "exact") return new ExactTreeBuilder; 12 | if (name == "hist") return new HistTreeBuilder; 13 | if (name == "hist_single") return new HistTreeBuilder_single; 14 | LOG(FATAL) << "unknown builder " << name; 15 | return nullptr; 16 | } 17 | 18 | -------------------------------------------------------------------------------- /src/thundergbm/builder/shard.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by shijiashuai on 2019-03-08. 
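// column_sampling below fills idx with 0..n_column-1, shuffles it on the host,
// and flags the first max(1, n_column * rate) shuffled ids in ignored_set via
// a device loop. Note that, as written, the sampled ids are the ones marked
// ignored, so rate acts as the fraction of columns excluded rather than kept.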
3 | // 4 | #include "thundergbm/builder/shard.h" 5 | #include "thrust/sequence.h" 6 | #include "thundergbm/util/device_lambda.cuh" 7 | 8 | void Shard::column_sampling(float rate) { 9 | if (rate < 1) { 10 | CHECK_GT(rate, 0); 11 | int n_column = columns.n_column; 12 | SyncArray idx(n_column); 13 | thrust::sequence(thrust::cuda::par, idx.device_data(), idx.device_end(), 0); 14 | std::random_shuffle(idx.host_data(), idx.host_data() + n_column); 15 | int sample_count = max(1, int(n_column * rate)); 16 | ignored_set.resize(n_column); 17 | auto idx_data = idx.device_data(); 18 | auto ignored_set_data = ignored_set.device_data(); 19 | device_loop(sample_count, [=]__device__(int i) { 20 | ignored_set_data[idx_data[i]] = true; 21 | }); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/thundergbm/gbm_R_interface.cpp: -------------------------------------------------------------------------------- 1 | // Created by Qinbin on 3/15/2020 2 | 3 | #include 4 | #include "thundergbm/parser.h" 5 | #include "thundergbm/predictor.h" 6 | using std::fstream; 7 | using std::stringstream; 8 | 9 | extern "C" { 10 | void train_R(int* depth, int* n_trees, int* n_gpus, int* verbose, 11 | int* profiling, char** data, int* max_num_bin, double* column_sampling_rate, 12 | int* bagging, int* n_parallel_trees, double* learning_rate, char** objective, 13 | int* num_class, int* min_child_weight, double* lambda_tgbm, double* gamma, 14 | char** tree_method, char** model_out) { 15 | 16 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg"); 17 | el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput); 18 | el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat); 19 | 20 | if (*verbose == 0) { 21 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 22 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 23 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false"); 24 | } else if (*verbose == 1) { 25 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 26 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 27 | } 28 | 29 | if (!(*profiling)) { 30 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 31 | } 32 | 33 | GBMParam model_param; 34 | model_param.depth = *depth; 35 | model_param.n_trees = *n_trees; 36 | model_param.n_device = *n_gpus; 37 | model_param.verbose = *verbose; 38 | model_param.profiling = *profiling; 39 | model_param.path = data[0]; 40 | model_param.max_num_bin = *max_num_bin; 41 | model_param.column_sampling_rate = (float_type) *column_sampling_rate; 42 | model_param.bagging = *bagging; 43 | model_param.n_parallel_trees = *n_parallel_trees; 44 | model_param.learning_rate = (float_type) * learning_rate; 45 | model_param.objective = objective[0]; 46 | model_param.num_class = *num_class; 47 | model_param.min_child_weight = *min_child_weight; 48 | model_param.lambda = (float_type) *lambda_tgbm; 49 | model_param.gamma = (float_type) *gamma; 50 | model_param.tree_method = tree_method[0]; 51 | model_param.rt_eps = 1e-6; 52 | model_param.tree_per_rounds = 1; 53 | 54 | DataSet dataset; 55 | Parser parser; 56 | vector> boosted_model; 57 | dataset.load_from_file(model_param.path, model_param); 58 | TreeTrainer trainer; 59 | boosted_model = trainer.train(model_param, 
dataset); 60 | parser.save_model(model_out[0], model_param, boosted_model, dataset); 61 | } 62 | 63 | void predict_R(char** data, char** model_in, int* verbose){ 64 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg"); 65 | el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput); 66 | el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat); 67 | 68 | if (*verbose == 0) { 69 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 70 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 71 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false"); 72 | } else if (*verbose == 1) { 73 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 74 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 75 | } 76 | 77 | GBMParam model_param; 78 | Parser parser; 79 | DataSet dataSet; 80 | vector> boosted_model; 81 | parser.load_model(model_in[0], model_param, boosted_model, dataSet); 82 | dataSet.load_from_file(data[0], model_param); 83 | //predict 84 | Predictor pred; 85 | vector y_pred_vec = pred.predict(model_param, boosted_model, dataSet); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/thundergbm/metric/metric.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-13. 3 | // 4 | #include "thundergbm/metric/metric.h" 5 | #include "thundergbm/metric/pointwise_metric.h" 6 | #include "thundergbm/metric/ranking_metric.h" 7 | #include "thundergbm/metric/multiclass_metric.h" 8 | 9 | Metric *Metric::create(string name) { 10 | if (name == "map") return new MAP; 11 | if (name == "rmse") return new RMSE; 12 | if (name == "ndcg") return new NDCG; 13 | if (name == "macc") return new MulticlassAccuracy; 14 | if (name == "error") return new BinaryClassMetric; 15 | LOG(FATAL) << "unknown metric " << name; 16 | return nullptr; 17 | } 18 | 19 | void Metric::configure(const GBMParam ¶m, const DataSet &dataset) { 20 | y.resize(dataset.y.size()); 21 | y.copy_from(dataset.y.data(), dataset.n_instances()); 22 | } 23 | -------------------------------------------------------------------------------- /src/thundergbm/metric/multiclass_metric.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-15. 
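// MulticlassAccuracy below reads the class-major score matrix
// (yp[k * n_instances + i] is the score of class k for instance i), takes the
// per-instance argmax over k on the device, and averages the hit indicators
// with thrust::reduce; BinaryClassMetric instead applies a sigmoid with a 0.5
// threshold and returns the error rate (1 - accuracy).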
3 | // 4 | #include "thundergbm/metric/multiclass_metric.h" 5 | #include "thundergbm/util/device_lambda.cuh" 6 | #include "thrust/reduce.h" 7 | 8 | 9 | float_type MulticlassAccuracy::get_score(const SyncArray &y_p) const { 10 | CHECK_EQ(num_class * y.size(), y_p.size()) << num_class << " * " << y.size() << " != " << y_p.size(); 11 | int n_instances = y.size(); 12 | auto y_data = y.device_data(); 13 | auto yp_data = y_p.device_data(); 14 | SyncArray is_true(n_instances); 15 | auto is_true_data = is_true.device_data(); 16 | int num_class = this->num_class; 17 | device_loop(n_instances, [=] __device__(int i){ 18 | int max_k = 0; 19 | float_type max_p = yp_data[i]; 20 | for (int k = 1; k < num_class; ++k) { 21 | if (yp_data[k * n_instances + i] > max_p) { 22 | max_p = yp_data[k * n_instances + i]; 23 | max_k = k; 24 | } 25 | } 26 | is_true_data[i] = max_k == y_data[i]; 27 | }); 28 | 29 | float acc = thrust::reduce(thrust::cuda::par, is_true_data, is_true_data + n_instances) / (float) n_instances; 30 | return acc; 31 | } 32 | 33 | float_type BinaryClassMetric::get_score(const SyncArray &y_p) const { 34 | int n_instances = y.size(); 35 | auto y_data = y.device_data(); 36 | auto yp_data = y_p.device_data(); 37 | SyncArray is_true(n_instances); 38 | auto is_true_data = is_true.device_data(); 39 | device_loop(n_instances, [=] __device__(int i){ 40 | int max_k = (1 / (1 + exp(-yp_data[i])) > 0.5) ? 1 : 0; 41 | is_true_data[i] = max_k == y_data[i]; 42 | }); 43 | 44 | float acc = thrust::reduce(thrust::cuda::par, is_true_data, is_true_data + n_instances) / (float) n_instances; 45 | return 1 - acc; 46 | } 47 | -------------------------------------------------------------------------------- /src/thundergbm/metric/pointwise_metric.cu: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-14. 3 | // 4 | #include "thrust/reduce.h" 5 | #include "thundergbm/util/device_lambda.cuh" 6 | #include "thundergbm/metric/pointwise_metric.h" 7 | 8 | float_type RMSE::get_score(const SyncArray &y_p) const { 9 | CHECK_EQ(y_p.size(), y.size()); 10 | int n_instances = y_p.size(); 11 | SyncArray sq_err(n_instances); 12 | auto sq_err_data = sq_err.device_data(); 13 | const float_type *y_data = y.device_data(); 14 | const float_type *y_predict_data = y_p.device_data(); 15 | device_loop(n_instances, [=] __device__(int i) { 16 | float_type e = y_predict_data[i] - y_data[i]; 17 | sq_err_data[i] = e * e; 18 | }); 19 | float_type rmse = 20 | sqrtf(thrust::reduce(thrust::cuda::par, sq_err.device_data(), sq_err.device_end()) / n_instances); 21 | return rmse; 22 | } 23 | 24 | -------------------------------------------------------------------------------- /src/thundergbm/metric/rank_metric.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by ss on 19-1-14. 
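// The list metrics below average a per-query score over n_group queries. For
// NDCG, assuming discounted_gain implements the standard form,
//   DCG@n = sum_{i<n} (2^{rel_i} - 1) / log2(i + 2)   and   NDCG = DCG / IDCG,
// where IDCG is the DCG of the ideally ordered list computed in get_IDCG.
// Queries with IDCG == 0 score 1 by convention, as do MAP queries with no hits.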
/src/thundergbm/metric/rank_metric.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-14.
3 | //
4 | #include <thundergbm/metric/ranking_metric.h>
5 | #ifndef _WIN32
6 | #include "parallel/algorithm"
7 | #endif
8 |
9 | float_type RankListMetric::get_score(const SyncArray<float_type> &y_p) const {
10 |     TIMED_FUNC(obj);
11 |     float_type sum_score = 0;
12 |     auto y_data0 = y.host_data();
13 |     auto yp_data0 = y_p.host_data();
14 | #pragma omp parallel for schedule(static) reduction(+:sum_score)
15 |     for (int k = 0; k < n_group; ++k) {
16 |         int group_start = gptr[k];
17 |         int len = gptr[k + 1] - group_start;
18 |         vector<float_type> query_y(len);
19 |         vector<float_type> query_yp(len);
20 |         memcpy(query_y.data(), y_data0 + group_start, len * sizeof(float_type));
21 |         memcpy(query_yp.data(), yp_data0 + group_start, len * sizeof(float_type));
22 |         sum_score += this->eval_query_group(query_y, query_yp, k);
23 |     }
24 |     return sum_score / n_group;
25 | }
26 |
27 | void RankListMetric::configure(const GBMParam &param, const DataSet &dataset) {
28 |     Metric::configure(param, dataset);
29 |
30 |     //init gptr
31 |     n_group = dataset.group.size();
32 |     configure_gptr(dataset.group, gptr);
33 |
34 |     //TODO parse from param
35 |     topn = (std::numeric_limits<int>::max)();
36 | }
37 |
38 | void RankListMetric::configure_gptr(const vector<int> &group, vector<int> &gptr) {
39 |     gptr = vector<int>(group.size() + 1, 0);
40 |     for (int i = 1; i < gptr.size(); ++i) {
41 |         gptr[i] = gptr[i - 1] + group[i - 1];
42 |     }
43 | }
44 |
45 | float_type MAP::eval_query_group(vector<float_type> &y, vector<float_type> &y_p, int group_id) const {
46 |     auto y_data = y.data();
47 |     auto yp_data = y_p.data();
48 |     int len = y.size();
49 |     vector<int> idx(len);
50 |     for (int i = 0; i < len; ++i) {
51 |         idx[i] = i;
52 |     }
53 | #ifdef _WIN32
54 |     std::sort(idx.begin(), idx.end(), [=](int a, int b) { return yp_data[a] > yp_data[b]; });
55 | #else
56 |     __gnu_parallel::sort(idx.begin(), idx.end(), [=](int a, int b) { return yp_data[a] > yp_data[b]; });
57 | #endif
58 |     int nhits = 0;
59 |     double sum_ap = 0;
60 |     for (int i = 0; i < len; ++i) {
61 |         if (y_data[idx[i]] != 0) {
62 |             nhits++;
63 |             if (i < topn) {
64 |                 sum_ap += (double) nhits / (i + 1);
65 |             }
66 |         }
67 |     }
68 |
69 |     if (nhits != 0)
70 |         return sum_ap / nhits;
71 |     else return 1;
72 | }
73 |
74 | void NDCG::configure(const GBMParam &param, const DataSet &dataset) {
75 |     RankListMetric::configure(param, dataset);
76 |     get_IDCG(gptr, dataset.y, idcg);
77 | }
78 |
79 | float_type NDCG::eval_query_group(vector<float_type> &y, vector<float_type> &y_p, int group_id) const {
80 |     CHECK_EQ(y.size(), y_p.size());
81 |     if (idcg[group_id] == 0) return 1;
82 |     int len = y.size();
83 |     vector<int> idx(len);
84 |     for (int i = 0; i < len; ++i) {
85 |         idx[i] = i;
86 |     }
87 |     auto label = y.data();
88 |     auto score = y_p.data();
89 | #ifdef _WIN32
90 |     std::sort(idx.begin(), idx.end(), [=](int a, int b) { return score[a] > score[b]; });
91 | #else
92 |     __gnu_parallel::sort(idx.begin(), idx.end(), [=](int a, int b) { return score[a] > score[b]; });
93 | #endif
94 |
95 |     float_type dcg = 0;
96 |     for (int i = 0; i < len; ++i) {
97 |         dcg += discounted_gain(static_cast<int>(label[idx[i]]), i);
98 |     }
99 |     return dcg / idcg[group_id];
100 | }
101 |
102 | void NDCG::get_IDCG(const vector<int> &gptr, const vector<float_type> &y, vector<float_type> &idcg) {
103 |     int n_group = gptr.size() - 1;
104 |     idcg.clear();
105 |     idcg.resize(n_group);
106 |     //calculate IDCG
107 | #pragma omp parallel for schedule(static)
108 |     for (int k = 0; k < n_group; ++k) {
109 |         int group_start = gptr[k];
110 |         int len = gptr[k + 1] - group_start;
111 |         vector<float_type> sorted_label(len);
112 |         memcpy(sorted_label.data(), y.data() + group_start, len * sizeof(float_type));
113 | #ifdef _WIN32
114 |         std::sort(sorted_label.begin(), sorted_label.end(), std::greater<float_type>());
115 | #else
116 |         __gnu_parallel::sort(sorted_label.begin(), sorted_label.end(), std::greater<float_type>());
117 | #endif
118 |         for (int i = 0; i < sorted_label.size(); ++i) {
119 |             //assume labels are int
120 |             idcg[k] += NDCG::discounted_gain(static_cast<int>(sorted_label[i]), i);
121 |         }
122 |     }
123 | }
124 |
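`get_IDCG` sorts each query group's labels in descending order and accumulates `discounted_gain`, giving the best DCG the group can achieve; `eval_query_group` divides the observed DCG by it. A standalone worked example, assuming the conventional gain `(2^label - 1) / log2(rank + 2)` (the actual formula is defined in `ranking_metric.h`, which is not shown here):

```cpp
// Worked DCG/IDCG/NDCG example mirroring the loops in rank_metric.cpp.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static double discounted_gain(int label, int rank) {
    // assumed conventional definition; see NDCG::discounted_gain for the real one
    return (std::pow(2.0, label) - 1.0) / std::log2(rank + 2.0);
}

int main() {
    std::vector<int> ranked_labels = {0, 2, 1};    // labels in predicted order
    double dcg = 0;
    for (int i = 0; i < (int) ranked_labels.size(); ++i)
        dcg += discounted_gain(ranked_labels[i], i);

    std::vector<int> ideal = ranked_labels;        // labels in ideal (descending) order
    std::sort(ideal.begin(), ideal.end(), std::greater<int>());
    double idcg = 0;
    for (int i = 0; i < (int) ideal.size(); ++i)
        idcg += discounted_gain(ideal[i], i);

    printf("DCG=%.4f IDCG=%.4f NDCG=%.4f\n", dcg, idcg, dcg / idcg);
}
```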
--------------------------------------------------------------------------------
/src/thundergbm/objective/multiclass_obj.cu:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-15.
3 | //
4 | #include "thundergbm/objective/multiclass_obj.h"
5 |
6 | void
7 | Softmax::get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair) {
8 |     CHECK_EQ(y.size(), y_p.size() / num_class);
9 |     CHECK_EQ(y_p.size(), gh_pair.size());
10 |     auto y_data = y.device_data();
11 |     auto yp_data = y_p.device_data();
12 |     auto gh_data = gh_pair.device_data();
13 |     int num_class = this->num_class;
14 |     int n_instances = y_p.size() / num_class;
15 |     device_loop(n_instances, [=]__device__(int i) {
16 |         float_type max = yp_data[i];
17 |         for (int k = 1; k < num_class; ++k) {
18 |             max = fmaxf(max, yp_data[k * n_instances + i]);
19 |         }
20 |         float_type sum = 0;
21 |         for (int k = 0; k < num_class; ++k) {
22 |             //-max to avoid numerical issue
23 |             sum += expf(yp_data[k * n_instances + i] - max);
24 |         }
25 |         for (int k = 0; k < num_class; ++k) {
26 |             float_type p = expf(yp_data[k * n_instances + i] - max) / sum;
27 |             //gradient = p_i - y_i
28 |             //approximate hessian = 2 * p_i * (1 - p_i)
29 |             //https://github.com/dmlc/xgboost/issues/2485
30 |             float_type g = k == y_data[i] ? (p - 1) : (p - 0);
31 |             float_type h = fmaxf(2 * p * (1 - p), 1e-16f);
32 |             gh_data[k * n_instances + i] = GHPair(g, h);
33 |         }
34 |     });
35 | }
36 |
37 | void Softmax::configure(GBMParam param, const DataSet &dataset) {
38 |     num_class = param.num_class;
39 |     label.resize(num_class);
40 |     CHECK_EQ(dataset.label.size(), num_class);
41 |     label.copy_from(dataset.label.data(), num_class);
42 | }
70 | void SoftmaxProb::predict_transform(SyncArray<float_type> &y) {
71 |     auto yp_data = y.device_data();
72 |     int num_class = this->num_class;
73 |     int n_instances = y.size() / num_class;
74 |     device_loop(n_instances, [=]__device__(int i) {
75 |         float_type max = yp_data[i];
76 |         for (int k = 1; k < num_class; ++k) {
77 |             max = fmaxf(max, yp_data[k * n_instances + i]);
78 |         }
79 |         float_type sum = 0;
80 |         for (int k = 0; k < num_class; ++k) {
81 |             //-max to avoid numerical issue
82 |             yp_data[k * n_instances + i] = expf(yp_data[k * n_instances + i] - max);
83 |             sum += yp_data[k * n_instances + i];
84 |         }
85 |         for (int k = 0; k < num_class; ++k) {
86 |             yp_data[k * n_instances + i] /= sum;
87 |         }
88 |     });
89 | }
90 |
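The kernel in `Softmax::get_gradient` implements the standard softmax gradient with the diagonal hessian approximation referenced in the xgboost issue linked above. A host-side reference of the same arithmetic for a single instance (illustrative sketch only; the real code runs this per instance inside a `device_loop`):

```cpp
// CPU reference of the softmax gradient/hessian computed above.
#include <algorithm>
#include <cmath>
#include <vector>

struct GH { float g, h; };

std::vector<GH> softmax_grad(const std::vector<float> &scores, int true_class) {
    // subtract the max score for numerical stability, as the CUDA kernel does
    float mx = *std::max_element(scores.begin(), scores.end());
    float sum = 0;
    for (float s : scores) sum += std::exp(s - mx);
    std::vector<GH> gh(scores.size());
    for (size_t k = 0; k < scores.size(); ++k) {
        float p = std::exp(scores[k] - mx) / sum;
        gh[k].g = ((int) k == true_class) ? p - 1 : p;  // gradient = p_k - y_k
        gh[k].h = std::max(2 * p * (1 - p), 1e-16f);    // approximate hessian
    }
    return gh;
}
```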
--------------------------------------------------------------------------------
/src/thundergbm/objective/objective_function.cu:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-1.
3 | //
4 | #include <thundergbm/objective/objective_function.h>
5 | #include "thundergbm/objective/regression_obj.h"
6 | #include "thundergbm/objective/multiclass_obj.h"
7 | #include "thundergbm/objective/ranking_obj.h"
8 |
9 | ObjectiveFunction *ObjectiveFunction::create(string name) {
10 |     if (name == "reg:linear") return new RegressionObj<SquareLoss>;
11 |     if (name == "reg:logistic") return new RegressionObj<LogisticLoss>;
12 |     if (name == "binary:logistic") return new LogClsObj<LogisticLoss>;
13 |     if (name == "multi:softprob") return new SoftmaxProb;
14 |     if (name == "multi:softmax") return new Softmax;
15 |     if (name == "rank:pairwise") return new LambdaRank;
16 |     if (name == "rank:ndcg") return new LambdaRankNDCG;
17 |     LOG(FATAL) << "undefined objective " << name;
18 |     return nullptr;
19 | }
20 |
21 | bool ObjectiveFunction::need_load_group_file(string name) {
22 |     return name == "rank:ndcg" || name == "rank:pairwise";
23 | }
24 |
25 | bool ObjectiveFunction::need_group_label(string name) {
26 |     return name == "multi:softprob" || name == "multi:softmax" || name == "binary:logistic";
27 | }
28 |
--------------------------------------------------------------------------------
/src/thundergbm/objective/ranking_obj.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-12.
3 | //
4 | #include <thundergbm/objective/ranking_obj.h>
5 | #include "thundergbm/metric/ranking_metric.h"
6 | #ifndef _WIN32
7 | #include <parallel/algorithm>
8 | #endif
9 | #include <random>
10 |
11 | void LambdaRank::configure(GBMParam param, const DataSet &dataset) {
12 |     sigma = 1;
13 |
14 |     //init gptr
15 |     n_group = dataset.group.size();
16 |     RankListMetric::configure_gptr(dataset.group, gptr);
17 |     CHECK_EQ(gptr.back(), dataset.n_instances());
18 | }
19 |
20 | void
21 | LambdaRank::get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair) {
22 |     TIMED_FUNC(obj);
23 |     {
24 |         auto gh_data = gh_pair.host_data();
25 | #pragma omp parallel for schedule(static)
26 |         for (int i = 0; i < gh_pair.size(); ++i) {
27 |             gh_data[i] = 0;
28 |         }
29 |     }
30 |     GHPair *gh_data0 = gh_pair.host_data();
31 |     const float_type *score0 = y_p.host_data();
32 |     const float_type *label_data0 = y.host_data();
33 |     PERFORMANCE_CHECKPOINT_WITH_ID(obj, "copy and init");
34 | #pragma omp parallel for schedule(static)
35 |     for (int k = 0; k < n_group; ++k) {
36 |         int group_start = gptr[k];
37 |         int len = gptr[k + 1] - group_start;
38 |         GHPair *gh_data = gh_data0 + group_start;
39 |         const float_type *score = score0 + group_start;
40 |         const float_type *label_data = label_data0 + group_start;
41 |         vector<int> idx(len);
42 |         for (int i = 0; i < len; ++i) { idx[i] = i; }
43 |         std::sort(idx.begin(), idx.end(), [=](int a, int b) { return score[a] > score[b]; });
44 |         vector<std::pair<float_type, int>> label_idx(len);
45 |         for (int i = 0; i < len; ++i) {
46 |             label_idx[i].first = label_data[idx[i]];
47 |             label_idx[i].second = idx[i];
48 |         }
49 |         //sort by label descending
50 |         std::sort(label_idx.begin(), label_idx.end(),
51 |                   [](std::pair<float_type, int> a, std::pair<float_type, int> b) { return a.first > b.first; });
52 |
53 |         std::mt19937 gen(std::rand());
54 |         for (int i = 0; i < len; ++i) {
55 |             int j = i + 1;
56 |             while (j < len && label_idx[i].first == label_idx[j].first) j++;
57 |             int nleft = i;
58 |             int nright = len - j;
59 |             //if not all are same label
60 |             if (nleft + nright != 0) {
61 |                 // bucket in [i,j), get a sample outside bucket
62 |                 std::uniform_int_distribution<> dis(0, nleft + nright - 1);
63 |                 for (int l = i; l < j; ++l) {
64 |                     int m = dis(gen);
65 |                     int high = 0;
66 |                     int low = 0;
67 |                     if (m < nleft) {
68 |                         high = m;
69 |                         low = l;
70 |                     } else {
71 |                         high = l;
72 |                         low = m + j - i;
73 |                     }
74 |                     float_type high_label = label_idx[high].first;
75 |                     float_type low_label = label_idx[low].first;
76 |                     int high_idx = label_idx[high].second;
77 |                     int low_idx = label_idx[low].second;
78 |
79 |                     float_type abs_delta_z = fabsf(get_delta_z(high_label, low_label, high, low, k));
80 |                     float_type rhoIJ = 1 / (1 + expf((score[high_idx] - score[low_idx])));
81 |                     float_type lambda = -abs_delta_z * rhoIJ;
82 |                     float_type hessian = 2 * fmaxf(abs_delta_z * rhoIJ * (1 - rhoIJ), 1e-16f);
83 |                     gh_data[high_idx] = gh_data[high_idx] + GHPair(lambda, hessian);
84 |                     gh_data[low_idx] = gh_data[low_idx] + GHPair(-lambda, hessian);
85 |                 }
86 |             }
87 |             i = j;
88 |         }
89 |     }
90 |     y_p.to_device();
91 | }
92 |
93 | string LambdaRank::default_metric_name() { return "map"; }
94 |
95 | //inline functions should be defined in the header file
96 | //inline float_type
97 | //LambdaRank::get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) { return 1; }
98 |
99 |
100 | void LambdaRankNDCG::configure(GBMParam param, const DataSet &dataset) {
101 |     LambdaRank::configure(param, dataset);
102 |     NDCG::get_IDCG(gptr, dataset.y, idcg);
103 | }
104 |
105 | float_type
106 | LambdaRankNDCG::get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) {
107 |     if (idcg[group_id] == 0) return 0;
108 |     float_type dgI1 = NDCG::discounted_gain(static_cast<int>(labelI), rankI);
109 |     float_type dgJ1 = NDCG::discounted_gain(static_cast<int>(labelJ), rankJ);
110 |     float_type dgI2 = NDCG::discounted_gain(static_cast<int>(labelI), rankJ);
111 |     float_type dgJ2 = NDCG::discounted_gain(static_cast<int>(labelJ), rankI);
112 |     return (dgI1 + dgJ1 - dgI2 - dgJ2) / idcg[group_id];
113 | }
114 |
115 | string LambdaRankNDCG::default_metric_name() { return "ndcg"; }
116 |
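`get_delta_z` is the |ΔNDCG| term of the LambdaRank family: the change in NDCG if the two documents of a pair swapped ranks, normalized by the group's IDCG. A numeric sketch, again assuming the conventional discounted gain (an assumption; the real formula is in `ranking_metric.h`):

```cpp
// Numeric illustration of LambdaRankNDCG::get_delta_z above.
#include <cmath>
#include <cstdio>

static double dg(int label, int rank) {
    return (std::pow(2.0, label) - 1.0) / std::log2(rank + 2.0);
}

int main() {
    // a label-2 document currently at rank 3, a label-0 document at rank 0
    int labelI = 2, rankI = 3, labelJ = 0, rankJ = 0;
    double idcg = dg(2, 0) + dg(1, 1) + dg(0, 2) + dg(0, 3); // ideal order 2,1,0,0
    double delta = (dg(labelI, rankI) + dg(labelJ, rankJ)
                  - dg(labelI, rankJ) - dg(labelJ, rankI)) / idcg;
    printf("delta_z = %.4f\n", std::fabs(delta)); // large delta -> strong lambda
}
```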
--------------------------------------------------------------------------------
/src/thundergbm/parser.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/10/19.
3 | //
4 |
5 | #include "thundergbm/parser.h"
6 | #include <fstream>
7 | using namespace std;
8 |
9 | void Parser::parse_param(GBMParam &model_param, int argc, char **argv){
10 |     model_param.depth = 6;
11 |     model_param.n_trees = 40;
12 |     model_param.n_device = 1;
13 |     model_param.min_child_weight = 1;
14 |     model_param.lambda = 1;
15 |     model_param.gamma = 1;
16 |     model_param.rt_eps = 1e-6;
17 |     model_param.max_num_bin = 255;
18 |     model_param.verbose = 1;
19 |     model_param.profiling = false;
20 |     model_param.column_sampling_rate = 1;
21 |     model_param.bagging = false;
22 |     model_param.n_parallel_trees = 1;
23 |     model_param.learning_rate = 1;
24 |     model_param.objective = "reg:linear";
25 |     model_param.num_class = 1;
26 |     model_param.path = "../dataset/test_dataset.txt";
27 |     model_param.tree_method = "auto";
28 |     model_param.tree_per_rounds = 1; // # tree of each round, depends on # class
29 |
30 |     if (argc < 2) {
31 |         printf("Usage: <config file>\n");
32 |         exit(0);
33 |     }
34 |
35 |     //parsing parameter values from configuration file or command line
36 |     auto parse_value = [&](const char *name_val){
37 |         char name[256], val[256];
38 |         if (sscanf(name_val, "%[^=]=%s", name, val) == 2) {
39 |             string str_name(name);
40 |             if((str_name.compare("max_depth") == 0) || (str_name.compare("depth") == 0))
41 |                 model_param.depth = atoi(val);
42 |             else if((str_name.compare("num_round") == 0) || (str_name.compare("n_trees") == 0))
43 |                 model_param.n_trees = atoi(val);
44 |             else if(str_name.compare("n_gpus") == 0)
45 |                 model_param.n_device = atoi(val);
46 |             else if((str_name.compare("verbosity") == 0) || (str_name.compare("verbose") == 0))
47 |                 model_param.verbose = atoi(val);
48 |             else if(str_name.compare("profiling") == 0)
49 |                 model_param.profiling = atoi(val);
50 |             else if(str_name.compare("data") == 0)
51 |                 model_param.path = val;
52 |             else if((str_name.compare("max_bin") == 0) || (str_name.compare("max_num_bin") == 0))
53 |                 model_param.max_num_bin = atoi(val);
54 |             else if((str_name.compare("colsample") == 0) || (str_name.compare("column_sampling_rate") == 0))
55 |                 model_param.column_sampling_rate = atof(val);
56 |             else if(str_name.compare("bagging") == 0)
57 |                 model_param.bagging = atoi(val);
58 |             else if((str_name.compare("num_parallel_tree") == 0) || (str_name.compare("n_parallel_trees") == 0))
59 |                 model_param.n_parallel_trees = atoi(val);
60 |             else if(str_name.compare("eta") == 0 || str_name.compare("learning_rate") == 0)
61 |                 model_param.learning_rate = atof(val);
62 |             else if(str_name.compare("objective") == 0)
63 |                 model_param.objective = val;
64 |             else if(str_name.compare("num_class") == 0)
65 |                 model_param.num_class = atoi(val);
66 |             else if(str_name.compare("min_child_weight") == 0)
67 |                 model_param.min_child_weight = atoi(val);
68 |             else if(str_name.compare("lambda") == 0 || str_name.compare("lambda_tgbm") == 0)
69 |                 model_param.lambda = atof(val);
70 |             else if(str_name.compare("gamma") == 0 || str_name.compare("min_split_loss") == 0)
71 |                 model_param.gamma = atof(val);
72 |             else if(str_name.compare("tree_method") == 0)
73 |                 model_param.tree_method = val;
74 |             else
75 |                 LOG(INFO) << "\"" << name << "\" is unknown option!";
76 |         }
77 |         else{
78 |             string str_name(name_val); //compare the raw token; "name" is not set when sscanf fails
79 |             if(str_name.compare("-help") == 0){
80 |                 printf("please refer to \"docs/parameters.md\" in the GitHub repository for more information about setting the options\n");
81 |                 exit(0);
82 |             }
83 |         }
84 |
85 |     };
86 |
87 |     //read configuration file
88 |     std::ifstream conf_file(argv[1]);
89 |     std::string line;
90 |     while (std::getline(conf_file, line))
91 |     {
92 |         //LOG(INFO) << line;
93 |         parse_value(line.c_str());
94 |     }
95 |
96 |     //TODO: confirm handling spaces around "="
97 |     for (int i = 0; i < argc; ++i) {
98 |         parse_value(argv[i]);
99 |     }//end parsing parameters
100 | }
101 |
102 | void Parser::load_model(string model_path, GBMParam &model_param, vector<vector<Tree>> &boosted_model, DataSet & dataset) {
103 |     std::ifstream ifs(model_path, ios::binary);
104 |     CHECK_EQ(ifs.is_open(), true);
105 |     int length;
106 |     ifs.read((char*)&length, sizeof(length));
107 |     char * temp = new char[length+1];
108 |     temp[length] = '\0';
109 |     // read param.objective
110 |     ifs.read(temp, length);
111 |     string str(temp);
112 |     model_param.objective = str; delete[] temp; //release the temporary buffer
113 |     ifs.read((char*)&model_param.learning_rate, sizeof(model_param.learning_rate));
114 |     ifs.read((char*)&model_param.num_class, sizeof(model_param.num_class));
115 |     ifs.read((char*)&model_param.n_trees, sizeof(model_param.n_trees));
116 |     ifs.read((char*)&model_param.base_score, sizeof(model_param.base_score));
117 |     int label_size;
118 |     ifs.read((char*)&label_size, sizeof(label_size));
119 |     float_type f;
120 |     dataset.label.clear();
121 |     for (int i = 0; i < label_size; ++i) {
122 |         ifs.read((char*)&f, sizeof(float_type));
123 |         dataset.label.push_back(f);
124 |     }
125 |     int boosted_model_size;
126 |     ifs.read((char*)&boosted_model_size, sizeof(boosted_model_size));
127 |     Tree t;
128 |     vector<Tree> v;
129 |     for (int i = 0; i < boosted_model_size; ++i) {
130 |         int boost_model_i_size;
131 |         ifs.read((char*)&boost_model_i_size, sizeof(boost_model_i_size));
132 |         for (int j = 0; j < boost_model_i_size; ++j) {
133 |             size_t syn_node_size;
134 |             ifs.read((char*)&syn_node_size, sizeof(syn_node_size));
135 |             SyncArray<Tree::TreeNode> tmp(syn_node_size);
136 |             ifs.read((char*)tmp.host_data(), sizeof(Tree::TreeNode) * syn_node_size);
137 |             t.nodes.resize(tmp.size());
138 |             t.nodes.copy_from(tmp);
139 |             v.push_back(t);
140 |         }
141 |         boosted_model.push_back(v);
142 |         v.clear();
143 |     }
144 |     ifs.close();
145 | }
146 | void Parser::save_model(string model_path, GBMParam &model_param, vector<vector<Tree>> &boosted_model, DataSet &dataset) {
147 |     ofstream out_model_file(model_path, ios::binary);
148 |     CHECK_EQ(out_model_file.is_open(), true);
149 |     int length = model_param.objective.length();
150 |     out_model_file.write((char*)&length, sizeof(length));
151 |     out_model_file.write(model_param.objective.c_str(), model_param.objective.length());
152 |     out_model_file.write((char*)&model_param.learning_rate, sizeof(model_param.learning_rate));
153 |     out_model_file.write((char*)&model_param.num_class, sizeof(model_param.num_class));
154 |     out_model_file.write((char*)&model_param.n_trees, sizeof(model_param.n_trees));
155 |     out_model_file.write((char*)&model_param.base_score, sizeof(model_param.base_score));
156 |     int label_size = dataset.label.size();
157 |     out_model_file.write((char*)&label_size, sizeof(label_size));
158 |     out_model_file.write((char*)&dataset.label[0], dataset.label.size() * sizeof(float_type));
159 |     int boosted_model_size = boosted_model.size();
160 |     out_model_file.write((char*)&boosted_model_size, sizeof(boosted_model_size));
161 |     for(int j = 0; j < boosted_model.size(); ++j) {
162 |         int boosted_model_j_size = boosted_model[j].size();
163 |         out_model_file.write((char*)&boosted_model_j_size, sizeof(boosted_model_j_size));
164 |         for (int i = 0; i < boosted_model_j_size; ++i) {
165 |             size_t syn_node_size = boosted_model[j][i].nodes.size();
166 |             out_model_file.write((char*)&syn_node_size, sizeof(syn_node_size));
167 |             out_model_file.write((char*)boosted_model[j][i].nodes.host_data(), syn_node_size * sizeof(Tree::TreeNode));
168 |         }
169 |     }
170 |     out_model_file.close();
171 | }
172 |
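`save_model` and `load_model` must stay in lock-step; the byte layout of `tgbm.model` implied by the writes above is summarized below (pseudo-schema for illustration only; everything is raw, architecture-dependent binary, and the exact scalar widths follow the `GBMParam` members, assumed here to be `int`/`float`):

```cpp
// Assumed on-disk layout of tgbm.model, as implied by Parser::save_model:
//
//   int        objective_length
//   char       objective[objective_length]   // e.g. "reg:linear"
//   float      learning_rate
//   int        num_class
//   int        n_trees
//   float      base_score
//   int        label_size
//   float_type labels[label_size]
//   int        boosted_model_size             // number of boosting rounds
//   repeated boosted_model_size times:
//     int      trees_in_round                 // tree_per_rounds
//     repeated trees_in_round times:
//       size_t n_nodes
//       Tree::TreeNode nodes[n_nodes]         // raw struct dump
```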
--------------------------------------------------------------------------------
/src/thundergbm/row_sampler.cu:
--------------------------------------------------------------------------------
1 | //
2 | // Created by shijiashuai on 2019-02-15.
3 | //
4 | #include <thundergbm/row_sampler.h>
5 | #include "thundergbm/util/multi_device.h"
6 | #include "thundergbm/util/device_lambda.cuh"
7 | #include "thrust/random.h"
8 |
9 | void RowSampler::do_bagging(MSyncArray<GHPair> &gradients) {
10 |     LOG(TRACE) << "do bagging";
11 |     using namespace thrust;
12 |     int n_instances = gradients.front().size();
13 |     SyncArray<int> idx(n_instances);
14 |     auto idx_data = idx.device_data();
15 |     int seed = std::rand();//TODO add a global random generator class
16 |     device_loop(n_instances, [=]__device__(int i) {
17 |         default_random_engine rng(seed);
18 |         uniform_int_distribution<int> uniform_dist(0, n_instances - 1);
19 |         rng.discard(i);
20 |         idx_data[i] = uniform_dist(rng);
21 |     });
22 |     SyncArray<int> ins_count(n_instances);
23 |     auto ins_count_data = ins_count.device_data();
24 |     device_loop(n_instances, [=]__device__(int i) {
25 |         int ins_id = idx_data[i];
26 |         atomicAdd(ins_count_data + ins_id, 1);
27 |     });
28 |     DO_ON_MULTI_DEVICES(gradients.size(), [&](int device_id){
29 |         auto gh_data = gradients[device_id].device_data();
30 |         device_loop(n_instances, [=]__device__(int i) {
31 |             gh_data[i].g = gh_data[i].g * ins_count_data[i];
32 |             gh_data[i].h = gh_data[i].h * ins_count_data[i];
33 |         });
34 |     });
35 | }
36 |
37 |
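`do_bagging` implements bootstrap-by-counting: draw `n_instances` ids with replacement, count how often each instance was drawn, then scale its gradient/hessian pair by that count (a count of zero effectively drops the instance). An equivalent single-threaded host sketch:

```cpp
// Host-side sketch of the bagging scheme in RowSampler::do_bagging above.
#include <cstdlib>
#include <vector>

struct GH { float g, h; };

void bagging_cpu(std::vector<GH> &gradients, unsigned seed) {
    int n = (int) gradients.size();
    std::srand(seed);
    std::vector<int> count(n, 0);
    for (int i = 0; i < n; ++i) count[std::rand() % n]++;  // sample with replacement
    for (int i = 0; i < n; ++i) {
        gradients[i].g *= count[i];                        // multiplicity-weighted gradients
        gradients[i].h *= count[i];
    }
}
```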
--------------------------------------------------------------------------------
/src/thundergbm/thundergbm_predict.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/10/19.
3 | //
4 |
5 | #include "thundergbm/parser.h"
6 | #include <thundergbm/dataset.h>
7 | #include "thundergbm/predictor.h"
8 | #ifdef _WIN32
9 | INITIALIZE_EASYLOGGINGPP
10 | #endif
11 | int main(int argc, char **argv) {
12 |     el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
13 |     el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
14 |     el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat);
15 |
16 |     GBMParam model_param;
17 |     Parser parser;
18 |     parser.parse_param(model_param, argc, argv);
19 |
20 |     if(model_param.verbose == 0) {
21 |         el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
22 |         el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
23 |         el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false");
24 |     }
25 |     else if (model_param.verbose == 1) {
26 |         el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
27 |         el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
28 |     }
29 |
30 |     //load data set
31 |     DataSet dataSet;
32 |     //load model
33 |     vector<vector<Tree>> boosted_model;
34 |     parser.load_model("tgbm.model", model_param, boosted_model, dataSet);
35 |     dataSet.load_from_file(model_param.path, model_param);
36 |     //predict
37 |     Predictor pred;
38 |     vector<float_type> y_pred_vec = pred.predict(model_param, boosted_model, dataSet);
39 |
40 |     //users can use y_pred_vec for their own purpose.
41 | }
42 |
--------------------------------------------------------------------------------
/src/thundergbm/thundergbm_train.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/9/19.
3 | //
4 |
5 | #include <thundergbm/trainer.h>
6 | #include "thundergbm/parser.h"
7 | #ifdef _WIN32
8 | INITIALIZE_EASYLOGGINGPP
9 | #endif
10 | int main(int argc, char **argv) {
11 |     el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
12 |     el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
13 |     el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat);
14 |
15 |     GBMParam model_param;
16 |     Parser parser;
17 |     parser.parse_param(model_param, argc, argv);
18 |     if(model_param.verbose == 0) {
19 |         el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
20 |         el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
21 |         el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false");
22 |     }
23 |     else if (model_param.verbose == 1) {
24 |         el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
25 |         el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
26 |     }
27 |
28 |     if (!model_param.profiling) {
29 |         el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
30 |     }
31 |
32 |     DataSet dataset;
33 |     vector<vector<Tree>> boosted_model;
34 |
35 |     // dataset.load_csc_from_file(model_param.path, model_param);
36 |     dataset.load_from_file(model_param.path, model_param);
37 |     TreeTrainer trainer;
38 |     boosted_model = trainer.train(model_param, dataset);
39 |     parser.save_model("tgbm.model", model_param, boosted_model, dataset);
40 | }
--------------------------------------------------------------------------------
/src/thundergbm/trainer.cu:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/9/19.
3 | //
4 | #include <thundergbm/trainer.h>
5 | #include "cuda_runtime_api.h"
6 |
7 | #include <thundergbm/tree.h>
8 | #include <thundergbm/dataset.h>
9 | #include <thundergbm/syncmem.h>
10 | #include "thundergbm/util/device_lambda.cuh"
11 | #include "thrust/reduce.h"
12 | #include "time.h"
13 | #include "thundergbm/booster.h"
14 | #include "chrono"
15 | #include <thundergbm/common.h>
16 | using namespace std;
17 |
18 | vector<vector<Tree>> TreeTrainer::train(GBMParam &param, const DataSet &dataset) {
19 |     if (param.tree_method == "auto") {
20 |         if (dataset.n_features() > 20000)
21 |             param.tree_method = "exact";
22 |         else
23 |             param.tree_method = "hist";
24 |     }
25 |     //correct the number of classes
26 |     if(param.objective.find("multi:") != std::string::npos || param.objective.find("binary:") != std::string::npos) {
27 |         int num_class = dataset.label.size();
28 |         if (param.num_class != num_class) {
29 |             LOG(INFO) << "updating number of classes from " << param.num_class << " to " << num_class;
30 |             param.num_class = num_class;
31 |         }
32 |         if(param.num_class > 2)
33 |             param.tree_per_rounds = param.num_class;
34 |     }
35 |     else if(param.objective.find("reg:") != std::string::npos){
36 |         param.num_class = 1;
37 |     }
38 |
39 |     vector<vector<Tree>> boosted_model;
40 |     Booster booster;
41 |     booster.init(dataset, param);
42 |     std::chrono::high_resolution_clock timer;
43 |     auto start = timer.now();
44 |     for (int i = 0; i < param.n_trees; ++i) {
45 |         //one iteration may produce multiple trees, depending on objectives
46 |         booster.boost(boosted_model, i + 1, param.n_trees);
47 |     }
48 |     auto stop = timer.now();
49 |     std::chrono::duration<float> training_time = stop - start;
50 |     LOG(INFO) << "training time = " << training_time.count();
51 |
52 |     std::atexit([]() {
53 |         SyncMem::clear_cache();
54 |     });
55 |     // SyncMem::clear_cache();
56 |     return boosted_model;
57 | }
58 |
--------------------------------------------------------------------------------
/src/thundergbm/tree.cu:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 18-1-18.
3 | //
4 | #include "thundergbm/tree.h"
5 | #include "thundergbm/util/device_lambda.cuh"
6 | #include "thrust/reduce.h"
7 |
8 | void Tree::init2(const SyncArray<GHPair> &gradients, const GBMParam &param) {
9 |     TIMED_FUNC(timerObj);
10 |     int n_max_nodes = static_cast<int>(pow(2, param.depth + 1) - 1);
11 |     nodes = SyncArray<Tree::TreeNode>(n_max_nodes);
12 |     auto node_data = nodes.device_data();
13 |     device_loop(n_max_nodes, [=]__device__(int i) {
14 |         node_data[i].final_id = i;
15 |         node_data[i].split_feature_id = -1;
16 |         node_data[i].is_valid = false;
17 |         node_data[i].parent_index = i == 0 ? -1 : (i - 1) / 2;
18 |         if (i < n_max_nodes / 2) {
19 |             node_data[i].is_leaf = false;
20 |             node_data[i].lch_index = i * 2 + 1;
21 |             node_data[i].rch_index = i * 2 + 2;
22 |         } else {
23 |             //leaf nodes
24 |             node_data[i].is_leaf = true;
25 |             node_data[i].lch_index = -1;
26 |             node_data[i].rch_index = -1;
27 |         }
28 |     });
29 |
30 |     //init root node
31 |     GHPair sum_gh = thrust::reduce(thrust::cuda::par, gradients.device_data(), gradients.device_end());
32 |     float_type lambda = param.lambda;
33 |     device_loop<1, 1>(1, [=]__device__(int i) {
34 |         Tree::TreeNode &root_node = node_data[0];
35 |         root_node.sum_gh_pair = sum_gh;
36 |         root_node.is_valid = true;
37 |         root_node.calc_weight(lambda);
38 |     });
39 | }
40 |
41 | string Tree::dump(int depth) const {
42 |     string s("\n");
43 |     preorder_traversal(0, depth, 0, s);
44 |     return s;
45 | }
46 |
47 | void Tree::preorder_traversal(int nid, int max_depth, int depth, string &s) const {
48 |     if(nid == -1)//child of leaf node
49 |         return;
50 |     const TreeNode &node = nodes.host_data()[nid];
51 |     const TreeNode *node_data = nodes.host_data();
52 |     if (node.is_valid && !node.is_pruned) {
53 |         s = s + string(static_cast<size_t>(depth), '\t');
54 |
55 |         if(node.is_leaf){
56 |             s = s + string_format("%d:leaf=%.6g\n", node.final_id, node.base_weight);
57 |         }
58 |         else {
59 |             int lch_final_id = node_data[node.lch_index].final_id;
60 |             int rch_final_id = node_data[node.rch_index].final_id;
61 |             string str_inter_node = string_format("%d:[f%d<%.6g] yes=%d,no=%d,missing=%d\n", node.final_id,
62 |                                                   node.split_feature_id + 1,
63 |                                                   node.split_value, lch_final_id, rch_final_id,
64 |                                                   node.default_right == 0 ? lch_final_id : rch_final_id);
65 |             s = s + str_inter_node;
66 |         }
67 |         // string_format("%d:[f%d<%.6g], weight=%f, gain=%f, dr=%d\n", node.final_id, node.split_feature_id + 1,
68 |         //               node.split_value,
69 |         //               node.base_weight, node.gain, node.default_right));
70 |     }
71 |     if (depth < max_depth) {
72 |         preorder_traversal(node.lch_index, max_depth, depth + 1, s);
73 |         preorder_traversal(node.rch_index, max_depth, depth + 1, s);
74 |     }
75 | }
76 |
77 | std::ostream &operator<<(std::ostream &os, const Tree::TreeNode &node) {
78 |     os << string_format("\nnid:%d,l:%d,v:%d,split_feature_id:%d,f:%f,gain:%f,r:%d,w:%f,", node.final_id, node.is_leaf,
79 |                         node.is_valid,
80 |                         node.split_feature_id, node.split_value, node.gain, node.default_right, node.base_weight);
81 |     os << "g/h:" << node.sum_gh_pair;
82 |     return os;
83 | }
84 |
85 | void Tree::reorder_nid() {
86 |     int nid = 0;
87 |     Tree::TreeNode *nodes_data = nodes.host_data();
88 |     for (int i = 0; i < nodes.size(); ++i) {
89 |         if (nodes_data[i].is_valid && !nodes_data[i].is_pruned) {
90 |             nodes_data[i].final_id = nid;
91 |             nid++;
92 |         }
93 |     }
94 | }
95 |
96 | int Tree::try_prune_leaf(int nid, int np, float_type gamma, vector<int> &leaf_child_count) {
97 |     Tree::TreeNode *nodes_data = nodes.host_data();
98 |     int p_nid = nodes_data[nid].parent_index;
99 |     if (p_nid == -1) return np;// is root
100 |     Tree::TreeNode &p_node = nodes_data[p_nid];
101 |     Tree::TreeNode &lch = nodes_data[p_node.lch_index];
102 |     Tree::TreeNode &rch = nodes_data[p_node.rch_index];
103 |     leaf_child_count[p_nid]++;
104 |     if (leaf_child_count[p_nid] >= 2 && p_node.gain < gamma) {
105 |         //do pruning
106 |         //delete two children
107 |         CHECK(lch.is_leaf);
108 |         CHECK(rch.is_leaf);
109 |         lch.is_pruned = true;
110 |         rch.is_pruned = true;
111 |         //make parent to leaf
112 |         p_node.is_leaf = true;
113 |         return try_prune_leaf(p_nid, np + 2, gamma, leaf_child_count);
114 |     } else return np;
115 | }
116 |
117 | void Tree::prune_self(float_type gamma) {
118 |     vector<int> leaf_child_count(nodes.size(), 0);
119 |     Tree::TreeNode *nodes_data = nodes.host_data();
120 |     int n_pruned = 0;
121 |     for (int i = 0; i < nodes.size(); ++i) {
122 |         if (nodes_data[i].is_leaf && nodes_data[i].is_valid) {
123 |             n_pruned = try_prune_leaf(i, n_pruned, gamma, leaf_child_count);
124 |         }
125 |     }
126 |     LOG(DEBUG) << string_format("%d nodes are pruned", n_pruned);
127 |     reorder_nid();
128 | }
129 |
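`init2` allocates the tree as an implicit complete binary heap: node `i` has parent `(i - 1) / 2` and children `2i + 1` and `2i + 2`, and the upper half of the array holds the leaves. A small standalone illustration of that layout:

```cpp
// Illustration of the implicit-heap node layout built by Tree::init2 above.
#include <cstdio>

int main() {
    const int depth = 2;
    const int n_max_nodes = (1 << (depth + 1)) - 1;   // 2^(depth+1) - 1 = 7
    for (int i = 0; i < n_max_nodes; ++i) {
        bool is_leaf = i >= n_max_nodes / 2;          // indices 3..6 are leaves
        printf("node %d: parent=%d lch=%d rch=%d %s\n",
               i, i == 0 ? -1 : (i - 1) / 2,
               is_leaf ? -1 : 2 * i + 1, is_leaf ? -1 : 2 * i + 2,
               is_leaf ? "(leaf)" : "");
    }
}
```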
--------------------------------------------------------------------------------
/src/thundergbm/util/common.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 18-1-16.
3 | //
4 | #include "thundergbm/common.h"
5 | INITIALIZE_EASYLOGGINGPP
6 |
7 | std::ostream &operator<<(std::ostream &os, const int_float &rhs) {
8 |     os << string_format("%d/%f", thrust::get<0>(rhs), thrust::get<1>(rhs));
9 |     return os;
10 | }
11 |
--------------------------------------------------------------------------------
/thundergbm-full.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/thundergbm-full.pdf
--------------------------------------------------------------------------------