├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── LICENSE
├── R
│   ├── README.md
│   └── gbm.R
├── README.md
├── dataset
│   ├── machine.conf
│   ├── test_dataset.txt
│   └── test_dataset.txt.group
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── _static
│   │   ├── default.css
│   │   ├── lang-logo-tgbm.png
│   │   ├── overall.png
│   │   └── tgbm-logo.png
│   ├── conf.py
│   ├── faq.md
│   ├── how-to.md
│   ├── index.md
│   ├── make.bat
│   ├── parameters.md
│   └── requirements.txt
├── include
│   └── thundergbm
│       ├── booster.h
│       ├── builder
│       │   ├── exact_tree_builder.h
│       │   ├── function_builder.h
│       │   ├── hist_tree_builder.h
│       │   ├── hist_tree_builder_single.h
│       │   ├── shard.h
│       │   └── tree_builder.h
│       ├── common.h
│       ├── config.h.in
│       ├── dataset.h
│       ├── hist_cut.h
│       ├── ins_stat.h
│       ├── metric
│       │   ├── metric.h
│       │   ├── multiclass_metric.h
│       │   ├── pointwise_metric.h
│       │   └── ranking_metric.h
│       ├── objective
│       │   ├── multiclass_obj.h
│       │   ├── objective_function.h
│       │   ├── ranking_obj.h
│       │   └── regression_obj.h
│       ├── parser.h
│       ├── predictor.h
│       ├── quantile_sketch.h
│       ├── row_sampler.h
│       ├── sparse_columns.h
│       ├── syncarray.h
│       ├── syncmem.h
│       ├── trainer.h
│       ├── tree.h
│       └── util
│           ├── cub_wrapper.h
│           ├── device_lambda.cuh
│           ├── log.h
│           └── multi_device.h
├── python
│   ├── LICENSE
│   ├── README.md
│   ├── benchmarks
│   │   ├── README.md
│   │   ├── convert_dataset_plk.py
│   │   ├── experiments.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── catboost_model.py
│   │   │   ├── datasets.py
│   │   │   ├── lightgbm_model.py
│   │   │   ├── thundergbm_model.py
│   │   │   └── xgboost_model.py
│   │   ├── run_exp.sh
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── data_utils.py
│   │       └── file_utils.py
│   ├── dist
│   │   ├── thundergbm-0.3.12-py2-none-win_amd64.whl
│   │   ├── thundergbm-0.3.12-py3-none-win_amd64.whl
│   │   ├── thundergbm-0.3.16-py3-none-any.whl
│   │   └── thundergbm-0.3.4-py3-none-win_amd64.whl
│   ├── examples
│   │   ├── classification_demo.py
│   │   ├── ranking_demo.py
│   │   └── regression_demo.py
│   ├── requirements.txt
│   ├── setup.py
│   └── thundergbm
│       ├── __init__.py
│       └── thundergbm.py
├── src
│   ├── test
│   │   ├── CMakeLists.txt
│   │   ├── test_csr2csc.cpp
│   │   ├── test_cub_wrapper.cu
│   │   ├── test_dataset.cpp
│   │   ├── test_for_refactor.cpp
│   │   ├── test_gbdt.cpp
│   │   ├── test_get_cut_point.cpp
│   │   ├── test_gradient.cu
│   │   ├── test_main.cpp
│   │   ├── test_metrics.cpp
│   │   ├── test_parser.cpp
│   │   ├── test_synarray.cpp
│   │   ├── test_synmem.cpp
│   │   └── test_tree.cpp
│   └── thundergbm
│       ├── CMakeLists.txt
│       ├── builder
│       │   ├── exact_tree_builder.cu
│       │   ├── function_builder.cu
│       │   ├── hist_tree_builder.cu
│       │   ├── hist_tree_builder_single.cu
│       │   ├── shard.cu
│       │   └── tree_builder.cu
│       ├── dataset.cpp
│       ├── gbm_R_interface.cpp
│       ├── hist_cut.cu
│       ├── metric
│       │   ├── metric.cu
│       │   ├── multiclass_metric.cu
│       │   ├── pointwise_metric.cu
│       │   └── rank_metric.cpp
│       ├── objective
│       │   ├── multiclass_obj.cu
│       │   ├── objective_function.cu
│       │   └── ranking_obj.cpp
│       ├── parser.cpp
│       ├── predictor.cu
│       ├── quantile_sketch.cpp
│       ├── row_sampler.cu
│       ├── scikit_tgbm.cpp
│       ├── sparse_columns.cu
│       ├── syncmem.cpp
│       ├── thundergbm_predict.cpp
│       ├── thundergbm_train.cpp
│       ├── trainer.cu
│       ├── tree.cu
│       └── util
│           ├── common.cpp
│           └── log.cpp
└── thundergbm-full.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 | .project
2 | *logs*
3 | *.pyc
4 | *.model
5 | bin
6 | *.binary
7 | *.out
8 | *.idea
9 | tags
10 | .cproject
11 | .settings
12 | Debug
13 | Release
14 | dataset*
15 | *~
16 | core
17 | *.DS_Store
18 | *.o
19 | *.swp
20 | *build*
21 | !builder
22 | !*build*.cu
23 | !*build*.h
24 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "cub"]
2 | path = cub
3 | url = https://github.com/NVlabs/cub.git
4 | [submodule "src/test/googletest"]
5 | path = src/test/googletest
6 | url = https://github.com/google/googletest.git
7 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | if(MSVC)
2 | cmake_minimum_required(VERSION 3.4)
3 | else()
4 | cmake_minimum_required(VERSION 2.8)
5 | endif()
6 | project(thundergbm)
7 | if(MSVC)
8 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
9 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
10 | endif()
11 |
12 | find_package(CUDA REQUIRED QUIET)
13 | find_package(OpenMP REQUIRED QUIET)
14 |
15 | if (NOT CMAKE_BUILD_TYPE)
16 | set(CMAKE_BUILD_TYPE Release)
17 | endif ()
18 |
19 | # CUDA 11
20 | if(CUDA_VERSION VERSION_LESS "11.0")
21 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11 -lineinfo --expt-extended-lambda --default-stream per-thread")
22 | if (CMAKE_VERSION VERSION_LESS "3.1")
23 | add_compile_options("-std=c++11")
24 | else ()
25 | set(CMAKE_CXX_STANDARD 11)
26 | endif ()
27 | else()
28 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++14 -lineinfo --expt-extended-lambda --default-stream per-thread")
29 | endif()
30 |
31 | if (OPENMP_FOUND)
32 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
33 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
34 | endif ()
35 |
36 | # for easylogging++ configuration
37 | add_definitions("-DELPP_FEATURE_PERFORMANCE_TRACKING")
38 | add_definitions("-DELPP_THREAD_SAFE")
39 | add_definitions("-DELPP_STL_LOGGING")
40 | add_definitions("-DELPP_NO_LOG_TO_FILE")
41 | # includes
42 | set(COMMON_INCLUDES ${PROJECT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR} ${PROJECT_SOURCE_DIR}/cub)
43 | include_directories(${COMMON_INCLUDES})
44 |
45 | add_subdirectory(src/thundergbm)
46 |
47 | set(BUILD_TESTS OFF CACHE BOOL "Build Tests")
48 | if (BUILD_TESTS)
49 | add_subdirectory(src/test)
50 | endif ()
51 |
52 | # configuration file
53 | set(DATASET_DIR ${PROJECT_SOURCE_DIR}/datasets/)
54 | configure_file(include/thundergbm/config.h.in config.h)
55 |
56 |
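57 | # Example out-of-source build (a sketch; assumes CUDA and OpenMP are installed):
58 | #   mkdir build && cd build && cmake .. && make -j
59 | # To also build the tests (see src/test):
60 | #   cmake -DBUILD_TESTS=ON .. && make -j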
--------------------------------------------------------------------------------
/R/README.md:
--------------------------------------------------------------------------------
1 | # R interface for thundergbm.
2 | Before you use the R interface, you must build ThunderGBM. Please refer to [Installation](https://thundergbm.readthedocs.io/en/latest/how-to.html) for building ThunderGBM.
3 |
4 | ## Methods
5 | By default, the directory for storing the training data and results is the working directory.
6 |
7 | *gbm_train_R(depth = 6, n_trees = 40, n_gpus = 1, verbose = 1,
8 | profiling = 0, data = 'None', max_num_bin = 255, column_sampling_rate = 1,
9 | bagging = 0, n_parallel_trees = 1, learning_rate = 1, objective = 'reg:linear',
10 | num_class = 1, min_child_weight = 1, lambda_tgbm = 1, gamma = 1,
11 | tree_method = 'auto', model_out = 'tgbm.model')*
12 |
13 | *gbm_predict_R(test_data = 'None', model_in = 'tgbm.model', verbose = 1)*
14 |
15 | ## Example
16 | * Step 1: go to the R interface directory.
17 | ```bash
18 | # in ThunderGBM root directory
19 | cd R
20 | ```
21 | * Step 2: start R in the terminal by typing ```R``` and pressing the Enter key.
22 | * Step 3: execute the following code in R.
23 | ```R
24 | source("gbm.R")
25 | gbm_train_R(data = "../dataset/test_dataset.txt")
26 | gbm_predict_R(test_data = "../dataset/test_dataset.txt")
27 | ```
28 |
29 | ## Parameters
30 | Please refer to [Parameters](https://github.com/Xtra-Computing/thundergbm/blob/master/docs/parameters.md).
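31 |
32 | ## Another example
33 | To train a ranking model instead of the default regression objective (a hypothetical run; ranking additionally requires a ``[train_file_name].group`` file next to the training data, such as ``test_dataset.txt.group``):
34 | ```R
35 | source("gbm.R")
36 | gbm_train_R(data = "../dataset/test_dataset.txt", objective = "rank:ndcg", depth = 8)
37 | gbm_predict_R(test_data = "../dataset/test_dataset.txt")
38 | ```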
--------------------------------------------------------------------------------
/R/gbm.R:
--------------------------------------------------------------------------------
1 | # Created by: Qinbin
2 | # Created on: 3/15/20
3 |
4 | check_location <- function(){
5 | if(Sys.info()['sysname'] == 'Windows'){
6 | if(!file.exists("../build/bin/Debug/thundergbm.dll")){
7 | stop("Please build the library first (or check that your working directory is thundergbm/R/)!")
8 | }
9 | dyn.load("../build/bin/Debug/thundergbm.dll")
10 | } else if(Sys.info()['sysname'] == 'Linux'){
11 | if(!file.exists("../build/lib/libthundergbm.so")){
12 | stop("Please build the library first (or check that your working directory is thundergbm/R/)!")
13 | }
14 | dyn.load("../build/lib/libthundergbm.so")
15 | } else if(Sys.info()['sysname'] == 'Darwin'){
16 | if(!file.exists("../build/lib/libthundergbm.dylib")){
17 | stop("Please build the library first (or check that your working directory is thundergbm/R/)!")
18 | }
19 | dyn.load("../build/lib/libthundergbm.dylib")
20 | } else{
21 | stop("OS not supported!")
22 | }
23 | }
24 | check_location() # Run this when the file is sourced
25 |
26 | gbm_train_R <-
27 | function(
28 | depth = 6, n_trees = 40, n_gpus = 1, verbose = 1,
29 | profiling = 0, data = 'None', max_num_bin = 255, column_sampling_rate = 1,
30 | bagging = 0, n_parallel_trees = 1, learning_rate = 1, objective = 'reg:linear',
31 | num_class = 1, min_child_weight = 1, lambda_tgbm = 1, gamma = 1,
32 | tree_method = 'auto', model_out = 'tgbm.model'
33 | )
34 | {
35 | check_location()
36 | if(!file.exists(data)){stop("The file containing the training dataset provided as an argument in 'data' does not exist")}
37 | res <- .C("train_R", as.integer(depth), as.integer(n_trees), as.integer(n_gpus), as.integer(verbose),
38 | as.integer(profiling), as.character(data), as.integer(max_num_bin), as.double(column_sampling_rate),
39 | as.integer(bagging), as.integer(n_parallel_trees), as.double(learning_rate), as.character(objective),
40 | as.integer(num_class), as.integer(min_child_weight), as.double(lambda_tgbm), as.double(gamma),
41 | as.character(tree_method), as.character(model_out))
42 | }
43 |
44 | gbm_predict_R <-
45 | function(
46 | test_data = 'None', model_in = 'tgbm.model', verbose = 1
47 | )
48 | {
49 | check_location()
50 | if(!file.exists(test_data)){stop("The file containing the test dataset provided as an argument in 'test_data' does not exist")}
51 | if(!file.exists(model_in)){stop("The file containing the model provided as an argument in 'model_in' does not exist")}
52 | res <- .C("predict_R", as.character(test_data), as.character(model_in), as.integer(verbose))
53 | }
54 |
55 |
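56 | # Example usage (a minimal sketch; run from the thundergbm/R directory as described in README.md):
57 | #   source("gbm.R")
58 | #   gbm_train_R(data = "../dataset/test_dataset.txt")
59 | #   gbm_predict_R(test_data = "../dataset/test_dataset.txt")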
--------------------------------------------------------------------------------
/dataset/machine.conf:
--------------------------------------------------------------------------------
1 | data=../dataset/test_dataset.txt
--------------------------------------------------------------------------------
/dataset/test_dataset.txt.group:
--------------------------------------------------------------------------------
1 | 100
2 | 300
3 | 400
4 | 200
5 | 605
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = ThunderGBM
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Documentation of ThunderGBM
2 | The documentation of ThunderGBM is written in Markdown and generated by Sphinx with recommonmark.
3 |
4 | ### Generate html files
5 |
6 | * go to this ```docs``` directory
7 | * type ```make html``` (as shown below)
8 |
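9 | For example (assuming Sphinx, recommonmark and sphinx_rtd_theme are installed, as described in how-to.md):
10 | ```bash
11 | cd docs
12 | make html
13 | # the generated pages are placed in _build/html
14 | ```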
--------------------------------------------------------------------------------
/docs/_static/default.css:
--------------------------------------------------------------------------------
1 | .section #basic-2-flip-flop-synchronizer{
2 | text-align:justify;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/_static/lang-logo-tgbm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/docs/_static/lang-logo-tgbm.png
--------------------------------------------------------------------------------
/docs/_static/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/docs/_static/overall.png
--------------------------------------------------------------------------------
/docs/_static/tgbm-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/docs/_static/tgbm-logo.png
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # ThunderGBM documentation build configuration file, created by
4 | # sphinx-quickstart on Sat Oct 28 23:38:46 2017.
5 | #
6 | # This file is execfile()d with the current directory set to its
7 | # containing dir.
8 | #
9 | # Note that not all possible configuration values are present in this
10 | # autogenerated file.
11 | #
12 | # All configuration values have a default; values that are commented out
13 | # serve to show the default.
14 |
15 | # If extensions (or modules to document with autodoc) are in another directory,
16 | # add these directories to sys.path here. If the directory is relative to the
17 | # documentation root, use os.path.abspath to make it absolute, like shown here.
18 | #
19 | # sys.path.insert(0, os.path.abspath('.'))
20 |
21 | import recommonmark
22 | from recommonmark import transform
23 | AutoStructify = transform.AutoStructify
24 |
25 | # -- General configuration ------------------------------------------------
26 |
27 | # If your documentation needs a minimal Sphinx version, state it here.
28 | #
29 | # needs_sphinx = '1.0'
30 |
31 | # Add any Sphinx extension module names here, as strings. They can be
32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
33 | # ones.
34 | extensions = ['sphinx.ext.mathjax']
35 |
36 | # Add any paths that contain templates here, relative to this directory.
37 | templates_path = ['_templates']
38 |
39 | # The suffix(es) of source filenames.
40 | # You can specify multiple suffix as a list of string:
41 | #
42 |
43 | source_parsers = {
44 | '.md': 'recommonmark.parser.CommonMarkParser',
45 | }
46 |
47 | source_suffix = ['.rst', '.md']
48 |
49 | # The master toctree document.
50 | master_doc = 'index'
51 |
52 | # General information about the project.
53 | project = u'ThunderGBM'
54 | copyright = u'2019, ThunderGBM Developers'
55 | author = u'ThunderGBM Developers'
56 |
57 | # The version info for the project you're documenting, acts as replacement for
58 | # |version| and |release|, also used in various other places throughout the
59 | # built documents.
60 | #
61 | # The short X.Y version.
62 | version = u'0.1'
63 | # The full version, including alpha/beta/rc tags.
64 | release = u'0.1'
65 |
66 | # The language for content autogenerated by Sphinx. Refer to documentation
67 | # for a list of supported languages.
68 | #
69 | # This is also used if you do content translation via gettext catalogs.
70 | # Usually you set "language" from the command line for these cases.
71 | language = None
72 |
73 | # List of patterns, relative to source directory, that match files and
74 | # directories to ignore when looking for source files.
75 | # This patterns also effect to html_static_path and html_extra_path
76 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
77 |
78 | # The name of the Pygments (syntax highlighting) style to use.
79 | pygments_style = 'sphinx'
80 |
81 | # If true, `todo` and `todoList` produce output, else they produce nothing.
82 | todo_include_todos = False
83 |
84 |
85 | # -- Options for HTML output ----------------------------------------------
86 |
87 | # The theme to use for HTML and HTML Help pages. See the documentation for
88 | # a list of builtin themes.
89 | #
90 | html_theme = 'sphinx_rtd_theme'
91 |
92 | # Theme options are theme-specific and customize the look and feel of a theme
93 | # further. For a list of options available for each theme, see the
94 | # documentation.
95 | #
96 | # html_theme_options = {}
97 |
98 | # Add any paths that contain custom static files (such as style sheets) here,
99 | # relative to this directory. They are copied after the builtin static files,
100 | # so a file named "default.css" will overwrite the builtin "default.css".
101 | html_static_path = ['_static']
102 |
103 | # Custom sidebar templates, must be a dictionary that maps document names
104 | # to template names.
105 | #
106 | # This is required for the alabaster theme
107 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
108 | html_sidebars = {
109 | '**': [
110 | 'relations.html', # needs 'show_related': True theme option to display
111 | 'searchbox.html',
112 | ]
113 | }
114 |
115 |
116 | # -- Options for HTMLHelp output ------------------------------------------
117 |
118 | # Output file base name for HTML help builder.
119 | htmlhelp_basename = 'ThunderGBMdoc'
120 |
121 |
122 | # -- Options for LaTeX output ---------------------------------------------
123 |
124 | latex_elements = {
125 | # The paper size ('letterpaper' or 'a4paper').
126 | #
127 | # 'papersize': 'letterpaper',
128 |
129 | # The font size ('10pt', '11pt' or '12pt').
130 | #
131 | # 'pointsize': '10pt',
132 |
133 | # Additional stuff for the LaTeX preamble.
134 | #
135 | # 'preamble': '',
136 |
137 | # Latex figure (float) alignment
138 | #
139 | # 'figure_align': 'htbp',
140 | }
141 |
142 | # Grouping the document tree into LaTeX files. List of tuples
143 | # (source start file, target name, title,
144 | # author, documentclass [howto, manual, or own class]).
145 | latex_documents = [
146 | (master_doc, 'ThunderGBM.tex', u'ThunderGBM Documentation',
147 | u'Zeyi Wen', 'manual'),
148 | ]
149 |
150 |
151 | # -- Options for manual page output ---------------------------------------
152 |
153 | # One entry per manual page. List of tuples
154 | # (source start file, name, description, authors, manual section).
155 | man_pages = [
156 | (master_doc, 'thundergbm', u'ThunderGBM Documentation',
157 | [author], 1)
158 | ]
159 |
160 |
161 | # -- Options for Texinfo output -------------------------------------------
162 |
163 | # Grouping the document tree into Texinfo files. List of tuples
164 | # (source start file, target name, title, author,
165 | # dir menu entry, description, category)
166 | texinfo_documents = [
167 | (master_doc, 'ThunderGBM', u'ThunderGBM Documentation',
168 | author, 'ThunderGBM', 'One line description of project.',
169 | 'Miscellaneous'),
170 | ]
171 |
172 | github_doc_root = 'https://github.com/rtfd/recommonmark/tree/master/doc/'
173 | def setup(app):
174 | app.add_config_value('recommonmark_config', {
175 | 'url_resolver': lambda url: github_doc_root + url,
176 | 'enable_eval_rst': True,
177 | }, True)
178 | app.add_transform(AutoStructify)
179 |
180 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | Frequently Asked Questions (FAQs)
2 | ======
3 | This page is dedicated to summarizing some frequently asked questions about ThunderGBM.
4 |
5 | ## FAQs of users
6 |
7 | * **What is the data format of the input file?**
8 | ThunderGBM uses the LibSVM format. You can also use ThunderGBM through its scikit-learn interface. An example of the LibSVM format is given at the end of this page.
9 |
10 | * **Can ThunderGBM run on CPUs?**
11 | No. ThunderGBM is specifically optimized for GBDTs and Random Forests on GPUs.
12 |
13 | * **Can ThunderGBM run on multiple GPUs?**
14 | Yes. You can use the ``n_gpus`` option to specify how many GPUs you want to use. Please refer to [Parameters](parameters.md) for more information.
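15 |
16 | * **What does the LibSVM format look like?**
17 | Each line holds the label of an instance followed by sparse ``feature_index:feature_value`` pairs. A small hypothetical example (made-up values):
18 | ```
19 | 1 3:0.5 7:1.2 15:0.3
20 | 0 2:0.1 7:0.8
21 | ```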
--------------------------------------------------------------------------------
/docs/how-to.md:
--------------------------------------------------------------------------------
1 | ThunderGBM How To
2 | ======
3 | This page gives key instructions for installing, using and contributing to ThunderGBM. Everyone in the community can contribute to ThunderGBM to make it better.
4 |
5 | ## How to install ThunderGBM
6 | First of all, you need to install the prerequisite libraries and tools. Then you can download and install ThunderGBM.
7 | ### Prerequisites
8 | * cmake 2.8 or above
9 | * gcc 4.8 or above for Linux | [CUDA](https://developer.nvidia.com/cuda-downloads) 8 or above
10 | * Visual C++ for Windows | CUDA 10
11 |
12 |
13 | ### Download
14 | ```bash
15 | git clone https://github.com/zeyiwen/thundergbm.git
16 | cd thundergbm
17 | #under the directory of thundergbm
18 | git submodule init cub && git submodule update
19 | ```
20 | ### Build on Linux
21 | ```bash
22 | #under the directory of "thundergbm"
23 | mkdir build && cd build && cmake .. && make -j
24 | ```
25 |
26 | ### Quick Start
27 | ```bash
28 | ./bin/thundergbm-train ../dataset/machine.conf
29 | ./bin/thundergbm-predict ../dataset/machine.conf
30 | ```
31 | You will see `RMSE = 0.489562` after a successful run.
32 |
33 | ### Build on Windows
34 | You can build the ThunderGBM library as follows:
35 | ```bash
36 | cd thundergbm
37 | mkdir build
38 | cd build
39 | cmake .. -DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE -DBUILD_SHARED_LIBS=TRUE -G "Visual Studio 15 2017 Win64"
40 | ```
41 | You need to change the generator name if you are using a different version of Visual Studio, which can be downloaded from [this link](https://www.visualstudio.com/vs/). The above commands generate Visual Studio project files; open the project in Visual Studio to build ThunderGBM. Please note that CMake should be 3.4 or above for Windows.
42 |
43 | ## How to use ThunderGBM using command line
44 | First of all, please refer to the above instruction for installing ThunderGBM. Then, you can run the demo by the following command.
45 | ```bash
46 | ./bin/thundergbm-train ../dataset/machine.conf
47 | ./bin/thundergbm-predict ../dataset/machine.conf
48 | ```
49 |
50 | If you would like to know more about the detailed options for running the binary, please use the `-help` option as follows.
51 | ```bash
52 | ./bin/thundergbm-train -help
53 | ```
54 |
55 | In ThunderGBM, the command line options can be added in the `machine.conf` file under the `dataset` folder. All the options are listed in the [Parameters](parameters.md) page.
56 |
57 | ## How to improve documentations
58 | Most of the documents can be viewed on GitHub. The documents can also be viewed in Read the Doc. The HTML files of our documents are generated by [Sphinx](http://www.sphinx-doc.org/en/stable/), and the source files of the documents are written using [Markdown](http://commonmark.org/). In the following, we describe how to setup the Sphinx environment.
59 |
60 | * Install sphinx
61 | ```bash
62 | pip install sphinx
63 | ```
64 |
65 | * Install Markdown Parser
66 | ```bash
67 | pip install recommonmark
68 | ```
69 | Note that ```recommonmark``` has a bug when working with Sphinx on some platforms, so you may need to hack into transform.py to fix the problem yourself. You can find instructions in [this link](https://github.com/sphinx-doc/sphinx/issues/3800).
70 |
71 | * Install Sphinx theme
72 | ```bash
73 | pip install sphinx_rtd_theme
74 | ```
75 |
76 | * Generate HTML
77 |
78 | Go to the "docs" directory of ThunderGBM and run:
79 | ```bash
80 | make html
81 | ```
82 |
83 | At this point, you should have generated the HTML documents of ThunderGBM; open them on your machine to see the outcome.
84 |
85 | ## Contribute to ThunderGBM
86 | You need to fetch the latest version of ThunderGBM before submitting a pull request.
87 | ```bash
88 | git remote add upstream https://github.com/Xtra-Computing/thundergbm.git
89 | git fetch upstream
90 | git rebase upstream/master
91 | ```
92 |
93 | ## How to build test for ThunderGBM
94 | For building test cases, you also need to obtain ``googletest`` using the following command.
95 | ```bash
96 | #under the thundergbm directory
97 | git submodule update --init src/test/googletest
98 | ```
99 | After obtaining the ``googletest`` submodule, you can build the test cases by the following commands.
100 | ```bash
101 | cd thundergbm
102 | mkdir build && cd build && cmake -DBUILD_TESTS=ON .. && make -j
103 | ```
104 |
105 | ## How to use ThunderGBM for ranking
106 |
107 | There are two key steps to use ThunderGBM for ranking.
108 | * First, you need to choose ``rank:pairwise`` or ``rank:ndcg`` to set the ``objective`` of ThunderGBM.
109 | * Second, you need to have a file called ``[train_file_name].group`` to specify the number of instances in each query.
110 |
111 | The remaining part is the same as classification and regression. Please refer to [Parameters](parameters.md) for more information about setting the parameters.
112 |
113 | ## How to build the Python wheel file for Linux
114 | Make sure your local repository is up to date with the latest version.
115 | * Clone ThunderGBM repository
116 | ```bash
117 | git clone https://github.com/zeyiwen/thundergbm.git
118 | cd thundergbm
119 | #under the directory of thundergbm
120 | git submodule init cub && git submodule update
121 | ```
122 | * Build the binary
123 | ```bash
124 | mkdir build && cd build && cmake .. && make -j
125 | ```
126 | * Build the python wheel file
127 | - change directory to python by `cd ../python`
128 | - update the version you are going to release in [setup.py](https://github.com/Xtra-Computing/thundergbm/blob/c89d6da6008f945c09aae521c95cfe5b8bdd8db5/python/setup.py#L20)
129 | - you may need to install the `wheel` dependency by ``pip3 install wheel``
130 | ```bash
131 | python3 setup.py bdist_wheel
132 | ```
133 | ## How to build the Python wheel file for Windows
134 | Make sure your local repository is up to date with the latest version.
135 | * Requirements
136 | - Visual Studio
137 | - CUDA 10.0 or above
138 | - python3.x
139 | * Clone ThunderGBM repository
140 | ```bash
141 | git clone https://github.com/zeyiwen/thundergbm.git
142 | cd thundergbm
143 | #under the directory of thundergbm
144 | git submodule init && git submodule update
145 | ```
146 | * Run CMake from the Visual Studio Developer Command Prompt
147 | ```bash
148 | mkdir build && cd build
149 | cmake .. -DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE -DBUILD_SHARED_LIBS=TRUE -G "Visual Studio 15 2017 Win64"
150 | ```
151 | You may need to change the version of Visual Studio if you are using a different version of Visual Studio.
152 | * Build binary file using Visual Studio
153 | - Open 'thundergbm/build/thundergbm.sln' with Visual Studio
154 | - Click 'Build all' in Visual Studio
155 | * Build the python wheel file
156 | - change directory to python by `cd ../python`
157 | - update the version you are going to release in [setup.py](https://github.com/Xtra-Computing/thundergbm/blob/c89d6da6008f945c09aae521c95cfe5b8bdd8db5/python/setup.py#L20)
158 | - you may need to install the `wheel` dependency by ``pip3 install wheel``
159 | ```bash
160 | python3 setup.py bdist_wheel
161 | ```
162 | * Upload the wheel file to [Pypi.org](https://pypi.org)
163 | - you may need to install the `twine` dependency by ``pip3 install twine``
164 | - you need to use ``python3 -m twine upload dist/*`` if ``twine`` is not included in ``PATH``
165 | ```bash
166 | twine upload dist/* --verbose
167 | ```
168 | * [Recommended] Draft a new release on [Release](https://github.com/Xtra-Computing/thundergbm/releases)
169 | * state the bugs fixed and new functions added.
170 |
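171 | * [Optional] Verify the wheel locally before uploading (a minimal sanity check; the exact wheel file name depends on the version set in setup.py):
172 | ```bash
173 | pip3 install dist/thundergbm-*.whl
174 | python3 -c "import thundergbm"
175 | ```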
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ThunderGBM: Fast GBDTs and Random Forests on GPUs
2 | ======================================
3 | ThunderGBM is dedicated to helping users apply GBDTs and Random Forests to solve problems efficiently and easily using GPUs. Key features of ThunderGBM are as follows.
4 | * Supports regression, classification and ranking.
5 | * Uses the same command line options as XGBoost, and supports the Python (scikit-learn) interface.
6 | * Supported Operating Systems: Linux and Windows.
7 | * ThunderGBM is often 10 times faster than XGBoost, LightGBM and CatBoost. It performs well on high-dimensional and sparse problems.
8 |
9 |
14 |
15 | ## More information about using ThunderGBM
16 | * [Parameters](parameters.md) | [How To](how-to.md) | [FAQ](faq.md)
17 |
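18 | A minimal sketch of the scikit-learn style interface (this assumes the ``thundergbm`` Python package is installed and exposes ``TGBMClassifier``; see [How To](how-to.md) for building it):
19 | ```python
20 | from thundergbm import TGBMClassifier
21 |
22 | clf = TGBMClassifier(depth=6, n_trees=40)  # same parameter names as the command line
23 | clf.fit(x_train, y_train)                  # x_train/y_train: your training data
24 | y_pred = clf.predict(x_test)
25 | ```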
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=ThunderGBM
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/parameters.md:
--------------------------------------------------------------------------------
1 | ThunderGBM Parameters
2 | =====================
3 | This page specifies the parameters of ThunderGBM. The parameters of ThunderGBM are identical to those of XGBoost (except some newly introduced parameters), so existing XGBoost users can easily get used to ThunderGBM.
4 |
5 | ### Key parameters for both *python* and *c++* command lines
6 | * ``verbose`` [default=1]
7 |
8 | - Printing information: 0 for silence, 1 for key information and 2 for more information.
9 |
10 | * ``depth`` [default=6]
11 |
12 | - The maximum depth of the decision trees. Shallow trees tend to have better generality, and deep trees are more likely to overfit the training data.
13 |
14 | * ``n_trees`` [default=40]
15 |
16 | - The number of training iterations. ``n_trees`` equals the number of trees in GBDTs.
17 |
18 | * ``n_gpus`` [default=1]
19 |
20 | - The number of GPUs to be used in the training.
21 |
22 | * ``max_num_bin`` [default=255]
23 |
24 | - The maximum number of bins in a histogram.
25 |
26 | * ``column_sampling_rate`` [default=1]
27 |
28 | - The sampling ratio for subsampling columns (i.e., features).
29 |
30 | * ``bagging`` [default=0]
31 |
32 | - This option is for training random forests. Set it to 1 to perform bagging.
33 |
34 | * ``n_parallel_trees`` [default=1]
35 |
36 | - This option is used for random forests to specify how many trees are trained per iteration.
37 |
38 | * ``learning_rate`` [default=1, alias(only for c++): ``eta``]
39 |
40 | - valid domain: [0,1]. This option sets the weight of each newly trained tree. Use ``eta < 1`` to mitigate overfitting.
41 |
42 | * ``objective`` [default="reg:linear"]
43 |
44 | - valid options include ``reg:linear``, ``reg:logistic``, ``binary:logistic``, ``multi:softprob``, ``multi:softmax``, ``rank:pairwise`` and ``rank:ndcg``.
45 | - ``reg:linear`` is for regression, ``reg:logistic`` and ``binary:logistic`` are for binary classification.
46 | - ``multi:softprob`` and ``multi:softmax`` are for multi-class classification. ``multi:softprob`` outputs probability for each class, and ``multi:softmax`` outputs the label only.
47 | - ``rank:pairwise`` and ``rank:ndcg`` are for ranking problems.
48 |
49 | * ``num_class`` [default=1]
50 | - Set the number of classes for multi-class classification. This option is not compulsory.
51 |
52 | * ``min_child_weight`` [default=1]
53 |
54 | - The minimum sum of instance weight (measured by the second order derivative) needed in a child node.
55 |
56 | * ``lambda_tgbm`` [default=1, alias(only for c++): ``lambda`` or ``reg_lambda``]
57 |
58 | - L2 regularization term on weights.
59 |
60 | * ``gamma`` [default=1, alias(only for c++): ``min_split_loss``]
61 |
62 | - The minimum loss reduction required to make a further split on a leaf node of the tree. ``gamma`` is used in the pruning stage.
63 | * ``tree_method`` [default="auto"]
64 |
65 | - "auto": select the approach of finding best splits using the builtin heuristics.
66 |
67 | - "exact": find the best split using enumeration on all the possible feature values.
68 |
69 | - "hist": find the best split using histogram based approach.
70 |
71 | ### Parameters only for *c++* command line:
72 | * ``data`` [default="../dataset/test_dataset.txt"]
73 |
74 | - The path to the training data set
75 |
76 | * ``model_out`` [default="tgbm.model"]
77 |
78 | - The file name of the output model. This option is used in training.
79 |
80 | * ``model_in`` [default="tgbm.model"]
81 |
82 | - The file name of the input model. This option is used in prediction.
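83 |
84 | ### Example configuration file
85 | As an illustration, the *c++* command line options above can be written one per line (``key=value``) in a configuration file and passed to the binaries, in the same way as ``dataset/machine.conf``. The values below are hypothetical:
86 | ```
87 | data=../dataset/test_dataset.txt
88 | objective=reg:linear
89 | depth=6
90 | n_trees=40
91 | learning_rate=1
92 | model_out=tgbm.model
93 | ```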
83 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==1.5.6
2 | commonmark
3 | mock
4 |
5 |
--------------------------------------------------------------------------------
/include/thundergbm/booster.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-17.
3 | //
4 |
5 | #ifndef THUNDERGBM_BOOSTER_H
6 | #define THUNDERGBM_BOOSTER_H
7 |
8 | #include <thundergbm/objective/objective_function.h>
9 | #include <thundergbm/metric/metric.h>
10 | #include <thundergbm/builder/function_builder.h>
11 | #include <thundergbm/util/multi_device.h>
12 | #include "thundergbm/common.h"
13 | #include "syncarray.h"
14 | #include "tree.h"
15 | #include "row_sampler.h"
16 |
17 | std::mutex mtx;
18 |
19 | class Booster {
20 | public:
21 | void init(const DataSet &dataSet, GBMParam &param);
22 |
23 | void boost(vector<vector<Tree>> &boosted_model, int epoch, int total_epoch);
24 |
25 | private:
26 | MSyncArray<GHPair> gradients;
27 | std::unique_ptr<ObjectiveFunction> obj;
28 | std::unique_ptr<Metric> metric;
29 | MSyncArray<float_type> y;
30 | std::unique_ptr<FunctionBuilder> fbuilder;
31 | RowSampler rowSampler;
32 | GBMParam param;
33 | int n_devices;
34 | };
35 |
36 | void Booster::init(const DataSet &dataSet, GBMParam &param) {
37 | int n_available_device;
38 | cudaGetDeviceCount(&n_available_device);
39 | CHECK_GE(n_available_device, param.n_device) << "only " << n_available_device
40 | << " GPUs available; please set correct number of GPUs to use";
41 | this->param = param;
42 | //fbuilder.reset(FunctionBuilder::create(param.tree_method));
43 | //if method is hist, and n_available_device is 1
44 | if(param.n_device==1 && param.tree_method == "hist"){
45 | fbuilder.reset(FunctionBuilder::create("hist_single"));
46 | }
47 | else{
48 | fbuilder.reset(FunctionBuilder::create(param.tree_method));
49 | }
50 |
51 | fbuilder->init(dataSet, param);
52 | obj.reset(ObjectiveFunction::create(param.objective));
53 | obj->configure(param, dataSet);
54 | metric.reset(Metric::create(obj->default_metric_name()));
55 | metric->configure(param, dataSet);
56 |
57 | n_devices = param.n_device;
58 | int n_outputs = param.num_class * dataSet.n_instances();
59 | gradients = MSyncArray<GHPair>(n_devices, n_outputs);
60 | y = MSyncArray<float_type>(n_devices, dataSet.n_instances());
61 |
62 | DO_ON_MULTI_DEVICES(n_devices, [&](int device_id) {
63 | y[device_id].copy_from(dataSet.y.data(), dataSet.n_instances());
64 | });
65 |
66 | //init base score
67 | //only support histogram-based method and single device now
68 | //TODO support exact and multi-device
69 | if(param.n_device == 1 && param.tree_method == "hist"){
70 | DO_ON_MULTI_DEVICES(n_devices, [&](int device_id){
71 | param.base_score = obj->init_base_score(y[device_id], fbuilder->get_raw_y_predict()[device_id], gradients[device_id]);
72 | });
73 | }
74 | }
75 |
76 | void Booster::boost(vector<vector<Tree>> &boosted_model, int epoch, int total_epoch) {
77 | TIMED_FUNC(timerObj);
78 | std::unique_lock<std::mutex> lock(mtx);
79 |
80 | //update gradients
81 | DO_ON_MULTI_DEVICES(n_devices, [&](int device_id) {
82 | obj->get_gradient(y[device_id], fbuilder->get_y_predict()[device_id], gradients[device_id]);
83 | });
84 | if (param.bagging) rowSampler.do_bagging(gradients);
85 | PERFORMANCE_CHECKPOINT(timerObj);
86 | //build new model/approximate function
87 | boosted_model.push_back(fbuilder->build_approximate(gradients));
88 |
89 | PERFORMANCE_CHECKPOINT(timerObj);
90 | //show metric on training set
91 | auto res = metric->get_score(fbuilder->get_y_predict().front());
92 | LOG(INFO) << "[" << epoch << "/" << total_epoch << "] " << metric->get_name() << " = " << res;
93 | }
94 |
95 | #endif //THUNDERGBM_BOOSTER_H
96 |
--------------------------------------------------------------------------------
/include/thundergbm/builder/function_builder.h:
--------------------------------------------------------------------------------
5 | #ifndef THUNDERGBM_FUNCTION_BUILDER_H
6 | #define THUNDERGBM_FUNCTION_BUILDER_H
7 |
8 | #include <thundergbm/tree.h>
9 | #include "thundergbm/common.h"
10 | #include "thundergbm/sparse_columns.h"
11 |
12 | class FunctionBuilder {
13 | public:
14 | virtual vector<Tree> build_approximate(const MSyncArray<GHPair> &gradients) = 0;
15 |
16 | virtual void init(const DataSet &dataset, const GBMParam &param) {
17 | this->param = param;
18 | };
19 |
20 | virtual const MSyncArray<float_type> &get_y_predict(){ return y_predict; };
21 | MSyncArray<float_type> &get_raw_y_predict(){ return y_predict; };
22 |
23 | virtual ~FunctionBuilder(){};
24 |
25 | static FunctionBuilder *create(std::string name);
26 |
27 | protected:
28 | MSyncArray<float_type> y_predict;
29 | GBMParam param;
30 | };
31 |
32 | #endif //THUNDERGBM_FUNCTION_BUILDER_H
33 |
--------------------------------------------------------------------------------
/include/thundergbm/builder/hist_tree_builder.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-17.
3 | //
4 |
5 | #ifndef THUNDERGBM_HIST_TREE_BUILDER_H
6 | #define THUNDERGBM_HIST_TREE_BUILDER_H
7 |
8 | #include <thundergbm/hist_cut.h>
9 | #include "thundergbm/common.h"
10 | #include "shard.h"
11 | #include "tree_builder.h"
12 |
13 |
14 | class HistTreeBuilder : public TreeBuilder {
15 | public:
16 |
17 | void init(const DataSet &dataset, const GBMParam &param) override;
18 |
19 | void get_bin_ids();
20 |
21 | void find_split(int level, int device_id) override;
22 |
23 | virtual ~HistTreeBuilder(){};
24 |
25 | void update_ins2node_id() override;
26 |
27 |
28 | private:
29 | vector<HistCut> cut;
30 | // MSyncArray<unsigned char> char_dense_bin_id;
31 | MSyncArray<unsigned char> dense_bin_id;
32 | MSyncArray<GHPair> last_hist;
33 |
34 | double build_hist_used_time=0;
35 | int build_n_hist = 0;
36 | int total_hist_num = 0;
37 | double total_dp_time = 0;
38 | double total_copy_time = 0;
39 | };
40 |
41 |
42 | #endif //THUNDERGBM_HIST_TREE_BUILDER_H
43 |
--------------------------------------------------------------------------------
/include/thundergbm/builder/hist_tree_builder_single.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-17.
3 | //
4 |
5 | #ifndef THUNDERGBM_HIST_TREE_BUILDER_H_single
6 | #define THUNDERGBM_HIST_TREE_BUILDER_H_single
7 |
8 | #include <thundergbm/hist_cut.h>
9 | #include "thundergbm/common.h"
10 | #include "shard.h"
11 | #include "tree_builder.h"
12 |
13 |
14 | class HistTreeBuilder_single : public TreeBuilder {
15 | public:
16 |
17 | void init(const DataSet &dataset, const GBMParam &param) override;
18 |
19 | void get_bin_ids();
20 |
21 | void find_split(int level, int device_id) override;
22 |
23 | virtual ~HistTreeBuilder_single(){};
24 |
25 | void update_ins2node_id() override;
26 |
27 | void update_tree() override;
28 |
29 |
30 | private:
31 | vector<HistCut> cut;
32 | // MSyncArray<unsigned char> char_dense_bin_id;
33 | MSyncArray<unsigned char> dense_bin_id;
34 | MSyncArray<GHPair> last_hist;
35 |
36 | //store csr dense_bin_id
37 | MSyncArray<unsigned char> csr_bin_id;
38 | MSyncArray<unsigned char> bin_id_origin;
39 | MSyncArray<int> csr_row_ptr;
40 | MSyncArray<int> csr_col_idx;
41 |
42 | double build_hist_used_time=0;
43 | int build_n_hist = 0;
44 | int total_hist_num = 0;
45 | double total_dp_time = 0;
46 | double total_copy_time = 0;
47 | bool use_gpu = 1;
48 | };
49 |
50 |
51 | #endif //THUNDERGBM_HIST_TREE_BUILDER_H_single
52 |
--------------------------------------------------------------------------------
/include/thundergbm/builder/shard.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-2.
3 | //
4 |
5 | #ifndef THUNDERGBM_SHARD_H
6 | #define THUNDERGBM_SHARD_H
7 |
8 |
9 | #include "thundergbm/sparse_columns.h"
10 | #include "thundergbm/tree.h"
11 |
12 | class SplitPoint;
13 |
14 | struct Shard {
15 | SparseColumns columns;//a subset of columns (or features)
16 | SyncArray<bool> ignored_set;//for column sampling
17 |
18 | void column_sampling(float rate);
19 | };
20 |
21 |
22 |
23 | class SplitPoint {
24 | public:
25 | float_type gain;
26 | GHPair fea_missing_gh;//missing gh in this segment
27 | GHPair rch_sum_gh;//right child total gh (missing gh included if default2right)
28 | GHPair lch_sum_gh;//left child total gh (missing gh included if default2left)
29 | bool default_right;
30 | int nid;
31 |
32 | //split condition
33 | int split_fea_id;
34 | float_type fval;//split on this feature value (for exact)
35 | unsigned char split_bid;//split on this bin id (for hist)
36 |
37 | SplitPoint() {
38 | nid = -1;
39 | split_fea_id = -1;
40 | gain = 0;
41 | default_right = true;
42 | }
43 |
44 | friend std::ostream &operator<<(std::ostream &output, const SplitPoint &sp) {
45 | output << sp.gain << "/" << sp.split_fea_id << "/" << sp.nid << "/" << sp.lch_sum_gh;
46 | return output;
47 | }
48 | };
49 |
50 | #endif //THUNDERGBM_SHARD_H
51 |
--------------------------------------------------------------------------------
/include/thundergbm/builder/tree_builder.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 19-1-23.
3 | //
4 |
5 | #ifndef THUNDERGBM_TREEBUILDER_H
6 | #define THUNDERGBM_TREEBUILDER_H
7 |
8 | #include <thundergbm/syncarray.h>
9 | #include "thundergbm/common.h"
10 | #include "shard.h"
11 | #include "function_builder.h"
12 |
13 |
14 | class TreeBuilder : public FunctionBuilder {
15 | public:
16 | virtual void find_split(int level, int device_id) = 0;
17 |
18 | virtual void update_ins2node_id() = 0;
19 |
20 | vector<Tree> build_approximate(const MSyncArray<GHPair> &gradients) override;
21 |
22 | void init(const DataSet &dataset, const GBMParam &param) override;
23 |
24 | virtual void update_tree();
25 |
26 | void predict_in_training(int k);
27 |
28 | virtual void split_point_all_reduce(int depth);
29 |
30 | virtual void ins2node_id_all_reduce(int depth);
31 |
32 | virtual ~TreeBuilder(){};
33 |
34 | protected:
35 | vector<Shard> shards;
36 | int n_instances;
37 | vector<Tree> trees;
38 | MSyncArray<int> ins2node_id;
39 | MSyncArray<SplitPoint> sp;
40 | MSyncArray<GHPair> gradients;
41 | vector<bool> has_split;
42 | };
43 |
44 |
45 | #endif //THUNDERGBM_TREEBUILDER_H
46 |
--------------------------------------------------------------------------------
/include/thundergbm/common.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 18-1-16.
3 | //
4 |
5 | #ifndef THUNDERGBM_COMMON_H
6 | #define THUNDERGBM_COMMON_H
7 |
8 | #include "thundergbm/util/log.h"
9 | #include "cuda_runtime_api.h"
10 | #include "cstdlib"
11 | #include "config.h"
12 | #include "thrust/tuple.h"
13 |
14 | using std::vector;
15 | using std::string;
16 |
17 | //CUDA macro
18 | #define USE_CUDA
19 | #define NO_GPU \
20 | LOG(FATAL)<<"Cannot use GPU when compiling without GPU"
21 | #define CUDA_CHECK(condition) \
22 | /* Code block avoids redefinition of cudaError_t error */ \
23 | do { \
24 | cudaError_t error = condition; \
25 | CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
26 | } while (false)
27 |
28 | //https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf
29 | template<typename ...Args>
30 | std::string string_format(const std::string &format, Args ... args) {
31 | size_t size = snprintf(nullptr, 0, format.c_str(), args ...) + 1; // Extra space for '\0'
32 | std::unique_ptr<char[]> buf(new char[size]);
33 | snprintf(buf.get(), size, format.c_str(), args ...);
34 | return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
35 | }
36 |
37 | //data types
38 | typedef float float_type;
39 | //typedef double float_type;
40 |
41 | #define HOST_DEVICE __host__ __device__
42 |
43 | struct GHPair {
44 | float_type g;
45 | float_type h;
46 |
47 | HOST_DEVICE GHPair operator+(const GHPair &rhs) const {
48 | GHPair res;
49 | res.g = this->g + rhs.g;
50 | res.h = this->h + rhs.h;
51 | return res;
52 | }
53 |
54 | HOST_DEVICE const GHPair operator-(const GHPair &rhs) const {
55 | GHPair res;
56 | res.g = this->g - rhs.g;
57 | res.h = this->h - rhs.h;
58 | return res;
59 | }
60 |
61 | HOST_DEVICE bool operator==(const GHPair &rhs) const {
62 | return this->g == rhs.g && this->h == rhs.h;
63 | }
64 |
65 | HOST_DEVICE bool operator!=(const GHPair &rhs) const {
66 | return !(*this == rhs);
67 | }
68 |
69 | HOST_DEVICE GHPair() : g(0), h(0) {};
70 |
71 | HOST_DEVICE GHPair(float_type v) : g(v), h(v) {};
72 |
73 | HOST_DEVICE GHPair(float_type g, float_type h) : g(g), h(h) {};
74 |
75 | friend std::ostream &operator<<(std::ostream &os,
76 | const GHPair &p) {
77 | os << string_format("%f/%f", p.g, p.h);
78 | return os;
79 | }
80 | };
81 |
82 | typedef thrust::tuple<int, float_type> int_float;
83 |
84 | std::ostream &operator<<(std::ostream &os, const int_float &rhs);
85 |
86 | struct GBMParam {
87 | int depth;
88 | int n_trees;
89 | float_type min_child_weight;
90 | float_type lambda;
91 | float_type gamma;
92 | float_type rt_eps;
93 | float column_sampling_rate;
94 | std::string path;
95 | int verbose;
96 | bool profiling;
97 | bool bagging;
98 | int n_parallel_trees;
99 | float learning_rate;
100 | std::string objective;
101 | int num_class;
102 | int tree_per_rounds; // #tree of each round, depends on #class
103 |
104 | //for histogram
105 | int max_num_bin;
106 | float base_score=0;
107 |
108 | int n_device;
109 |
110 | std::string tree_method;
111 | };
112 | #endif //THUNDERGBM_COMMON_H
113 |
--------------------------------------------------------------------------------
/include/thundergbm/config.h.in:
--------------------------------------------------------------------------------
1 | #cmakedefine DATASET_DIR "@DATASET_DIR@"
2 |
--------------------------------------------------------------------------------
/include/thundergbm/dataset.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 18-1-17.
3 | //
4 |
5 | #ifndef THUNDERGBM_DATASET_H
6 | #define THUNDERGBM_DATASET_H
7 |
8 | #include "common.h"
9 | #include "syncarray.h"
10 |
11 | class DataSet {
12 | public:
13 | ///load dataset from file
14 | void load_from_sparse(int n_instances, float *csr_val, int *csr_row_ptr, int *csr_col_idx, float *y,
15 | int *group, int num_group, GBMParam &param);
16 | void load_from_file(string file_name, GBMParam &param);
17 | void load_csc_from_file(string file_name, GBMParam &param, int const nfeatures=500);
18 | void load_group_file(string file_name);
19 | void group_label();
20 |
21 | size_t n_features() const;
22 |
23 | size_t n_instances() const;
24 |
25 | vector<float_type> csr_val;
26 | vector<int> csr_row_ptr;
27 | vector<int> csr_col_idx;
28 | vector<float_type> y;
29 | size_t n_features_;
30 | vector<int> group;
31 | vector<float_type> label;
32 |
33 |
34 | // csc variables
35 | vector<float_type> csc_val;
36 | vector<int> csc_row_idx;
37 | vector<int> csc_col_ptr;
38 |
39 | // whether the dataset is too big; if so, fall back to the CPU
40 | bool use_cpu = false;
41 | };
42 |
43 | #endif //THUNDERGBM_DATASET_H
44 |
--------------------------------------------------------------------------------
/include/thundergbm/hist_cut.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by qinbin on 2018/5/9.
3 | //
4 |
5 | #ifndef THUNDERGBM_HIST_CUT_H
6 | #define THUNDERGBM_HIST_CUT_H
7 |
8 | #include "common.h"
9 | #include "sparse_columns.h"
10 | #include "thundergbm/dataset.h"
11 | #include "thundergbm/tree.h"
12 | #include "ins_stat.h"
13 |
14 | class HistCut {
15 | public:
16 | //split_point[i] stores the split points of feature i
17 | //std::vector<vector<float_type>> split_points;
18 | vector<float_type> cut_points;
19 | vector<int> row_ptr;
20 | //for gpu
21 | SyncArray<float_type> cut_points_val;
22 | SyncArray<int> cut_row_ptr;
23 | SyncArray<int> cut_fid;
24 |
25 | HistCut() = default;
26 |
27 | HistCut(const HistCut &cut) {
28 | cut_points = cut.cut_points;
29 | row_ptr = cut.row_ptr;
30 | cut_points_val.copy_from(cut.cut_points_val);
31 | cut_row_ptr.copy_from(cut.cut_row_ptr);
32 | }
33 |
34 | void get_cut_points2(SparseColumns &columns, int max_num_bins, int n_instances);
35 | void get_cut_points3(SparseColumns &columns, int max_num_bins, int n_instances);
36 |
37 | //for hist on single device
38 | void get_cut_points_single(SparseColumns &columns, int max_num_bins, int n_instances);
39 | };
40 |
41 | #endif //THUNDERGBM_HIST_CUT_H
42 |
--------------------------------------------------------------------------------
/include/thundergbm/ins_stat.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by shijiashuai on 5/7/18.
3 | //
4 |
5 | #ifndef THUNDERGBM_INS_STAT_H
6 | #define THUNDERGBM_INS_STAT_H
7 |
8 | #include "syncarray.h"
9 | #include "thundergbm/objective/objective_function.h"
10 |
11 | class InsStat {
12 | public:
13 |
14 | ///gradient and hessian
15 | SyncArray<GHPair> gh_pair;
16 |
17 | ///backup for bagging
18 | SyncArray<GHPair> gh_pair_backup;
19 | ///node id
20 | SyncArray<int> nid;
21 | ///target value
22 | SyncArray<float_type> y;
23 | ///predict value
24 | SyncArray<float_type> y_predict;
25 |
26 | std::unique_ptr<ObjectiveFunction> obj;
27 |
28 | int n_instances;
29 |
30 | InsStat() = default;
31 |
32 | explicit InsStat(size_t n_instances) {
33 | resize(n_instances);
34 | }
35 |
36 | void resize(size_t n_instances);
37 |
38 | void update_gradient();
39 |
40 | void reset_nid();
41 |
42 | void do_bagging();
43 | };
44 |
45 | #endif //THUNDERGBM_INS_STAT_H
46 |
--------------------------------------------------------------------------------
/include/thundergbm/metric/metric.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-13.
3 | //
4 |
5 | #ifndef THUNDERGBM_METRIC_H
6 | #define THUNDERGBM_METRIC_H
7 |
8 | #include <thundergbm/common.h>
9 | #include <thundergbm/dataset.h>
10 |
11 | class Metric {
12 | public:
13 | virtual float_type get_score(const SyncArray<float_type> &y_p) const = 0;
14 |
15 | virtual void configure(const GBMParam &param, const DataSet &dataset);
16 |
17 | static Metric *create(string name);
18 |
19 | virtual string get_name() const = 0;
20 |
21 | protected:
22 | SyncArray<float_type> y;
23 | };
24 |
25 |
26 | #endif //THUNDERGBM_METRIC_H
27 |
--------------------------------------------------------------------------------
/include/thundergbm/metric/multiclass_metric.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-15.
3 | //
4 |
5 | #ifndef THUNDERGBM_MULTICLASS_METRIC_H
6 | #define THUNDERGBM_MULTICLASS_METRIC_H
7 |
8 | #include "thundergbm/common.h"
9 | #include "metric.h"
10 |
11 | class MulticlassMetric: public Metric {
12 | public:
13 | void configure(const GBMParam &param, const DataSet &dataset) override {
14 | Metric::configure(param, dataset);
15 | num_class = param.num_class;
16 | CHECK_EQ(num_class, dataset.label.size());
17 | label.resize(num_class);
18 | label.copy_from(dataset.label.data(), num_class);
19 | }
20 |
21 | protected:
22 | int num_class;
23 | SyncArray<float_type> label;
24 | };
25 |
26 | class MulticlassAccuracy: public MulticlassMetric {
27 | public:
28 | float_type get_score(const SyncArray<float_type> &y_p) const override;
29 |
30 | string get_name() const override { return "multi-class accuracy"; }
31 | };
32 |
33 | class BinaryClassMetric: public MulticlassAccuracy{
34 | public:
35 | float_type get_score(const SyncArray<float_type> &y_p) const override;
36 | string get_name() const override { return "test error";}
37 | };
38 |
39 |
40 | #endif //THUNDERGBM_MULTICLASS_METRIC_H
41 |
--------------------------------------------------------------------------------
/include/thundergbm/metric/pointwise_metric.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-13.
3 | //
4 |
5 | #ifndef THUNDERGBM_POINTWISE_METRIC_H
6 | #define THUNDERGBM_POINTWISE_METRIC_H
7 |
8 | #include "metric.h"
9 |
10 | class RMSE : public Metric {
11 | public:
12 | float_type get_score(const SyncArray<float_type> &y_p) const override;
13 |
14 | string get_name() const override { return "RMSE"; }
15 | };
16 |
17 | #endif //THUNDERGBM_POINTWISE_METRIC_H
18 |
--------------------------------------------------------------------------------
/include/thundergbm/metric/ranking_metric.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-13.
3 | //
4 |
5 | #ifndef THUNDERGBM_RANKING_METRIC_H
6 | #define THUNDERGBM_RANKING_METRIC_H
7 |
8 | #include "metric.h"
9 |
10 | class RankListMetric : public Metric {
11 | public:
12 | float_type get_score(const SyncArray<float_type> &y_p) const override;
13 |
14 | void configure(const GBMParam &param, const DataSet &dataset) override;
15 |
16 | static void configure_gptr(const vector<int> &group, vector<int> &gptr);
17 |
18 | protected:
19 | virtual float_type eval_query_group(vector<float_type> &y, vector<float_type> &y_p, int group_id) const = 0;
20 |
21 | vector<int> gptr;
22 | int n_group;
23 | int topn;
24 | };
25 |
26 |
27 | class MAP : public RankListMetric {
28 | public:
29 | string get_name() const override { return "MAP"; }
30 |
31 | protected:
32 | float_type eval_query_group(vector<float_type> &y, vector<float_type> &y_p, int group_id) const override;
33 | };
34 |
35 | class NDCG : public RankListMetric {
36 | public:
37 | string get_name() const override { return "NDCG"; };
38 |
39 | void configure(const GBMParam &param, const DataSet &dataset) override;
40 |
41 | inline HOST_DEVICE static float_type discounted_gain(int label, int rank) {
42 | return ((1 << label) - 1) / log2f(rank + 1 + 1);
43 | }
44 |
45 | static void get_IDCG(const vector<int> &gptr, const vector<float_type> &y, vector<float_type> &idcg);
46 |
47 | protected:
48 | float_type eval_query_group(vector<float_type> &y, vector<float_type> &y_p, int group_id) const override;
49 |
50 | private:
51 | vector<float_type> idcg;
52 | };
53 |
54 |
55 | #endif //THUNDERGBM_RANKING_METRIC_H
56 |
--------------------------------------------------------------------------------
/include/thundergbm/objective/multiclass_obj.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-3.
3 | //
4 |
5 | #ifndef THUNDERGBM_MULTICLASS_OBJ_H
6 | #define THUNDERGBM_MULTICLASS_OBJ_H
7 |
8 | #include "objective_function.h"
9 | #include "thundergbm/util/device_lambda.cuh"
10 |
11 | class Softmax : public ObjectiveFunction {
12 | public:
13 | void get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p,
14 | SyncArray<GHPair> &gh_pair) override;
15 |
16 | void predict_transform(SyncArray<float_type> &y) override;
17 |
18 | void configure(GBMParam param, const DataSet &dataset) override;
19 |
20 | string default_metric_name() override { return "macc"; }
21 |
22 | virtual ~Softmax() override = default;
23 |
24 | protected:
25 | int num_class;
26 | SyncArray<float_type> label;
27 | };
28 |
29 |
30 | class SoftmaxProb : public Softmax {
31 | public:
32 | void predict_transform(SyncArray<float_type> &y) override;
33 |
34 | ~SoftmaxProb() override = default;
35 |
36 | };
37 |
38 |
39 |
40 | #endif //THUNDERGBM_MULTICLASS_OBJ_H
41 |
--------------------------------------------------------------------------------
/include/thundergbm/objective/objective_function.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-1.
3 | //
4 |
5 | #ifndef THUNDERGBM_OBJECTIVE_FUNCTION_H
6 | #define THUNDERGBM_OBJECTIVE_FUNCTION_H
7 |
8 | #include <thundergbm/common.h>
9 | #include <thundergbm/dataset.h>
10 |
11 | class ObjectiveFunction {
12 | public:
13 | //todo different target type
14 | virtual void
15 | get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair) = 0;
16 | virtual void
17 | predict_transform(SyncArray<float_type> &y){};
18 |
19 | //init base score
20 | //SyncArray<GHPair> &gh_pair is used as temporary storage
21 | virtual float init_base_score(const SyncArray<float_type> &y, SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair){
22 | LOG(INFO) << "the base score function is not implemented for this objective";
23 | return 0;
24 | }
25 |
26 | virtual void configure(GBMParam param, const DataSet &dataset) = 0;
27 | virtual string default_metric_name() = 0;
28 |
29 | static ObjectiveFunction* create(string name);
30 |
31 | //a file containing the number of instances per query; similar to XGBoost
32 | static bool need_load_group_file(string name);
33 | static bool need_group_label(string name);
34 | virtual ~ObjectiveFunction() = default;
35 | };
36 |
37 | #endif //THUNDERGBM_OBJECTIVE_FUNCTION_H
38 |
--------------------------------------------------------------------------------
/include/thundergbm/objective/ranking_obj.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-10.
3 | //
4 |
5 | #ifndef THUNDERGBM_RANKING_OBJ_H
6 | #define THUNDERGBM_RANKING_OBJ_H
7 |
8 | #include "objective_function.h"
9 |
10 | /**
11 | * LambdaRank objective; see Burges, "From RankNet to LambdaRank to LambdaMART: An Overview" (MSR-TR-2010-82):
12 | * https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf
13 | */
14 | class LambdaRank : public ObjectiveFunction {
15 | public:
16 | void get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p,
17 | SyncArray<GHPair> &gh_pair) override;
18 |
19 | void configure(GBMParam param, const DataSet &dataset) override;
20 |
21 | string default_metric_name() override;
22 |
23 | virtual ~LambdaRank() override = default;
24 |
25 | protected:
26 | virtual inline float_type get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) { return 1; };
27 |
28 | vector<int> gptr; //group start position
29 | int n_group;
30 |
31 | float_type sigma;
32 | };
33 |
34 | class LambdaRankNDCG : public LambdaRank {
35 | public:
36 | void configure(GBMParam param, const DataSet &dataset) override;
37 |
38 | string default_metric_name() override;
39 |
40 | protected:
41 | float_type get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) override;
42 |
43 | private:
44 | vector<float_type> idcg;
45 | };
46 |
47 |
48 | #endif //THUNDERGBM_RANKING_OBJ_H
49 |
--------------------------------------------------------------------------------
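For a pair of documents (i, j) in one query group with label_i > label_j, the paper linked above defines the lambda gradient through rho = 1 / (1 + exp(sigma · (s_i − s_j))). A NumPy sketch under those definitions (names here are illustrative, not the library's API; `delta_z` is 1 for plain `LambdaRank` and the NDCG change from swapping i and j for `LambdaRankNDCG`, cf. `get_delta_z`):

```python
import numpy as np

def lambda_gh(s_i, s_j, sigma=1.0, delta_z=1.0):
    # s_i, s_j: current model scores, with label_i > label_j assumed
    rho = 1.0 / (1.0 + np.exp(sigma * (s_i - s_j)))
    g = -sigma * rho * abs(delta_z)                        # gradient on i; j receives -g
    h = sigma * sigma * rho * (1.0 - rho) * abs(delta_z)   # second-order term
    return g, h

g, h = lambda_gh(s_i=0.3, s_j=0.8)  # mis-ordered pair: strong push apart
```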
/include/thundergbm/objective/regression_obj.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-1.
3 | //
4 |
5 | #ifndef THUNDERGBM_REGRESSION_OBJ_H
6 | #define THUNDERGBM_REGRESSION_OBJ_H
7 |
8 | #include "objective_function.h"
9 | #include "thundergbm/util/device_lambda.cuh"
10 | #include "thrust/reduce.h"
11 |
12 | template<template<typename> class Loss>
13 | class RegressionObj : public ObjectiveFunction {
14 | public:
15 | void get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p,
16 | SyncArray<GHPair> &gh_pair) override {
17 | CHECK_EQ(y.size(), y_p.size()) << y.size() << " != " << y_p.size();
18 | auto y_data = y.device_data();
19 | auto y_p_data = y_p.device_data();
20 | auto gh_pair_data = gh_pair.device_data();
21 | //compute the (gradient, hessian) pair of each instance in parallel on the GPU
22 | device_loop(y.size(), [=]__device__(int i) {
23 | gh_pair_data[i] = Loss<float_type>::gradient(y_data[i], y_p_data[i]);
24 | });
25 | }
26 |
27 | void predict_transform(SyncArray<float_type> &y) override {
28 | auto y_data = y.device_data();
29 | device_loop(y.size(), [=]__device__(int i) {
30 | y_data[i] = Loss<float_type>::predict_transform(y_data[i]);
31 | });
32 | }
33 |
34 | //base score
35 | float init_base_score(const SyncArray<float_type> &y, SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair){
36 |
37 | //get gradients first; SyncArray<GHPair> &gh_pair is used as temporary storage
38 | get_gradient(y,y_p,gh_pair);
39 |
40 | //get sum gh_pair
41 | GHPair sum_gh = thrust::reduce(thrust::cuda::par, gh_pair.device_data(), gh_pair.device_end());
42 |
43 | //get weight
44 | float weight = -sum_gh.g / fmax(sum_gh.h, (double)(1e-6));
45 |
46 | float base_score = weight;
47 | LOG(INFO) << "base_score " << base_score;
48 | return base_score;
49 | }
63 | };
64 | template<template<typename> class Loss>
65 | class LogClsObj: public RegressionObj<Loss>{
66 | public:
67 | void get_gradient(const SyncArray<float_type> &y, const SyncArray<float_type> &y_p,
68 | SyncArray<GHPair> &gh_pair) override {
69 | auto y_data = y.device_data();
70 | auto y_p_data = y_p.device_data();
71 | auto gh_pair_data = gh_pair.device_data();
72 | device_loop(y.size(), [=]__device__(int i) {
73 | gh_pair_data[i] = Loss<float_type>::gradient(y_data[i], y_p_data[i]);
74 | });
75 | }
76 | void predict_transform(SyncArray<float_type> &y) {
77 | //this method transform y(#class * #instances) into y(#instances)
78 | auto yp_data = y.device_data();
79 | auto label_data = label.device_data();
80 | // int num_class = this->num_class;
81 | int n_instances = y.size();
82 | device_loop(n_instances, [=]__device__(int i) {
83 | //yp_data[i] = Loss::predict_transform(yp_data[i]);
84 | int max_k = (yp_data[i] > 0) ? 1 : 0;
85 | yp_data[i] = label_data[max_k];
86 | });
87 | //TODO not to make a temp_y?
88 | SyncArray<float_type> temp_y(n_instances);
89 | temp_y.copy_from(y.device_data(), n_instances);
90 | y.resize(n_instances);
91 | y.copy_from(temp_y);
92 | }
93 |
94 | //base score
95 | float init_base_score(const SyncArray<float_type> &y, SyncArray<float_type> &y_p, SyncArray<GHPair> &gh_pair){
96 |
97 | //get gradients first; SyncArray<GHPair> &gh_pair is used as temporary storage
98 | get_gradient(y,y_p,gh_pair);
99 |
100 | //get sum gh_pair
101 | GHPair sum_gh = thrust::reduce(thrust::cuda::par, gh_pair.device_data(), gh_pair.device_end());
102 |
103 | //get weight
104 | float weight = -sum_gh.g / fmax(sum_gh.h, (double)(1e-6));
105 | //sigmod transform
106 | weight = 1 / (1 + expf(-weight));
107 | float base_score = -logf(1.0f / weight - 1.0f);
108 | LOG(INFO) << "base_score " << base_score;
109 | return base_score;
110 | }
126 | SyncArray<float_type> label;
127 | };
128 |
129 | template<typename T>
130 | struct SquareLoss {
131 | HOST_DEVICE static GHPair gradient(T y, T y_p) { return GHPair(y_p - y, 1); }
132 |
133 | HOST_DEVICE static T predict_transform(T x) { return x; }
134 | };
135 |
136 | //for probability regression
137 | template<typename T>
138 | struct LogisticLoss {
139 | HOST_DEVICE static GHPair gradient(T y, T y_p);
140 |
141 | HOST_DEVICE static T predict_transform(T x);
142 | };
143 |
144 | template<>
145 | struct LogisticLoss<float> {
146 | HOST_DEVICE static GHPair gradient(float y, float y_p) {
147 | float p = sigmoid(y_p);
148 | return GHPair(p - y, fmaxf(p * (1 - p), 1e-16f));
149 | }
150 |
151 | HOST_DEVICE static float predict_transform(float y) { return sigmoid(y); }
152 |
153 | HOST_DEVICE static float sigmoid(float x) {return 1 / (1 + expf(-x));}
154 | };
155 |
156 | template<>
157 | struct LogisticLoss<double> {
158 | HOST_DEVICE static GHPair gradient(double y, double y_p) {
159 | double p = sigmoid(y_p);
160 | return GHPair(p - y, fmax(p * (1 - p), 1e-16));
161 | }
162 |
163 | HOST_DEVICE static double predict_transform(double x) { return 1 / (1 + exp(-x)); }
164 |
165 | HOST_DEVICE static double sigmoid(double x) {return 1 / (1 + exp(-x));}
166 | };
167 |
168 | #endif //THUNDERGBM_REGRESSION_OBJ_H
169 |
--------------------------------------------------------------------------------
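`init_base_score` above takes one Newton step on the initial prediction: `weight = -Σg / max(Σh, 1e-6)`. With `SquareLoss` (g = y_p − y, h = 1) and y_p starting at 0, that step is exactly the label mean. A quick NumPy check of that identity (illustrative only, not the CUDA code):

```python
import numpy as np

def base_score_square_loss(y):
    y_p = np.zeros_like(y)
    g = y_p - y                       # SquareLoss gradient at y_p = 0
    h = np.ones_like(y)               # SquareLoss hessian is constant 1
    return -g.sum() / max(h.sum(), 1e-6)

y = np.array([1.0, 2.0, 6.0])
assert np.isclose(base_score_square_loss(y), y.mean())
```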
/include/thundergbm/parser.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/10/19.
3 | //
4 |
5 | #ifndef THUNDERGBM_PARAM_PARSER_H
6 | #define THUNDERGBM_PARAM_PARSER_H
7 |
8 | #include "tree.h"
9 | #include "dataset.h"
10 |
11 | class Parser{
12 | public:
13 | void parse_param(GBMParam &model_param, int argc, char **argv);
14 | void load_model(string model_path, GBMParam &model_param, vector<vector<Tree>> &boosted_model, DataSet &dataSet);
15 | void save_model(string model_path, GBMParam &model_param, vector<vector<Tree>> &boosted_model, DataSet &dataSet);
16 | };
17 |
18 | #endif //THUNDERGBM_PARAM_PARSER_H
19 |
--------------------------------------------------------------------------------
/include/thundergbm/predictor.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/12/19.
3 | //
4 |
5 | #ifndef THUNDERGBM_PREDICTOR_H
6 | #define THUNDERGBM_PREDICTOR_H
7 |
8 | #include "thundergbm/tree.h"
9 | #include <thundergbm/dataset.h>
10 |
11 | class Predictor{
12 | public:
13 | vector<float_type> predict(const GBMParam &model_param, const vector<vector<Tree>> &boosted_model, const DataSet &dataSet);
14 | void predict_raw(const GBMParam &model_param, const vector<vector<Tree>> &boosted_model,
15 | const DataSet &dataSet, SyncArray<float_type> &y_predict);
16 | };
17 |
18 | #endif //THUNDERGBM_PREDICTOR_H
19 |
--------------------------------------------------------------------------------
/include/thundergbm/quantile_sketch.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by qinbin on 2018/5/9.
3 | //
4 | #ifndef THUNDERGBM_QUANTILE_SKETCH_H
5 | #define THUNDERGBM_QUANTILE_SKETCH_H
6 |
7 | #include "common.h"
8 | #include <utility>
9 | #include <tuple>
10 |
11 | using std::pair;
12 | using std::tuple;
13 | using std::vector;
14 |
15 |
16 | class entry{
17 | public:
18 | float_type val;//a cut point candidate
19 | float_type rmin;//total weights of feature values less than val
20 | float_type rmax;//total weights of feature values less than or equal to val
21 | float_type w;
22 | entry() {};
23 | entry(float_type val, float_type rmin, float_type rmax, float_type w) : val(val), rmin(rmin), rmax(rmax), w(w) {};
24 | };
25 |
26 | class summary{
27 | public:
28 | int entry_size;
29 | int entry_reserve_size;
30 | vector<entry> entries;
31 | summary(): entry_size(0),entry_reserve_size(0) {
32 | //entries.clear();
33 | };
34 | summary(int entry_size, int reserve_size): entry_size(entry_size), entry_reserve_size(reserve_size) {entries.resize(reserve_size);};
35 | void Reserve(int size);//reserve memory for the summary
36 | void Prune(summary& src,int size);//reduce the number of cut point candidates of the summary
37 | void Merge(summary& src1, summary& src2);//merge two summaries
38 | void Copy(summary& src);
39 |
40 | };
41 |
42 | /**
43 | * @brief: store the pairs before constructing a summary
44 | */
45 | class Qitem{
46 | public:
47 | int tail;
48 | vector<pair<float_type, float_type>> data;
49 | Qitem(): tail(0) {
50 | //data.clear();
51 | };
52 | void GetSummary(summary& ret);
53 | };
54 |
55 |
56 | class quanSketch{
57 | public:
58 | int numOfLevel;//the summary has multiple levels
59 | int summarySize;//max size of the first level summary
60 | Qitem Qentry;
61 | vector<summary> summaries;
62 | summary t_summary; //for process_nodes
63 | void Init(int maxn, float_type eps);
64 | void Add(float_type, float_type);
65 | void GetSummary(summary& dest);
66 | quanSketch(): numOfLevel(0), summarySize(0) {
67 | //summaries.clear();
68 | };
69 |
70 | };
71 | #endif //THUNDERGBM_QUANTILE_SKETCH_H
72 |
--------------------------------------------------------------------------------
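The `entry` fields above are the usual weighted-quantile bookkeeping: for a candidate value `val`, `rmin` is the total weight of entries before it in sorted order and `rmax` additionally includes the weight of `val` itself. A NumPy sketch of building those ranks from a weighted sample and answering a quantile query (a sketch of the idea only, not the class's multi-level merge/prune machinery):

```python
import numpy as np

def build_summary(values, weights):
    order = np.argsort(values)
    v = np.asarray(values, float)[order]
    w = np.asarray(weights, float)[order]
    below = np.concatenate(([0.0], np.cumsum(w)[:-1]))  # weight of entries preceding each one
    return v, below, below + w                          # val, rmin, rmax

def query_quantile(v, rmin, rmax, q):
    # return the candidate whose rank interval covers q * total_weight
    target = q * rmax[-1]
    i = min(np.searchsorted(rmax, target), len(v) - 1)
    return v[i]

v, rmin, rmax = build_summary([3.0, 1.0, 2.0, 2.0], [1.0, 1.0, 1.0, 1.0])
print(query_quantile(v, rmin, rmax, 0.5))  # median candidate
```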
/include/thundergbm/row_sampler.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by shijiashuai on 2019-02-15.
3 | //
4 |
5 | #ifndef THUNDERGBM_ROW_SAMPLER_H
6 | #define THUNDERGBM_ROW_SAMPLER_H
7 |
8 | #include "thundergbm/common.h"
9 | #include "syncarray.h"
10 |
11 | class RowSampler {
12 | public:
13 | void do_bagging(MSyncArray<GHPair> &gradients);
14 | };
15 | #endif //THUNDERGBM_ROW_SAMPLER_H
16 |
--------------------------------------------------------------------------------
/include/thundergbm/sparse_columns.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by shijiashuai on 5/7/18.
3 | //
4 |
5 | #ifndef THUNDERGBM_SPARSE_COLUMNS_H
6 | #define THUNDERGBM_SPARSE_COLUMNS_H
7 |
8 | #include "syncarray.h"
9 | #include "dataset.h"
10 |
11 | class SparseColumns {//one feature corresponding to one column
12 | public:
13 | SyncArray<float_type> csc_val;
14 | SyncArray<int> csc_row_idx;
15 | SyncArray<int> csc_col_ptr;
16 |
17 | //original order without sort
18 | SyncArray<float_type> csc_val_origin;
19 | SyncArray<int> csc_row_idx_origin;
20 | SyncArray<int> csc_col_ptr_origin;
21 |
22 | //csr data
23 | SyncArray<float_type> csr_val;
24 | SyncArray<int> csr_row_ptr;
25 | SyncArray<int> csr_col_idx;
26 |
27 | int max_trick_depth = -1;
28 | int max_trick_nodes = -1;
29 |
30 | int n_column;
31 | int n_row;
32 | int column_offset;
33 | size_t nnz;
34 |
35 | void csr2csc_gpu(const DataSet &dataSet, vector<std::unique_ptr<SparseColumns>> &);
36 |
37 | //function for data transfer to gpu , this function is only for single GPU device
38 | void to_gpu(const DataSet &dataSet, vector<std::unique_ptr<SparseColumns>> &);
39 |
40 | void csr2csc_cpu(const DataSet &dataset, vector<std::unique_ptr<SparseColumns>> &);
41 | void csc_by_default(const DataSet &dataset, vector<std::unique_ptr<SparseColumns>> &v_columns);
42 | void to_multi_devices(vector<std::unique_ptr<SparseColumns>> &) const;
43 |
44 | };
45 | #endif //THUNDERGBM_SPARSE_COLUMNS_H
46 |
--------------------------------------------------------------------------------
/include/thundergbm/syncarray.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 17-9-17.
3 | //
4 |
5 | #ifndef THUNDERGBM_SYNCDATA_H
6 | #define THUNDERGBM_SYNCDATA_H
7 |
8 | #include "thundergbm/util/log.h"
9 | #include "syncmem.h"
10 |
11 | /**
12 | * @brief Wrapper of SyncMem with a type
13 | * @tparam T type of element
14 | */
15 | template<typename T>
16 | class SyncArray : public el::Loggable {
17 | public:
18 | /**
19 | * initialize class that can store given count of elements
20 | * @param count the given count
21 | */
22 | explicit SyncArray(size_t count) : mem(new SyncMem(sizeof(T) * count)), size_(count) {
23 | }
24 |
25 | SyncArray() : mem(nullptr), size_(0) {}
26 |
27 | ~SyncArray() { delete mem; };
28 |
29 | const T *host_data() const {
30 | to_host();
31 | return static_cast<const T *>(mem->host_data());
32 | };
33 |
34 | const T *device_data() const {
35 | to_device();
36 | return static_cast<const T *>(mem->device_data());
37 | };
38 |
39 | T *host_data() {
40 | to_host();
41 | return static_cast<T *>(mem->host_data());
42 | };
43 |
44 | T *device_data() {
45 | to_device();
46 | return static_cast<T *>(mem->device_data());
47 | };
48 |
49 | T *device_end() {
50 | return device_data() + size();
51 | };
52 |
53 | const T *device_end() const {
54 | return device_data() + size();
55 | };
56 |
57 | void set_host_data(T *host_ptr) {
58 | mem->set_host_data(host_ptr);
59 | }
60 |
61 | void set_device_data(T *device_ptr) {
62 | mem->set_device_data(device_ptr);
63 | }
64 |
65 | void to_host() const {
66 | CHECK_GT(size_, 0);
67 | mem->to_host();
68 | }
69 |
70 | void to_device() const {
71 | CHECK_GT(size_, 0);
72 | mem->to_device();
73 | }
74 |
75 | /**
76 | * copy device data. This will call to_device() implicitly.
77 | * @param source source data pointer (data can be on host or device)
78 | * @param count the count of elements
79 | */
80 | void copy_from(const T *source, size_t count) {
81 |
82 | #ifdef USE_CUDA
83 | thunder::device_mem_copy(mem->device_data(), source, sizeof(T) * count);
84 | #else
85 | memcpy(mem->host_data(), source, sizeof(T) * count);
86 | #endif
87 | };
88 |
89 | void copy_from(const SyncArray &source) {
90 |
91 | CHECK_EQ(size(), source.size()) << "destination and source count doesn't match";
92 | #ifdef USE_CUDA
93 | if (get_owner_id() == source.get_owner_id())
94 | copy_from(source.device_data(), source.size());
95 | else
96 | CUDA_CHECK(cudaMemcpyPeer(mem->device_data(), get_owner_id(), source.device_data(), source.get_owner_id(),
97 | source.mem_size()));
98 | #else
99 | copy_from(source.host_data(), source.size());
100 | #endif
101 | };
102 |
103 | /**
104 | * resize to a new size. This will also clear all data.
105 | * @param count
106 | */
107 | void resize(size_t count) {
108 | if (mem != nullptr) {
109 | delete mem;
110 | }
111 | mem = new SyncMem(sizeof(T) * count);
112 | this->size_ = count;
113 | };
114 |
115 | /*
116 | * resize to a new size. This will not clear the origin data.
117 | * @param count
118 | */
119 | void resize_without_delete(size_t count) {
120 | // delete mem;
121 | mem = new SyncMem(sizeof(T) * count);
122 | this->size_ = count;
123 | };
124 |
125 |
126 | size_t mem_size() const {//number of bytes
127 | return mem->size();
128 | }
129 |
130 | size_t size() const {//number of values
131 | return size_;
132 | }
133 |
134 | SyncMem::HEAD head() const {
135 | return mem->head();
136 | }
137 |
138 | void log(el::base::type::ostream_t &ostream) const override {
139 | int i;
140 | ostream << "[";
141 | const T *data = host_data();
142 | for (i = 0; i < size() - 1 && i < el::base::consts::kMaxLogPerContainer - 1; ++i) {
143 | // for (i = 0; i < size() - 1; ++i) {
144 | ostream << data[i] << ",";
145 | }
146 | ostream << host_data()[i];
147 | if (size() <= el::base::consts::kMaxLogPerContainer) {
148 | ostream << "]";
149 | } else {
150 | ostream << ", ...(" << size() - el::base::consts::kMaxLogPerContainer << " more)";
151 | }
152 | };
153 |
154 | int get_owner_id() const {
155 | return mem->get_owner_id();
156 | }
157 |
158 | //move constructor
159 | SyncArray(SyncArray &&rhs) noexcept : mem(rhs.mem), size_(rhs.size_) {
160 | rhs.mem = nullptr;
161 | rhs.size_ = 0;
162 | }
163 |
164 | //move assign
165 | SyncArray &operator=(SyncArray &&rhs) noexcept {
166 | delete mem;
167 | mem = rhs.mem;
168 | size_ = rhs.size_;
169 |
170 | rhs.mem = nullptr;
171 | rhs.size_ = 0;
172 | return *this;
173 | }
174 |
175 | SyncArray(const SyncArray &) = delete;
176 |
177 | SyncArray &operator=(const SyncArray &) = delete;
178 |
179 |
180 | //new function clear gpu mem
181 | void clear_device(){
182 | to_host();
183 | mem->free_device();
184 | }
185 |
186 | private:
187 | SyncMem *mem;
188 | size_t size_;
189 | };
190 |
191 | //SyncArray for multiple devices
192 | template<typename T>
193 | class MSyncArray : public vector<SyncArray<T>> {
194 | public:
195 | explicit MSyncArray(size_t n_device) : base_class(n_device) {};
196 |
197 | explicit MSyncArray(size_t n_device, size_t size) : base_class(n_device) {
198 | for (int i = 0; i < n_device; ++i) {
199 | this->at(i) = SyncArray<T>(size);
200 | }
201 | };
202 |
203 | MSyncArray() : base_class() {};
204 |
205 | //move constructor and assign
206 | MSyncArray(MSyncArray &&) = default;
207 |
208 | MSyncArray &operator=(MSyncArray &&) = default;
209 |
210 | MSyncArray(const MSyncArray &) = delete;
211 |
212 | MSyncArray &operator=(const MSyncArray &) = delete;
213 |
214 | private:
215 | typedef vector<SyncArray<T>> base_class;
216 | };
217 |
218 | #endif //THUNDERGBM_SYNCDATA_H
219 |
--------------------------------------------------------------------------------
/include/thundergbm/syncmem.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 17-9-16.
3 | //
4 |
5 | #ifndef THUNDERGBM_SYNCMEM_H
6 | #define THUNDERGBM_SYNCMEM_H
7 |
8 | #include "common.h"
9 | #include "cub/util_allocator.cuh"
10 |
11 | using namespace cub;
12 | namespace thunder {
13 |
14 | inline void device_mem_copy(void *dst, const void *src, size_t size) {
15 | #ifdef USE_CUDA
16 | CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyDefault));
17 | #else
18 | NO_GPU;
19 | #endif
20 | }
21 |
22 | class DeviceAllocator : public CachingDeviceAllocator {
23 | public:
24 | DeviceAllocator(unsigned int bin_growth, unsigned int min_bin, unsigned int max_bin, size_t max_cached_bytes,
25 | bool skip_cleanup, bool debug) : CachingDeviceAllocator(bin_growth, min_bin, max_bin,
26 | max_cached_bytes, skip_cleanup,
27 | debug) {};
28 |
29 | cudaError_t DeviceAllocate(int device, void **d_ptr, size_t bytes, cudaStream_t active_stream = 0);
30 |
31 | cudaError_t DeviceAllocate(void **d_ptr, size_t bytes, cudaStream_t active_stream = 0);
32 | };
33 |
34 | class HostAllocator : public CachingDeviceAllocator {
35 | public:
36 | HostAllocator(unsigned int bin_growth, unsigned int min_bin, unsigned int max_bin, size_t max_cached_bytes,
37 | bool skip_cleanup, bool debug) : CachingDeviceAllocator(bin_growth, min_bin, max_bin,
38 | max_cached_bytes, skip_cleanup, debug) {};
39 |
40 | cudaError_t DeviceAllocate(int device, void **d_ptr, size_t bytes, cudaStream_t active_stream = 0);
41 |
42 | cudaError_t DeviceAllocate(void **d_ptr, size_t bytes, cudaStream_t active_stream = 0);
43 |
44 | cudaError_t DeviceFree(int device, void *d_ptr);
45 |
46 | cudaError_t DeviceFree(void *d_ptr);
47 |
48 | cudaError_t FreeAllCached();
49 |
50 | ~HostAllocator() override;
51 | };
52 |
53 | /**
54 | * @brief Auto-synced memory for CPU and GPU
55 | */
56 | class SyncMem {
57 | public:
58 | SyncMem();
59 |
60 | /**
61 | * create a piece of synced memory with given size. The GPU/CPU memory will not be allocated immediately, but
62 | * allocated when it is first used.
63 | * @param size the size of memory (in Bytes)
64 | */
65 | explicit SyncMem(size_t size);
66 |
67 | ~SyncMem();
68 |
69 | ///return raw host pointer
70 | void *host_data();
71 |
72 | ///return raw device pointer
73 | void *device_data();
74 |
75 | /**
76 | * set host data pointer to another host pointer, and its memory will not be managed by this class
77 | * @param data another host pointer
78 | */
79 | void set_host_data(void *data);
80 |
81 | /**
82 | * set device data pointer to another device pointer, and its memory will not be managed by this class
83 | * @param data another device pointer
84 | */
85 | void set_device_data(void *data);
86 |
87 | ///transfer data to host
88 | void to_host();
89 |
90 | ///transfer data to device
91 | void to_device();
92 |
93 | ///return the size of memory
94 | size_t size() const;
95 |
96 | ///to determine where the newest data is located
97 | enum HEAD {
98 | HOST, DEVICE, UNINITIALIZED
99 | };
100 |
101 | HEAD head() const;
102 |
103 | int get_owner_id() const {
104 | return device_id;
105 | }
106 |
107 | static void clear_cache() {
108 | device_allocator.FreeAllCached();
109 | host_allocator.FreeAllCached();
110 | };
111 |
112 | //new func to clear gpu mem
113 | void free_device(){
114 | device_allocator.DeviceFree(device_ptr);
115 | }
116 |
117 | private:
118 | void *device_ptr;
119 | void *host_ptr;
120 | bool own_device_data;
121 | bool own_host_data;
122 | size_t size_;
123 | HEAD head_;
124 | int device_id;
125 | static DeviceAllocator device_allocator;
126 | static HostAllocator host_allocator;
127 |
128 | inline void malloc_host(void **ptr, size_t size);
129 |
130 | inline void free_host(void *ptr);
131 |
132 | };
133 | }
134 | using thunder::SyncMem;
135 | #endif //THUNDERGBM_SYNCMEM_H
136 |
--------------------------------------------------------------------------------
/include/thundergbm/trainer.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by zeyi on 1/9/19.
3 | //
4 |
5 | #ifndef THUNDERGBM_TRAINER_H
6 | #define THUNDERGBM_TRAINER_H
7 |
8 | #include "common.h"
9 | #include "tree.h"
10 | #include "dataset.h"
11 |
12 | class TreeTrainer{
13 | public:
14 | vector<vector<Tree>> train(GBMParam &param, const DataSet &dataset);
15 | // float_type train(GBMParam &param);
16 | // float_type train_exact(GBMParam &param);
17 | // float_type train_hist(GBMParam &param);
18 |
19 | };
20 |
21 | #endif //THUNDERGBM_TRAINER_H
22 |
--------------------------------------------------------------------------------
/include/thundergbm/tree.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 18-1-18.
3 | //
4 |
5 | #ifndef THUNDERGBM_TREE_H
6 | #define THUNDERGBM_TREE_H
7 |
8 | #include "syncarray.h"
9 | #include "sstream"
10 |
11 |
12 | class Tree {
13 | public:
14 | struct TreeNode {
15 | int final_id;// node id after pruning, may not equal to node index
16 | int lch_index;// index of left child
17 | int rch_index;// index of right child
18 | int parent_index;// index of parent node
19 | float_type gain;// gain of splitting this node
20 | float_type base_weight;
21 | int split_feature_id;
22 | float_type split_value;
23 | unsigned char split_bid;
24 | bool default_right;
25 | bool is_leaf;
26 | bool is_valid;// non-valid nodes are those that are "children" of leaf nodes
27 | bool is_pruned;// pruned after pruning
28 |
29 | GHPair sum_gh_pair;
30 |
31 | friend std::ostream &operator<<(std::ostream &os,
32 | const TreeNode &node);
33 |
34 | HOST_DEVICE void calc_weight(float_type lambda) {
35 | this->base_weight = -sum_gh_pair.g / (sum_gh_pair.h + lambda);
36 | }
37 |
38 | HOST_DEVICE bool splittable() const {
39 | return !is_leaf && is_valid;
40 | }
41 |
42 | };
43 |
44 | Tree() = default;
45 |
46 | Tree(const Tree &tree) {
47 | nodes.resize(tree.nodes.size());
48 | nodes.copy_from(tree.nodes);
49 | }
50 |
51 | Tree &operator=(const Tree &tree) {
52 | nodes.resize(tree.nodes.size());
53 | nodes.copy_from(tree.nodes);
54 | return *this;
55 | }
56 |
57 | void init2(const SyncArray<GHPair> &gradients, const GBMParam &param);
58 |
59 | string dump(int depth) const;
60 |
61 | SyncArray<TreeNode> nodes;
62 |
63 | void prune_self(float_type gamma);
64 |
65 | private:
66 | void preorder_traversal(int nid, int max_depth, int depth, string &s) const;
67 |
68 | int try_prune_leaf(int nid, int np, float_type gamma, vector<int> &leaf_child_count);
69 |
70 | void reorder_nid();
71 | };
72 |
73 | #endif //THUNDERGBM_TREE_H
74 |
--------------------------------------------------------------------------------
/include/thundergbm/util/device_lambda.cuh:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 18-1-19.
3 | //
4 |
5 | #ifndef THUNDERGBM_DEVICE_LAMBDA_H
6 | #define THUNDERGBM_DEVICE_LAMBDA_H
7 |
8 | #include "thundergbm/common.h"
9 |
10 | template<typename L>
11 | __global__ void lambda_kernel(size_t len, L lambda) {
12 | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += blockDim.x * gridDim.x) {
13 | lambda(i);
14 | }
15 | }
16 |
17 | template<typename L>
18 | __global__ void anonymous_kernel_k(L lambda) {
19 | lambda();
20 | }
21 |
22 | template<typename L>
23 | __global__ void lambda_2d_sparse_kernel(const int *len2, L lambda) {
24 | int i = blockIdx.x;
25 | int begin = len2[i];
26 | int end = len2[i + 1];
27 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) {
28 | lambda(i, j);
29 | }
30 | }
31 |
32 | template<typename L>
33 | __global__ void lambda_2d_maximum_sparse_kernel(const int *len2, const int maximum, L lambda) {
34 | int i = blockIdx.x;
35 | int begin = len2[i];
36 | int end = len2[i + 1];
37 | int interval = (end - begin) / maximum;
38 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) {
39 | lambda(i, j, interval);
40 | }
41 | }
42 |
43 | ///p100 has 56 MPs, using 32*56 thread blocks, one MP supports 2048 threads
44 | //a6000 has 84 mps
45 | template<int NUM_BLOCK = 32 * 84, int BLOCK_SIZE = 256, typename L>
46 | inline void device_loop(size_t len, L lambda) {
47 | if (len > 0) {
48 | lambda_kernel << < NUM_BLOCK, BLOCK_SIZE >> > (len, lambda);
49 | cudaDeviceSynchronize();
50 | /*cudaError_t error = cudaPeekAtLastError();*/
51 | if(cudaPeekAtLastError() == cudaErrorInvalidResourceHandle){
52 | cudaGetLastError();
53 | LOG(INFO) << "warning: cuda invalid resource handle, potential issue of driver version and cuda version mismatch";
54 | } else {
55 | CUDA_CHECK(cudaPeekAtLastError());
56 | }
57 | }
58 | }
59 |
60 | template<typename L>
61 | inline void anonymous_kernel(L lambda, size_t num_fv, size_t smem_size = 0, int NUM_BLOCK = 32 * 84, int BLOCK_SIZE = 256) {
62 | size_t tmp_num_block = num_fv / (BLOCK_SIZE * 8);
63 | NUM_BLOCK = std::min(NUM_BLOCK, (int)std::max(tmp_num_block, (size_t)32));
64 | anonymous_kernel_k<< < NUM_BLOCK, BLOCK_SIZE, smem_size >> > (lambda);
65 | cudaDeviceSynchronize();
66 | if(cudaPeekAtLastError() == cudaErrorInvalidResourceHandle){
67 | cudaGetLastError();
68 | LOG(INFO) << "warning: cuda invalid resource handle, potential issue of driver version and cuda version mismatch";
69 | } else {
70 | CUDA_CHECK(cudaPeekAtLastError());
71 | }
72 | }
73 |
74 | /**
75 | * @brief: (len1 x NUM_BLOCK) is the total number of blocks; len2 is an array of lengths.
76 | */
77 | template<typename L>
78 | void device_loop_2d(int len1, const int *len2, L lambda, unsigned int NUM_BLOCK = 4 * 84,
79 | unsigned int BLOCK_SIZE = 256) {
80 | if (len1 > 0) {
81 | lambda_2d_sparse_kernel << < dim3(len1, NUM_BLOCK), BLOCK_SIZE >> > (len2, lambda);
82 | cudaDeviceSynchronize();
83 | CUDA_CHECK(cudaPeekAtLastError());
84 | }
85 | }
86 |
87 | /**
88 | * @brief: (len1 x NUM_BLOCK) is the total number of blocks; len2 is an array of lengths.
89 | */
90 | template<typename L>
91 | void device_loop_2d_with_maximum(int len1, const int *len2, const int maximum, L lambda,
92 | unsigned int NUM_BLOCK = 4 * 84,
93 | unsigned int BLOCK_SIZE = 256) {
94 | if (len1 > 0) {
95 | lambda_2d_maximum_sparse_kernel << < dim3(len1, NUM_BLOCK), BLOCK_SIZE >> > (len2, maximum, lambda);
96 | cudaDeviceSynchronize();
97 | CUDA_CHECK(cudaPeekAtLastError());
98 | }
99 | }
100 |
101 | //for sparse format bin ids
102 | //func for new sparse loop
103 | template<typename L>
104 | __global__ void lambda_hist_csr_root_kernel(const int *csr_row_ptr, L lambda) {
105 |
106 | int i = blockIdx.x;
107 | int begin = csr_row_ptr[i];
108 | int end = csr_row_ptr[i + 1];
109 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) {
110 | //i for instance
111 | //j for feature
112 | lambda(i, j );
113 | }
114 | }
115 |
116 |
117 |
118 | template<typename L>
119 | void device_loop_hist_csr_root(int n_instances, const int *csr_row_ptr, L lambda , unsigned int NUM_BLOCK = 4 * 84,
120 | unsigned int BLOCK_SIZE = 256) {
121 | if (n_instances > 0) {
122 | lambda_hist_csr_root_kernel << < dim3(n_instances, NUM_BLOCK), BLOCK_SIZE >> > (csr_row_ptr, lambda);
123 | cudaDeviceSynchronize();
124 | CUDA_CHECK(cudaPeekAtLastError());
125 | }
126 | }
127 |
128 | template<typename L>
129 | __global__ void lambda_hist_csr_node_kernel(L lambda) {
130 | int i = blockIdx.x;
131 | int current_pos = blockIdx.y * blockDim.x + threadIdx.x;
132 | int stride = blockDim.x * gridDim.y;
133 | lambda(i,current_pos,stride);
134 | }
135 |
136 | template<typename L>
137 | void device_loop_hist_csr_node(int n_instances, const int *csr_row_ptr, L lambda , unsigned int NUM_BLOCK = 4 * 84,
138 | unsigned int BLOCK_SIZE = 256) {
139 | if (n_instances > 0) {
140 | lambda_hist_csr_node_kernel << < dim3(n_instances, NUM_BLOCK), BLOCK_SIZE >> > (lambda);
141 | cudaDeviceSynchronize();
142 | CUDA_CHECK(cudaPeekAtLastError());
143 | }
144 | }
145 |
146 |
147 | #endif //THUNDERGBM_DEVICE_LAMBDA_H
148 |
--------------------------------------------------------------------------------
/include/thundergbm/util/multi_device.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 18-6-18.
3 | //
4 |
5 | #ifndef THUNDERGBM_MULTI_DEVICE_H
6 | #define THUNDERGBM_MULTI_DEVICE_H
7 |
8 | #include "thundergbm/common.h"
9 |
10 | //switch to specific device and do something, then switch back to the original device
11 | //FIXME make this macro into a function?
12 | #define DO_ON_DEVICE(device_id, something) \
13 | do { \
14 | int org_device_id = 0; \
15 | CUDA_CHECK(cudaGetDevice(&org_device_id)); \
16 | CUDA_CHECK(cudaSetDevice(device_id)); \
17 | something; \
18 | CUDA_CHECK(cudaSetDevice(org_device_id)); \
19 | } while (false)
20 |
21 | /**
22 | * Do something on multiple devices, then switch back to the original device
23 | *
24 | *
25 | * example:
26 | *
27 | * DO_ON_MULTI_DEVICES(n_devices, [&](int device_id){
28 | * //do_something_on_device(device_id);
29 | * });
30 | */
31 |
32 | template<typename L>
33 | void DO_ON_MULTI_DEVICES(int n_devices, L do_something) {
34 | int org_device_id = 0;
35 | CUDA_CHECK(cudaGetDevice(&org_device_id));
36 | #pragma omp parallel for num_threads(n_devices)
37 | for (int device_id = 0; device_id < n_devices; device_id++) {
38 | CUDA_CHECK(cudaSetDevice(device_id));
39 | do_something(device_id);
40 | }
41 | CUDA_CHECK(cudaSetDevice(org_device_id));
42 |
43 | }
44 |
45 | #endif //THUNDERGBM_MULTI_DEVICE_H
46 |
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | We provide a scikit-learn wrapper interface. Before you use the Python interface, you must build ThunderGBM. Note that both dense and sparse matrices are supported in ThunderGBM. Dense matrices are converted into CSR format, on which the parallel training algorithm on GPUs is based.
2 |
3 | ## Instructions for building ThunderGBM
4 | * Please refer to [Installation](http://thundergbm.readthedocs.io/en/latest/how-to.html) for building ThunderGBM.
5 |
6 | * Then, if you want to install the Python package from source, go to the project root directory and run:
7 | ```bash
8 | cd python && python setup.py install
9 | ```
10 | Or you can install via pip
11 | ```bash
12 | pip3 install thundergbm
13 | ```
14 | * After you have successfully installed ThunderGBM, you can import TGBMModel:
15 | ```python
16 | from thundergbm import TGBMClassifier
17 | from sklearn import datasets
18 | clf = TGBMClassifier()
19 | X, y = datasets.load_digits(return_X_y=True)
20 | clf.fit(X, y)
21 | clf.save_model(model_path)
22 | ```
23 | * Load model
24 | ```python
25 | from thundergbm import TGBMClassifier
26 | clf = TGBMClassifier(objective="your-objective")
27 | # You should specify the same objective here as in the training stage
28 | clf.load_model(model_path)
29 | y_pred = clf.predict(X)
30 | ```
31 | ## Prerequisites
32 | * numpy | scipy | sklearn
33 |
34 | ## Example
35 |
36 | * Step 1: Create a file called ```tgbm_test.py``` which has the following content.
37 | ```python
38 | from thundergbm import *
39 | from sklearn.datasets import *
40 | from sklearn.metrics import mean_squared_error
41 | from math import sqrt
42 |
43 | x,y = load_svmlight_file("../dataset/test_dataset.txt")
44 | clf = TGBMRegressor()
45 | clf.fit(x,y)
46 |
47 | x2,y2=load_svmlight_file("../dataset/test_dataset.txt")
48 | y_predict=clf.predict(x2)
49 |
50 | rms = sqrt(mean_squared_error(y2, y_predict))
51 | print(rms)
52 |
53 | ```
54 | * Step 2: Run the python script.
55 | ```bash
56 | python tgbm_test.py
57 | ```
58 |
59 | ## ThunderGBM class
60 | *class TGBMModel(depth = 6, num_round = 40, n_device = 1, min_child_weight = 1.0, lambda_tgbm = 1.0, gamma = 1.0, max_num_bin = 255, verbose = 0, column_sampling_rate = 1.0, bagging = 0, n_parallel_trees = 1, learning_rate = 1.0, objective = "reg:linear", num_class = 1, path = "../dataset/test_dataset.txt")*
61 |
62 | Please note that ``TGBMClassifier`` and ``TGBMRegressor`` are wrappers of ``TGBMModel``, and their constructors have the same parameters as ``TGBMModel``.
63 |
64 | ### Parameters
65 | Please refer to [Parameters](https://github.com/zeyiwen/thundergbm/blob/master/docs/parameters.md) in ThunderGBM documentations.
66 |
67 | ### Methods
68 | *fit(X, y)*:\
69 | Fit the GBM model according to the given training data.
70 |
71 | *predict(X)*:\
72 | Perform prediction on samples in X.
73 |
74 | *save_model(path)*:\
75 | Save the model to the file path.
76 |
77 | *load_model(path)*:\
78 | Load the model from the file path.
79 |
80 | *cv(X, y, folds=None, nfold=5, shuffle=True, seed=0)*:\
81 | * folds: A length *n* list of tuples. Each tuple is (in, out), where *in* is a list of indices to be used as the training samples for the *n*-th fold and *out*
82 | is a list of indices to be used as the testing samples for the *n*-th fold.
83 | * shuffle (bool): Whether to shuffle the data before creating folds.
84 | * seed (int): Seed used to generate the folds.
85 |
86 | *get_shap_trees()*:\
87 | Return the model in a form that can be used by the SHAP explainer (see the usage sketch below).
88 |
--------------------------------------------------------------------------------
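A hedged usage sketch of the `cv` and `get_shap_trees` methods documented above. The `cv` call follows the signature given in this README; passing the result of `get_shap_trees()` to `shap.TreeExplainer` is an assumption about how the returned model is meant to be consumed, so it is left commented out.

```python
from thundergbm import TGBMRegressor
from sklearn.datasets import load_svmlight_file

X, y = load_svmlight_file("../dataset/test_dataset.txt")
clf = TGBMRegressor()
clf.cv(X, y, nfold=5, shuffle=True, seed=0)   # 5-fold cross validation

clf.fit(X, y)
# import shap                                            # assumes the shap package is installed
# explainer = shap.TreeExplainer(clf.get_shap_trees())   # assumption, see note above
# shap_values = explainer.shap_values(X.toarray())
```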
/python/benchmarks/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Introduction
3 |
4 | These are some Python scripts that help you reproduce our experimental results. We would like to thank the author of [GBM-Benchmarks](https://github.com/RAMitchell/GBM-Benchmarks), which our scripts are built on top of.
5 |
6 | ## Requirements
7 | _python3.x_ _numpy_ _sklearn_ _xgboost_ _lightgbm_ _catboost_ _thundergbm_
8 |
9 | ## How to run
10 |
11 | You can use the command `python3 experiments.py [model_name] [device_type] [dataset_name]` to run the scripts. The candidate values of each parameter are as follows:
12 |
13 | - model_name
14 | - xgboost
15 | - lightgbm
16 | - catboost
17 | - thundergbm
18 | - device_type
19 | - cpu
20 | - gpu
21 | - dataset_name
22 | - news20
23 | - higgs
24 | - log1p
25 | - cifar
26 |
27 | ## Files descriptions
28 | - model
29 | - base_model.py
30 | - datasets.py
31 | - catboost_model.py
32 | - lightgbm_model.py
33 | - xgboost_model.py
34 | - thundergbm_model.py
35 | - utils
36 | - data_utils.py
37 | - file_utils.py
38 | - experiments.py
39 | - convert_dataset_plk.py
40 |
41 | Folder **_model_** contains the model file of each library, all of which inherit from `BaseModel` in `base_model.py`. Folder **_utils_** contains a few tools, including a data helper and a file I/O helper. `convert_dataset_plk.py` converts a normal `libsvm` file into a Python pickle file; the datasets we use in the experiments are sometimes large, which makes loading them from text time-consuming, and pickling sharply reduces the loading time. `experiments.py` is the main entrance of our scripts.
42 |
43 | ## How to add more datasets
44 |
45 | Since the built-in datasets of our scripts are limited, you can add the datasets you want by modifying the file `utils/data_utils.py`. There are some dataset templates in that script that may help you add your own dataset easily; see the sketch after this file.
46 |
--------------------------------------------------------------------------------
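As a concrete illustration of the template pattern mentioned above, here is a hedged sketch of a new getter added inside `utils/data_utils.py`, where `os`, `BASE_DATASET_DIR` and `convert_to_plk` are already defined; the URL and names are placeholders to replace with your own dataset.

```python
# add to utils/data_utils.py, next to the existing getters
get_mydata_url = 'https://example.com/path/to/mydata.bz2'  # placeholder URL


def get_mydata(num_rows=None):
    filename = os.path.join(BASE_DATASET_DIR, 'mydata.bz2')
    plk_filename = filename + '.plk'
    X, y = convert_to_plk(filename, plk_filename, data_url=get_mydata_url)
    return X.toarray(), y  # densify like get_higgs(); keep X sparse for large data
```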
/python/benchmarks/convert_dataset_plk.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from sklearn.datasets import load_svmlight_file
3 | import os
4 | from sklearn.datasets import dump_svmlight_file
5 |
6 |
7 | base_dir = './raw_dataset'
8 | plk_dir = './datasets'
9 | for dataset_name in os.listdir(base_dir):
10 | print(dataset_name)
11 | # if 'binary' not in dataset_name:
12 | # continue
13 | X, y = load_svmlight_file(os.path.join(base_dir, dataset_name))
14 | print(X.shape)
15 | print(y.shape)
16 | # y2 = [1 if x % 2 == 0 else 0 for x in y]
17 | # dump_svmlight_file(X, y2, open('SVNH.2classes', 'wb'))
18 | pickle.dump((X, y), open(os.path.join(plk_dir, dataset_name+'.plk'), 'wb'), protocol=4)
19 | # X, y = load_svmlight_file('datasets/YearPredictMSD.train')
--------------------------------------------------------------------------------
/python/benchmarks/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/benchmarks/model/__init__.py
--------------------------------------------------------------------------------
/python/benchmarks/model/base_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import mean_squared_error, accuracy_score
3 |
4 |
5 | class BaseModel(object):
6 | """
7 | Base model to run the test
8 | """
9 |
10 | def __init__(self):
11 | self.max_depth = 6
12 | self.learning_rate = 1
13 | self.min_split_loss = 1
14 | self.min_weight = 1
15 | self.L1_reg = 1
16 | self.L2_reg = 1
17 | self.num_rounds = 40
18 | self.max_bin = 255
19 | self.use_gpu = True
20 | self.params = {}
21 |
22 | self.model = None # self.model is different with different libraries
23 |
24 | def _config_model(self, data):
25 | """
26 | To config the model with different params
27 | """
28 | pass
29 |
30 | def _train_model(self, data):
31 | """
32 | To train model
33 | :param data:
34 | :return:
35 | """
36 | pass
37 |
38 | def _predict(self, data):
39 | pass
40 |
41 | def eval(self, data, pred):
42 | """
43 | To eval the predict results with specified metric
44 | :param data:
45 | :param pred:
46 | :return:
47 | """
48 | if data.metric == "RMSE":
49 | with open('pred', 'w') as f:
50 | for x in pred:
51 | f.write(str(x) + '\n')
52 | return np.sqrt(mean_squared_error(data.y_test, pred))
53 | elif data.metric == "Accuracy":
54 | # Threshold prediction if binary classification
55 | if data.task == "Classification":
56 | pred = pred > 0.5
57 | elif data.task == "Multiclass classification":
58 | if pred.ndim > 1:
59 | pred = np.argmax(pred, axis=1)
60 | return accuracy_score(data.y_test, pred)
61 | else:
62 | raise ValueError("Unknown metric: " + data.metric)
63 |
64 | def run_model(self, data):
65 | """
66 | To run model
67 | :param data:
68 | :return:
69 | """
70 | self._config_model(data)
71 | elapsed = self._train_model(data)
72 | # metric = 0
73 | metric = self._predict(data)
74 | print("##### Elapsed time: %.5f #####" % (elapsed))
75 | print("##### Predict %s: %.4f #####" % (data.metric, metric))
76 |
77 | return elapsed, metric
78 |
79 | def model_name(self):
80 | pass
81 |
82 |
--------------------------------------------------------------------------------
/python/benchmarks/model/catboost_model.py:
--------------------------------------------------------------------------------
1 | from model.base_model import BaseModel
2 | import numpy as np
3 | import catboost as cb
4 | import time
5 | import utils.data_utils as du
6 | from model.datasets import Dataset
7 |
8 |
9 | class CatboostModel(BaseModel):
10 |
11 | def __init__(self):
12 | BaseModel.__init__(self)
13 |
14 | def _config_model(self, data):
15 | self.params['eta'] = self.learning_rate
16 | self.params['depth'] = self.max_depth
17 | self.params['l2_leaf_reg'] = self.L2_reg
18 | self.params['devices'] = "0"
19 | self.params['max_bin'] = self.max_bin
20 | self.params['thread_count'] = 20
21 | self.params['iterations'] = self.num_rounds
22 |
23 | if self.use_gpu:
24 | self.params['task_type'] = 'GPU'
25 | else:
26 | self.params['task_type'] = 'CPU'
27 | if data.task == "Multiclass classification":
28 | self.params['loss_function'] = 'MultiClass'
29 | self.params["classes_count"] = int(np.max(data.y_test) + 1)
30 | self.params["eval_metric"] = 'MultiClass'
31 | if data.task == "Classification":
32 | self.params['loss_function'] = 'Logloss'
33 | if data.task == "Ranking":
34 | self.params['loss_function'] = 'YetiRank'
35 | self.params["eval_metric"] = 'NDCG'
36 |
37 | def _train_model(self, data):
38 | print(self.params)
39 | dtrain = cb.Pool(data.X_train, data.y_train)
40 | if data.task == 'Ranking':
41 | dtrain.set_group_id(data.groups)
42 |
43 | start = time.time()
44 | self.model = cb.train(pool=dtrain, params=self.params, )
45 | elapsed = time.time() - start
46 |
47 | return elapsed
48 |
49 |
50 | def _predict(self, data):
51 | # test dataset for catboost
52 | cb_test = cb.Pool(data.X_test, data.y_test)
53 | if data.task == 'Ranking':
54 | cb_test.set_group_id(data.groups)
55 | preds = self.model.predict(cb_test)
56 | metric = self.eval(data, preds)
57 |
58 | return metric
59 |
60 | def model_name(self):
61 | name = "catboost_"
62 | use_cpu = "gpu_" if self.use_gpu else "cpu_"
63 | nr = str(self.num_rounds) + "_"
64 | return name + use_cpu + nr + str(self.max_depth)
65 |
66 |
67 | if __name__ == "__main__":
68 | X, y, groups = du.get_yahoo()
69 | dataset = Dataset(name='yahoo', task='Ranking', metric='NDCG', get_func=du.get_yahoo)
70 | print(dataset.X_train.shape)
71 | print(dataset.y_test.shape)
72 |
73 | t_start = time.time()
74 | xgbModel = CatboostModel()
75 | xgbModel.use_gpu = False
76 | xgbModel.run_model(data=dataset)
77 |
78 | elapsed = time.time() - t_start
79 | print("--------->> " + str(elapsed))
--------------------------------------------------------------------------------
/python/benchmarks/model/datasets.py:
--------------------------------------------------------------------------------
1 | from sklearn.model_selection import train_test_split
2 | import utils.data_utils as du
3 |
4 | class Dataset(object):
5 |
6 | def __init__(self, name, task, metric, X=None, y=None, get_func=None):
7 | """
8 | Please notice that the training set and test set are the same here.
9 | """
10 | group = None
11 | self.name = name
12 | self.task = task
13 | self.metric = metric
14 | if task == 'Ranking':
15 | if get_func is not None:
16 | X, y, group = get_func()
17 | else:
18 | if get_func is not None:
19 | X, y = get_func()
20 | self.X_train = X
21 | self.X_test = X
22 | self.y_train = y
23 | self.y_test = y
24 | self.groups = group
25 |
26 | def split_dataset(self, test_size=0.1):
27 | """
28 | Split the dataset in a certain proportion
29 | :param test_size: the proportion of test set
30 | :return:
31 | """
32 | self.X_train, self.X_test, self.y_train, self.y_test = \
33 | train_test_split(self.X_train, self.y_train, test_size=test_size)
34 |
35 |
36 | if __name__ == "__main__":
37 | X, y = du.get_higgs()
38 | dataset = Dataset(name='higgs', task='Regression', metric='RMSE', get_func=du.get_higgs)
39 | print(dataset.X_train.shape)
40 | print(dataset.y_test.shape)
41 |
42 | dataset.split_dataset(test_size=0.5)
43 | print(dataset.X_train.shape)
44 | print(dataset.y_test.shape)
45 |
--------------------------------------------------------------------------------
/python/benchmarks/model/lightgbm_model.py:
--------------------------------------------------------------------------------
1 | from model.base_model import BaseModel
2 | import numpy as np
3 | import lightgbm as lgb
4 | import time
5 | import utils.data_utils as du
6 | from model.datasets import Dataset
7 |
8 |
9 | class LightGBMModel(BaseModel):
10 |
11 | def __init__(self):
12 | BaseModel.__init__(self)
13 |
14 | def _config_model(self, data):
15 | self.params['task'] = 'train'
16 | self.params['boosting_type'] = 'gbdt'
17 | self.params['max_depth'] = self.max_depth
18 | self.params['num_leaves'] = 2 ** self.params['max_depth']
19 | # self.params['min_sum_hessian+in_leaf'] = 1
20 | self.params['min_split_gain'] = self.min_split_loss
21 | self.params['min_child_weight'] = self.min_weight
22 | self.params['lambda_l1'] = self.L1_reg
23 | self.params['lambda_l2'] = self.L2_reg
24 | self.params['max_bin'] = self.max_bin
25 | self.params['num_threads'] = 20
26 |
27 | if self.use_gpu:
28 | self.params['device'] = 'gpu'
29 | else:
30 | self.params['device'] = 'cpu'
31 | if data.task == "Regression":
32 | self.params["objective"] = "regression"
33 | elif data.task == "Multiclass classification":
34 | self.params["objective"] = "multiclass"
35 | self.params["num_class"] = int(np.max(data.y_test) + 1)
36 | elif data.task == "Classification":
37 | self.params["objective"] = "binary"
38 | elif data.task == "Ranking":
39 | self.params["objective"] = "lambdarank"
40 | else:
41 | raise ValueError("Unknown task: " + data.task)
42 |
43 |
44 | def _train_model(self, data):
45 | print(self.params)
46 | lgb_train = lgb.Dataset(data.X_train, data.y_train)
47 | if data.task == 'Ranking':
48 | lgb_train.set_group(data.groups)
49 |
50 | start = time.time()
51 | self.model = lgb.train(self.params,
52 | lgb_train,
53 | num_boost_round=self.num_rounds)
54 | elapsed = time.time() - start
55 |
56 | return elapsed
57 |
58 | def _predict(self, data):
59 | pred = self.model.predict(data.X_test)
60 | metric = self.eval(data, pred)
61 |
62 | return metric
63 |
64 | def model_name(self):
65 | name = "lightgbm_"
66 | use_cpu = "gpu_" if self.use_gpu else "cpu_"
67 | nr = str(self.num_rounds) + "_"
68 | return name + use_cpu + nr + str(self.max_depth)
69 |
70 |
71 | if __name__ == "__main__":
72 | X, y, groups = du.get_yahoo()
73 | dataset = Dataset(name='yahoo', task='Ranking', metric='NDCG', get_func=du.get_yahoo)
74 | print(dataset.X_train.shape)
75 | print(dataset.y_test.shape)
76 |
77 | t_start = time.time()
78 | xgbModel = LightGBMModel()
79 | xgbModel.use_gpu = False
80 | xgbModel.run_model(data=dataset)
81 |
82 | elapsed = time.time() - t_start
83 | print("--------->> " + str(elapsed))
84 |
--------------------------------------------------------------------------------
/python/benchmarks/model/thundergbm_model.py:
--------------------------------------------------------------------------------
1 | from model.base_model import BaseModel
2 | import thundergbm as tgb
3 | import time
4 | import numpy as np
5 | import utils.data_utils as du
6 | from model.datasets import Dataset
7 |
8 |
9 | class ThunderGBMModel(BaseModel):
10 |
11 | def __init__(self, depth=6, n_device=1, n_parallel_trees=1,
12 | verbose=0, column_sampling_rate=1.0, bagging=0, tree_method='auto'):
13 | BaseModel.__init__(self)
14 | self.verbose = verbose
15 | self.n_device = n_device
16 | self.column_sampling_rate = column_sampling_rate
17 | self.bagging = bagging
18 | self.n_parallel_trees = n_parallel_trees
19 | self.tree_method = tree_method
20 | self.objective = ""
21 | self.num_class = 1
22 |
23 | def _config_model(self, data):
24 | if data.task == "Regression":
25 | self.objective = "reg:linear"
26 | elif data.task == "Multiclass classification":
27 | self.objective = "multi:softmax"
28 | self.num_class = int(np.max(data.y_test) + 1)
29 | elif data.task == "Classification":
30 | self.objective = "binary:logistic"
31 | elif data.task == "Ranking":
32 | self.objective = "rank:ndcg"
33 | else:
34 | raise ValueError("Unknown task: " + data.task)
35 |
36 |
37 | def _train_model(self, data):
38 | if data.task == 'Regression':
39 | self.model = tgb.TGBMRegressor(tree_method=self.tree_method, depth = self.max_depth, n_trees = 40, n_gpus = 1, \
40 | min_child_weight = 1.0, lambda_tgbm = 1.0, gamma = 1.0,\
41 | max_num_bin = 255, verbose = 0, column_sampling_rate = 1.0,\
42 | bagging = 0, n_parallel_trees = 1, learning_rate = 1.0, \
43 | objective = "reg:linear", num_class = 1)
44 | else:
45 | self.model = tgb.TGBMClassifier(bagging=1, lambda_tgbm=1, learning_rate=0.07, min_child_weight=1.2, n_gpus=1, verbose=0,
46 | n_parallel_trees=40, gamma=0.2, depth=self.max_depth, n_trees=40, tree_method=self.tree_method, objective='multi:softprob')
47 | start = time.time()
48 | self.model.fit(data.X_train, data.y_train)
49 | elapsed = time.time() - start
50 | print("##################!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! %.5f" % elapsed)
51 |
52 | return elapsed
53 |
54 | def _predict(self, data):
55 | pred = self.model.predict(data.X_test)
56 | metric = self.eval(data, pred)
57 | return metric
58 |
59 |
60 | def model_name(self):
61 | name = "thundergbm_"
62 | use_cpu = "gpu_" if self.use_gpu else "cpu_"
63 | nr = str(self.num_rounds) + "_"
64 | return name + use_cpu + nr + str(self.max_depth)
65 |
66 |
67 | if __name__ == "__main__":
68 | # X, y = du.get_higgs()
69 | dataset = Dataset(name='higgs', task='Regression', metric='RMSE', get_func=du.get_realsim)
70 | print(dataset.X_train.shape)
71 | print(dataset.y_test.shape)
72 |
73 | t_start = time.time()
74 |
75 | tgmModel = ThunderGBMModel()
76 | tgmModel.tree_method = 'hist'
77 | tgmModel.run_model(data=dataset)
78 |
79 | elapsed = time.time() - t_start
80 |
81 | print("--------->> " + str(elapsed))
--------------------------------------------------------------------------------
/python/benchmarks/model/xgboost_model.py:
--------------------------------------------------------------------------------
1 | from model.base_model import BaseModel
2 | import numpy as np
3 | import xgboost as xgb
4 | import time
5 | import utils.data_utils as du
6 | from model.datasets import Dataset
7 |
8 |
9 | class XGboostModel(BaseModel):
10 |
11 | def __init__(self, use_exact=False, debug_verbose=1):
12 | BaseModel.__init__(self)
13 | self.use_exact = use_exact
14 | self.debug_verbose = debug_verbose
15 |
16 | def _config_model(self, data):
17 | self.params['max_depth'] = self.max_depth
18 | self.params['learning_rate'] = self.learning_rate
19 | self.params['min_split_loss'] = self.min_split_loss
20 | self.params['min_child_weight'] = self.min_weight
21 | self.params['alpha'] = self.L1_reg
22 | self.params['lambda'] = self.L2_reg
23 | self.params['debug_verbose'] = self.debug_verbose
24 | self.params['max_bin'] = self.max_bin
25 |
26 | if self.use_gpu:
27 | self.params['tree_method'] = ('gpu_exact' if self.use_exact
28 | else 'gpu_hist')
29 | self.params['n_gpus'] = 1
30 | else:
31 | self.params['nthread'] = 20
32 | self.params['tree_method'] = ('exact' if self.use_exact else 'hist')
33 |
34 | self.params["predictor"] = "gpu_predictor"
35 | if data.task == "Regression":
36 | self.params["objective"] = "reg:squarederror"
37 | elif data.task == "Multiclass classification":
38 | self.params["objective"] = "multi:softprob"
39 | self.params["num_class"] = int(np.max(data.y_test) + 1)
40 | elif data.task == "Classification":
41 | self.params["objective"] = "binary:logistic"
42 | elif data.task == "Ranking":
43 | self.params["objective"] = "rank:ndcg"
44 | else:
45 | raise ValueError("Unknown task: " + data.task)
46 |
47 | def _train_model(self, data):
48 | print(self.params)
49 | dtrain = xgb.DMatrix(data.X_train, data.y_train)
50 | if data.task == 'Ranking':
51 | dtrain.set_group(data.groups)
52 | t_start = time.time()
53 | self.model = xgb.train(self.params, dtrain, self.num_rounds, [(dtrain, "train")])
54 | elapsed_time = time.time() - t_start
55 |
56 | return elapsed_time
57 |
58 | def _predict(self, data):
59 | dtest = xgb.DMatrix(data.X_test, data.y_test)
60 | if data.task == 'Ranking':
61 | dtest.set_group(data.groups)
62 | pred = self.model.predict(dtest)
63 | metric = self.eval(data=data, pred=pred)
64 |
65 | return metric
66 |
67 | def model_name(self):
68 | name = "xgboost_"
69 | use_cpu = "gpu_" if self.use_gpu else "cpu_"
70 | nr = str(self.num_rounds) + "_"
71 | return name + use_cpu + nr + str(self.max_depth)
72 |
73 |
74 |
75 |
76 | if __name__ == "__main__":
77 | X, y, groups = du.get_yahoo()
78 | dataset = Dataset(name='yahoo', task='Ranking', metric='NDCG', get_func=du.get_yahoo)
79 | print(dataset.X_train.shape)
80 | print(dataset.y_test.shape)
81 |
82 | t_start = time.time()
83 | xgbModel = XGboostModel()
84 | xgbModel.use_gpu = False
85 | xgbModel.run_model(data=dataset)
86 |
87 | elapsed = time.time() - t_start
88 | print("--------->> " + str(elapsed))
--------------------------------------------------------------------------------
/python/benchmarks/run_exp.sh:
--------------------------------------------------------------------------------
1 | # lgb
2 | #python3 experiments.py lightgbm cpu higgs
3 | #python3 experiments.py lightgbm cpu log1p
4 | #python3 experiments.py lightgbm cpu cifar
5 | python3 experiments.py lightgbm cpu news20
6 |
7 | # lgb gpu
8 | #python3 experiments.py lightgbm gpu higgs
9 | #python3 experiments.py lightgbm gpu log1p
10 | #python3 experiments.py lightgbm gpu cifar
11 | python3 experiments.py lightgbm gpu news20
12 |
13 | # xgboost
14 | #python3 experiments.py xgboost cpu higgs
15 | #python3 experiments.py xgboost cpu log1p
16 | #python3 experiments.py xgboost cpu cifar
17 | python3 experiments.py xgboost cpu news20
18 |
19 | # xgboost gpu
20 | #python3 experiments.py xgboost gpu higgs
21 | #python3 experiments.py xgboost gpu log1p
22 | #python3 experiments.py xgboost gpu cifar
23 | python3 experiments.py xgboost gpu news20
24 |
25 | # catboost
26 | #python3 experiments.py catboost cpu higgs
27 | #python3 experiments.py catboost cpu log1p
28 | #python3 experiments.py catboost cpu cifar
29 | python3 experiments.py catboost cpu news20
30 |
31 | # catboost gpu
32 | #python3 experiments.py catboost gpu higgs
33 | #python3 experiments.py catboost gpu log1p
34 | #python3 experiments.py catboost gpu cifar
35 | python3 experiments.py catboost gpu news20
36 |
--------------------------------------------------------------------------------
/python/benchmarks/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/benchmarks/utils/__init__.py
--------------------------------------------------------------------------------
/python/benchmarks/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pandas as pd
4 | from sklearn import datasets
5 | import pickle
6 | import numpy as np
7 |
8 | if sys.version_info[0] >= 3:
9 | from urllib.request import urlretrieve # pylint: disable=import-error,no-name-in-module
10 | else:
11 | from urllib import urlretrieve # pylint: disable=import-error,no-name-in-module
12 |
13 | # to make a base directory of datasets
14 | BASE_DATASET_DIR = 'datasets'
15 | if not os.path.exists(BASE_DATASET_DIR):
16 | print("Made dir: %s" % BASE_DATASET_DIR)
17 | os.mkdir(BASE_DATASET_DIR)
18 |
19 |
20 | def convert_to_plk(filename, plk_filename, data_url):
21 | if not os.path.isfile(plk_filename):
22 | print('Downloading dataset, please wait...')
23 | urlretrieve(data_url, filename)
24 | X, y = datasets.load_svmlight_file(filename)
25 | pickle.dump((X, y), open(plk_filename, 'wb'))
26 | os.remove(filename)
27 | print("----------- loading dataset %s -----------" % filename)
28 | X, y = pickle.load(open(plk_filename, 'rb'))
29 | print("----------- Finished loading dataset %s -----------" % filename)
30 | return X, y
31 |
32 |
33 | # -------------------------------------------------------------------------------------------------
34 | get_higgs_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/HIGGS.bz2' # pylint: disable=line-too-long
35 | def get_higgs(num_rows=None):
36 | filename = os.path.join(BASE_DATASET_DIR, 'HIGGS.bz2')
37 | plk_filename = filename + '.plk'
38 | X, y = convert_to_plk(filename, plk_filename, data_url=get_higgs_url)
39 |
40 | return X.toarray(), y
41 |
42 |
43 | # -------------------------------------------------------------------------------------------------
44 | def get_covtype(num_rows=None):
45 | data = datasets.fetch_covtype()
46 | X = data.data
47 | y = data.target
48 | if num_rows is not None:
49 | X = X[0:num_rows]
50 | y = y[0:num_rows]
51 |
52 | return X, y
53 |
54 |
55 | # -------------------------------------------------------------------------------------------------
56 | get_e2006_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/E2006.train.bz2' # pylint: disable=line-too-long
57 | def get_e2006(num_rows=None):
58 | filename = os.path.join(BASE_DATASET_DIR, 'E2006.train')
59 | plk_filename = filename + '.plk'
60 | X, y = convert_to_plk(filename, plk_filename, data_url=get_e2006_url)
61 |
62 | return X.toarray(), y
63 |
64 |
65 | # -------------------------------------------------------------------------------------------------
66 | get_log1p_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/log1p.E2006.train.bz2'
67 | def get_log1p(num_rows=None):
68 | filename = os.path.join(BASE_DATASET_DIR, 'log1p.train.bz2')
69 | plk_filename = filename + '.plk'
70 | X, y = convert_to_plk(filename, plk_filename, data_url=get_log1p_url)
71 |
72 | return X, y
73 |
74 |
75 | # -------------------------------------------------------------------------------------------------
76 | get_news20_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/news20.binary.bz2'
77 | def get_news20(num_rows=None):
78 | filename = os.path.join(BASE_DATASET_DIR, 'news20.binary.bz2')
79 | plk_filename = filename + '.plk'
80 | X, y = convert_to_plk(filename, plk_filename, data_url=get_news20_url)
81 |
82 | return X.toarray(), y
83 |
84 |
85 | # -------------------------------------------------------------------------------------------------
86 | get_realsim_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/real-sim.bz2'
87 | def get_realsim(num_rows=None):
88 | filename = os.path.join(BASE_DATASET_DIR, 'real-sim.bz2')
89 | plk_filename = filename + '.plk'
90 | X, y = convert_to_plk(filename, plk_filename, data_url=get_realsim_url)
91 |
92 | return X.toarray(), y
93 |
94 |
95 | # -------------------------------------------------------------------------------------------------
96 | get_susy_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/SUSY.bz2'
97 | def get_susy(num_rows=None):
98 | filename = os.path.join(BASE_DATASET_DIR, 'susy.bz2')
99 | plk_filename = filename + '.plk'
100 | X, y = convert_to_plk(filename, plk_filename, data_url=get_susy_url)
101 |
102 | return X.toarray(), y
103 |
104 |
105 | # -------------------------------------------------------------------------------------------------
106 | get_epsilon_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/epsilon_normalized.bz2'
107 | def get_epsilon(num_rows=None):
108 | filename = os.path.join(BASE_DATASET_DIR, 'epsilon_normalized')
109 | plk_filename = filename + '.plk'
110 | X, y = convert_to_plk(filename, plk_filename, data_url=get_epsilon_url)
111 |
112 | return X.toarray(), y
113 |
114 |
115 | # -------------------------------------------------------------------------------------------------
116 | get_cifar_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/cifar10.bz2'
117 | def get_cifar(num_rows=None):
118 | filename = os.path.join(BASE_DATASET_DIR, 'cifar')
119 | plk_filename = filename + '.plk'
120 | X, y = convert_to_plk(filename, plk_filename, data_url=get_cifar_url)
121 |
122 | return X.toarray(), y
123 |
124 |
125 | # -------------------------------------------------------------------------------------------------
126 | get_ins_url = ''
127 | def get_ins(num_rows=None):
128 | filename = os.path.join(BASE_DATASET_DIR, 'ins')
129 | plk_filename = filename + '.plk'
130 | X, y = convert_to_plk(filename, plk_filename, data_url=get_ins_url)
131 |
132 | return X.toarray(), y
133 |
134 |
135 | get_cifar10_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/cifar10.bz2'
136 | def get_cifar10(num_rows=None):
137 | filename = os.path.join(BASE_DATASET_DIR, 'cifar10')
138 | plk_filename = filename + '.plk'
139 | X, y = convert_to_plk(filename, plk_filename, data_url=get_cifar10_url)
140 |
141 | return X.toarray(), y
142 |
143 |
144 |
145 | get_news20_multiclass_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.bz2'
146 | def get_news20(num_rows=None):  # NOTE: shadows the binary news20 loader defined above
147 | filename = os.path.join(BASE_DATASET_DIR, 'news20.bz2')
148 | plk_filename = filename + '.plk'
149 | X, y = convert_to_plk(filename, plk_filename, data_url=get_news20_multiclass_url)
150 |
151 | return X.toarray(), y
152 |
153 |
154 | def load_groups(filename):
155 | groups = []
156 | with open(filename) as f:
157 | for line in f:
158 | if line.strip() != '':
159 | groups.append(int(line.strip()))
160 |
161 | return np.asarray(groups)
162 |
163 |
164 | get_yahoo_url = ''
165 | def get_yahoo():
166 | filename = os.path.join(BASE_DATASET_DIR, 'yahoo-ltr-libsvm')
167 | group_filename = os.path.join(BASE_DATASET_DIR, 'yahoo-ltr-libsvm.group')
168 | plk_filename = filename + '.plk'
169 | X, y = convert_to_plk(filename, plk_filename, data_url=get_yahoo_url)
170 | groups = load_groups(group_filename)
171 |
172 | return X.toarray(), y, groups
173 |
174 | if __name__ == "__main__":
175 | X, y, groups = get_yahoo()
176 | print(X.shape)
177 | print(y.shape)
--------------------------------------------------------------------------------
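
Note: convert_to_plk above implements a download-once cache: fetch the LIBSVM file, parse it with sklearn, pickle the (X, y) pair, and reuse the pickle on later runs. A minimal standalone sketch of the same pattern (the function name and paths here are illustrative, not part of the benchmark suite):

import os
import pickle
from urllib.request import urlretrieve
from sklearn import datasets

def load_cached_svmlight(raw_path, cache_path, url):
    # Download and parse only when no pickle cache exists yet.
    if not os.path.isfile(cache_path):
        urlretrieve(url, raw_path)
        X, y = datasets.load_svmlight_file(raw_path)
        with open(cache_path, 'wb') as f:
            pickle.dump((X, y), f)
        os.remove(raw_path)  # keep only the parsed cache, as convert_to_plk does
    with open(cache_path, 'rb') as f:
        return pickle.load(f)
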
/python/benchmarks/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 |
4 | def add_data(df, algorithm, data, elapsed, metric):
5 | time_col = (data.name, 'Time(s)')
6 | metric_col = (data.name, data.metric)
7 | try:
8 | df.insert(len(df.columns), time_col, '-')
9 | df.insert(len(df.columns), metric_col, '-')
10 | except ValueError:  # the columns were already inserted by an earlier call
11 | pass
12 |
13 | df.at[algorithm, time_col] = elapsed
14 | df.at[algorithm, metric_col] = metric
15 |
16 |
17 | def write_results(df, filename, format):
18 | if format == "latex":
19 | tmp_df = df.copy()
20 | tmp_df.columns = pd.MultiIndex.from_tuples(tmp_df.columns)
21 | with open(filename, "a") as file:
22 | file.write(tmp_df.to_latex())
23 | elif format == "csv":
24 | with open(filename, "a") as file:
25 | file.write(df.to_csv())
26 | else:
27 | raise ValueError("Unknown format: " + format)
28 |
29 | print(format + " results written to: " + filename)
--------------------------------------------------------------------------------
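
add_data above keys rows by algorithm and columns by (dataset, measure) pairs, so repeated calls build up one comparison grid; write_results then appends it to a file. A usage sketch, assuming a namedtuple stand-in for the dataset object (the real one comes from benchmarks/model/datasets.py):

import pandas as pd
from collections import namedtuple
from utils.file_utils import add_data, write_results

Data = namedtuple('Data', ['name', 'metric'])  # hypothetical stand-in for the benchmark dataset object

df = pd.DataFrame()
add_data(df, 'xgboost', Data('higgs', 'AUC'), 12.3, 0.84)
add_data(df, 'thundergbm', Data('higgs', 'AUC'), 4.1, 0.84)
write_results(df, 'results.csv', 'csv')  # appends the grid as CSV
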
/python/dist/thundergbm-0.3.12-py2-none-win_amd64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.12-py2-none-win_amd64.whl
--------------------------------------------------------------------------------
/python/dist/thundergbm-0.3.12-py3-none-win_amd64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.12-py3-none-win_amd64.whl
--------------------------------------------------------------------------------
/python/dist/thundergbm-0.3.16-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.16-py3-none-any.whl
--------------------------------------------------------------------------------
/python/dist/thundergbm-0.3.4-py3-none-win_amd64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/thundergbm/e3f824e4bdeba9230f4230b121f87119f63c425c/python/dist/thundergbm-0.3.4-py3-none-win_amd64.whl
--------------------------------------------------------------------------------
/python/examples/classification_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../")
3 |
4 | from thundergbm import TGBMClassifier
5 | from sklearn.datasets import load_digits
6 | from sklearn.metrics import accuracy_score
7 |
8 | if __name__ == '__main__':
9 | x, y = load_digits(return_X_y=True)
10 | clf = TGBMClassifier()
11 | clf.fit(x, y)
12 | y_pred = clf.predict(x)
13 | accuracy = accuracy_score(y, y_pred)
14 | print(accuracy)
--------------------------------------------------------------------------------
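
The demo above scores the classifier on its own training data, so the printed number is training accuracy. A variant with a held-out test split (a sketch; same API as the demo):

from thundergbm import TGBMClassifier
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, y = load_digits(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
clf = TGBMClassifier()
clf.fit(x_train, y_train)
print(accuracy_score(y_test, clf.predict(x_test)))  # held-out accuracy
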
/python/examples/ranking_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../")
3 |
4 | import thundergbm
5 | from sklearn.datasets import load_svmlight_file
6 |
7 | def load_test_data(file_path):
8 | X, y = load_svmlight_file(file_path)
9 |
10 | return X, y
11 |
12 | def load_groups(file_path):
13 | groups = []
14 | with open(file_path, 'r') as f:
15 | for line in f.readlines():
16 | if line.strip() != '':
17 | groups.append(int(line.strip()))
18 |
19 | return groups
20 |
21 | if __name__ == "__main__":
22 | X, y = load_test_data('../../dataset/test_dataset.txt')
23 | groups = load_groups('../../dataset/test_dataset.txt.group')
24 | # tgbm_model = thundergbm.TGBMClassifier(depth=6, n_trees=40, objective='rank:ndcg')
25 | tgbm_model = thundergbm.TGBMRanker(depth=6, n_trees=40, objective='rank:ndcg')
26 | tgbm_model.fit(X, y, groups)
27 | pred_result = tgbm_model.predict(X, groups)
28 | print(pred_result)
29 |
30 |
--------------------------------------------------------------------------------
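
The .group file consumed by load_groups is plain text with one integer per line; each integer is the number of consecutive rows in the LIBSVM file that belong to one query, and the integers must sum to the number of instances. A sketch using the demo's load_groups with a hypothetical file (not shipped with the repo):

# my_data.group contains three lines: 3, 2, 4
groups = load_groups('my_data.group')  # -> [3, 2, 4]
assert sum(groups) == X.shape[0]  # group sizes must cover every row of X
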
/python/examples/regression_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../")
3 |
4 | from thundergbm import TGBMRegressor
5 | from sklearn.datasets import load_boston
6 | from sklearn.metrics import mean_squared_error
7 | from math import sqrt
8 |
9 | if __name__ == '__main__':
10 | x, y = load_boston(return_X_y=True)
11 | clf = TGBMRegressor()
12 | clf.fit(x, y)
13 | y_pred = clf.predict(x)
14 | rmse = sqrt(mean_squared_error(y, y_pred))
15 | print(rmse)
--------------------------------------------------------------------------------
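
Note that load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, so this demo fails on recent scikit-learn releases. The same flow on the California housing dataset (a substitute sketch, not the repo's original demo):

from math import sqrt
from thundergbm import TGBMRegressor
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error

x, y = fetch_california_housing(return_X_y=True)
clf = TGBMRegressor()
clf.fit(x, y)
print(sqrt(mean_squared_error(y, clf.predict(x))))  # training RMSE
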
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | scikit-learn
3 | numpy
4 |
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import setuptools
3 | from shutil import copyfile
4 | from sys import platform
5 |
6 | dirname = path.dirname(path.abspath(__file__))
7 |
8 | if platform == "linux" or platform == "linux2":
9 | lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.so'))
10 | elif platform == "win32":
11 | lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/thundergbm.dll'))
12 | elif platform == "darwin":
13 | lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.dylib'))
14 | else:
15 | print("OS not supported!")
16 | raise SystemExit(1)
17 | if not path.exists(path.join(dirname, "thundergbm", path.basename(lib_path))):
18 | copyfile(lib_path, path.join(dirname, "thundergbm", path.basename(lib_path)))
19 | setuptools.setup(name="thundergbm",
20 | version="0.3.16",
21 | packages=["thundergbm"],
22 | package_dir={"python": "thundergbm"},
23 | description="A Fast GBM Library on GPUs and CPUs",
24 | long_description="""The mission of ThunderGBM is to help users easily and efficiently apply GBDTs and Random Forests to solve problems. ThunderGBM exploits GPUs and multi-core CPUs to achieve high efficiency""",
25 | long_description_content_type="text/plain",
26 | license='Apache-2.0',
27 | author='Xtra Computing Group',
28 | maintainer='thundergbm contributors',
29 | maintainer_email='wenzy@comp.nus.edu.sg',
30 | url="https://github.com/zeyiwen/thundergbm",
31 | package_data={"thundergbm": [path.basename(lib_path)]},
32 | install_requires=['numpy', 'scipy', 'scikit-learn'],
33 | classifiers=[
34 | "Programming Language :: Python :: 3",
35 | "License :: OSI Approved :: Apache Software License",
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/python/thundergbm/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*-coding:utf-8-*-
3 | """
4 | * Name : __init__.py
5 | * Author : Locke
6 | * Version : 0.0.1
7 | * Description :
8 | """
9 | name = "thundergbm"
10 | from .thundergbm import *
11 |
--------------------------------------------------------------------------------
/src/test/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(googletest)
2 |
3 | include_directories(googletest/googletest/include)
4 | include_directories(googletest/googlemock/include)
5 |
6 | file(GLOB TEST_SRC *)
7 |
8 | cuda_add_executable(${PROJECT_NAME}-test ${TEST_SRC} ${COMMON_INCLUDES})
9 | target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME} gtest)
10 |
11 |
--------------------------------------------------------------------------------
/src/test/test_cub_wrapper.cu:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "thundergbm/util/cub_wrapper.h"
3 | #include "thundergbm/syncarray.h"
4 |
5 | class CubWrapperTest: public::testing::Test {
6 | public:
7 | SyncArray<int> values;
8 | SyncArray<int> keys;
9 |
10 | protected:
11 | void SetUp() override {
12 | values.resize(4);
13 | keys.resize(4);
14 | auto values_data = values.host_data();
15 | values_data[0] = 1;
16 | values_data[1] = 4;
17 | values_data[2] = -1;
18 | values_data[3] = 0;
19 | auto keys_data = keys.host_data();
20 | keys_data[0] = 1;
21 | keys_data[1] = 3;
22 | keys_data[2] = 2;
23 | keys_data[3] = 4;
24 | }
25 |
26 | };
27 |
28 | TEST_F(CubWrapperTest, test_cub_sort_by_key) {
29 | // keys from [1, 3, 2, 4] to [1, 2, 3, 4]
30 | // values from [1, 4, -1, 0] to [1, -1, 4, 0]
31 | cub_sort_by_key(keys, values);
32 | auto keys_data = keys.host_data();
33 | auto values_data = values.host_data();
34 | EXPECT_EQ(keys_data[0], 1);
35 | EXPECT_EQ(keys_data[1], 2);
36 | EXPECT_EQ(keys_data[2], 3);
37 | EXPECT_EQ(keys_data[3], 4);
38 | EXPECT_EQ(values_data[0], 1);
39 | EXPECT_EQ(values_data[1], -1);
40 | EXPECT_EQ(values_data[2], 4);
41 | EXPECT_EQ(values_data[3], 0);
42 | }
43 |
44 | TEST_F(CubWrapperTest, test_cub_seg_sort_by_key) {
45 | SyncArray<int> seg_ptr(3); // 2 segments
46 | auto seg_ptr_data = seg_ptr.host_data();
47 | seg_ptr_data[0] = 0;
48 | seg_ptr_data[1] = 2;
49 | seg_ptr_data[2] = 4;
50 | // keys from [1, 3, 2, 4] to [1, 3, 2, 4]
51 | // values from [1, 4, -1, 0] to [1, 4, -1, 0]
52 | /*cub_seg_sort_by_key(keys, values, seg_ptr);*/
53 | auto keys_data = keys.host_data();
54 | auto values_data = values.host_data();
55 | EXPECT_EQ(keys_data[0], 1);
56 | EXPECT_EQ(keys_data[1], 3);
57 | EXPECT_EQ(keys_data[2], 2);
58 | EXPECT_EQ(keys_data[3], 4);
59 | EXPECT_EQ(values_data[0], 1);
60 | EXPECT_EQ(values_data[1], 4);
61 | EXPECT_EQ(values_data[2], -1);
62 | EXPECT_EQ(values_data[3], 0);
63 | }
64 |
65 | TEST_F(CubWrapperTest, test_sort_array) {
66 | sort_array(values);
67 | auto values_data = values.host_data();
68 | EXPECT_EQ(values_data[0], -1);
69 | EXPECT_EQ(values_data[1], 0);
70 | EXPECT_EQ(values_data[2], 1);
71 | EXPECT_EQ(values_data[3], 4);
72 | }
73 |
74 | TEST_F(CubWrapperTest, test_max_elem) {
75 | int max_elem = max_elements(values);
76 | EXPECT_EQ(max_elem, 4);
77 | }
78 |
79 |
--------------------------------------------------------------------------------
/src/test/test_dataset.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 17-9-17.
3 | //
4 | #include "gtest/gtest.h"
5 | #include "thundergbm/dataset.h"
6 | #include "omp.h"
7 |
8 | class DatasetTest : public ::testing::Test {
9 | public:
10 | GBMParam param;
11 | vector<float_type> csr_val;
12 | vector<int> csr_row_ptr;
13 | vector<int> csr_col_idx;
14 | vector<float_type> y;
15 | size_t n_features_;
16 | vector<float_type> label;
17 | protected:
18 | void SetUp() override {
19 | param.depth = 6;
20 | param.n_trees = 40;
21 | param.n_device = 1;
22 | param.min_child_weight = 1;
23 | param.lambda = 1;
24 | param.gamma = 1;
25 | param.rt_eps = 1e-6;
26 | param.max_num_bin = 255;
27 | param.verbose = false;
28 | param.profiling = false;
29 | param.column_sampling_rate = 1;
30 | param.bagging = false;
31 | param.n_parallel_trees = 1;
32 | param.learning_rate = 1;
33 | param.objective = "reg:linear";
34 | param.num_class = 1;
35 | param.path = "../dataset/test_dataset.txt";
36 | param.tree_method = "auto";
37 | if (!param.verbose) {
38 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
39 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
40 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "true");
41 | }
42 | if (!param.profiling) {
43 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
44 | }
45 | }
46 |
47 | void load_from_file(string file_name, GBMParam ¶m) {
48 | LOG(INFO) << "loading LIBSVM dataset from file \"" << file_name << "\"";
49 | y.clear();
50 | csr_row_ptr.resize(1, 0);
51 | csr_col_idx.clear();
52 | csr_val.clear();
53 | n_features_ = 0;
54 |
55 | std::ifstream ifs(file_name, std::ifstream::binary);
56 | CHECK(ifs.is_open()) << "file " << file_name << " not found";
57 |
58 | int buffer_size = 4 << 20;
59 | char *buffer = (char *)malloc(buffer_size);
60 | //a fixed-size std::array may cause stack overflow on Windows
61 | //std::array<char, 4 << 20> buffer{};
62 | const int nthread = omp_get_max_threads();
63 |
64 | auto find_last_line = [](char *ptr, const char *begin) {
65 | while (ptr != begin && *ptr != '\n' && *ptr != '\r' && *ptr != '\0') --ptr;
66 | return ptr;
67 | };
68 |
69 | while (ifs) {
70 | ifs.read(buffer, buffer_size);
71 | char *head = buffer;
72 | //ifs.read(buffer.data(), buffer.size());
73 | //char *head = buffer.data();
74 | size_t size = ifs.gcount();
75 | vector<vector<float_type>> y_(nthread);
76 | vector<vector<int>> col_idx_(nthread);
77 | vector<vector<int>> row_len_(nthread);
78 | vector<vector<float_type>> val_(nthread);
79 |
80 | vector<int> max_feature(nthread, 0);
81 | bool is_zero_base = false;
82 |
83 | #pragma omp parallel num_threads(nthread)
84 | {
85 | //get working area of this thread
86 | int tid = omp_get_thread_num();
87 | size_t nstep = (size + nthread - 1) / nthread;
88 | size_t sbegin = (std::min)(tid * nstep, size - 1);
89 | size_t send = (std::min)((tid + 1) * nstep, size - 1);
90 | char *pbegin = find_last_line(head + sbegin, head);
91 | char *pend = find_last_line(head + send, head);
92 |
93 | //move stream start position to the end of last line
94 | if (tid == nthread - 1) {
95 | if (ifs.eof())
96 | pend = head + send;
97 | else
98 | ifs.seekg(-(head + send - pend), std::ios_base::cur);
99 | }
100 |
101 | //read instances line by line
102 | //TODO optimize parse line
103 | char *lbegin = pbegin;
104 | char *lend = lbegin;
105 | while (lend != pend) {
106 | //get one line
107 | lend = lbegin + 1;
108 | while (lend != pend && *lend != '\n' && *lend != '\r' && *lend != '\0') {
109 | ++lend;
110 | }
111 | string line(lbegin, lend);
112 | if (line != "\n") {
113 | std::stringstream ss(line);
114 |
115 | //read label of an instance
116 | y_[tid].push_back(0);
117 | ss >> y_[tid].back();
118 |
119 | row_len_[tid].push_back(0);
120 | string tuple;
121 | while (ss >> tuple) {
122 | int i;
123 | float v;
124 | CHECK_EQ(sscanf(tuple.c_str(), "%d:%f", &i, &v), 2)
125 | << "read error, using [index]:[value] format";
126 | //TODO one-based and zero-based
127 | col_idx_[tid].push_back(i - 1);//one based
128 | if(i - 1 == -1){
129 | is_zero_base = true;
130 | }
131 | CHECK_GE(i - 1, -1) << "dataset format error";
132 | val_[tid].push_back(v);
133 | if (i > max_feature[tid]) {
134 | max_feature[tid] = i;
135 | }
136 | row_len_[tid].back()++;
137 | }
138 | }
139 | //read next instance
140 | lbegin = lend;
141 |
142 | }
143 | }
144 | for (int i = 0; i < nthread; i++) {
145 | if (max_feature[i] > n_features_)
146 | n_features_ = max_feature[i];
147 | }
148 | for (int tid = 0; tid < nthread; tid++) {
149 | csr_val.insert(csr_val.end(), val_[tid].begin(), val_[tid].end());
150 | if(is_zero_base){
151 | for (int i = 0; i < col_idx_[tid].size(); ++i) {
152 | col_idx_[tid][i]++;
153 | }
154 | }
155 | csr_col_idx.insert(csr_col_idx.end(), col_idx_[tid].begin(), col_idx_[tid].end());
156 | for (int row_len : row_len_[tid]) {
157 | csr_row_ptr.push_back(csr_row_ptr.back() + row_len);
158 | }
159 | }
160 | for (int i = 0; i < nthread; i++) {
161 | this->y.insert(y.end(), y_[i].begin(), y_[i].end());
162 | this->label.insert(label.end(), y_[i].begin(), y_[i].end());
163 | }
164 | }
165 | ifs.close();
166 | free(buffer);
167 | }
168 | };
169 |
170 | TEST_F(DatasetTest, load_dataset){
171 | DataSet dataset;
172 | load_from_file(param.path, param);
173 | dataset.load_from_file(param.path, param);
174 | printf("### Dataset: %s, num_instances: %d, num_features: %d. ###\n",
175 | param.path.c_str(),
176 | dataset.n_instances(),
177 | dataset.n_features());
178 | EXPECT_EQ(dataset.n_instances(), 1605);
179 | EXPECT_EQ(dataset.n_features_, 119);
180 | EXPECT_EQ(dataset.label[0], -1);
181 | EXPECT_EQ(dataset.csr_val[1], 1);
182 |
183 | for(int i = 0; i < csr_val.size(); i++)
184 | EXPECT_EQ(csr_val[i], dataset.csr_val[i]);
185 | for(int i = 0; i < csr_row_ptr.size(); i++)
186 | EXPECT_EQ(csr_row_ptr[i], dataset.csr_row_ptr[i]);
187 | for(int i = 0; i < csr_col_idx.size(); i++)
188 | EXPECT_EQ(csr_col_idx[i], dataset.csr_col_idx[i]);
189 | }
190 |
191 |
--------------------------------------------------------------------------------
/src/test/test_for_refactor.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by ss on 19-1-14.
3 | //
4 | #include <thundergbm/trainer.h>
5 | #include <thundergbm/dataset.h>
6 | #include "gtest/gtest.h"
7 | #include "thundergbm/common.h"
8 |
9 | class TrainTest : public ::testing::Test {
10 | public:
11 | GBMParam param;
12 | TreeTrainer trainer;
13 | protected:
14 | void SetUp() override {
15 | param.depth = 6;
16 | param.n_trees = 40;
17 | param.n_device = 1;
18 | param.min_child_weight = 1;
19 | param.lambda = 1;
20 | param.gamma = 1;
21 | param.rt_eps = 1e-6;
22 | param.max_num_bin = 255;
23 | param.verbose = false;
24 | param.profiling = false;
25 | param.column_sampling_rate = 1;
26 | param.bagging = false;
27 | param.n_parallel_trees = 1;
28 | param.learning_rate = 1;
29 | param.objective = "reg:linear";
30 | param.num_class = 1;
31 | param.path = "../dataset/test_dataset.txt";
32 | param.tree_method = "auto";
33 | if (!param.verbose) {
34 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
35 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
36 | }
37 | if (!param.profiling) {
38 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
39 | }
40 | }
41 |
42 | void TearDown() override {
43 | }
44 | };
45 |
46 | TEST_F(TrainTest, news20) {
47 | param.path = DATASET_DIR "news20.scale";
48 | // trainer.train(param, <#initializer#>);
49 | // float_type rmse = trainer.train(param);
50 | // EXPECT_NEAR(rmse, 2.55274, 1e-5);
51 | }
52 |
53 | TEST_F(TrainTest, covtype) {
54 | param.path = DATASET_DIR "covtype";
55 | // float_type rmse = trainer.train2(param);
56 | // trainer.train(param, <#initializer#>);
57 | param.bagging = true;
58 | // trainer.train(param, <#initializer#>);
59 | param.column_sampling_rate = 0.5;
60 | // trainer.train(param, <#initializer#>);
61 | // EXPECT_NEAR(rmse, 0.730795, 1e-5);
62 | }
63 |
64 | TEST_F(TrainTest, covtype_multiclass) {
65 | param.path = DATASET_DIR "covtype";
66 | param.num_class = 7;
67 | param.objective = "multi:softprob";
68 | // trainer.train(param, <#initializer#>);
69 | }
70 |
71 | TEST_F(TrainTest, mnist_multiclass) {
72 | param.path = DATASET_DIR "mnist.scale";
73 | param.objective = "multi:softprob";
74 | param.num_class = 10;
75 | // trainer.train(param, <#initializer#>);
76 | }
77 |
78 | TEST_F(TrainTest, cifar10_multiclass) {
79 | param.path = DATASET_DIR "cifar10";
80 | param.objective = "multi:softprob";
81 | param.num_class = 10;
82 | // trainer.train(param, <#initializer#>);
83 | }
84 |
85 | TEST_F(TrainTest, sector_multiclass) {
86 | param.path = DATASET_DIR "sector.scale";
87 | param.objective = "multi:softprob";
88 | param.tree_method = "hist";
89 | param.num_class = 105;
90 | // trainer.train(param, <#initializer#>);
91 | }
92 |
93 | TEST_F(TrainTest, news20_multiclass) {
94 | param.path = DATASET_DIR "news20.scale";
95 | param.objective = "multi:softprob";
96 | param.num_class = 20;
97 | // trainer.train(param, <#initializer#>);
98 | }
99 |
100 | TEST_F(TrainTest, rcv1_multiclass) {
101 | param.path = DATASET_DIR "rcv1_train.multiclass";
102 | param.objective = "multi:softprob";
103 | param.tree_method = "hist";
104 | param.num_class = 51;
105 | // trainer.train(param, <#initializer#>);
106 | }
107 |
108 | TEST_F(TrainTest, yahoo_ranking) {
109 | // param.path = "dataset/rank.train";
110 | param.path = DATASET_DIR "yahoo-ltr-libsvm";
111 | param.objective = "rank:ndcg";
112 | // trainer.train(param, <#initializer#>);
113 | // float_type rmse = trainer.train(param);
114 | }
115 |
--------------------------------------------------------------------------------
/src/test/test_gbdt.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "thundergbm/trainer.h"
3 | #include "thundergbm/predictor.h"
4 | #include "thundergbm/dataset.h"
5 | #include "thundergbm/tree.h"
6 |
7 | class GBDTTest: public ::testing::Test {
8 | public:
9 | GBMParam param;
10 | protected:
11 | void SetUp() override {
12 | param.depth = 3;
13 | param.n_trees = 5;
14 | param.n_device = 1;
15 | param.min_child_weight = 1;
16 | param.lambda = 1;
17 | param.gamma = 1;
18 | param.rt_eps = 1e-6;
19 | param.max_num_bin = 255;
20 | param.verbose = false;
21 | param.profiling = false;
22 | param.column_sampling_rate = 1;
23 | param.bagging = false;
24 | param.n_parallel_trees = 1;
25 | param.learning_rate = 1;
26 | param.objective = "reg:linear";
27 | param.num_class = 1;
28 | param.path = "../dataset/test_dataset.txt";
29 | param.tree_method = "hist";
30 | param.tree_per_rounds = 1;
31 | if (!param.verbose) {
32 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
33 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
34 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "true");
35 | }
36 | if (!param.profiling) {
37 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
38 | }
39 | }
40 | };
41 |
42 | // training GBDTs in hist manner
43 | TEST_F(GBDTTest, test_hist) {
44 | param.tree_method = "hist";
45 | DataSet dataset;
46 | dataset.load_from_file(param.path, param);
47 | TreeTrainer trainer;
48 | trainer.train(param, dataset);
49 | }
50 |
51 | // training GBDTs in exact manner
52 | TEST_F(GBDTTest, test_exact) {
53 | param.tree_method = "exact";
54 | DataSet dataset;
55 | dataset.load_from_file(param.path, param);
56 | TreeTrainer trainer;
57 | trainer.train(param, dataset);
58 | }
59 |
60 | // test bagging
61 | TEST_F(GBDTTest,test_bagging) {
62 | param.bagging = true;
63 | DataSet dataset;
64 | dataset.load_from_file(param.path, param);
65 | TreeTrainer trainer;
66 | trainer.train(param, dataset);
67 | }
68 |
69 | // test different number of parallel trees
70 | TEST_F(GBDTTest, test_n_parallel_trees) {
71 | param.n_parallel_trees = 2;
72 | DataSet dataset;
73 | dataset.load_from_file(param.path, param);
74 | TreeTrainer trainer;
75 | trainer.train(param, dataset);
76 | }
77 |
78 | // test predictor
79 | TEST_F(GBDTTest, test_predictor) {
80 | DataSet dataset;
81 | dataset.load_from_file(param.path, param);
82 | TreeTrainer trainer;
83 | vector<vector<Tree>> boosted_model;
84 | boosted_model = trainer.train(param, dataset);
85 | Predictor predictor;
86 | predictor.predict(param, boosted_model, dataset);
87 | }
88 |
--------------------------------------------------------------------------------
/src/test/test_get_cut_point.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by hanfeng on 6/11/19.
3 | //
4 |
5 | #include "gtest/gtest.h"
6 | #include "thundergbm/common.h"
7 | #include "thundergbm/sparse_columns.h"
8 | #include "thundergbm/hist_cut.h"
9 | #include "thrust/unique.h"
10 | #include "thrust/execution_policy.h"
11 | #include "thundergbm/builder/shard.h"
12 |
13 |
14 | class GetCutPointTest : public ::testing::Test {
15 | public:
16 | GBMParam param;
17 | HistCut cut;
18 |
19 | vector<float_type> cut_points;
20 | vector<int> row_ptr;
21 | //for gpu
22 | SyncArray<float_type> cut_points_val;
23 | SyncArray<int> cut_row_ptr;
24 | SyncArray<int> cut_fid;
25 |
26 |
27 | protected:
28 | void SetUp() override {
29 | param.max_num_bin = 255;
30 | param.depth = 6;
31 | param.n_trees = 40;
32 | param.n_device = 1;
33 | param.min_child_weight = 1;
34 | param.lambda = 1;
35 | param.gamma = 1;
36 | param.rt_eps = 1e-6;
37 | param.verbose = false;
38 | param.profiling = false;
39 | param.column_sampling_rate = 1;
40 | param.bagging = false;
41 | param.n_parallel_trees = 1;
42 | param.learning_rate = 1;
43 | param.objective = "reg:linear";
44 | param.num_class = 1;
45 | param.path = "../dataset/test_dataset.txt";
46 | param.tree_method = "auto";
47 | if (!param.verbose) {
48 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
49 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
50 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false");
51 | }
52 | if (!param.profiling) {
53 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
54 | }
55 | }
56 |
57 | void get_cut_points2(SparseColumns &columns, int max_num_bins, int n_instances){
58 | int n_column = columns.n_column;
59 | auto csc_val_data = columns.csc_val.host_data();
60 | auto csc_col_ptr_data = columns.csc_col_ptr.host_data();
61 | cut_points.clear();
62 | row_ptr.clear();
63 | row_ptr.resize(1, 0);
64 |
65 | //TODO do this on GPU
66 | for (int fid = 0; fid < n_column; ++fid) {
67 | int col_start = csc_col_ptr_data[fid];
68 | int col_len = csc_col_ptr_data[fid + 1] - col_start;
69 | auto val_data = csc_val_data + col_start;
70 | vector<float_type> unique_val(col_len);
71 |
72 | int unique_len = thrust::unique_copy(thrust::host, val_data, val_data + col_len, unique_val.data()) - unique_val.data();
73 | if (unique_len <= max_num_bins) {
74 | row_ptr.push_back(unique_len + row_ptr.back());
75 | for (int i = 0; i < unique_len; ++i) {
76 | cut_points.push_back(unique_val[i]);
77 | }
78 | } else {
79 | row_ptr.push_back(max_num_bins + row_ptr.back());
80 | for (int i = 0; i < max_num_bins; ++i) {
81 | cut_points.push_back(unique_val[unique_len / max_num_bins * i]);
82 | }
83 | }
84 | }
85 |
86 | cut_points_val.resize(cut_points.size());
87 | cut_points_val.copy_from(cut_points.data(), cut_points.size());
88 | cut_row_ptr.resize(row_ptr.size());
89 | cut_row_ptr.copy_from(row_ptr.data(), row_ptr.size());
90 | cut_fid.resize(cut_points.size());
91 | }
92 | };
93 |
94 |
95 | TEST_F(GetCutPointTest, covtype) {
96 | //param.path = "../dataset/covtype";
97 | DataSet dataset;
98 | dataset.load_from_file(param.path, param);
99 | SparseColumns columns;
100 | //vector<std::unique_ptr<SparseColumns>> v_columns(1);
101 | //columns.csr2csc_gpu(dataset, v_columns);
102 | //cut.get_cut_points2(columns, param.max_num_bin, dataset.n_instances());
103 |
104 | printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. ###\n",
105 | param.path.c_str(),
106 | dataset.n_instances(),
107 | dataset.n_features());
108 |
109 | //this->get_cut_points2(columns, param.max_num_bin,dataset.n_instances());
110 |
111 | // --- test cut_points_val
112 | //auto gpu_cut_points_val = cut.cut_points_val.host_data();
113 | //auto cpu_cut_points_val = this->cut_points_val.host_data();
114 | //for(int i = 0; i < cut.cut_points_val.size(); i++)
115 | //EXPECT_EQ(gpu_cut_points_val[i], cpu_cut_points_val[i]);
116 |
117 | // --- test cut_row_ptr
118 | //auto gpu_cut_row_ptr = cut.cut_row_ptr.host_data();
119 | //auto cpu_cut_row_ptr = this->cut_row_ptr.host_data();
120 | //for(int i = 0; i < cut.cut_row_ptr.size(); i++)
121 | //EXPECT_EQ(gpu_cut_row_ptr[i], cpu_cut_row_ptr[i]);
122 |
123 | // --- test cut_fid
124 | //EXPECT_EQ(cut.cut_fid.size(), this->cut_fid.size());
125 | }
126 |
127 |
128 | TEST_F(GetCutPointTest, real_sim) {
129 | DataSet dataset;
130 | dataset.load_from_file(param.path, param);
131 | SparseColumns columns;
132 | //vector<std::unique_ptr<SparseColumns>> v_columns(1);
133 | //columns.csr2csc_gpu(dataset, v_columns);
134 | //cut.get_cut_points2(columns, param.max_num_bin, dataset.n_instances());
135 |
136 | //printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. ###\n",
137 | //param.path.c_str(),
138 | //dataset.n_instances(),
139 | //dataset.n_features());
140 |
141 | //this->get_cut_points2(columns, param.max_num_bin,dataset.n_instances());
142 |
143 | // --- test cut_points_val
144 | //auto gpu_cut_points_val = cut.cut_points_val.host_data();
145 | //auto cpu_cut_points_val = this->cut_points_val.host_data();
146 | //for(int i = 0; i < cut.cut_points_val.size(); i++)
147 | //EXPECT_EQ(gpu_cut_points_val[i], cpu_cut_points_val[i]);
148 |
149 | // --- test cut_row_ptr
150 | //auto gpu_cut_row_ptr = cut.cut_row_ptr.host_data();
151 | //auto cpu_cut_row_ptr = this->cut_row_ptr.host_data();
152 | //for(int i = 0; i < cut.cut_row_ptr.size(); i++)
153 | //EXPECT_EQ(gpu_cut_row_ptr[i], cpu_cut_row_ptr[i]);
154 |
155 | //// --- test cut_fid
156 | //EXPECT_EQ(cut.cut_fid.size(), this->cut_fid.size());
157 | }
158 |
159 | TEST_F(GetCutPointTest, susy) {
160 | //param.path = "../dataset/SUSY";
161 | DataSet dataset;
162 | dataset.load_from_file(param.path, param);
163 | SparseColumns columns;
164 | //vector<std::unique_ptr<SparseColumns>> v_columns(1);
165 | //columns.csr2csc_gpu(dataset, v_columns);
166 | //cut.get_cut_points2(columns, param.max_num_bin, dataset.n_instances());
167 |
168 | //printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. ###\n",
169 | //param.path.c_str(),
170 | //dataset.n_instances(),
171 | //dataset.n_features());
172 |
173 | //this->get_cut_points2(columns, param.max_num_bin,dataset.n_instances());
174 |
175 | // --- test cut_points_val
176 | //auto gpu_cut_points_val = cut.cut_points_val.host_data();
177 | //auto cpu_cut_points_val = this->cut_points_val.host_data();
178 | //for(int i = 0; i < cut.cut_points_val.size(); i++)
179 | //EXPECT_EQ(gpu_cut_points_val[i], cpu_cut_points_val[i]);
180 |
181 | // --- test cut_row_ptr
182 | //auto gpu_cut_row_ptr = cut.cut_row_ptr.host_data();
183 | //auto cpu_cut_row_ptr = this->cut_row_ptr.host_data();
184 | //for(int i = 0; i < cut.cut_row_ptr.size(); i++)
185 | //EXPECT_EQ(gpu_cut_row_ptr[i], cpu_cut_row_ptr[i]);
186 |
187 | // --- test cut_fid
188 | //EXPECT_EQ(cut.cut_fid.size(), this->cut_fid.size());
189 | }
190 |
--------------------------------------------------------------------------------
/src/test/test_gradient.cu:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "thundergbm/objective/multiclass_obj.h"
3 | #include "thundergbm/objective/regression_obj.h"
4 | #include "thundergbm/objective/ranking_obj.h"
5 | #include "thundergbm/dataset.h"
6 | #include "thundergbm/parser.h"
7 | #include "thundergbm/syncarray.h"
8 |
9 |
10 | class GradientTest: public ::testing::Test {
11 | public:
12 | GBMParam param;
13 | protected:
14 | void SetUp() override {
15 | param.depth = 6;
16 | param.n_trees = 40;
17 | param.n_device = 1;
18 | param.min_child_weight = 1;
19 | param.lambda = 1;
20 | param.gamma = 1;
21 | param.rt_eps = 1e-6;
22 | param.max_num_bin = 255;
23 | param.verbose = false;
24 | param.profiling = false;
25 | param.column_sampling_rate = 1;
26 | param.bagging = false;
27 | param.n_parallel_trees = 1;
28 | param.learning_rate = 1;
29 | param.objective = "reg:linear";
30 | param.num_class = 2;
31 | param.path = "../dataset/test_dataset.txt";
32 | param.tree_method = "hist";
33 | if (!param.verbose) {
34 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
35 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
36 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "true");
37 | }
38 | if (!param.profiling) {
39 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
40 | }
41 | }
42 | };
43 |
44 | TEST_F(GradientTest, test_softmax_obj) {
45 | DataSet dataset;
46 | dataset.load_from_file(param.path, param);
47 | param.num_class = 2;
48 | dataset.label.resize(2);
49 | dataset.label[0] = 0;
50 | dataset.label[1] = 1;
51 | Softmax softmax;
52 | softmax.configure(param, dataset);
53 |
54 | // check the metric name
55 | EXPECT_EQ(softmax.default_metric_name(), "macc");
56 | SyncArray<float_type> y_true(4);
57 | SyncArray<float_type> y_pred(8);
58 | auto y_pred_data = y_pred.host_data();
59 | for(int i = 0; i < 8; i++)
60 | y_pred_data[i] = -i;
61 | SyncArray<GHPair> gh_pair(8);
62 | softmax.get_gradient(y_true, y_pred, gh_pair);
63 |
64 | // test the transform function
65 | EXPECT_EQ(y_pred.size(), 8);
66 | softmax.predict_transform(y_pred);
67 | EXPECT_EQ(y_pred.size(), 4);
68 | }
69 |
70 | TEST_F(GradientTest, test_softmaxprob_obj) {
71 | DataSet dataset;
72 | dataset.load_from_file(param.path, param);
73 | param.num_class = 2;
74 | dataset.label.resize(2);
75 | dataset.label[0] = 0;
76 | dataset.label[1] = 1;
77 | SoftmaxProb smp;
78 | smp.configure(param, dataset);
79 |
80 | // check the metric name
81 | EXPECT_EQ(smp.default_metric_name(), "macc");
82 | SyncArray<float_type> y_true(4);
83 | SyncArray<float_type> y_pred(8);
84 | auto y_pred_data = y_pred.host_data();
85 | for(int i = 0; i < 8; i++)
86 | y_pred_data[i] = -i;
87 | SyncArray<GHPair> gh_pair(8);
88 | smp.get_gradient(y_true, y_pred, gh_pair);
89 |
90 | // test the transform function
91 | EXPECT_EQ(y_pred.size(), 8);
92 | smp.predict_transform(y_pred);
93 | EXPECT_EQ(y_pred.size(), 8);
94 | }
95 |
96 | TEST_F(GradientTest, test_regression_obj) {
97 | DataSet dataset;
98 | dataset.load_from_file(param.path, param);
99 | RegressionObj<SquareLoss> rmse;
100 | SyncArray<float_type> y_true(4);
101 | SyncArray<float_type> y_pred(4);
102 | auto y_pred_data = y_pred.host_data();
103 | for(int i = 0; i < 4; i++)
104 | y_pred_data[i] = -i;
105 | SyncArray<GHPair> gh_pair(4);
106 | EXPECT_EQ(rmse.default_metric_name(), "rmse");
107 | rmse.get_gradient(y_true, y_pred, gh_pair);
108 | }
109 |
110 | TEST_F(GradientTest, test_logcls_obj) {
111 | DataSet dataset;
112 | dataset.load_from_file(param.path, param);
113 | LogClsObj<SquareLoss> logcls;
114 | SyncArray<float_type> y_true(4);
115 | SyncArray<float_type> y_pred(4);
116 | auto y_pred_data = y_pred.host_data();
117 | for(int i = 0; i < 4; i++)
118 | y_pred_data[i] = -i;
119 | SyncArray<GHPair> gh_pair(4);
120 | EXPECT_EQ(logcls.default_metric_name(), "error");
121 | logcls.get_gradient(y_true, y_pred, gh_pair);
122 | }
123 |
124 | /*TEST_F(GradientTest, test_squareloss_obj) {*/
125 | /*DataSet dataset;*/
126 | /*dataset.load_from_file(param.path, param);*/
127 | /*}*/
128 |
--------------------------------------------------------------------------------
/src/test/test_main.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jiashuai on 17-9-15.
3 | //
4 | #include <thundergbm/util/log.h>
5 | #include "gtest/gtest.h"
6 | #ifdef _WIN32
7 | INITIALIZE_EASYLOGGINGPP
8 | #endif
9 | int main(int argc, char **argv) {
10 | ::testing::InitGoogleTest(&argc, argv);
11 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
12 | el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
13 | el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat);
14 | return RUN_ALL_TESTS();
15 | }
16 |
--------------------------------------------------------------------------------
/src/test/test_metrics.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "thundergbm/metric/metric.h"
3 | #include "thundergbm/metric/multiclass_metric.h"
4 | #include "thundergbm/metric/pointwise_metric.h"
5 | #include "thundergbm/metric/ranking_metric.h"
6 |
7 |
8 | class MetricTest : public ::testing::Test {
9 | public:
10 | GBMParam param;
11 | vector<float_type> csr_val;
12 | vector<int> csr_row_ptr;
13 | vector<int> csr_col_idx;
14 | vector<float_type> y;
15 | size_t n_features_;
16 | vector<float_type> label;
17 | protected:
18 | void SetUp() override {
19 | param.depth = 6;
20 | param.n_trees = 40;
21 | param.n_device = 1;
22 | param.min_child_weight = 1;
23 | param.lambda = 1;
24 | param.gamma = 1;
25 | param.rt_eps = 1e-6;
26 | param.max_num_bin = 255;
27 | param.verbose = false;
28 | param.profiling = false;
29 | param.column_sampling_rate = 1;
30 | param.bagging = false;
31 | param.n_parallel_trees = 1;
32 | param.learning_rate = 1;
33 | param.objective = "reg:linear";
34 | param.num_class = 2;
35 | param.path = "../dataset/test_dataset.txt";
36 | param.tree_method = "hist";
37 | if (!param.verbose) {
38 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
39 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
40 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "true");
41 | }
42 | if (!param.profiling) {
43 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
44 | }
45 | }
46 | };
47 |
48 |
49 | TEST_F(MetricTest, test_multiclass_metric_config) {
50 | DataSet dataset;
51 | dataset.load_from_file(param.path, param);
52 | MulticlassAccuracy mmetric;
53 | dataset.y.resize(4);
54 | dataset.label.resize(2);
55 | mmetric.configure(param, dataset);
56 | EXPECT_EQ(mmetric.get_name(), "multi-class accuracy");
57 | }
58 |
59 | TEST_F(MetricTest, test_multiclass_metric_score) {
60 | DataSet dataset;
61 | dataset.load_from_file(param.path, param);
62 | MulticlassAccuracy mmetric;
63 | dataset.y.resize(4);
64 | dataset.label.resize(2);
65 | mmetric.configure(param, dataset);
66 |
67 | SyncArray<float_type> y_pred(8);
68 | auto y_pred_data = y_pred.host_data();
69 | y_pred_data[0] = 1;
70 | EXPECT_EQ(mmetric.get_score(y_pred), 0) << mmetric.get_score(y_pred);
71 | }
72 |
73 | TEST_F(MetricTest, test_binaryclass_metric_score) {
74 | DataSet dataset;
75 | dataset.load_from_file(param.path, param);
76 | BinaryClassMetric mmetric;
77 | dataset.y.resize(4);
78 | dataset.label.resize(2);
79 | mmetric.configure(param, dataset);
80 |
81 | SyncArray<float_type> y_pred(4);
82 | auto y_pred_data = y_pred.host_data();
83 | y_pred_data[0] = 1;
84 | EXPECT_EQ(mmetric.get_score(y_pred), 1) << mmetric.get_score(y_pred);
85 | }
86 |
87 |
88 | TEST_F(MetricTest, test_pointwise_metric_score) {
89 | DataSet dataset;
90 | dataset.load_from_file(param.path, param);
91 | RMSE mmetric;
92 | dataset.y.resize(4);
93 | mmetric.configure(param, dataset);
94 |
95 | SyncArray<float_type> y_pred(4);
96 | EXPECT_EQ(mmetric.get_score(y_pred), 1) << mmetric.get_score(y_pred);
97 | }
98 |
--------------------------------------------------------------------------------
/src/test/test_parser.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "thundergbm/parser.h"
3 | #include "thundergbm/dataset.h"
4 | #include "thundergbm/tree.h"
5 |
6 | class ParserTest: public ::testing::Test {
7 | public:
8 | GBMParam param;
9 | vector