├── .gitignore ├── README.md ├── bench.py ├── environment.yml ├── images ├── binned-gradient-hess.png ├── binning101.png ├── gradient-split-points.png ├── histogram_subtraction.png ├── hyperp_default.png ├── hyperp_learning_rate_05.png ├── hyperp_max_depth_3.png ├── hyperp_max_leaf_nodes_200.png ├── hyperp_min_samples_leaf_100.png ├── hyperp_n_iter_no_change_10.png ├── hyperp_tol_1e-2.png ├── images_generator.ipynb ├── tree-growing-1.png └── tree-growing-2.png ├── presentation.md └── presentation.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | cache/ 2 | HIGGS.csv.gz 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | 131 | catboost_info 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Dive into scikit-learn's HistGradientBoosting Classifier and Regressor 2 | 3 | - [Link to slides](https://github.com/thomasjpfan/pydata-2019-histgradientboosting/blob/master/presentation.pdf) 4 | - [Link to youtube](https://youtu.be/J9QQ6l_HToU) 5 | 6 | ## Running benchmarks 7 | 8 | 0. Install anaconda 9 | 10 | 1. 
Setup environment 11 | 12 | ```bash 13 | conda env create -f environment.yml 14 | conda activate 2019-pydata-nyc-hist 15 | ``` 16 | 17 | 2. Run benchmarks for each library 18 | 19 | First run will download the HIGGS dataset which is 2.6 GB! 20 | 21 | ```bash 22 | # This is the number of cores (no hyperthreading) 23 | export OMP_NUM_THREADS=12 24 | python bench.py sklearn 25 | python bench.py catboost 26 | python bench.py lightgbm 27 | python bench.py xgboost 28 | ``` 29 | -------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | from urllib.request import urlretrieve 2 | import os 3 | from gzip import GzipFile 4 | from time import time 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from joblib import Memory 10 | from sklearn.model_selection import train_test_split 11 | from sklearn.metrics import accuracy_score, roc_auc_score 12 | # To use this experimental feature, we need to explicitly ask for it: 13 | from sklearn.experimental import enable_hist_gradient_boosting # noqa 14 | from sklearn.ensemble import HistGradientBoostingClassifier 15 | from sklearn.ensemble._hist_gradient_boosting.utils import ( 16 | get_equivalent_estimator) 17 | 18 | 19 | HERE = os.path.dirname(__file__) 20 | URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/00280/" 21 | "HIGGS.csv.gz") 22 | m = Memory(location='./cache', mmap_mode='r') 23 | 24 | 25 | @m.cache 26 | def load_data(): 27 | filename = os.path.join(HERE, URL.rsplit('/', 1)[-1]) 28 | if not os.path.exists(filename): 29 | print(f"Downloading {URL} to {filename} (2.6 GB)...") 30 | urlretrieve(URL, filename) 31 | print("done.") 32 | 33 | print(f"Parsing {filename}...") 34 | tic = time() 35 | with GzipFile(filename) as f: 36 | df = pd.read_csv(f, header=None, dtype=np.float32) 37 | toc = time() 38 | print(f"Loaded {df.values.nbytes / 1e9:0.3f} GB in {toc - tic:0.3f}s") 39 | return df 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('library', choices=['sklearn', 'lightgbm', 45 | 'xgboost', 'catboost']) 46 | parser.add_argument('--n-trees', type=int, default=100) 47 | 48 | args = parser.parse_args() 49 | 50 | n_trees = args.n_trees 51 | 52 | df = load_data() 53 | target = df.values[:, 0] 54 | data = np.ascontiguousarray(df.values[:, 1:]) 55 | data_train, data_test, target_train, target_test = train_test_split( 56 | data, target, test_size=.2, random_state=0) 57 | 58 | n_samples, n_features = data_train.shape 59 | print(f"Training set with {n_samples} records with {n_features} features.") 60 | 61 | est = HistGradientBoostingClassifier(loss='binary_crossentropy', 62 | max_iter=n_trees, 63 | n_iter_no_change=None, 64 | random_state=0, 65 | verbose=1) 66 | 67 | if args.library == 'sklearn': 68 | print("Fitting a sklearn model...") 69 | tic = time() 70 | est.fit(data_train, target_train) 71 | toc = time() 72 | predicted_test = est.predict(data_test) 73 | predicted_proba_test = est.predict_proba(data_test) 74 | roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) 75 | acc = accuracy_score(target_test, predicted_test) 76 | print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, " 77 | f"ACC: {acc :.4f}") 78 | 79 | elif args.library == 'lightgbm': 80 | print("Fitting a LightGBM model...") 81 | tic = time() 82 | lightgbm_est = get_equivalent_estimator(est, lib='lightgbm') 83 | lightgbm_est.fit(data_train, target_train) 84 | toc = time() 
85 | predicted_test = lightgbm_est.predict(data_test) 86 | predicted_proba_test = lightgbm_est.predict_proba(data_test) 87 | roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) 88 | acc = accuracy_score(target_test, predicted_test) 89 | print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, " 90 | f"ACC: {acc :.4f}") 91 | 92 | elif args.library == 'xgboost': 93 | print("Fitting an XGBoost model...") 94 | tic = time() 95 | xgboost_est = get_equivalent_estimator(est, lib='xgboost') 96 | xgboost_est.fit(data_train, target_train) 97 | toc = time() 98 | predicted_test = xgboost_est.predict(data_test) 99 | predicted_proba_test = xgboost_est.predict_proba(data_test) 100 | roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) 101 | acc = accuracy_score(target_test, predicted_test) 102 | print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, " 103 | f"ACC: {acc :.4f}") 104 | 105 | else: # catboost 106 | print("Fitting a Catboost model...") 107 | tic = time() 108 | catboost_est = get_equivalent_estimator(est, lib='catboost') 109 | catboost_est.fit(data_train, target_train) 110 | toc = time() 111 | predicted_test = catboost_est.predict(data_test) 112 | predicted_proba_test = catboost_est.predict_proba(data_test) 113 | roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) 114 | acc = accuracy_score(target_test, predicted_test) 115 | print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, " 116 | f"ACC: {acc :.4f}") 117 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: 2019-pydata-nyc-hist 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - catboost=0.18 7 | - lightgbm=2.2 8 | - python=3.7 9 | - scikit-learn=0.21 10 | - xgboost=0.90 11 | - matplotlib 12 | - jupyterlab 13 | -------------------------------------------------------------------------------- /images/binned-gradient-hess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/binned-gradient-hess.png -------------------------------------------------------------------------------- /images/binning101.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/binning101.png -------------------------------------------------------------------------------- /images/gradient-split-points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/gradient-split-points.png -------------------------------------------------------------------------------- /images/histogram_subtraction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/histogram_subtraction.png -------------------------------------------------------------------------------- /images/hyperp_default.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_default.png -------------------------------------------------------------------------------- /images/hyperp_learning_rate_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_learning_rate_05.png -------------------------------------------------------------------------------- /images/hyperp_max_depth_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_max_depth_3.png -------------------------------------------------------------------------------- /images/hyperp_max_leaf_nodes_200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_max_leaf_nodes_200.png -------------------------------------------------------------------------------- /images/hyperp_min_samples_leaf_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_min_samples_leaf_100.png -------------------------------------------------------------------------------- /images/hyperp_n_iter_no_change_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_n_iter_no_change_10.png -------------------------------------------------------------------------------- /images/hyperp_tol_1e-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/hyperp_tol_1e-2.png -------------------------------------------------------------------------------- /images/tree-growing-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/tree-growing-1.png -------------------------------------------------------------------------------- /images/tree-growing-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/images/tree-growing-2.png -------------------------------------------------------------------------------- /presentation.md: -------------------------------------------------------------------------------- 1 | slide-dividers: # 2 | slidenumbers: true 3 | 4 | # Deep Dive into scikit-learn's **HistGradientBoosting** Classifier and Regressor 5 | [.header: alignment(center), text-scale(1.7)] 6 | [.text: alignment(left), text-scale(1)] 7 | [.slidenumbers: false] 8 | 9 | Thomas J Fan 10 | Scikit-learn Core Developer 11 | @thomasjpfan 12 | 13 | # Scikit-learn API 🛠 14 | 15 | ```py 16 | from sklearn.experimental import 
enable_hist_gradient_boosting 17 | from sklearn.ensemble import HistGradientBoostingClassifier 18 | 19 | clf = HistGradientBoostingClassifier() 20 | 21 | clf.fit(X, y) 22 | 23 | clf.predict(X) 24 | ``` 25 | 26 | # Supervised learning 📖 27 | 28 | $$ 29 | y = f(X) 30 | $$ 31 | 32 | - X of shape `(n_samples, n_features)` 33 | - y of shape `(n_samples,)` 34 | 35 | # HistGradient**Boosting** 36 | [.header: alignment(center)] 37 | 38 | # Boosting 🏂 39 | 40 | $$ 41 | f(X) = h_0(X) + h_1(X) + h_2(X) + ... 42 | $$ 43 | 44 | $$ 45 | f(X) = \sum_i h_i(X) 46 | $$ 47 | 48 | # Hist**Gradient**Boosting 49 | [.header: alignment(center)] 50 | 51 | # Gradient 🏔 (Loss Function) 52 | 53 | - **Regression** 54 | 1. `least_squares` 55 | 1. `least_absolute_deviation` 56 | 57 | - **Classification** 58 | 1. `binary_crossentropy` 59 | 1. `categorical_crossentropy` 60 | 61 | # Gradient 🏔 (Regression Loss Function) 62 | 63 | - `least_squares` 64 | 65 | $$ 66 | L(y, f(X)) = \frac{1}{2}||y - f(X)||^2 67 | $$ 68 | 69 | # Gradient 🏔 - `least_squares` 70 | 71 | - **Gradient** 72 | 73 | $$ 74 | \nabla L(y, f(X)) = -(y - f(X)) 75 | $$ 76 | 77 | - **Hessian** 78 | 79 | $$ 80 | \nabla^2 L(y, f(X)) = 1 81 | $$ 82 | 83 | # Gradient Boosting 🏂 84 | 85 | - Initial Condition 86 | 87 | $$ 88 | f_0(X) = C 89 | $$ 90 | 91 | - Recursive Condition 92 | 93 | $$ 94 | f_{m+1}(X) = f_{m}(X) - \eta \nabla L(y, f_{m}(X)) 95 | $$ 96 | 97 | where $$\eta$$ is the learning rate 98 | 99 | # Gradient Boosting 🏂 - `least_squares` 100 | 101 | $$ 102 | f_{m+1}(X) = f_{m}(X) + \eta(y - f_{m}(X)) 103 | $$ 104 | 105 | - Let $$h_{m}(X)=(y - f_{m}(X))$$ 106 | 107 | $$ 108 | f_{m+1}(X) = f_{m}(X) + \eta h_{m}(X) 109 | $$ 110 | 111 | - We need to learn $$h_{m}(X)$$! 112 | - For the next example, let $$\eta=1$$ 113 | 114 | # Gradient Boosting 🏂 (Example, part 1) 115 | 116 | $$ 117 | f_0(X) = C 118 | $$ 119 | 120 | | $$X$$ | $$y$$ | $$f_0(X)$$ | $$y - f_0(X)$$ | $$h_0(X)$$ | 121 | | --- | --- | --- | --- | --- | 122 | | 35 | 70 | 78 | -8 | -7 | 123 | | 45 | 90 | 78 | 12 | 10 | 124 | | 25 | 80 | 78 | 2 | 5 | 125 | | 15 | 50 | 78 | -28 | -20 | 126 | | 55 | 100 | 78 | 22 | 25 | 127 | 128 | # Gradient Boosting 🏂 (Example, part 2) 129 | 130 | $$ 131 | f_{m+1}(X) = f_{m}(X) + h_{m}(X) 132 | $$ 133 | 134 | | $$f_0(X)$$ | $$h_0(X)$$ | $$f_1(X)$$ | $$y - f_1(X)$$ | $$h_1(X)$$ | $$f_2(X)$$ | 135 | | --- | --- | --- | --- | --- | --- | 136 | | 78 | -7 | 71 | -1 | -1 | 70 | 137 | | 78 | 10 | 88 | 2 | 1 | 89 | 138 | | 78 | 5 | 83 | -3 | -4 | 79 | 139 | | 78 | -20 | 58 | -8 | -6 | 52 | 140 | | 78 | 25 | 103 | -3 | -2 | 101 | 141 | 142 | # Gradient Boosting 🏂 (Example, part 3) 143 | 144 | With two iterations in boosting: 145 | 146 | $$ 147 | f(X) = C + h_0(X) + h_1(X) 148 | $$ 149 | 150 | - **predict**: With X = 40 151 | 152 | $$ 153 | f(40) = 78 + h_0(40) + h_1(40) 154 | $$ 155 | 156 | # How to learn $$h_m(X)$$? 157 | 158 | [.header: alignment(center), text-scale(2)] 159 | [.text: alignment(center), text-scale(2)] 160 | 🌲! 161 | 162 | # Tree Growing 🌲 (part 1) 163 | 164 | ![right fit](images/tree-growing-1.png) 165 | 166 | 1. For every feature 167 | 1. Sort feature 168 | 1. For every split point 169 | 1. Evaluate split 170 | 1. 
Pick **best** split 171 | 172 | # Tree Growing 🌲 (part 2) 173 | 174 | - Recall Loss, Gradient, Hessian 175 | 176 | $$ 177 | L(y, f(X)) = \frac{1}{2}||y - f(X)||^2 178 | $$ 179 | 180 | $$ 181 | G = \nabla L(y, f(X)) = -(y - f(X)) 182 | $$ 183 | 184 | $$ 185 | H = \nabla^2 L(y, f(X)) = 1 186 | $$ 187 | 188 | # Tree Growing 🌲 (part 3) 189 | 190 | - How to evaluate split? 191 | 192 | $$ 193 | Gain = \dfrac{1}{2}\left[\dfrac{G_L^2}{H_L+\lambda} 194 | + \dfrac{G_R^2}{H_R + \lambda} - \dfrac{(G_L+G_R)^2}{H_L+H_R+\lambda}\right] 195 | $$ 196 | 197 | - $$\lambda$$: `l2_regularization=0` 198 | 199 | # Tree Growing 🌲 (part 4) 200 | 201 | ![right fit](images/tree-growing-2.png) 202 | 203 | 1. For every feature 204 | 1. Sort feature 205 | 1. For every split point 206 | 1. Evaluate split 207 | 1. Pick **best** split 208 | 209 | - Done? 210 | 211 | # Tree Growing 🌲 (part 5) 212 | 213 | ![right fit](images/tree-growing-2.png) 214 | 215 | 1. For every feature 216 | 1. Sort feature - _**O(nlog(n))**_ 217 | 1. For every split point - _**O(n)**_ 218 | 1. Evaluate split 219 | 1. Pick **best** split 220 | 221 | # **Hist**GradientBoosting 222 | [.header: alignment(center)] 223 | 224 | # Binning! 🗑 (part 1) 225 | 226 | ![inline fit](images/binning101.png) 227 | 228 | # Binning! 🗑 (part 2) 229 | 230 | ```py 231 | # Original data 232 | [-0.752, 2.7042, 1.3919, 0.5091, -2.0636, 233 | -2.064, -2.6514, 2.1977, 0.6007, 1.2487, ...] 234 | 235 | # Binned data 236 | [4, 9, 7, 6, 2, 1, 0, 8, 6, 7, ...] 237 | ``` 238 | 239 | # Histograms! 📊 (part 1) 240 | 241 | ![inline fit](images/binned-gradient-hess.png) 242 | 243 | # Histograms! 📊 (part 2) 244 | 245 | 1. For every feature 246 | 1. Build histogram _**O(n)**_ 247 | 1. For every split point - _**O(n\_bins)**_ 248 | 1. Evaluate split 249 | 1. Pick **best** split 250 | 251 | ![right fit](images/gradient-split-points.png) 252 | 253 | # Histograms! 📊 (part 3) 254 | 255 | ![inline fit](images/histogram_subtraction.png) 256 | 257 | # Trees = $$h_{m}(X)$$ 🌲 258 | 259 | $$ 260 | f(X) = C + \eta\sum h_{m}(X) 261 | $$ 262 | 263 | # Overview of Algorithm 👀 264 | 265 | 1. Bin data 266 | 1. Make initial predictions (constant) 267 | 1. Calculate gradients and hessians 268 | 1. Grow Trees For Boosting 269 | 1. Find best splits 270 | 1. Add tree to predictors 271 | 1. Update gradients and hessians 272 | 273 | # Implementation? 🤔 274 | 275 | - Pure Python? 276 | - Numpy? 277 | - Cython? 278 | - Cython + OpenMP! 279 | 280 | # OpenMP! (Bin data 🗑, part 1) 281 | 282 | 1. _**Bin data**_ 283 | 1. Make initial predictions (constant) 284 | 1. Calculate gradients and hessians 285 | 1. Grow Trees For Boosting 286 | 1. Find best splits by building histograms 287 | 1. Add tree to predictors 288 | 1. Update gradients and hessians 289 | 290 | # OpenMP! (Bin data 🗑, part 2) 291 | 292 | [.code-highlight: 1] 293 | 294 | ```py 295 | for i in range(data.shape[0]): 296 | left, right = 0, binning_thresholds.shape[0] 297 | while left < right: 298 | middle = (right + left - 1) // 2 299 | if data[i] <= binning_thresholds[middle]: 300 | right = middle 301 | else: 302 | left = middle + 1 303 | binned[i] = left 304 | ``` 305 | 306 | # OpenMP! 
(Bin data 🗑, part 3) 307 | 308 | [.code-highlight: 1-4] 309 | 310 | ```py 311 | # sklearn/ensemble/_hist_gradient_boosting/_binning.pyx 312 | for i in prange(data.shape[0], 313 | schedule='static', 314 | nogil=True): 315 | left, right = 0, binning_thresholds.shape[0] 316 | while left < right: 317 | middle = (right + left - 1) // 2 318 | if data[i] <= binning_thresholds[middle]: 319 | right = middle 320 | else: 321 | left = middle + 1 322 | binned[i] = left 323 | ``` 324 | 325 | # OpenMP! (building histograms 🌋, part 1) 326 | 327 | 1. Bin data 328 | 1. Make initial predictions (constant) 329 | 1. Calculate gradients and hessians 330 | 1. Grow Trees For Boosting 331 | 1. Find best splits by _**building histograms**_ 332 | 1. Add tree to predictors 333 | 1. Update gradients and hessians 334 | 335 | # OpenMP! (building histograms 🌋, part 2) 336 | 337 | [.code-highlight: all] 338 | [.code-highlight: 1-4] 339 | [.code-highlight: 6-8] 340 | 341 | ```py 342 | # sklearn/ensemble/_hist_gradient_boosting/histogram.pyx 343 | with nogil: 344 | for feature_idx in prange(n_features, schedule='static'): 345 | self._compute_histogram_brute_single_feature(...) 346 | 347 | for feature_idx in prange(n_features, schedule='static', 348 | nogil=True): 349 | _subtract_histograms(feature_idx, ...) 350 | ``` 351 | 352 | # OpenMP! (Find best splits ✂️, part 1) 353 | 354 | 1. Bin data 355 | 1. Make initial predictions (constant) 356 | 1. Calculate gradients and hessians 357 | 1. Grow Trees For Boosting 358 | 1. _**Find best splits**_ by building histograms 359 | 1. Add tree to predictors 360 | 1. Update gradients and hessians 361 | 362 | # OpenMP! (Find best splits ✂️, part 2) 363 | 364 | ```py 365 | # sklearn/ensemble/_hist_gradient_boosting/splitting.pyx 366 | for feature_idx in prange(n_features, schedule='static'): 367 | # For each feature, find best bin to split on 368 | ``` 369 | 370 | # OpenMP! (Splitting ✂️, part 3) 371 | 372 | ```py 373 | # sklearn/ensemble/_hist_gradient_boosting/splitting.pyx 374 | for thread_idx in prange(n_threads, schedule='static', 375 | chunksize=1): 376 | # splits a partition of node 377 | ``` 378 | 379 | # OpenMP! (Update gradients and hessians 🏔, part 1) 380 | 381 | 1. Bin data 382 | 1. Make initial predictions (constant) 383 | 1. Calculate gradients and hessians 384 | 1. Grow Trees For Boosting 385 | 1. Find best splits by building histograms 386 | 1. Add tree to predictors 387 | 2. _**Update gradients and hessians**_ 388 | 389 | # OpenMP! (Update gradients and hessians 🏔, part 2) 390 | 391 | - `least_squares` 392 | 393 | ```py 394 | # sklearn/ensemble/_hist_gradient_boosting/_loss.pyx 395 | for i in prange(n_samples, schedule='static', nogil=True): 396 | gradients[i] = raw_predictions[i] - y_true[i] 397 | ``` 398 | 399 | # Hyperparameters (Bin data 🗑, part 1) 400 | 401 | 1. _**Bin data**_ 402 | 1. Make initial predictions (constant) 403 | 1. Calculate gradients and hessians 404 | 1. Grow Trees For Boosting 405 | 1. Find best splits by building _**histograms**_ 406 | 1. Add tree to predictors 407 | 1. Update gradients and hessians 408 | 409 | # Hyperparameters (Bin data 🗑, part 2) 410 | 411 | - `max_bins=255` 412 | 413 | ![inline fit](images/binning101.png) 414 | 415 | # Hyperparameters (Loss 📉, part 1) 416 | 417 | 1. Bin data 418 | 1. _**Make initial predictions (constant)**_ 419 | 1. Calculate _**gradients and hessians**_ 420 | 1. Grow Trees For Boosting 421 | 1. Find best splits by building histograms 422 | 1. Add tree to predictors 423 | 2. 
_**Update gradients and hessians**_ 424 | 425 | # Hyperparameters (Loss 📉, part 2) 426 | 427 | - `HistGradientBoostingRegressor` 428 | 1. `loss=least_squares` (default) 429 | 1. `least_absolute_deviation` 430 | 431 | - `HistGradientBoostingClassifier` 432 | 1. `loss=auto` (default) 433 | 1. `binary_crossentropy` 434 | 1. `categorical_crossentropy` 435 | 436 | - `l2_regularization=0` 437 | 438 | # Hyperparameters (Boosting 🏂, part 1) 439 | 440 | 1. Bin data 441 | 1. Make initial predictions (constant) 442 | 1. Calculate gradients and hessians 443 | 1. Grow Trees For _**Boosting**_ 444 | 1. Find best splits by building histograms 445 | 1. Add tree to predictors 446 | 1. Update gradients and hessians 447 | 448 | # Hyperparameters (Boosting 🏂, part 2) 449 | 450 | - `learning_rate=0.1` ($$\eta$$) 451 | - `max_iter=100` 452 | 453 | $$ 454 | f(X) = C + \eta\sum_{m}^{max\_iter}h_{m}(X) 455 | $$ 456 | 457 | # Hyperparameters (Boosting 🏂, part 3) 458 | 459 | ![inline fit](images/hyperp_default.png) 460 | 461 | # Hyperparameters (Boosting 🏂, part 4) 462 | 463 | ![inline fit](images/hyperp_learning_rate_05.png) 464 | 465 | # Hyperparameters (Grow Trees 🎄, part 1) 466 | 467 | 1. Bin data 468 | 1. Make initial predictions (constant) 469 | 1. Calculate gradients and hessians 470 | 1. _**Grow Trees**_ For Boosting 471 | 1. Find best splits by building histograms 472 | 1. Add tree to predictors 473 | 1. Update gradients and hessians 474 | 475 | # Hyperparameters (Grow Trees 🎄, part 2) 476 | 477 | - `max_leaf_nodes=31` 478 | - `max_depth=None` 479 | - `min_samples_leaf=20` 480 | 481 | # Hyperparameters (Grow Trees 🎄, part 3) 482 | 483 | ![inline fit](images/hyperp_max_leaf_nodes_200.png) 484 | 485 | # Hyperparameters (Grow Trees 🎄, part 4) 486 | 487 | ![inline fit](images/hyperp_max_depth_3.png) 488 | 489 | # Hyperparameters (Early Stopping 🛑, part 1) 490 | 491 | 1. Bin data 492 | 1. _**Split into a validation dataset**_ 493 | 1. Make initial predictions (constant) 494 | 1. Calculate gradients and hessians 495 | 1. Grow Trees For Boosting 496 | 1. ... 497 | 1. 
_**Stop if the early-stopping condition is met**_ 498 | 499 | # Hyperparameters (Early Stopping 🛑, part 2) 500 | 501 | - `scoring=None` (could be 'loss') 502 | - `validation_fraction=0.1` 503 | - `n_iter_no_change=None` 504 | - `tol=1e-7` 505 | 506 | # Hyperparameters (Early Stopping 🛑, part 3) 507 | 508 | ![inline fit](images/hyperp_n_iter_no_change_10.png) 509 | 510 | # Hyperparameters (Misc 🎁) 511 | 512 | - `verbose=0` 513 | - `random_state=None` 514 | - `export OMP_NUM_THREADS=12` 515 | 516 | # Benchmarks 🚀 (HIGGS Part 1) 517 | 518 | - 8,800,000 records 519 | - 28 features 520 | - binary classification (1 for signal, 0 for background) 521 | 522 | # Benchmarks 🚀 (HIGGS Part 2) 523 | 524 | - `max_iter=100`, `learning_rate=0.1`, `export OMP_NUM_THREADS=12` 525 | 526 | | library | time | roc auc | accuracy | 527 | |----------|------|---------|----------| 528 | | sklearn | 38s | 0.8125 | 0.7324 | 529 | | lightgbm | 39s | 0.8124 | 0.7322 | 530 | | xgboost | 48s | 0.8126 | 0.7326 | 531 | | catboost | 100s | 0.8004 | 0.7222 | 532 | 533 | # Benchmarks 🚀 (HIGGS Part 3) 534 | 535 | - `max_iter=500` 536 | 537 | | library | time | roc auc | accuracy | 538 | |----------|------|---------|----------| 539 | | sklearn | 129s | 0.8281 | 0.7461 | 540 | | lightgbm | 125s | 0.8283 | 0.7462 | 541 | | xgboost | 149s | 0.8285 | 0.7465 | 542 | | catboost | 427s | 0.8225 | 0.7412 | 543 | 544 | 545 | # Benchmarks 🚀 (HIGGS Part 4) 546 | 547 | `export OMP_NUM_THREADS=4` `max_iter=100` (on my laptop) 548 | 549 | | library | time (12 cores) | time (4 cores) | 550 | |----------|------|---------| 551 | | sklearn | 38s | 85s | 552 | | lightgbm | 39s | 86s | 553 | | xgboost | 48s | 115s | 554 | | catboost | 100s | 164s | 555 | 556 | # Benchmarks 🚀 (HIGGS Part 5) 557 | 558 | [.header: alignment(center), text-scale(1.8)] 559 | [.text: alignment(center), text-scale(1.8)] 560 | DEMO! 561 | 562 | # Roadmap 🛣 (In upcoming 0.22) 563 | 564 | - ~~Missing Values~~ 565 | 566 | ```py 567 | from sklearn.experimental import enable_hist_gradient_boosting 568 | from sklearn.ensemble import HistGradientBoostingClassifier 569 | from sklearn.datasets import make_classification 570 | import numpy as np 571 | X, y = make_classification(random_state=42) 572 | X[:10, 0] = np.nan 573 | 574 | gbdt = HistGradientBoostingClassifier().fit(X, y) 575 | print(gbdt.predict(X[:20])) 576 | # [0 0 1 1 0 0 0 1 0 1 1 0 0 0 1 1 1 0 0 1] 577 | ``` 578 | 579 | # Roadmap 🛣 (After 0.22) 580 | 581 | - Discrete (Categorical) Feature support 582 | - Sparse Data 583 | - Sample Weights 584 | 585 | # Thank You for Working on This 🎉 586 | 587 | - @hug_nicolas - Associate Research Scientist @ Columbia University 588 | - All the core developers for reviewing! 589 | 590 | # Conclusion 🎉 591 | 592 | ```py 593 | from sklearn.experimental import enable_hist_gradient_boosting 594 | from sklearn.ensemble import HistGradientBoostingClassifier 595 | from sklearn.ensemble import HistGradientBoostingRegressor 596 | ``` 597 | 598 | - Try out the dev build (for missing values): 599 | 600 | ```bash 601 | pip install --pre -f https://sklearn-nightly.scdn8.secure.raxcdn.com scikit-learn 602 | ``` 603 | 604 | - [github.com/thomasjpfan/pydata-2019-histgradientboosting](https://github.com/thomasjpfan/pydata-2019-histgradientboosting) 605 | 606 | - Twitter: @thomasjpfan 607 | 608 | # Appendix 609 | 610 | - Loss function with l2 regularization 611 | 612 | $$ 613 | L(y, f(X)) = \frac{1}{2}||y - f(X)||^2 + \lambda \sum_i w_i^2 614 | $$ 615 | 616 | where $$w_i$$ is the score of leaf $$i$$. 
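# Appendix (part 2)

- In the same $$G$$, $$H$$, $$\lambda$$ notation as Tree Growing 🌲 (part 3), minimizing a second-order approximation of this regularized loss for a single leaf gives (a sketch of the math, not the exact implementation):

$$
w^{*} = -\dfrac{G}{H + \lambda}
$$

- so each leaf contributes $$-\dfrac{1}{2}\dfrac{G^2}{H + \lambda}$$ to the objective, and comparing a parent node with its two children recovers the $$Gain$$ formula from Tree Growing 🌲 (part 3).

# Appendix (part 3)

A minimal sketch of the boosting example (parts 1-3), assuming $$\eta = 1$$ as in the slides and using plain depth-1 `DecisionTreeRegressor` trees as stand-ins for the histogram-based trees, so the fitted $$h_m$$ are illustrative and need not match the table values exactly:

```py
import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.array([[35], [45], [25], [15], [55]], dtype=np.float64)
y = np.array([70, 90, 80, 50, 100], dtype=np.float64)

f = np.full_like(y, y.mean())   # f_0(X) = C = 78
trees = []
for m in range(2):              # two boosting iterations
    # each h_m is fit to the residuals, i.e. the negative gradient
    # of the least_squares loss at the current predictions
    h = DecisionTreeRegressor(max_depth=1).fit(X, y - f)
    f += h.predict(X)           # f_{m+1}(X) = f_m(X) + h_m(X)
    trees.append(h)

# predict for X = 40: f(40) = C + h_0(40) + h_1(40)
print(y.mean() + sum(h.predict([[40.0]])[0] for h in trees))
```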
617 | -------------------------------------------------------------------------------- /presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/pydata-2019-histgradientboosting/2907a2d556dd5ce94d63eaa1a257631f0b5d9861/presentation.pdf --------------------------------------------------------------------------------