├── README.md ├── examples ├── Hyperopt_LightGBM_with_Focal_Loss.ipynb ├── Lightgbm_with_Focal_Loss.ipynb └── Lightgbm_with_Focal_Loss_multiclass.ipynb ├── lightgbm_with_focal_loss.py ├── prepare_data.py └── utils ├── feature_tools.py ├── metrics.py └── train_hyperopt.py /README.md: -------------------------------------------------------------------------------- 1 | # LightGBM with Focal Loss 2 | This is implementation of the [Focal 3 | Loss](https://arxiv.org/pdf/1708.02002.pdf)[1] to be used with LightGBM. 4 | 5 | The companion Medium post can be found [here](https://medium.com/@jrzaurin/lightgbm-with-the-focal-loss-for-imbalanced-datasets-9836a9ae00ca). 6 | 7 | The Focal Loss for 8 | [LightGBM](https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf)[2] 9 | can be simply coded as: 10 | 11 | ```python 12 | def focal_loss_lgb(y_pred, dtrain, alpha, gamma): 13 | a,g = alpha, gamma 14 | y_true = dtrain.label 15 | def fl(x,t): 16 | p = 1/(1+np.exp(-x)) 17 | return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p)+(1-t)*np.log(1-p) ) 18 | partial_fl = lambda x: fl(x, y_true) 19 | grad = derivative(partial_fl, y_pred, n=1, dx=1e-6) 20 | hess = derivative(partial_fl, y_pred, n=2, dx=1e-6) 21 | return grad, hess 22 | 23 | ``` 24 | 25 | to use it one would need the corresponding evaluation function: 26 | 27 | ```python 28 | def focal_loss_lgb_eval_error(y_pred, dtrain, alpha, gamma): 29 | a,g = alpha, gamma 30 | y_true = dtrain.label 31 | p = 1/(1+np.exp(-y_pred)) 32 | loss = -( a*y_true + (1-a)*(1-y_true) ) * (( 1 - ( y_true*p + (1-y_true)*(1-p)) )**g) * ( y_true*np.log(p)+(1-y_true)*np.log(1-p) ) 33 | return 'focal_loss', np.mean(loss), False 34 | ``` 35 | 36 | And to use it, simply: 37 | 38 | ```python 39 | focal_loss = lambda x,y: focal_loss_lgb(x, y, 0.25, 1.) 40 | eval_error = lambda x,y: focal_loss_lgb_eval_error(x, y, 0.25, 1.) 41 | lgbtrain = lgb.Dataset(X_tr, y_tr, free_raw_data=True) 42 | lgbeval = lgb.Dataset(X_val, y_val) 43 | params = {'learning_rate':0.1, 'num_boost_round':10} 44 | model = lgb.train(params, lgbtrain, valid_sets=[lgbeval], fobj=focal_loss, feval=eval_error ) 45 | ``` 46 | 47 | In the `examples` directory you will find more details, including how to use [Hyperopt](https://github.com/hyperopt/hyperopt) in combination with LightGBM and the Focal Loss, or how to adapt the Focal Loss to a multi-class classification problem. 48 | 49 | Any comment: jrzaurin@gmail.com 50 | 51 | ### References: 52 | [1] Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. Focal Loss for Dense Object Detection 53 | 54 | [2] Guolin Ke, Qi Meng Thomas Finley, et al., 2017. LightGBM: A Highly Efficient Gradient Boosting 55 | Decision Tree 56 | -------------------------------------------------------------------------------- /examples/Hyperopt_LightGBM_with_Focal_Loss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Hyperparameter Optimization of LightGBM with Focal Loss\n", 8 | "\n", 9 | "Here I will quicky show how to use [Hyperopt](https://github.com/hyperopt/hyperopt) to optimize all LightGBM's hyperparameters and $\\alpha$ and $\\gamma$ for the Focal Loss. 
\n", 10 | "\n", 11 | "I am going to assume that we want to optimise \"against\" a standard metric for imbalanced datasets such as the F1 score\n", 12 | "\n", 13 | "We first need to code that metric to be passed to LightGBM" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 7, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import lightgbm as lgb\n", 24 | "import pickle\n", 25 | "\n", 26 | "from pathlib import Path\n", 27 | "from sklearn.model_selection import train_test_split\n", 28 | "from sklearn.metrics import f1_score\n", 29 | "from scipy.misc import derivative\n", 30 | "from hyperopt import hp, tpe, fmin, Trials\n", 31 | "\n", 32 | "def sigmoid(x): return 1./(1. + np.exp(-x))\n", 33 | "\n", 34 | "def focal_loss_lgb(y_pred, dtrain, alpha, gamma):\n", 35 | " \"\"\"\n", 36 | " Focal Loss for lightgbm\n", 37 | "\n", 38 | " Parameters:\n", 39 | " -----------\n", 40 | " y_pred: numpy.ndarray\n", 41 | " array with the predictions\n", 42 | " dtrain: lightgbm.Dataset\n", 43 | " alpha, gamma: float\n", 44 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 45 | " \"\"\"\n", 46 | " a,g = alpha, gamma\n", 47 | " y_true = dtrain.label\n", 48 | " def fl(x,t):\n", 49 | " p = 1/(1+np.exp(-x))\n", 50 | " return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p)+(1-t)*np.log(1-p) )\n", 51 | " partial_fl = lambda x: fl(x, y_true)\n", 52 | " grad = derivative(partial_fl, y_pred, n=1, dx=1e-6)\n", 53 | " hess = derivative(partial_fl, y_pred, n=2, dx=1e-6)\n", 54 | " return grad, hess\n", 55 | "\n", 56 | "def lgb_focal_f1_score(preds, lgbDataset):\n", 57 | " \"\"\"\n", 58 | " When using custom losses the row prediction needs to passed through a\n", 59 | " sigmoid to represent a probability\n", 60 | "\n", 61 | " Parameters:\n", 62 | " -----------\n", 63 | " preds: numpy.ndarray\n", 64 | " array with the predictions\n", 65 | " lgbDataset: lightgbm.Dataset\n", 66 | " \"\"\"\n", 67 | " preds = sigmoid(preds)\n", 68 | " binary_preds = [int(p>0.5) for p in preds]\n", 69 | " y_true = lgbDataset.get_label()\n", 70 | " return 'f1', f1_score(y_true, binary_preds), True" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "Let's now define our objective function" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 15, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "def objective(params):\n", 87 | " \"\"\"\n", 88 | " objective function for lightgbm.\n", 89 | " \"\"\"\n", 90 | " # hyperopt casts as float\n", 91 | " params['num_boost_round'] = int(params['num_boost_round'])\n", 92 | " params['num_leaves'] = int(params['num_leaves'])\n", 93 | "\n", 94 | " # need to be passed as parameter\n", 95 | " params['verbose'] = -1\n", 96 | " params['seed'] = 1\n", 97 | "\n", 98 | " focal_loss = lambda x,y: focal_loss_lgb(x, y,\n", 99 | " params['alpha'], params['gamma'])\n", 100 | " # if you do not want an annoying warning related to the unrecognised param\n", 101 | " # 'alpha', simple pop them out from the dict params here and insert them\n", 102 | " # back before return. 
For this particular notebook I can live with it, so\n", 103 | " # I will leave it\n", 104 | " cv_result = lgb.cv(\n", 105 | " params,\n", 106 | " train,\n", 107 | " num_boost_round=params['num_boost_round'],\n", 108 | " fobj = focal_loss,\n", 109 | " feval = lgb_focal_f1_score,\n", 110 | " nfold=3,\n", 111 | " stratified=True,\n", 112 | " early_stopping_rounds=20)\n", 113 | " # I save the length or the results (i.e. the number of estimators) because\n", 114 | " # it might have stopped earlier and is always useful to have that\n", 115 | " # information \n", 116 | " early_stop_dict[objective.i] = len(cv_result['f1-mean'])\n", 117 | " score = round(cv_result['f1-mean'][-1], 4)\n", 118 | " objective.i+=1\n", 119 | " return -score" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "Now the parameter space that we are going to be exploring:" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "space = {\n", 136 | " 'learning_rate': hp.uniform('learning_rate', 0.01, 0.2),\n", 137 | " 'num_boost_round': hp.quniform('num_boost_round', 50, 500, 20),\n", 138 | " 'num_leaves': hp.quniform('num_leaves', 31, 255, 4),\n", 139 | " 'min_child_weight': hp.uniform('min_child_weight', 0.1, 10),\n", 140 | " 'colsample_bytree': hp.uniform('colsample_bytree', 0.5, 1.),\n", 141 | " 'subsample': hp.uniform('subsample', 0.5, 1.),\n", 142 | " 'reg_alpha': hp.uniform('reg_alpha', 0.01, 0.1),\n", 143 | " 'reg_lambda': hp.uniform('reg_lambda', 0.01, 0.1),\n", 144 | " 'alpha': hp.uniform('alpha', 0.1, 0.75),\n", 145 | " 'gamma': hp.uniform('gamma', 0.5, 5)\n", 146 | " }" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "And we are ready, let's just load some data and run the whole thing" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 9, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/html": [ 164 | "
\n", 165 | "\n", 178 | "\n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | "
ageworkclasseducationmarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryeducation_occupationnative_country_occupation
119610.28767100000000.00.00.397959000
12300.09589011010010.00.00.397959011
160670.58904111111010.00.00.193878011
129140.45205511222000.00.00.479592022
63430.20547912232000.00.00.397959033
\n", 286 | "
" 287 | ], 288 | "text/plain": [ 289 | " age workclass education marital_status occupation \\\n", 290 | "11961 0.287671 0 0 0 0 \n", 291 | "1230 0.095890 1 1 0 1 \n", 292 | "16067 0.589041 1 1 1 1 \n", 293 | "12914 0.452055 1 1 2 2 \n", 294 | "6343 0.205479 1 2 2 3 \n", 295 | "\n", 296 | " relationship race gender capital_gain capital_loss hours_per_week \\\n", 297 | "11961 0 0 0 0.0 0.0 0.397959 \n", 298 | "1230 0 0 1 0.0 0.0 0.397959 \n", 299 | "16067 1 0 1 0.0 0.0 0.193878 \n", 300 | "12914 2 0 0 0.0 0.0 0.479592 \n", 301 | "6343 2 0 0 0.0 0.0 0.397959 \n", 302 | "\n", 303 | " native_country education_occupation native_country_occupation \n", 304 | "11961 0 0 0 \n", 305 | "1230 0 1 1 \n", 306 | "16067 0 1 1 \n", 307 | "12914 0 2 2 \n", 308 | "6343 0 3 3 " 309 | ] 310 | }, 311 | "execution_count": 9, 312 | "metadata": {}, 313 | "output_type": "execute_result" 314 | } 315 | ], 316 | "source": [ 317 | "PATH = Path(\"../data/\")\n", 318 | "databunch = pickle.load(open(PATH/'adult_databunch.p', 'rb'))\n", 319 | "colnames = databunch.colnames\n", 320 | "categorical_columns = databunch.categorical_columns + databunch.crossed_columns\n", 321 | "X = databunch.data\n", 322 | "y = databunch.target\n", 323 | "# you know, in real life, train, valid AND test, and you keep it somewhere safe...\n", 324 | "X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25,\n", 325 | " random_state=1, stratify=y)\n", 326 | "# let's have a look:\n", 327 | "X.head()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 11, 333 | "metadata": {}, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "[1 0 0 ... 0 0 1]\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "print(y.values)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 12, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "train = lgb.Dataset(\n", 354 | " X_tr, y_tr,\n", 355 | " feature_name=colnames,\n", 356 | " categorical_feature = categorical_columns,\n", 357 | " free_raw_data=False)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 16, 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "\r", 370 | " 0%| | 0/5 [00:00 0.5).astype('int')" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": 20, 522 | "metadata": {}, 523 | "outputs": [ 524 | { 525 | "name": "stdout", 526 | "output_type": "stream", 527 | "text": [ 528 | "0.7121898206846586\n" 529 | ] 530 | } 531 | ], 532 | "source": [ 533 | "print(f1_score(y_val, preds))" 534 | ] 535 | } 536 | ], 537 | "metadata": { 538 | "kernelspec": { 539 | "display_name": "Python 3", 540 | "language": "python", 541 | "name": "python3" 542 | }, 543 | "language_info": { 544 | "codemirror_mode": { 545 | "name": "ipython", 546 | "version": 3 547 | }, 548 | "file_extension": ".py", 549 | "mimetype": "text/x-python", 550 | "name": "python", 551 | "nbconvert_exporter": "python", 552 | "pygments_lexer": "ipython3", 553 | "version": "3.6.5" 554 | } 555 | }, 556 | "nbformat": 4, 557 | "nbformat_minor": 2 558 | } 559 | -------------------------------------------------------------------------------- /examples/Lightgbm_with_Focal_Loss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## How to code and use the Focal Loss with LightGBM\n", 8 | 
"\n", 9 | "The [Focal Loss](https://arxiv.org/pdf/1708.02002.pdf) for LightGBM can be coded as:" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 67, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import pickle\n", 20 | "import lightgbm as lgb\n", 21 | "\n", 22 | "from pathlib import Path\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "from sklearn.metrics import f1_score\n", 25 | "from scipy.misc import derivative" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 68, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def focal_loss_lgb(y_pred, dtrain, alpha, gamma):\n", 35 | " \"\"\"\n", 36 | " Focal Loss for lightgbm\n", 37 | "\n", 38 | " Parameters:\n", 39 | " -----------\n", 40 | " y_pred: numpy.ndarray\n", 41 | " array with the predictions\n", 42 | " dtrain: lightgbm.Dataset\n", 43 | " alpha, gamma: float\n", 44 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 45 | " \"\"\"\n", 46 | " a,g = alpha, gamma\n", 47 | " y_true = dtrain.label\n", 48 | " def fl(x,t):\n", 49 | " p = 1/(1+np.exp(-x))\n", 50 | " return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p)+(1-t)*np.log(1-p) )\n", 51 | " partial_fl = lambda x: fl(x, y_true)\n", 52 | " grad = derivative(partial_fl, y_pred, n=1, dx=1e-6)\n", 53 | " hess = derivative(partial_fl, y_pred, n=2, dx=1e-6)\n", 54 | " return grad, hess" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "If we are going to use it as our custom loss, we also need our custom evaluation function" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 69, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "def focal_loss_lgb_eval_error(y_pred, dtrain, alpha, gamma):\n", 71 | " \"\"\"\n", 72 | " Adapation of the Focal Loss for lightgbm to be used as evaluation loss\n", 73 | "\n", 74 | " Parameters:\n", 75 | " -----------\n", 76 | " y_pred: numpy.ndarray\n", 77 | " array with the predictions\n", 78 | " dtrain: lightgbm.Dataset\n", 79 | " alpha, gamma: float\n", 80 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 81 | " \"\"\"\n", 82 | " a,g = alpha, gamma\n", 83 | " y_true = dtrain.label\n", 84 | " p = 1/(1+np.exp(-y_pred))\n", 85 | " loss = -( a*y_true + (1-a)*(1-y_true) ) * (( 1 - ( y_true*p + (1-y_true)*(1-p)) )**g) * ( y_true*np.log(p)+(1-y_true)*np.log(1-p) )\n", 86 | " return 'focal_loss', np.mean(loss), False" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "To use them, first we need to make them partial functions of **only** `y_pred` and `dtrain`, since this is a structural requirement for LighGBM. Then, we simply pass them as parameters. \n", 94 | "\n", 95 | "Let me first load some processed data:" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 76, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/html": [ 106 | "
\n", 107 | "\n", 120 | "\n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | "
ageworkclasseducationmarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryeducation_occupationnative_country_occupation
119610.28767100000000.00.00.397959000
12300.09589011010010.00.00.397959011
160670.58904111111010.00.00.193878011
129140.45205511222000.00.00.479592022
63430.20547912232000.00.00.397959033
\n", 228 | "
" 229 | ], 230 | "text/plain": [ 231 | " age workclass education marital_status occupation \\\n", 232 | "11961 0.287671 0 0 0 0 \n", 233 | "1230 0.095890 1 1 0 1 \n", 234 | "16067 0.589041 1 1 1 1 \n", 235 | "12914 0.452055 1 1 2 2 \n", 236 | "6343 0.205479 1 2 2 3 \n", 237 | "\n", 238 | " relationship race gender capital_gain capital_loss hours_per_week \\\n", 239 | "11961 0 0 0 0.0 0.0 0.397959 \n", 240 | "1230 0 0 1 0.0 0.0 0.397959 \n", 241 | "16067 1 0 1 0.0 0.0 0.193878 \n", 242 | "12914 2 0 0 0.0 0.0 0.479592 \n", 243 | "6343 2 0 0 0.0 0.0 0.397959 \n", 244 | "\n", 245 | " native_country education_occupation native_country_occupation \n", 246 | "11961 0 0 0 \n", 247 | "1230 0 1 1 \n", 248 | "16067 0 1 1 \n", 249 | "12914 0 2 2 \n", 250 | "6343 0 3 3 " 251 | ] 252 | }, 253 | "execution_count": 76, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "PATH = Path(\"../data/\")\n", 260 | "databunch = pickle.load(open(PATH/'adult_databunch.p', 'rb'))\n", 261 | "colnames = databunch.colnames\n", 262 | "categorical_columns = databunch.categorical_columns + databunch.crossed_columns\n", 263 | "X = databunch.data\n", 264 | "y = databunch.target\n", 265 | "# you know, in real life, train, valid AND test, and you keep it somewhere safe...\n", 266 | "X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25,\n", 267 | " random_state=1, stratify=y)\n", 268 | "# let's have a look:\n", 269 | "X.head()" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 78, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "[1 0 0 ... 0 0 1]\n" 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "print(y.values)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "LightGBM with Focal Loss" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 71, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "lgtrain = lgb.Dataset(\n", 303 | " X_tr, y_tr,\n", 304 | " feature_name=colnames,\n", 305 | " categorical_feature = categorical_columns,\n", 306 | " free_raw_data=False)\n", 307 | "lgvalid = lgtrain.create_valid(X_val, y_val)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 72, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stderr", 317 | "output_type": "stream", 318 | "text": [ 319 | "/usr/local/lib/python3.6/site-packages/lightgbm/engine.py:118: UserWarning: Found `num_boost_round` in params. Will use it instead of argument\n", 320 | " warnings.warn(\"Found `{}` in params. 
Will use it instead of argument\".format(alias))\n", 321 | "/usr/local/lib/python3.6/site-packages/lightgbm/basic.py:1205: UserWarning: Using categorical_feature in Dataset.\n", 322 | " warnings.warn('Using categorical_feature in Dataset.')\n", 323 | "/usr/local/lib/python3.6/site-packages/lightgbm/basic.py:1209: UserWarning: categorical_feature in Dataset is overridden.\n", 324 | "New categorical_feature is ['education', 'education_occupation', 'gender', 'marital_status', 'native_country', 'native_country_occupation', 'occupation', 'race', 'relationship', 'workclass']\n", 325 | " 'New categorical_feature is {}'.format(sorted(list(categorical_feature))))\n", 326 | "/usr/local/lib/python3.6/site-packages/lightgbm/basic.py:762: UserWarning: categorical_feature in param dict is overridden.\n", 327 | " warnings.warn('categorical_feature in param dict is overridden.')\n" 328 | ] 329 | }, 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "[1]\tvalid_0's focal_loss: 0.098781\n", 335 | "[2]\tvalid_0's focal_loss: 0.0898273\n", 336 | "[3]\tvalid_0's focal_loss: 0.0821333\n", 337 | "[4]\tvalid_0's focal_loss: 0.0755058\n", 338 | "[5]\tvalid_0's focal_loss: 0.0697994\n", 339 | "[6]\tvalid_0's focal_loss: 0.064839\n", 340 | "[7]\tvalid_0's focal_loss: 0.0605124\n", 341 | "[8]\tvalid_0's focal_loss: 0.0567805\n", 342 | "[9]\tvalid_0's focal_loss: 0.0534902\n", 343 | "[10]\tvalid_0's focal_loss: 0.0506304\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "focal_loss = lambda x,y: focal_loss_lgb(x, y, 0.25, 2.)\n", 349 | "eval_error = lambda x,y: focal_loss_lgb_eval_error(x, y, 0.25, 2.)\n", 350 | "params = {'learning_rate':0.1, 'num_boost_round':10}\n", 351 | "model = lgb.train(\n", 352 | " params,\n", 353 | " lgtrain,\n", 354 | " valid_sets=[lgvalid],\n", 355 | " fobj=focal_loss,\n", 356 | " feval=eval_error\n", 357 | " )" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "### Sklearn's API\n", 365 | "\n", 366 | "If you prefer to use LightGBM's sklearn API, simply replace `dtrain` with `y_true`, and swap the predictions and ground truth order, like:" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 73, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "def focal_loss_lgb_sk(y_true, y_pred, alpha, gamma):\n", 376 | " \"\"\"\n", 377 | " Focal Loss for lightgbm\n", 378 | "\n", 379 | " Parameters:\n", 380 | " -----------\n", 381 | " y_pred: numpy.ndarray\n", 382 | " array with the predictions\n", 383 | " dtrain: lightgbm.Dataset\n", 384 | " alpha, gamma: float\n", 385 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 386 | " \"\"\"\n", 387 | " a,g = alpha, gamma\n", 388 | " def fl(x,t):\n", 389 | " p = 1/(1+np.exp(-x))\n", 390 | " return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p)+(1-t)*np.log(1-p) )\n", 391 | " partial_fl = lambda x: fl(x, y_true)\n", 392 | " grad = derivative(partial_fl, y_pred, n=1, dx=1e-6)\n", 393 | " hess = derivative(partial_fl, y_pred, n=2, dx=1e-6)\n", 394 | " return grad, hess" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 74, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "def focal_loss_lgb_eval_error_sk(y_true, y_pred, alpha, gamma):\n", 404 | " \"\"\"\n", 405 | " Adapation of the Focal Loss for lightgbm to be used as evaluation loss\n", 406 | "\n", 407 | " Parameters:\n", 408 | " -----------\n", 409 | " y_pred: numpy.ndarray\n", 410 | " array with 
the predictions\n", 411 | " dtrain: lightgbm.Dataset\n", 412 | " alpha, gamma: float\n", 413 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 414 | " \"\"\"\n", 415 | " a,g = alpha, gamma\n", 416 | " p = 1/(1+np.exp(-y_pred))\n", 417 | " loss = -( a*y_true + (1-a)*(1-y_true) ) * (( 1 - ( y_true*p + (1-y_true)*(1-p)) )**g) * ( y_true*np.log(p)+(1-y_true)*np.log(1-p) )\n", 418 | " return 'focal_loss', np.mean(loss), False" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 75, 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "[1]\tvalid_0's focal_loss: 0.0988352\n", 431 | "[2]\tvalid_0's focal_loss: 0.0899494\n", 432 | "[3]\tvalid_0's focal_loss: 0.0823239\n", 433 | "[4]\tvalid_0's focal_loss: 0.0757314\n", 434 | "[5]\tvalid_0's focal_loss: 0.0700502\n", 435 | "[6]\tvalid_0's focal_loss: 0.0651475\n", 436 | "[7]\tvalid_0's focal_loss: 0.0608702\n", 437 | "[8]\tvalid_0's focal_loss: 0.0571672\n", 438 | "[9]\tvalid_0's focal_loss: 0.0539455\n", 439 | "[10]\tvalid_0's focal_loss: 0.051152\n" 440 | ] 441 | }, 442 | { 443 | "data": { 444 | "text/plain": [ 445 | "LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,\n", 446 | " importance_type='split', learning_rate=0.1, max_depth=-1,\n", 447 | " min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0,\n", 448 | " n_estimators=100, n_jobs=-1, num_boost_round=10, num_leaves=31,\n", 449 | " objective= at 0x115bc1950>, random_state=None,\n", 450 | " reg_alpha=0.0, reg_lambda=0.0, silent=True, subsample=1.0,\n", 451 | " subsample_for_bin=200000, subsample_freq=0)" 452 | ] 453 | }, 454 | "execution_count": 75, 455 | "metadata": {}, 456 | "output_type": "execute_result" 457 | } 458 | ], 459 | "source": [ 460 | "focal_loss = lambda x,y: focal_loss_lgb_sk(x, y, 0.25, 2.)\n", 461 | "eval_error = lambda x,y: focal_loss_lgb_eval_error_sk(x, y, 0.25, 2.)\n", 462 | "model = lgb.LGBMClassifier(objective=focal_loss, learning_rate=0.1, num_boost_round=10)\n", 463 | "model.fit(\n", 464 | " X_tr,\n", 465 | " y_tr,\n", 466 | " eval_set=[(X_val, y_val)],\n", 467 | " eval_metric=eval_error)" 468 | ] 469 | } 470 | ], 471 | "metadata": { 472 | "kernelspec": { 473 | "display_name": "Python 3", 474 | "language": "python", 475 | "name": "python3" 476 | }, 477 | "language_info": { 478 | "codemirror_mode": { 479 | "name": "ipython", 480 | "version": 3 481 | }, 482 | "file_extension": ".py", 483 | "mimetype": "text/x-python", 484 | "name": "python", 485 | "nbconvert_exporter": "python", 486 | "pygments_lexer": "ipython3", 487 | "version": "3.6.5" 488 | } 489 | }, 490 | "nbformat": 4, 491 | "nbformat_minor": 2 492 | } 493 | -------------------------------------------------------------------------------- /examples/Lightgbm_with_Focal_Loss_multiclass.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### LightGBM with Focal Loss for Multiclass classification problems\n", 8 | "\n", 9 | "Let me show how to adapt the Focal Loss implementation for binary classification to a multiclass classification problem.\n", 10 | "\n", 11 | "The idea is to face the problem using the Binary Cross Entropy With Logits (borrowing from `Pytorch` notation `BCEWithLogitsLoss`). 
\n", 12 | "\n", 13 | "$$\n", 14 | "loss = -[y_{\\text true} \\cdot log\\sigma(x) + (1-y_{\\text true}) \\cdot log(1-\\sigma(x))] \n", 15 | "$$\n", 16 | "\n", 17 | "Where $\\sigma$ is the sigmoid function\n", 18 | "\n", 19 | "For example, let's assume we have a problem with 10 classes and we have two samples/observations" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import numpy as np\n", 29 | "\n", 30 | "y_true = np.random.choice(10, (1,2))\n", 31 | "# from -2 to 2 to illustrate the fact the preds coming from lightGBM when using custom losses are NOT probs\n", 32 | "y_pred = np.random.uniform(low=-2, high=2, size=(2, 10))" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "array([[0, 0]])" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "# labels\n", 53 | "y_true" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "array([[-0.62900913, 0.92265852, -1.33477174, -1.89011705, 1.85566209,\n", 65 | " 0.76361995, 0.1983925 , -0.5764042 , -0.84919259, 0.92979002],\n", 66 | " [ 1.99012283, 0.43470132, 0.35491818, 1.48850368, -0.20095172,\n", 67 | " 1.96445624, 1.25049923, -0.86754563, -1.6867512 , 1.098587 ]])" 68 | ] 69 | }, 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "# predictions\n", 77 | "y_pred" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "def sigmoid(x): \n", 87 | " return 1./(1. 
+ np.exp(-x))\n", 88 | "\n", 89 | "def softmax(x):\n", 90 | " exp_x = np.exp(x - np.max(x))\n", 91 | " return exp_x / (np.sum(exp_x, axis=1, keepdims=True) + 1e-6)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# labels one-hot encoded\n", 101 | "y_true_oh = np.eye(10)[y_true][0]" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 113 | " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" 114 | ] 115 | }, 116 | "execution_count": 6, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "y_true_oh" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "data": { 132 | "text/plain": [ 133 | "0.9219779528703436" 134 | ] 135 | }, 136 | "execution_count": 7, 137 | "metadata": {}, 138 | "output_type": "execute_result" 139 | } 140 | ], 141 | "source": [ 142 | "# BCEWithLogitsLoss\n", 143 | "( -( y_true_oh * np.log(sigmoid(y_pred)) + (1-y_true_oh) * np.log(1-sigmoid(y_pred)) ) ).mean()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "### Multiclass Focal Loss\n", 151 | "\n", 152 | "Before we jump to the Focal Loss code, let's focus for one second in a sentence in the LightGBM [documentation](https://lightgbm.readthedocs.io/en/latest/index.html) site : *\"For multi-class task, the preds is group by class_id first, then group by row_id. If you want to get i-th row preds in j-th class, the access way is score[j $\\times$ num_data + i] and you should group grad and hess in this way as well.\"*\n", 153 | "\n", 154 | "Let's assume we have 100 rows and 4 classes" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 8, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "preds = np.random.rand(100*4)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "To access to the prediction for class `1` for the 20-th row we need the index 1 $\\times$ 100 + 20 = 120\n", 171 | "\n", 172 | "We will compute the Focal Loss using the `BCEWithLogitsLoss` which requires that we have an array of predictions of shape (num_data, num_class). \n", 173 | "\n", 174 | "Therefore, to reshape the predictions (scores) coming from lightGBM to that format, we need to use 'Fortran' style." 
175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 10, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "0.8213133926994343" 186 | ] 187 | }, 188 | "execution_count": 10, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "preds[120]" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 11, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "0.8213133926994343" 206 | ] 207 | }, 208 | "execution_count": 11, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "preds.reshape(-1 , 4, order='F')[20, 1]" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "And in general" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 12, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "True" 233 | ] 234 | }, 235 | "execution_count": 12, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "np.all(preds[:100] == preds.reshape(-1 , 4, order='F')[:100,0])" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "So, without further ado:" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 13, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "def focal_loss_lgb(y_pred, dtrain, alpha, gamma, num_class):\n", 258 | " \"\"\"\n", 259 | " Focal Loss for lightgbm\n", 260 | "\n", 261 | " Parameters:\n", 262 | " -----------\n", 263 | " y_pred: numpy.ndarray\n", 264 | " array with the predictions\n", 265 | " dtrain: lightgbm.Dataset\n", 266 | " alpha, gamma: float\n", 267 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 268 | " num_class: int\n", 269 | " number of classes\n", 270 | " \"\"\"\n", 271 | " a,g = alpha, gamma\n", 272 | " y_true = dtrain.label\n", 273 | " # N observations x num_class arrays\n", 274 | " y_true = np.eye(num_class)[y_true.astype('int')]\n", 275 | " y_pred = y_pred.reshape(-1,num_class, order='F')\n", 276 | " # alpha and gamma multiplicative factors with BCEWithLogitsLoss\n", 277 | " def fl(x,t):\n", 278 | " p = 1/(1+np.exp(-x))\n", 279 | " return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p)+(1-t)*np.log(1-p) )\n", 280 | " partial_fl = lambda x: fl(x, y_true)\n", 281 | " grad = derivative(partial_fl, y_pred, n=1, dx=1e-6)\n", 282 | " hess = derivative(partial_fl, y_pred, n=2, dx=1e-6)\n", 283 | " # flatten in column-major (Fortran-style) order\n", 284 | " return grad.flatten('F'), hess.flatten('F')" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "And that's it really. Now one would want/need the corresponding evalulation function." 
292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 14, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "def focal_loss_lgb_eval_error(y_pred, dtrain, alpha, gamma, num_class):\n", 301 | " \"\"\"\n", 302 | " Focal Loss for lightgbm\n", 303 | "\n", 304 | " Parameters:\n", 305 | " -----------\n", 306 | " y_pred: numpy.ndarray\n", 307 | " array with the predictions\n", 308 | " dtrain: lightgbm.Dataset\n", 309 | " alpha, gamma: float\n", 310 | " See original paper https://arxiv.org/pdf/1708.02002.pdf\n", 311 | " num_class: int\n", 312 | " number of classes\n", 313 | " \"\"\"\n", 314 | " a,g = alpha, gamma\n", 315 | " y_true = dtrain.label\n", 316 | " y_true = np.eye(num_class)[y_true.astype('int')]\n", 317 | " y_pred = y_pred.reshape(-1, num_class, order='F')\n", 318 | " p = 1/(1+np.exp(-y_pred))\n", 319 | " loss = -( a*y_true + (1-a)*(1-y_true) ) * (( 1 - ( y_true*p + (1-y_true)*(1-p)) )**g) * ( y_true*np.log(p)+(1-y_true)*np.log(1-p) )\n", 320 | " # a variant can be np.sum(loss)/num_class\n", 321 | " return 'focal_loss', np.mean(loss), False" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "### EXAMPLE" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 15, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "name": "stderr", 338 | "output_type": "stream", 339 | "text": [ 340 | "/usr/local/lib/python3.6/site-packages/lightgbm/__init__.py:46: UserWarning: Starting from version 2.2.1, the library file in distribution wheels for macOS is built by the Apple Clang (Xcode_8.3.3) compiler.\n", 341 | "This means that in case of installing LightGBM from PyPI via the ``pip install lightgbm`` command, you don't need to install the gcc compiler anymore.\n", 342 | "Instead of that, you need to install the OpenMP library, which is required for running LightGBM on the system with the Apple Clang compiler.\n", 343 | "You can install the OpenMP library by the following command: ``brew install libomp``.\n", 344 | " \"You can install the OpenMP library by the following command: ``brew install libomp``.\", UserWarning)\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "import numpy as np\n", 350 | "import lightgbm as lgb\n", 351 | "\n", 352 | "from sklearn import datasets\n", 353 | "from sklearn.model_selection import train_test_split\n", 354 | "from sklearn.metrics import accuracy_score\n", 355 | "from scipy.misc import derivative\n", 356 | "\n", 357 | "# very inadequate dataset as is perfectly balanced, but just to illustrate\n", 358 | "digits = datasets.load_digits()\n", 359 | "X = digits.data\n", 360 | "y = digits.target" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 16, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=1)\n", 370 | "lgbtrain = lgb.Dataset(X_tr, y_tr, free_raw_data=True)\n", 371 | "lgbeval = lgb.Dataset(X_val, y_val)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 17, 377 | "metadata": {}, 378 | "outputs": [ 379 | { 380 | "name": "stdout", 381 | "output_type": "stream", 382 | "text": [ 383 | "[1]\tvalid_0's focal_loss: 0.107288\n", 384 | "[2]\tvalid_0's focal_loss: 0.0951971\n", 385 | "[3]\tvalid_0's focal_loss: 0.0846662\n", 386 | "[4]\tvalid_0's focal_loss: 0.0755319\n", 387 | "[5]\tvalid_0's focal_loss: 0.0675866\n", 388 | "[6]\tvalid_0's focal_loss: 0.0605897\n", 389 | "[7]\tvalid_0's focal_loss: 
0.0544604\n", 390 | "[8]\tvalid_0's focal_loss: 0.0490753\n", 391 | "[9]\tvalid_0's focal_loss: 0.0442874\n", 392 | "[10]\tvalid_0's focal_loss: 0.0400507\n" 393 | ] 394 | }, 395 | { 396 | "name": "stderr", 397 | "output_type": "stream", 398 | "text": [ 399 | "/usr/local/lib/python3.6/site-packages/lightgbm/engine.py:118: UserWarning: Found `num_boost_round` in params. Will use it instead of argument\n", 400 | " warnings.warn(\"Found `{}` in params. Will use it instead of argument\".format(alias))\n" 401 | ] 402 | } 403 | ], 404 | "source": [ 405 | "focal_loss = lambda x,y: focal_loss_lgb(x, y, 0.25, 2., 10)\n", 406 | "eval_error = lambda x,y: focal_loss_lgb_eval_error(x, y, 0.25, 2., 10)\n", 407 | "params = {'learning_rate':0.1, 'num_boost_round':10, 'num_class':10}\n", 408 | "# model = lgb.train(params, lgbtrain, fobj=focal_loss)\n", 409 | "model = lgb.train(params, lgbtrain, valid_sets=[lgbeval], fobj=focal_loss, feval=eval_error)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 18, 415 | "metadata": { 416 | "scrolled": true 417 | }, 418 | "outputs": [ 419 | { 420 | "data": { 421 | "text/plain": [ 422 | "0.9083333333333333" 423 | ] 424 | }, 425 | "execution_count": 18, 426 | "metadata": {}, 427 | "output_type": "execute_result" 428 | } 429 | ], 430 | "source": [ 431 | "accuracy_score(y_val, np.argmax(softmax(model.predict(X_val)), axis=1))" 432 | ] 433 | } 434 | ], 435 | "metadata": { 436 | "kernelspec": { 437 | "display_name": "Python 3", 438 | "language": "python", 439 | "name": "python3" 440 | }, 441 | "language_info": { 442 | "codemirror_mode": { 443 | "name": "ipython", 444 | "version": 3 445 | }, 446 | "file_extension": ".py", 447 | "mimetype": "text/x-python", 448 | "name": "python", 449 | "nbconvert_exporter": "python", 450 | "pygments_lexer": "ipython3", 451 | "version": "3.6.5" 452 | } 453 | }, 454 | "nbformat": 4, 455 | "nbformat_minor": 2 456 | } 457 | -------------------------------------------------------------------------------- /lightgbm_with_focal_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lightgbm as lgb 3 | import argparse 4 | import pickle 5 | import warnings 6 | 7 | from pathlib import Path 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.preprocessing import RobustScaler 10 | from utils.train_hyperopt import LGBOptimizer 11 | 12 | 13 | warnings.filterwarnings("ignore") 14 | 15 | if __name__ == '__main__': 16 | 17 | ap = argparse.ArgumentParser() 18 | ap.add_argument("--dataset", required=True) 19 | ap.add_argument("--save_experiment", action="store_true") 20 | ap.add_argument("--with_focal_loss", action="store_true") 21 | ap.add_argument("--is_unbalance", action="store_true") 22 | args = vars(ap.parse_args()) 23 | 24 | PATH = Path("data/") 25 | is_unbalance = args['is_unbalance'] 26 | with_focal_loss = args['with_focal_loss'] 27 | save_experiment = args['save_experiment'] 28 | 29 | if args['dataset'] == 'adult': 30 | 31 | databunch = pickle.load(open(PATH/'adult_databunch.p', 'rb')) 32 | colnames = databunch.colnames 33 | categorical_columns = databunch.categorical_columns + databunch.crossed_columns 34 | X = databunch.data 35 | y = databunch.target 36 | X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, 37 | random_state=1, stratify=y) 38 | 39 | lgbopt = LGBOptimizer( 40 | args['dataset'], 41 | train_set=[X_tr, y_tr], 42 | eval_set=[X_te, y_te], 43 | colnames=colnames, 44 | categorical_columns=categorical_columns, 45 | 
is_unbalance=is_unbalance, 46 | with_focal_loss=with_focal_loss, 47 | save=save_experiment) 48 | lgbopt.optimize(maxevals=100) 49 | 50 | if args['dataset'] == 'credit': 51 | 52 | databunch = pickle.load(open("data/credit_databunch.p", 'rb')) 53 | colnames = databunch.colnames 54 | X = databunch.data 55 | y = databunch.target 56 | X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, 57 | random_state=1, stratify=y) 58 | 59 | lgbopt = LGBOptimizer( 60 | args['dataset'], 61 | train_set=[X_tr, y_tr], 62 | eval_set=[X_te, y_te], 63 | colnames=colnames, 64 | is_unbalance=is_unbalance, 65 | with_focal_loss=with_focal_loss, 66 | save=save_experiment) 67 | lgbopt.optimize(maxevals=100) 68 | 69 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pickle 3 | import warnings 4 | 5 | from pathlib import Path 6 | from utils.feature_tools import FeatureTools 7 | from sklearn.preprocessing import MinMaxScaler, RobustScaler 8 | 9 | import pdb 10 | 11 | warnings.filterwarnings("ignore") 12 | 13 | if __name__ == '__main__': 14 | 15 | PATH = Path('data/') 16 | train_fname = 'adult.data' 17 | test_fname = 'adult.test' 18 | 19 | df_tr = pd.read_csv(PATH/train_fname) 20 | df_te = pd.read_csv(PATH/test_fname) 21 | 22 | adult_df = pd.concat([df_tr, df_te]).sample(frac=1) 23 | adult_df.drop('fnlwgt', axis=1, inplace=True) 24 | 25 | adult_df['target'] = (adult_df['income_bracket'].apply(lambda x: '>50K' in x)).astype(int) 26 | adult_df.drop('income_bracket', axis=1, inplace=True) 27 | 28 | categorical_cols = list(adult_df.select_dtypes(include=['object']).columns) 29 | scale_cols = [c for c in adult_df.columns if c not in categorical_cols+['target']] 30 | crossed_cols = (['education', 'occupation'], ['native_country', 'occupation']) 31 | 32 | preprocessor = FeatureTools() 33 | adult_databunch = preprocessor(adult_df, target_col='target', scale_cols=scale_cols, 34 | scaler=MinMaxScaler(), categorical_cols=categorical_cols, x_cols=crossed_cols) 35 | pickle.dump(adult_databunch, open(PATH/'adult_databunch.p', "wb")) 36 | 37 | credit_df = pd.read_csv(PATH/'creditcard.csv.zip') 38 | scale_cols = ['Time', 'Amount'] 39 | preprocessor = FeatureTools() 40 | credit_databunch = preprocessor(credit_df, target_col='Class', scale_cols=scale_cols, 41 | scaler=MinMaxScaler()) 42 | pickle.dump(credit_databunch, open(PATH/'credit_databunch.p', "wb")) 43 | 44 | -------------------------------------------------------------------------------- /utils/feature_tools.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code here taking from: https://github.com/jrzaurin/ml_pipelines/blob/master/utils/feature_tools.py 3 | ''' 4 | import pandas as pd 5 | import copy 6 | 7 | from sklearn.utils import Bunch 8 | 9 | class FeatureTools(object): 10 | """Collection of preprocessing methods""" 11 | 12 | @staticmethod 13 | def num_scaler(df_inp, cols, sc, trained=False): 14 | """ 15 | Method to scale numeric columns in a dataframe 16 | Parameters: 17 | ----------- 18 | df_inp: Pandas.DataFrame 19 | cols: List 20 | List of numeric columns to be scaled 21 | sc: Scaler object. 
From sklearn.preprocessing or similar structure 22 | trained: Boolean 23 | If True it will only be used to 'transform' 24 | Returns: 25 | -------- 26 | df: Pandas.DataFrame 27 | transformed/normalised dataframe 28 | sc: trained scaler 29 | """ 30 | df = df_inp.copy() 31 | if not trained: 32 | df[cols] = sc.fit_transform(df[cols]) 33 | else: 34 | df[cols] = sc.transform(df[cols]) 35 | return df, sc 36 | 37 | @staticmethod 38 | def cross_columns(df_inp, x_cols): 39 | """ 40 | Method to build crossed columns. These are new columns that are the 41 | cartesian product of the parent columns. 42 | Parameters: 43 | ----------- 44 | df_inp: Pandas.DataFrame 45 | x_cols: List. 46 | List of tuples with the columns to cross 47 | e.g. [('colA', 'colB'),('colC', 'colD')] 48 | Returns: 49 | -------- 50 | df: Pandas.DataFrame 51 | pandas dataframe with the new crossed columns 52 | colnames: List 53 | list the new column names 54 | """ 55 | df = df_inp.copy() 56 | colnames = ['_'.join(x_c) for x_c in x_cols] 57 | crossed_columns = {k:v for k,v in zip(colnames, x_cols)} 58 | 59 | for k, v in crossed_columns.items(): 60 | df[k] = df[v].apply(lambda x: '-'.join(x), axis=1) 61 | 62 | return df, colnames 63 | 64 | @staticmethod 65 | def val2idx(df_inp, cols, val_to_idx=None): 66 | """ 67 | This is basically a LabelEncoder that returns a dictionary with the 68 | mapping of the labels. 69 | Parameters: 70 | ----------- 71 | df_inp: Pandas.DataFrame 72 | cols: List 73 | List of categorical columns to encode 74 | val_to_idx: Dict 75 | LabelEncoding dictionary if already exists 76 | Returns: 77 | -------- 78 | df: Pandas.DataFrame 79 | pandas dataframe with the categorical columns encoded 80 | val_to_idx: Dict 81 | dictionary with the encoding mappings 82 | """ 83 | df = df_inp.copy() 84 | if not val_to_idx: 85 | 86 | val_types = dict() 87 | for c in cols: 88 | val_types[c] = df[c].unique() 89 | 90 | val_to_idx = dict() 91 | for k, v in val_types.items(): 92 | val_to_idx[k] = {o: i for i, o in enumerate(val_types[k])} 93 | 94 | for k, v in val_to_idx.items(): 95 | df[k] = df[k].apply(lambda x: v[x]) 96 | 97 | return df, val_to_idx 98 | 99 | def __call__(self, df_inp, target_col, scale_cols=None, scaler=None, 100 | categorical_cols=None, x_cols=None): 101 | """ 102 | Parameters: 103 | ----------- 104 | df_inp: Pandas.DataFrame 105 | target_col: Str 106 | scale_cols: List 107 | List of numerical columns to be scaled 108 | scaler: Scaler. 
From sklearn.preprocessing or object with the same 109 | structure 110 | categorical_cols: List 111 | List with the categorical columns 112 | x_cols: List 113 | List of tuples with the columns to cross 114 | """ 115 | 116 | df = df_inp.copy() 117 | databunch = Bunch() 118 | 119 | if scale_cols: 120 | assert scaler is not None, 'scaler argument is missing' 121 | databunch.scale_cols = scale_cols 122 | df, sc = self.num_scaler(df, scale_cols, scaler) 123 | databunch.scaler = sc 124 | else: 125 | databunch.scale_cols = None 126 | 127 | if categorical_cols: 128 | databunch.categorical_cols = categorical_cols 129 | if x_cols: 130 | df, crossed_cols = self.cross_columns(df, x_cols) 131 | df, encoding_d = self.val2idx(df, categorical_cols+crossed_cols) 132 | databunch.crossed_cols = crossed_cols 133 | databunch.encoding_dict = encoding_d 134 | else: 135 | df, encoding_d = self.val2idx(df, categorical_cols) 136 | databunch.crossed_cols = None 137 | databunch.encoding_dict = encoding_d 138 | else: 139 | databunch.encoding_dict = None 140 | databunch.categorical_cols = None 141 | 142 | databunch.target = df[target_col] 143 | df.drop(target_col, axis=1, inplace=True) 144 | databunch.data = df 145 | databunch.colnames = df.columns.tolist() 146 | 147 | return databunch 148 | 149 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lightgbm as lgb 3 | 4 | from sklearn.metrics import f1_score 5 | from scipy.misc import derivative 6 | 7 | 8 | def sigmoid(x): return 1./(1. + np.exp(-x)) 9 | 10 | 11 | def best_threshold(y_true, pred_proba, proba_range, verbose=False): 12 | """ 13 | Function to find the probability threshold that optimises the f1_score 14 | 15 | Comment: this function is not used in this repo, but I include it in case the 16 | it useful 17 | 18 | Parameters: 19 | ----------- 20 | y_true: numpy.ndarray 21 | array with the true labels 22 | pred_proba: numpy.ndarray 23 | array with the predicted probability 24 | proba_range: numpy.ndarray 25 | range of probabilities to explore. 26 | e.g. np.arange(0.1,0.9,0.01) 27 | 28 | Return: 29 | ----------- 30 | tuple with the optimal threshold and the corresponding f1_score 31 | """ 32 | scores = [] 33 | for prob in proba_range: 34 | pred = [int(p>prob) for p in pred_proba] 35 | score = f1_score(y_true,pred) 36 | scores.append(score) 37 | if verbose: 38 | print("INFO: prob threshold: {}. 
score :{}".format(round(prob,3), round(score,5))) 39 | best_score = scores[np.argmax(scores)] 40 | optimal_threshold = proba_range[np.argmax(scores)] 41 | return (optimal_threshold, best_score) 42 | 43 | 44 | def focal_loss_lgb(y_pred, dtrain, alpha, gamma): 45 | """ 46 | Focal Loss for lightgbm 47 | 48 | Parameters: 49 | ----------- 50 | y_pred: numpy.ndarray 51 | array with the predictions 52 | dtrain: lightgbm.Dataset 53 | alpha, gamma: float 54 | See original paper https://arxiv.org/pdf/1708.02002.pdf 55 | """ 56 | a,g = alpha, gamma 57 | y_true = dtrain.label 58 | def fl(x,t): 59 | p = 1/(1+np.exp(-x)) 60 | return -( a*t + (1-a)*(1-t) ) * (( 1 - ( t*p + (1-t)*(1-p)) )**g) * ( t*np.log(p) + (1-t)*np.log(1-p) ) 61 | partial_fl = lambda x: fl(x, y_true) 62 | grad = derivative(partial_fl, y_pred, n=1, dx=1e-6) 63 | hess = derivative(partial_fl, y_pred, n=2, dx=1e-6) 64 | return grad, hess 65 | 66 | 67 | def focal_loss_lgb_eval_error(y_pred, dtrain, alpha, gamma): 68 | """ 69 | Adapation of the Focal Loss for lightgbm to be used as evaluation loss 70 | 71 | Parameters: 72 | ----------- 73 | y_pred: numpy.ndarray 74 | array with the predictions 75 | dtrain: lightgbm.Dataset 76 | alpha, gamma: float 77 | See original paper https://arxiv.org/pdf/1708.02002.pdf 78 | """ 79 | a,g = alpha, gamma 80 | y_true = dtrain.label 81 | p = 1/(1+np.exp(-y_pred)) 82 | loss = -( a*y_true + (1-a)*(1-y_true) ) * (( 1 - ( y_true*p + (1-y_true)*(1-p)) )**g) * ( y_true*np.log(p)+(1-y_true)*np.log(1-p) ) 83 | return 'focal_loss', np.mean(loss), False 84 | 85 | 86 | def lgb_f1_score(preds, lgbDataset): 87 | """ 88 | Implementation of the f1 score to be used as evaluation score for lightgbm 89 | 90 | Parameters: 91 | ----------- 92 | preds: numpy.ndarray 93 | array with the predictions 94 | lgbDataset: lightgbm.Dataset 95 | """ 96 | binary_preds = [int(p>0.5) for p in preds] 97 | y_true = lgbDataset.get_label() 98 | return 'f1', f1_score(y_true, binary_preds), True 99 | 100 | 101 | def lgb_focal_f1_score(preds, lgbDataset): 102 | """ 103 | Adaptation of the implementation of the f1 score to be used as evaluation 104 | score for lightgbm. The adaptation is required since when using custom losses 105 | the row prediction needs to passed through a sigmoid to represent a 106 | probability 107 | 108 | Parameters: 109 | ----------- 110 | preds: numpy.ndarray 111 | array with the predictions 112 | lgbDataset: lightgbm.Dataset 113 | """ 114 | preds = sigmoid(preds) 115 | binary_preds = [int(p>0.5) for p in preds] 116 | y_true = lgbDataset.get_label() 117 | return 'f1', f1_score(y_true, binary_preds), True -------------------------------------------------------------------------------- /utils/train_hyperopt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pickle 4 | import lightgbm as lgb 5 | import warnings 6 | 7 | from pathlib import Path 8 | from .metrics import (focal_loss_lgb, focal_loss_lgb_eval_error, lgb_f1_score, 9 | lgb_focal_f1_score, sigmoid) 10 | from sklearn.metrics import (accuracy_score, f1_score, precision_score, 11 | recall_score, confusion_matrix) 12 | from sklearn.utils import Bunch 13 | from hyperopt import hp, tpe, fmin, Trials 14 | 15 | import pdb 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | class LGBOptimizer(object): 21 | """Use Hyperopt to optimize LightGBM 22 | 23 | # Arguments (details only when args are not self-explanatory) 24 | train_set: List 25 | List with the training dataset e.g. 
[X_tr, y_tr] 26 | eval_set: List 27 | List with the training dataset 28 | """ 29 | def __init__(self, dataname, train_set, eval_set, colnames, 30 | categorical_columns=None, out_dir=Path('data'), is_unbalance=False, 31 | with_focal_loss=False, save=False): 32 | 33 | self.PATH = out_dir 34 | self.dataname = dataname 35 | self.is_unbalance = is_unbalance 36 | self.with_focal_loss = with_focal_loss 37 | self.save = save 38 | 39 | self.early_stop_dict = {} 40 | 41 | self.colnames = colnames 42 | self.categorical_columns = categorical_columns 43 | 44 | self.X_tr, self.y_tr = train_set[0], train_set[1] 45 | self.X_val, self.y_val = eval_set[0], eval_set[1] 46 | self.lgtrain = lgb.Dataset( 47 | self.X_tr, self.y_tr, 48 | feature_name=self.colnames, 49 | categorical_feature = self.categorical_columns, 50 | free_raw_data=False) 51 | self.lgvalid = self.lgtrain.create_valid( 52 | self.X_val, self.y_val) 53 | 54 | def optimize(self, maxevals=200): 55 | 56 | param_space = self.hyperparameter_space() 57 | objective = self.get_objective(self.lgtrain) 58 | objective.i=0 59 | trials = Trials() 60 | best = fmin(fn=objective, 61 | space=param_space, 62 | algo=tpe.suggest, 63 | max_evals=maxevals, 64 | trials=trials) 65 | best['num_boost_round'] = self.early_stop_dict[trials.best_trial['tid']] 66 | best['num_leaves'] = int(best['num_leaves']) 67 | best['verbose'] = -1 68 | 69 | if self.with_focal_loss: 70 | focal_loss = lambda x,y: focal_loss_lgb(x, y, best['alpha'], best['gamma']) 71 | model = lgb.train(best, self.lgtrain, fobj=focal_loss) 72 | preds = model.predict(self.X_val) 73 | preds = sigmoid(preds) 74 | preds = (preds > 0.5).astype('int') 75 | else: 76 | model = lgb.train(best, self.lgtrain) 77 | preds = model.predict(self.lgvalid.data) 78 | preds = (preds > 0.5).astype('int') 79 | 80 | acc = accuracy_score(self.y_val, preds) 81 | f1 = f1_score(self.y_val, preds) 82 | prec = precision_score(self.y_val, preds) 83 | rec = recall_score(self.y_val, preds) 84 | cm = confusion_matrix(self.y_val, preds) 85 | 86 | print('acc: {:.4f}, f1 score: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format( 87 | acc, f1, prec, rec)) 88 | print('confusion_matrix') 89 | print(cm) 90 | 91 | if self.save: 92 | results = Bunch(acc=acc, f1=f1, prec=prec, rec=rec, cm=cm) 93 | out_fname = 'results_'+self.dataname 94 | if self.is_unbalance: 95 | out_fname += '_unb' 96 | if self.with_focal_loss: 97 | out_fname += '_fl' 98 | out_fname += '.p' 99 | results.model = model 100 | results.best_params = best 101 | pickle.dump(results, open(self.PATH/out_fname, 'wb')) 102 | 103 | self.best = best 104 | self.model = model 105 | 106 | def get_objective(self, train): 107 | 108 | def objective(params): 109 | """ 110 | objective function for lightgbm. 
111 | """ 112 | # hyperopt casts as float 113 | params['num_boost_round'] = int(params['num_boost_round']) 114 | params['num_leaves'] = int(params['num_leaves']) 115 | 116 | # need to be passed as parameter 117 | if self.is_unbalance: 118 | params['is_unbalance'] = True 119 | params['verbose'] = -1 120 | params['seed'] = 1 121 | 122 | if self.with_focal_loss: 123 | focal_loss = lambda x,y: focal_loss_lgb(x, y, 124 | params['alpha'], params['gamma']) 125 | cv_result = lgb.cv( 126 | params, 127 | train, 128 | num_boost_round=params['num_boost_round'], 129 | fobj = focal_loss, 130 | feval = lgb_focal_f1_score, 131 | nfold=3, 132 | stratified=True, 133 | early_stopping_rounds=20) 134 | else: 135 | cv_result = lgb.cv( 136 | params, 137 | train, 138 | num_boost_round=params['num_boost_round'], 139 | metrics='binary_logloss', 140 | feval = lgb_f1_score, 141 | nfold=3, 142 | stratified=True, 143 | early_stopping_rounds=20) 144 | self.early_stop_dict[objective.i] = len(cv_result['f1-mean']) 145 | score = round(cv_result['f1-mean'][-1], 4) 146 | objective.i+=1 147 | return -score 148 | 149 | return objective 150 | 151 | def hyperparameter_space(self, param_space=None): 152 | 153 | space = { 154 | 'learning_rate': hp.uniform('learning_rate', 0.01, 0.2), 155 | 'num_boost_round': hp.quniform('num_boost_round', 50, 500, 20), 156 | 'num_leaves': hp.quniform('num_leaves', 31, 255, 4), 157 | 'min_child_weight': hp.uniform('min_child_weight', 0.1, 10), 158 | 'colsample_bytree': hp.uniform('colsample_bytree', 0.5, 1.), 159 | 'subsample': hp.uniform('subsample', 0.5, 1.), 160 | 'reg_alpha': hp.uniform('reg_alpha', 0.01, 0.1), 161 | 'reg_lambda': hp.uniform('reg_lambda', 0.01, 0.1), 162 | } 163 | if self.with_focal_loss: 164 | space['alpha'] = hp.uniform('alpha', 0.1, 0.75) 165 | space['gamma'] = hp.uniform('gamma', 0.5, 5) 166 | if param_space: 167 | return param_space 168 | else: 169 | return space --------------------------------------------------------------------------------