├── GRANDE ├── __init__.py └── tf_legacy │ └── GRANDE.py ├── figures ├── grande.jpg └── tabarena_leaderboard.jpg ├── requirements.txt ├── setup.py ├── LICENSE ├── README.md ├── GRANDE_minimal_example_with_comparison_REG.ipynb └── GRANDE_minimal_example_with_comparison_BINARY.ipynb /GRANDE/__init__.py: -------------------------------------------------------------------------------- 1 | from .GRANDE import GRANDE -------------------------------------------------------------------------------- /figures/grande.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-marton/GRANDE/HEAD/figures/grande.jpg -------------------------------------------------------------------------------- /figures/tabarena_leaderboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-marton/GRANDE/HEAD/figures/tabarena_leaderboard.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core runtime requirements (GRANDE/GRANDE.py) 2 | numpy>=1.25.0,<2.4.0 3 | pandas>=2.0.0,<2.4.0 4 | torch>=2.6,<2.10 5 | scikit-learn>=1.4.0,<1.8.0 6 | category-encoders>=2.6.4,<=2.9.0 7 | tqdm>=4.38,<5 8 | autogluon>=1.3.0,<=1.4.0 9 | 10 | # Notebook / example-only requirements (GRANDE_minimal_example_with_comparison_BINARY.ipynb) 11 | openml>=0.15.0,<=0.15.1 12 | xgboost>=2.0,<3.1 13 | catboost>=1.2,<1.3 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="GRANDE", 5 | license="MIT", 6 | version="0.2.0", 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'numpy>=1.25.0,<2.4.0', 10 | 'pandas>=2.0.0,<2.4.0', 11 | 'torch>=2.6,<2.10', 12 | 
'scikit-learn>=1.4.0,<1.8.0', 13 | 'category-encoders>=2.6.4,<=2.9.0', 14 | 'tqdm>=4.38,<5', 15 | 'autogluon>=1.3.0,<=1.4.0', 16 | ], 17 | author="Sascha Marton", 18 | author_email="sascha.marton@gmail.com", 19 | description="A novel ensemble method for hard, axis-aligned decision trees learned end-to-end with gradient descent.", 20 | long_description=open('README.md').read(), 21 | long_description_content_type="text/markdown", 22 | url="https://github.com/s-marton/GRANDE", 23 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sascha Marton 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌳 GRANDE: Gradient-Based Decision Tree Ensembles 🌳 2 | 3 | [![PyPI version](https://img.shields.io/pypi/v/GRANDE)](https://pypi.org/project/GRANDE/) [![OpenReview](https://img.shields.io/badge/OpenReview-XEFWBxi075-blue)](https://openreview.net/forum?id=XEFWBxi075) [![arXiv](https://img.shields.io/badge/arXiv-2309.17130-b31b1b.svg)](https://arxiv.org/abs/2309.17130) 4 | 5 | 6 |
7 | 8 | TabArena Leaderboard 9 | 10 |

TabArena. The updated PyTorch GRANDE has been evaluated on TabArena and achieved strong results.

11 | 12 |
13 | 14 | 🔍 What's new? 15 | - PyTorch-native implementation for seamless integration; TensorFlow is maintained as a legacy version. 16 | - Strong results on TabArena, specifically for binary classification and regression. Multi-class results are weaker and drag down the overall performance, which will hopefully be fixed in a future release. 17 | - Method updates for improved performance, including optional categorical and numerical embeddings. 18 | - Training improvements (optimizers, schedulers, early stopping, optional SWA). 19 | - Enhanced preprocessing pipeline with optional frequency encoding and robust normalization. 20 | 21 |
22 | 23 | GRANDE Overview 24 | 25 |

Figure 1: Overview of GRANDE. GRANDE learns hard, axis-aligned trees end-to-end via gradient descent and uses dynamic instance-wise leaf weighting to combine estimators into a strong ensemble.

26 |
27 | 28 | 29 | 🌳 GRANDE is a gradient-based decision tree ensemble for tabular data. 30 | GRANDE trains ensembles of hard, axis-aligned decision trees end-to-end with gradient descent. Each estimator contributes via instance-wise leaf weights that are learned jointly with split locations and leaf values. This combines the strong inductive bias of trees with the flexibility of neural optimization. The PyTorch version optionally augments inputs with learnable categorical and numerical embeddings, improving representation capacity while preserving interpretability of splits. 31 | 32 | 📝 More details in the paper: https://openreview.net/forum?id=XEFWBxi075 33 | 34 | 35 | ## Cite us 36 | ```text 37 | @inproceedings{ 38 | marton2024grande, 39 | title={{GRANDE}: Gradient-Based Decision Tree Ensembles}, 40 | author={Sascha Marton and Stefan L{\"u}dtke and Christian Bartelt and Heiner Stuckenschmidt}, 41 | booktitle={The Twelfth International Conference on Learning Representations}, 42 | year={2024}, 43 | url={https://openreview.net/forum?id=XEFWBxi075} 44 | } 45 | ``` 46 | 47 | ## Installation 48 | To install the latest release: 49 | ```bash 50 | pip install git+https://github.com/s-marton/GRANDE.git 51 | ``` 52 | 53 | ## Dependencies 54 | Install core runtime requirements (and optional notebook/example dependencies) via: 55 | 56 | ```bash 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | Notes: 61 | - The file contains a **core** section (library runtime deps) and a **notebook/example-only** section (OpenML/XGBoost/CatBoost). 62 | 63 | ## Usage (PyTorch) 64 | Example aligned with the attached notebook (binary classification, OpenML dataset 46915). GPU is recommended. 
65 | 66 | ```python 67 | # Enable GPU (optional) 68 | import os 69 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 70 | 71 | # Load data 72 | from sklearn.model_selection import train_test_split 73 | import openml 74 | import numpy as np 75 | import sklearn 76 | 77 | dataset = openml.datasets.get_dataset(46915, download_data=True, download_qualities=True, download_features_meta_data=True) 78 | X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) 79 | 80 | X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42) 81 | X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42) 82 | 83 | # GRANDE (PyTorch) 84 | from GRANDE import GRANDE 85 | 86 | params = { 87 | 'depth': 5, 88 | 'n_estimators': 1024, 89 | 90 | 'learning_rate_weights': 0.001, 91 | 'learning_rate_index': 0.01, 92 | 'learning_rate_values': 0.05, 93 | 'learning_rate_leaf': 0.05, 94 | 'learning_rate_embedding': 0.02, # used if embeddings are enabled 95 | 96 | # Embeddings (set True to enable) 97 | 'use_category_embeddings': False, # True to enable 98 | 'embedding_dim_cat': 8, 99 | 'use_numeric_embeddings': False, # True to enable 100 | 'embedding_dim_num': 8, 101 | 'embedding_threshold': 1, # low-cardinality split for categorical embeddings 102 | 'loo_cardinality': 10, # high-cardinality split for encoders 103 | 104 | 'dropout': 0.2, 105 | 'selected_variables': 0.8, 106 | 'data_subset_fraction': 1.0, 107 | 'bootstrap': False, 108 | 'missing_values': False, 109 | 110 | 'optimizer': 'adam', # options: nadam, radam, adamw, adam 111 | 'cosine_decay_restarts': False, 112 | 'reduce_on_plateau_scheduler': True, 113 | 'label_smoothing': 0.0, 114 | 'use_class_weights': False, 115 | 'focal_loss': False, 116 | 'swa': False, 117 | 'es_metric': True, # AUC for binary, MSE for regression, val_loss for multiclass 118 | 119 | 'epochs': 250, 120 | 'batch_size': 256, 121 | 'early_stopping_epochs': 50, 122 
| 123 | 'use_freq_enc': False, 124 | 'use_robust_scale_smoothing': False, 125 | 126 | # Important: use problem_type, not objective 127 | 'problem_type': 'binary', # {'binary', 'multiclass', 'regression'} 128 | 129 | 'random_seed': 42, 130 | 'verbose': 2, 131 | } 132 | 133 | model_grande = GRANDE(params=params) 134 | model_grande.fit(X=X_train, y=y_train, X_val=X_valid, y_val=y_valid) 135 | 136 | # Predict 137 | preds_grande = model_grande.predict_proba(X_test) 138 | 139 | # Evaluate (binary) 140 | accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_grande[:, 1])) 141 | f1 = sklearn.metrics.f1_score(y_test, np.round(preds_grande[:, 1]), average='macro') 142 | roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande[:, 1], average='macro') 143 | 144 | print('Accuracy GRANDE:', accuracy) 145 | print('F1 Score GRANDE:', f1) 146 | print('ROC AUC GRANDE:', roc_auc) 147 | ``` 148 | 149 | Notes: 150 | - Set use_category_embeddings/use_numeric_embeddings to True to enable embeddings. 151 | - For multiclass, use problem_type='multiclass'. For regression, use 'regression'. 152 | - TensorFlow is supported as a legacy version; the PyTorch path is the recommended/default. 153 | 154 | ## More 155 | This is an experimental implementation. If you encounter issues, please open an issue or report unexpected behavior. 
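To make the overview more concrete, here is a minimal plain-Python sketch of the inference path: each tree routes an instance to exactly one leaf through hard, axis-aligned splits, and the estimators are then combined with a softmax over the weights of the leaves that were reached. It loosely follows the dense tree indexing in the legacy TensorFlow `call` method, but it is an illustration under assumptions, not the library's implementation: the function names and the dictionary keys (`features`, `thresholds`, `leaf_weights`, `leaf_values`) are hypothetical, and entmax feature selection, straight-through gradients, dropout and feature subsetting are all left out.

```python
import math


def leaf_paths(depth):
    """For every leaf, list the (internal_node, go_right) pairs on its root-to-leaf path.

    Internal nodes are numbered breadth-first with the root at index 0.
    """
    paths = []
    for leaf in range(2 ** depth):
        path = []
        for level in range(1, depth + 1):
            go_right = (leaf // 2 ** (depth - level)) % 2
            node = 2 ** (level - 1) + leaf // 2 ** (depth - level + 1) - 1
            path.append((node, go_right))
        paths.append(path)
    return paths


def leaf_indicator(x, features, thresholds, depth):
    """One-hot list marking the leaf that instance x reaches under hard splits."""
    # A decision of 1 means "threshold exceeds the feature value", i.e. go left.
    decisions = [1 if thresholds[i] > x[features[i]] else 0
                 for i in range(2 ** depth - 1)]
    indicator = []
    for path in leaf_paths(depth):
        reached = 1
        for node, go_right in path:
            reached *= decisions[node] if go_right == 0 else 1 - decisions[node]
        indicator.append(reached)
    return indicator


def predict(x, trees, depth):
    """Regression-style ensemble output with instance-wise leaf weights."""
    logits, values = [], []
    for tree in trees:
        leaf = leaf_indicator(x, tree["features"], tree["thresholds"], depth).index(1)
        logits.append(tree["leaf_weights"][leaf])
        values.append(tree["leaf_values"][leaf])
    # Softmax over estimators, computed separately for each instance.
    shift = max(logits)
    exps = [math.exp(logit - shift) for logit in logits]
    total = sum(exps)
    return sum(e / total * v for e, v in zip(exps, values))


# Two hypothetical depth-1 trees (one split each) with equal leaf weights.
trees = [
    {"features": [0], "thresholds": [0.5],
     "leaf_weights": [0.0, 0.0], "leaf_values": [1.0, 3.0]},
    {"features": [1], "thresholds": [0.0],
     "leaf_weights": [0.0, 0.0], "leaf_values": [10.0, 20.0]},
]
print(predict([0.2, 1.0], trees, depth=1))  # 10.5
```

With equal leaf weights the softmax reduces to a plain average of the two reached leaf values (1.0 and 20.0). Unequal weights let each instance favor the estimators whose reached leaf carries the larger weight, which is the instance-wise weighting idea described above.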
156 | -------------------------------------------------------------------------------- /GRANDE/tf_legacy/GRANDE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | #import tensorflow_addons as tfa 4 | 5 | import sklearn 6 | from copy import deepcopy 7 | import category_encoders as ce 8 | import pandas as pd 9 | import math 10 | from focal_loss import SparseCategoricalFocalLoss 11 | 12 | import pickle 13 | import zipfile 14 | import os 15 | 16 | class GRANDE(tf.keras.Model): 17 | def __init__(self, 18 | params, 19 | args): 20 | 21 | params.update(args) 22 | self.config = None 23 | 24 | super(GRANDE, self).__init__() 25 | self.set_params(**params) 26 | 27 | self.is_fitted = False 28 | 29 | self.internal_node_num_ = 2 ** self.depth - 1 30 | self.leaf_node_num_ = 2 ** self.depth 31 | 32 | tf.keras.utils.set_random_seed(self.random_seed) 33 | 34 | def call(self, inputs, training): 35 | # einsum syntax: 36 | # - b is the batch size 37 | # - e is the number of estimators 38 | # - l the number of leaf nodes (i.e. the number of paths) 39 | # - i is the number of internal nodes 40 | # - d is the depth (i.e. 
the length of each path) 41 | # - n is the number of variables (one value is stored for each variable) 42 | 43 | 44 | # Adjust data: For each estimator select a subset of the features (output shape: (b, e, n)) 45 | if self.data_subset_fraction < 1.0 and training: 46 | #select subset of samples for each estimator during training if hyperparameter set accordingly 47 | X_estimator = tf.nn.embedding_lookup(tf.transpose(inputs), self.features_by_estimator) 48 | X_estimator = tf.transpose(tf.gather_nd(tf.transpose(X_estimator, [0,2,1]), tf.expand_dims(self.data_select,2), batch_dims=1), [1,0,2]) 49 | else: 50 | #use complete data 51 | X_estimator = tf.gather(inputs, self.features_by_estimator, axis=1) 52 | 53 | #entmax transformation 54 | split_index_array = entmax15TF(self.split_index_array) 55 | 56 | #use ST-Operator to get one-hot encoded vector for feature index 57 | split_index_array = split_index_array - tf.stop_gradient(split_index_array - tf.one_hot(tf.argmax(split_index_array, axis=-1), depth=split_index_array.shape[-1])) 58 | 59 | # as split_index_array_selected is one-hot-encoded, taking the sum over the last axis after multiplication results in selecting the desired value at the index 60 | s1_sum = tf.einsum("ein,ein->ei", self.split_values, split_index_array) 61 | s2_sum = tf.einsum("ben,ein->bei", X_estimator, split_index_array) 62 | 63 | # calculate the split (output shape: (b, e, i)) 64 | node_result = (tf.nn.softsign(s1_sum-s2_sum) + 1) / 2 65 | 66 | #use round operation with ST operator to get hard decision for each node 67 | node_result_corrected = node_result - tf.stop_gradient(node_result - tf.round(node_result)) 68 | 69 | #generate tensor for further calculation: 70 | # - internal_node_index_list holds the indices for the internal nodes traversed for each path (there are l paths) in the tree 71 | # - for each estimator and for each path in each estimator, the tensors hold the information for all internal nodes traversed 72 | # - the resulting shape 
of the tensors is (b, e, l, d): 73 | node_result_extended = tf.gather(node_result_corrected, self.internal_node_index_list, axis=2) 74 | 75 | 76 | #reduce the path via multiplication to get result for each path (in each estimator) based on the results of the corresponding internal nodes (output shape: (b, e, l)) 77 | p = tf.reduce_prod(((1-self.path_identifier_list)*node_result_extended + self.path_identifier_list*(1-node_result_extended)), axis=3) 78 | 79 | #calculate instance-wise leaf weights for each estimator by selecting the weight of the selected path for each estimator 80 | estimator_weights_leaf = tf.einsum("el,bel->be", self.estimator_weights, p) 81 | 82 | #use softmax over weights for each instance 83 | estimator_weights_leaf_softmax = tf.keras.activations.softmax(estimator_weights_leaf) 84 | 85 | #optional dropout (deactivating random estimators) 86 | estimator_weights_leaf_softmax = self.apply_dropout_leaf(estimator_weights_leaf_softmax, training=training) 87 | 88 | #get raw prediction for each estimator 89 | #optionally transform to probability distribution before weighting 90 | if self.objective == 'regression': 91 | layer_output = tf.einsum('el,bel->be', self.leaf_classes_array, p) 92 | layer_output = tf.einsum('be,be->be', estimator_weights_leaf_softmax, layer_output) 93 | elif self.objective == 'binary': 94 | if self.from_logits: 95 | layer_output = tf.einsum('el,bel->be', self.leaf_classes_array, p) 96 | else: 97 | layer_output = tf.math.sigmoid(tf.einsum('el,bel->be', self.leaf_classes_array, p)) 98 | layer_output = tf.einsum('be,be->be', estimator_weights_leaf_softmax, layer_output) 99 | elif self.objective == 'classification': 100 | if self.from_logits: 101 | layer_output = tf.einsum('elc,bel->bec', self.leaf_classes_array, p) 102 | layer_output = tf.einsum('be,bec->bec', estimator_weights_leaf_softmax, layer_output) 103 | else: 104 | layer_output = tf.keras.activations.softmax(tf.einsum('elc,bel->bec', self.leaf_classes_array, p)) 105 | 
layer_output = tf.einsum('be,bec->bec', estimator_weights_leaf_softmax, layer_output) 106 | 107 | if self.data_subset_fraction < 1.0 and training: 108 | result = tf.scatter_nd(indices=tf.expand_dims(self.data_select, 2), updates=tf.transpose(layer_output), shape=[tf.shape(inputs)[0]]) 109 | result = (result / self.counts) * self.n_estimators 110 | else: 111 | if self.objective == 'regression' or self.objective == 'binary': 112 | result = tf.einsum('be->b', layer_output) 113 | else: 114 | result = tf.einsum('bec->bc', layer_output) 115 | 116 | if self.objective == 'regression' or self.objective == 'binary': 117 | result = tf.expand_dims(result, 1) 118 | 119 | return result 120 | 121 | def apply_preprocessing(self, X): 122 | 123 | if not isinstance(X, pd.DataFrame): 124 | X = pd.DataFrame(X) 125 | 126 | if len(self.num_columns) > 0: 127 | X[self.num_columns] = X[self.num_columns].fillna(self.mean_train_num) 128 | if len(self.cat_columns) > 0: 129 | X[self.cat_columns] = X[self.cat_columns].fillna(self.mode_train_cat) 130 | 131 | X = self.encoder_ordinal.transform(X) 132 | X = self.encoder_loo.transform(X) 133 | X = self.encoder_ohe.transform(X) 134 | 135 | X = self.normalizer.transform(X.values.astype(np.float64)) 136 | 137 | return X 138 | 139 | def perform_preprocessing(self, 140 | X_train, 141 | y_train, 142 | X_val, 143 | y_val): 144 | 145 | if isinstance(y_train, pd.Series): 146 | try: 147 | y_train = y_train.values.codes.astype(np.float64) 148 | except: 149 | pass 150 | if isinstance(y_val, pd.Series): 151 | try: 152 | y_val = y_val.values.codes.astype(np.float64) 153 | except: 154 | pass 155 | 156 | if not isinstance(X_train, pd.DataFrame): 157 | X_train = pd.DataFrame(X_train) 158 | if not isinstance(X_val, pd.DataFrame): 159 | X_val = pd.DataFrame(X_val) 160 | 161 | self.mean = np.mean(y_train) 162 | self.std = np.std(y_train) 163 | 164 | binary_indices = [] 165 | low_cardinality_indices = [] 166 | high_cardinality_indices = [] 167 | num_columns = [] 168 | 
for column_index, column in enumerate(X_train.columns): 169 | if column_index in self.cat_idx: 170 | if len(X_train.iloc[:,column_index].unique()) <= 2: 171 | binary_indices.append(column) 172 | elif len(X_train.iloc[:,column_index].unique()) < 5: 173 | low_cardinality_indices.append(column) 174 | else: 175 | high_cardinality_indices.append(column) 176 | else: 177 | num_columns.append(column) 178 | 179 | cat_columns = [col for col in X_train.columns if col not in num_columns] 180 | 181 | if len(num_columns) > 0: 182 | self.mean_train_num = X_train[num_columns].mean(axis=0) # per-column training means (previously only the first column's mean was used) 183 | X_train[num_columns] = X_train[num_columns].fillna(self.mean_train_num) 184 | X_val[num_columns] = X_val[num_columns].fillna(self.mean_train_num) 185 | if len(cat_columns) > 0: 186 | self.mode_train_cat = X_train[cat_columns].mode(axis=0).iloc[0] 187 | X_train[cat_columns] = X_train[cat_columns].fillna(self.mode_train_cat) 188 | X_val[cat_columns] = X_val[cat_columns].fillna(self.mode_train_cat) 189 | 190 | self.cat_columns = cat_columns 191 | self.num_columns = num_columns 192 | 193 | self.encoder_ordinal = ce.OrdinalEncoder(cols=binary_indices) 194 | self.encoder_ordinal.fit(X_train) 195 | X_train = self.encoder_ordinal.transform(X_train) 196 | X_val = self.encoder_ordinal.transform(X_val) 197 | 198 | self.encoder_loo = ce.LeaveOneOutEncoder(cols=high_cardinality_indices) 199 | if self.objective == 'regression': 200 | self.encoder_loo.fit(X_train, (y_train-self.mean)/self.std) 201 | else: 202 | self.encoder_loo.fit(X_train, y_train) 203 | X_train = self.encoder_loo.transform(X_train) 204 | X_val = self.encoder_loo.transform(X_val) 205 | 206 | self.encoder_ohe = ce.OneHotEncoder(cols=low_cardinality_indices) 207 | self.encoder_ohe.fit(X_train) 208 | X_train = self.encoder_ohe.transform(X_train) 209 | X_val = self.encoder_ohe.transform(X_val) 210 | 211 | X_train = X_train.astype(np.float32) 212 | X_val = X_val.astype(np.float32) 213 | 214 | quantile_noise = 1e-4 215 | quantile_train 
= np.copy(X_train.values).astype(np.float64) 216 | np.random.seed(42) 217 | stds = np.std(quantile_train, axis=0, keepdims=True) 218 | noise_std = quantile_noise / np.maximum(stds, quantile_noise) 219 | quantile_train += noise_std * np.random.randn(*quantile_train.shape) 220 | 221 | quantile_train = pd.DataFrame(quantile_train, columns=X_train.columns, index=X_train.index) 222 | 223 | self.normalizer = sklearn.preprocessing.QuantileTransformer( 224 | n_quantiles=min(quantile_train.shape[0], 1000), 225 | output_distribution='normal', 226 | ) 227 | 228 | self.normalizer.fit(quantile_train.values.astype(np.float64)) 229 | X_train = self.normalizer.transform(X_train.values.astype(np.float64)) 230 | X_val = self.normalizer.transform(X_val.values.astype(np.float64)) 231 | 232 | return X_train, y_train, X_val, y_val 233 | 234 | def convert_to_numpy(self,data): 235 | """ 236 | Converts input data (Pandas DataFrame, TensorFlow tensor, PyTorch tensor, list, or similar iterable) to a NumPy array. 237 | 238 | Args: 239 | data: Input data to be converted. Can be a Pandas DataFrame, TensorFlow tensor, list, or similar iterable. 240 | 241 | Returns: 242 | numpy_array: A NumPy array representation of the input data. 
243 | """ 244 | # Check if the data is a Pandas DataFrame 245 | if isinstance(data, (pd.DataFrame, pd.Series)): 246 | return data.values 247 | 248 | # Check if the data is a TensorFlow tensor 249 | elif isinstance(data, tf.Tensor): 250 | return data.numpy() 251 | 252 | # Check if the data is a list or similar iterable (not including strings) 253 | elif isinstance(data, (list, tuple, np.ndarray)): 254 | return np.array(data) 255 | 256 | else: 257 | raise TypeError("The input data type is not supported for conversion.") 258 | 259 | def fit(self, 260 | X_train, 261 | y_train, 262 | X_val=None, 263 | y_val=None, 264 | **kwargs): 265 | 266 | if self.preprocess_data: 267 | X_train, y_train, X_val, y_val = self.perform_preprocessing(X_train, y_train, X_val, y_val) 268 | else: 269 | X_train = self.convert_to_numpy(X_train) 270 | y_train = self.convert_to_numpy(y_train) 271 | X_val = self.convert_to_numpy(X_val) 272 | y_val = self.convert_to_numpy(y_val) 273 | 274 | jit_compile = True #X_train.shape[0] < 10_000 275 | 276 | self.number_of_variables = X_train.shape[1] 277 | if self.use_class_weights: 278 | if self.objective == 'classification' or self.objective == 'binary': 279 | self.number_of_classes = len(np.unique(y_train)) 280 | self.class_weights = sklearn.utils.class_weight.compute_class_weight(class_weight = 'balanced', classes = np.unique(y_train), y = y_train) 281 | 282 | self.class_weight_dict = {} 283 | for i in range(self.number_of_classes): 284 | self.class_weight_dict[i] = self.class_weights[i] 285 | 286 | else: 287 | self.number_of_classes = 1 288 | self.class_weights = np.ones_like(np.unique(y_train)) 289 | self.class_weight_dict = None 290 | else: 291 | if self.objective == 'classification' or self.objective == 'binary': 292 | self.number_of_classes = len(np.unique(y_train)) 293 | else: 294 | self.number_of_classes = 1 295 | self.class_weights = np.ones_like(np.unique(y_train)) 296 | self.class_weight_dict = None 297 | 298 | self.build_model() 299 | 300 | 
self.compile(loss=self.loss_name, metrics=[], jit_compile=jit_compile, mean=self.mean, std=self.std, class_weight=self.class_weights) 301 | 302 | 303 | train_data = tf.data.Dataset.from_tensor_slices((tf.dtypes.cast(tf.convert_to_tensor(X_train), tf.float32), 304 | tf.dtypes.cast(tf.convert_to_tensor(y_train), tf.float32))) 305 | 306 | if self.data_subset_fraction < 1.0: 307 | train_data = (train_data 308 | #.shuffle(32_768) 309 | .cache() 310 | .batch(batch_size=self.batch_size, drop_remainder=True) 311 | .prefetch(tf.data.AUTOTUNE) 312 | ) 313 | else: 314 | train_data = (train_data 315 | .shuffle(32_768) 316 | .cache() 317 | .batch(batch_size=self.batch_size, drop_remainder=False) 318 | .prefetch(tf.data.AUTOTUNE) 319 | ) 320 | 321 | if X_val is not None and y_val is not None: 322 | validation_data = (X_val, y_val) 323 | validation_data = tf.data.Dataset.from_tensor_slices((tf.dtypes.cast(tf.convert_to_tensor(validation_data[0]), tf.float32), 324 | tf.dtypes.cast(tf.convert_to_tensor(validation_data[1]), tf.float32))) 325 | 326 | 327 | validation_data = (validation_data 328 | .cache() 329 | .batch(batch_size=self.batch_size, drop_remainder=False) 330 | .prefetch(tf.data.AUTOTUNE) 331 | ) 332 | 333 | monitor = 'val_loss' 334 | else: 335 | monitor = 'loss' 336 | 337 | 338 | if 'callbacks' not in kwargs.keys(): 339 | callbacks = [] 340 | 341 | early_stopping = tf.keras.callbacks.EarlyStopping(monitor=monitor, 342 | patience=self.early_stopping_epochs, 343 | min_delta=1e-3, 344 | restore_best_weights=True) 345 | callbacks.append(early_stopping) 346 | 347 | if 'reduce_lr' in kwargs.keys() and kwargs['reduce_lr']: 348 | reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, patience=self.early_stopping_epochs//3) 349 | callbacks.append(reduce_lr) 350 | 351 | super(GRANDE, self).fit(train_data, 352 | validation_data = validation_data, 353 | epochs = self.epochs, 354 | callbacks = callbacks, 355 | class_weight = self.class_weight_dict, 356 | verbose=self.verbose, 
357 | **kwargs) 358 | 359 | def build_model(self): 360 | 361 | tf.keras.utils.set_random_seed(self.random_seed) 362 | 363 | if self.selected_variables > 1: 364 | self.selected_variables = min(self.selected_variables, self.number_of_variables) 365 | else: 366 | self.selected_variables = int(self.number_of_variables * self.selected_variables) 367 | self.selected_variables = min(self.selected_variables, 50) 368 | self.selected_variables = max(self.selected_variables, 10) 369 | self.selected_variables = min(self.selected_variables, self.number_of_variables) 370 | if self.objective != 'binary': 371 | self.data_subset_fraction = 1.0 372 | if self.data_subset_fraction < 1.0: 373 | if self.batch_size * self.data_subset_fraction > 4: 374 | self.subset_size = tf.cast(self.batch_size * self.data_subset_fraction, tf.int32) 375 | else: 376 | self.subset_size = tf.cast(4, tf.int32) 377 | 378 | if self.bootstrap: 379 | self.data_select = tf.random.uniform(shape=(self.n_estimators, self.subset_size), minval=0, maxval=self.batch_size, dtype=tf.int32) 380 | else: 381 | 382 | indices = [np.random.choice(self.batch_size, size=self.subset_size, replace=False) for _ in range(self.n_estimators)] 383 | self.data_select = tf.stack(indices) 384 | 385 | items, self.counts = np.unique(self.data_select, return_counts=True) 386 | self.counts = tf.constant(self.counts, dtype=tf.float32) 387 | 388 | self.features_by_estimator = tf.stack([np.random.choice(self.number_of_variables, size=(self.selected_variables), replace=False, p=None) for _ in range(self.n_estimators)]) 389 | 390 | self.path_identifier_list = [] 391 | self.internal_node_index_list = [] 392 | for leaf_index in tf.unstack(tf.constant([i for i in range(self.leaf_node_num_)])): 393 | for current_depth in tf.unstack(tf.constant([i for i in range(1, self.depth+1)])): 394 | path_identifier = tf.cast(tf.math.floormod(tf.math.floor(leaf_index/(tf.math.pow(2, (self.depth-current_depth)))), 2), tf.float32) 395 | internal_node_index = 
tf.cast(tf.cast(tf.math.pow(2, (current_depth-1)), tf.float32) + tf.cast(tf.math.floor(leaf_index/(tf.math.pow(2, (self.depth-(current_depth-1))))), tf.float32) - 1.0, tf.int64) 396 | self.path_identifier_list.append(path_identifier) 397 | self.internal_node_index_list.append(internal_node_index) 398 | self.path_identifier_list = tf.reshape(tf.stack(self.path_identifier_list), (-1,self.depth)) 399 | self.internal_node_index_list = tf.reshape(tf.cast(tf.stack(self.internal_node_index_list), tf.int64), (-1,self.depth)) 400 | 401 | leaf_classes_array_shape = [self.n_estimators,self.leaf_node_num_,] if self.objective == 'binary' or self.objective == 'regression' else [self.n_estimators, self.leaf_node_num_, self.number_of_classes] 402 | 403 | weight_shape = [self.n_estimators,self.leaf_node_num_] 404 | 405 | self.estimator_weights = self.add_weight(shape=weight_shape, 406 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 1}}, 407 | trainable=True, 408 | name="estimator_weights",) 409 | 410 | self.split_values = self.add_weight(shape=(self.n_estimators, self.internal_node_num_, self.selected_variables), 411 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 2}}, 412 | trainable=True, 413 | name="split_values",) 414 | 415 | 416 | self.split_index_array = self.add_weight(shape=(self.n_estimators, self.internal_node_num_, self.selected_variables), 417 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 3}}, 418 | trainable=True, 419 | name="split_index_array",) 420 | 421 | self.leaf_classes_array = self.add_weight(shape=leaf_classes_array_shape, 422 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 4}}, 423 | trainable=True, 424 | name="leaf_classes_array",) 425 | 426 | 427 | def compile(self, 428 | loss, 429 | metrics, 430 | jit_compile, 431 | **kwargs): 432 | 433 | if self.objective == 'classification': 434 | 435 | if loss == 
'crossentropy': 436 | if not self.focal_loss: 437 | loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=self.from_logits) #tf.keras.losses.get('categorical_crossentropy') 438 | else: 439 | loss_function = SparseCategoricalFocalLoss(gamma=2, class_weight=self.class_weights, from_logits=self.from_logits) 440 | else: 441 | loss_function = tf.keras.losses.get(loss) 442 | try: 443 | loss_function.from_logits = self.from_logits 444 | except: 445 | pass 446 | elif self.objective == 'binary': 447 | if loss == 'crossentropy': 448 | if not self.focal_loss: 449 | loss_function = tf.keras.losses.BinaryCrossentropy(from_logits=self.from_logits) #tf.keras.losses.get('binary_crossentropy') 450 | else: 451 | loss_function = tf.keras.losses.BinaryFocalCrossentropy(alpha=0.5, gamma=2, from_logits=self.from_logits) 452 | else: 453 | loss_function = tf.keras.losses.get(loss) 454 | try: 455 | loss_function.from_logits = self.from_logits 456 | except: 457 | pass 458 | 459 | elif self.objective == 'regression': 460 | loss_function = loss_function_regression(loss_name=loss, mean=kwargs['mean'], std=kwargs['std']) 461 | 462 | loss_function = loss_function_weighting(loss_function, temp=self.temperature) 463 | self.weights_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_weights, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps) 464 | self.index_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_index, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps) 465 | self.values_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_values, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps) 466 | self.leaf_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_leaf, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps) 467 | 468 | super(GRANDE, 
self).compile(loss=loss_function, metrics=metrics, jit_compile=jit_compile) 469 | 470 | def train_step(self, data): 471 | 472 | if len(data) == 3: 473 | x, y, sample_weight = data 474 | else: 475 | x, y = data 476 | sample_weight = None 477 | 478 | if not self.built: 479 | _ = self(x, training=True) 480 | 481 | with tf.GradientTape(persistent=True) as tape: 482 | y_pred = self(x, training=True) 483 | loss = self.compute_loss(x=None, y=y, y_pred=y_pred, sample_weight=sample_weight) 484 | 485 | weights_gradients = tape.gradient(loss, [self.estimator_weights]) 486 | self.weights_optimizer.apply_gradients(zip(weights_gradients, [self.estimator_weights])) 487 | 488 | index_gradients = tape.gradient(loss, [self.split_index_array]) 489 | self.index_optimizer.apply_gradients(zip(index_gradients, [self.split_index_array])) 490 | 491 | values_gradients = tape.gradient(loss, [self.split_values]) 492 | self.values_optimizer.apply_gradients(zip(values_gradients, [self.split_values])) 493 | 494 | leaf_gradients = tape.gradient(loss, [self.leaf_classes_array]) 495 | self.leaf_optimizer.apply_gradients(zip(leaf_gradients, [self.leaf_classes_array])) 496 | 497 | 498 | for metric in self.metrics: 499 | if metric.name == "loss": 500 | metric.update_state(loss) 501 | else: 502 | metric.update_state(y, y_pred) 503 | 504 | return {m.name: m.result() for m in self.metrics} 505 | 506 | 507 | def predict(self, X): 508 | 509 | if self.preprocess_data: 510 | X = self.apply_preprocessing(X) 511 | else: 512 | X = self.convert_to_numpy(X) 513 | 514 | preds = super(GRANDE, self).predict(X, self.batch_size, verbose=0) 515 | preds = tf.convert_to_tensor(preds) 516 | if self.objective == 'regression': 517 | preds = preds * self.std + self.mean 518 | else: 519 | if self.from_logits: 520 | if self.objective == 'binary': 521 | preds = tf.math.sigmoid(preds) 522 | elif self.objective == 'classification': 523 | preds = tf.keras.activations.softmax(preds) 524 | 525 | if self.objective == 'binary': 526 | 
preds = tf.stack([1-tf.squeeze(preds), tf.squeeze(preds)], axis=-1) 527 | 528 | return preds.numpy() 529 | 530 | def set_params(self, **kwargs): 531 | 532 | if self.config is None: 533 | self.config = self.default_parameters() 534 | 535 | self.config.update(kwargs) 536 | 537 | if self.config['n_estimators'] == 1: 538 | self.config['selected_variables'] = 1.0 539 | self.config['data_subset_fraction'] = 1.0 540 | self.config['bootstrap'] = False 541 | self.config['dropout'] = 0.0 542 | 543 | if 'loss' not in self.config.keys(): 544 | if self.config['objective'] == 'classification' or self.config['objective'] == 'binary': 545 | self.config['loss'] = 'crossentropy' 546 | self.config['focal_loss'] = False 547 | elif self.config['objective'] == 'regression': 548 | self.config['loss'] = 'mse' 549 | self.config['focal_loss'] = False 550 | 551 | self.config['optimizer_name'] = self.config.pop('optimizer') 552 | self.config['loss_name'] = self.config.pop('loss') 553 | 554 | for arg_key, arg_value in self.config.items(): 555 | setattr(self, arg_key, arg_value) 556 | 557 | tf.keras.utils.set_random_seed(self.random_seed) 558 | 559 | 560 | def get_params(self): 561 | return self.config 562 | 563 | def define_trial_parameters(self, trial, args): 564 | params = { 565 | 'depth': trial.suggest_int("depth", 3, 7), 566 | 'n_estimators': trial.suggest_int("n_estimators", 512, 2048), 567 | 568 | 'learning_rate_weights': trial.suggest_float("learning_rate_weights", 0.0001, 0.25), 569 | 'learning_rate_index': trial.suggest_float("learning_rate_index", 0.0001, 0.25), 570 | 'learning_rate_values': trial.suggest_float("learning_rate_values", 0.0001, 0.25), 571 | 'learning_rate_leaf': trial.suggest_float("learning_rate_leaf", 0.0001, 0.25), 572 | 573 | 'cosine_decay_steps': trial.suggest_categorical("cosine_decay_steps", [0, 100, 1000]), 574 | 575 | 'dropout': trial.suggest_categorical("dropout", [0, 0.25, 0.5]), 576 | 577 | 'selected_variables': 
trial.suggest_categorical("selected_variables", [1.0, 0.75, 0.5]), 578 | 'data_subset_fraction': trial.suggest_categorical("data_subset_fraction", [1.0, 0.8]), 579 | } 580 | 581 | try: 582 | if args['objective'] != 'regression': 583 | params['focal_loss'] = trial.suggest_categorical("focal_loss", [True, False]) 584 | params['temperature'] = trial.suggest_categorical("temperature", [0, 0.25]) 585 | except: 586 | if self.objective != 'regression': 587 | params['focal_loss'] = trial.suggest_categorical("focal_loss", [True, False]) 588 | params['temperature'] = trial.suggest_categorical("temperature", [0, 0.25]) 589 | return params 590 | 591 | def get_random_parameters(self, seed): 592 | rs = np.random.RandomState(seed) 593 | params = { 594 | 'depth': rs.randint(3, 7), 595 | 'n_estimators': rs.randint(512, 2048), 596 | 597 | 'learning_rate_weights': rs.uniform(0.0001, 0.25), 598 | 'learning_rate_index': rs.uniform(0.0001, 0.25), 599 | 'learning_rate_values': rs.uniform(0.0001, 0.25), 600 | 'learning_rate_leaf': rs.uniform(0.0001, 0.25), 601 | 602 | 'cosine_decay_steps': rs.choice([0, 100, 1000], p=[0.5, 0.25, 0.25]), 603 | 'dropout': rs.choice([0, 0.25]), 604 | 605 | 'selected_variables': rs.choice([1.0, 0.75, 0.5]), 606 | 'data_subset_fraction': rs.choice([1.0, 0.8]), 607 | 608 | } 609 | 610 | if self.objective != 'regression': 611 | params['focal_loss'] = rs.choice([True, False]) 612 | params['temperature'] = rs.choice([1, 1/3, 1/5, 1/7, 1/9, 0], p=[0.1, 0.1, 0.1, 0.1, 0.1,0.5]) 613 | 614 | return params 615 | 616 | def default_parameters(self): 617 | params = { 618 | 'depth': 5, 619 | 'n_estimators': 2048, 620 | 621 | 'learning_rate_weights': 0.005, 622 | 'learning_rate_index': 0.01, 623 | 'learning_rate_values': 0.01, 624 | 'learning_rate_leaf': 0.01, 625 | 626 | 'optimizer': 'adam', 627 | 'cosine_decay_steps': 0, 628 | 'temperature': 0.0, 629 | 630 | 'initializer': 'RandomNormal', 631 | 632 | 'loss': 'crossentropy', 633 | 'focal_loss': False, 634 | 635 | 
'from_logits': True, 636 | 'use_class_weights': True, 637 | 'preprocess_data': True, 638 | 639 | 'dropout': 0.0, 640 | 641 | 'selected_variables': 0.8, 642 | 'data_subset_fraction': 1.0, 643 | 'bootstrap': False, 644 | } 645 | 646 | return params 647 | 648 | def save_model(self, save_path='model_gande'): 649 | config = {'params': { 650 | 'depth': self.depth, 651 | 'n_estimators': self.n_estimators, 652 | 653 | 'std': self.std, 654 | 'mean': self.mean, 655 | 'mode_train_cat': self.mode_train_cat, 656 | 'mean_train_num': self.mean_train_num, 657 | 658 | 'encoder_ordinal': self.encoder_ordinal, 659 | 'encoder_loo': self.encoder_loo, 660 | 'encoder_ohe': self.encoder_ohe, 661 | 'normalizer': self.normalizer, 662 | 663 | 'number_of_classes': self.number_of_classes, 664 | 'number_of_variables': self.number_of_variables, 665 | 666 | 'selected_variables': self.selected_variables, 667 | 'data_subset_fraction': self.data_subset_fraction, 668 | }, 669 | 'args': { 670 | 'objective': self.objective, 671 | 672 | 'random_seed': self.random_seed, 673 | 'verbose': self.verbose, 674 | 675 | } 676 | } 677 | 678 | temp_dir = 'temp_model' 679 | os.makedirs(temp_dir, exist_ok=True) 680 | 681 | np.savez(os.path.join(temp_dir, 'variables.npz'), 682 | estimator_weights=self.estimator_weights.numpy(), 683 | split_values=self.split_values.numpy(), 684 | split_index_array=self.split_index_array.numpy(), 685 | leaf_classes_array=self.leaf_classes_array.numpy()) 686 | 687 | config_path = os.path.join(temp_dir, 'config.pkl') 688 | with open(config_path, 'wb') as config_file: 689 | pickle.dump(config, config_file) 690 | 691 | with zipfile.ZipFile(save_path, 'w') as model_zip: 692 | for dirname, _, filenames in os.walk(temp_dir): 693 | for filename in filenames: 694 | file_path = os.path.join(dirname, filename) 695 | model_zip.write(file_path, arcname=os.path.relpath(file_path, temp_dir)) 696 | 697 | for dirname, _, filenames in os.walk(temp_dir): 698 | for filename in filenames: 699 | 
os.remove(os.path.join(dirname, filename)) 700 | os.rmdir(temp_dir) 701 | 702 | print(f"Model and config saved to {save_path}") 703 | 704 | 705 | def load_model(load_path='model_gande'): 706 | temp_dir = 'temp_model' 707 | with zipfile.ZipFile(load_path, 'r') as model_zip: 708 | model_zip.extractall(temp_dir) 709 | 710 | config_path = os.path.join(temp_dir, 'config.pkl') 711 | with open(config_path, 'rb') as config_file: 712 | config = pickle.load(config_file) 713 | 714 | 715 | model = GRANDE(params=config['params'], args=config['args']) 716 | 717 | model.build_model() 718 | model.predict(np.random.uniform(0,1, (1,config['params']['number_of_variables']))) 719 | 720 | variables = np.load(os.path.join(temp_dir, 'variables.npz')) 721 | 722 | model.estimator_weights.assign(variables['estimator_weights']) 723 | model.split_values.assign(variables['split_values']) 724 | model.split_index_array.assign(variables['split_index_array']) 725 | model.leaf_classes_array.assign(variables['leaf_classes_array']) 726 | 727 | for dirname, _, filenames in os.walk(temp_dir): 728 | for filename in filenames: 729 | os.remove(os.path.join(dirname, filename)) 730 | os.rmdir(temp_dir) 731 | 732 | print("Model and config loaded successfully.") 733 | 734 | return model 735 | 736 | def apply_dropout_leaf(self, 737 | index_array: tf.Tensor, 738 | training: bool): 739 | 740 | if training and self.dropout > 0.0: 741 | index_array = tf.nn.dropout(index_array, rate=self.dropout)*(1-self.dropout) 742 | index_array = index_array/tf.expand_dims(tf.reduce_sum(index_array, axis=1), 1) 743 | else: 744 | index_array = index_array 745 | 746 | return index_array 747 | 748 | 749 | def entmax15TF(inputs, axis=-1): 750 | 751 | # Implementation taken from: https://github.com/deep-spin/entmax/tree/master/entmax 752 | 753 | """ 754 | Entmax 1.5 implementation, heavily inspired by 755 | * paper: https://arxiv.org/pdf/1905.05702.pdf 756 | * pytorch code: https://github.com/deep-spin/entmax 757 | :param inputs: 
similar to softmax logits, but for entmax1.5 758 | :param axis: entmax1.5 outputs will sum to 1 over this axis 759 | :return: entmax activations of same shape as inputs 760 | """ 761 | @tf.custom_gradient 762 | def _entmax_inner(inputs): 763 | with tf.name_scope('entmax'): 764 | inputs = inputs / 2 # divide by 2 so as to solve actual entmax 765 | inputs -= tf.reduce_max(inputs, axis, keepdims=True) # subtract max for stability 766 | 767 | threshold, _ = entmax_threshold_and_supportTF(inputs, axis) 768 | outputs_sqrt = tf.nn.relu(inputs - threshold) 769 | outputs = tf.square(outputs_sqrt) 770 | 771 | def grad_fn(d_outputs): 772 | with tf.name_scope('entmax_grad'): 773 | d_inputs = d_outputs * outputs_sqrt 774 | q = tf.reduce_sum(d_inputs, axis=axis, keepdims=True) 775 | q = q / tf.reduce_sum(outputs_sqrt, axis=axis, keepdims=True) 776 | d_inputs -= q * outputs_sqrt 777 | return d_inputs 778 | 779 | return outputs, grad_fn 780 | 781 | return _entmax_inner(inputs) 782 | 783 | 784 | def top_k_over_axisTF(inputs, k, axis=-1, **kwargs): 785 | """ performs tf.nn.top_k over any chosen axis """ 786 | with tf.name_scope('top_k_along_axis'): 787 | if axis == -1: 788 | return tf.nn.top_k(inputs, k, **kwargs) 789 | 790 | perm_order = list(range(inputs.shape.ndims)) 791 | perm_order.append(perm_order.pop(axis)) 792 | inv_order = [perm_order.index(i) for i in range(len(perm_order))] 793 | 794 | input_perm = tf.transpose(inputs, perm_order) 795 | input_perm_sorted, sort_indices_perm = tf.nn.top_k( 796 | input_perm, k=k, **kwargs) 797 | 798 | input_sorted = tf.transpose(input_perm_sorted, inv_order) 799 | sort_indices = tf.transpose(sort_indices_perm, inv_order) 800 | return input_sorted, sort_indices 801 | 802 | 803 | def _make_ix_likeTF(inputs, axis=-1): 804 | """ creates indices 0, ... 
, input[axis] unsqueezed to input dimensions """ 805 | assert inputs.shape.ndims is not None 806 | rho = tf.cast(tf.range(1, tf.shape(inputs)[axis] + 1), dtype=inputs.dtype) 807 | view = [1] * inputs.shape.ndims 808 | view[axis] = -1 809 | return tf.reshape(rho, view) 810 | 811 | 812 | def gather_over_axisTF(values, indices, gather_axis): 813 | """ 814 | replicates the behavior of torch.gather for tf<=1.8; 815 | for newer versions use tf.gather with batch_dims 816 | :param values: tensor [d0, ..., dn] 817 | :param indices: int64 tensor of same shape as values except for gather_axis 818 | :param gather_axis: performs gather along this axis 819 | :returns: gathered values, same shape as values except for gather_axis 820 | If gather_axis == 2 821 | gathered_values[i, j, k, ...] = values[i, j, indices[i, j, k, ...], ...] 822 | see torch.gather for more details 823 | """ 824 | assert indices.shape.ndims is not None 825 | assert indices.shape.ndims == values.shape.ndims 826 | 827 | ndims = indices.shape.ndims 828 | gather_axis = gather_axis % ndims 829 | shape = tf.shape(indices) 830 | 831 | selectors = [] 832 | for axis_i in range(ndims): 833 | if axis_i == gather_axis: 834 | selectors.append(indices) 835 | else: 836 | index_i = tf.range(tf.cast(shape[axis_i], dtype=indices.dtype), dtype=indices.dtype) 837 | index_i = tf.reshape(index_i, [-1 if i == axis_i else 1 for i in range(ndims)]) 838 | index_i = tf.tile(index_i, [shape[i] if i != axis_i else 1 for i in range(ndims)]) 839 | selectors.append(index_i) 840 | 841 | return tf.gather_nd(values, tf.stack(selectors, axis=-1)) 842 | 843 | 844 | def entmax_threshold_and_supportTF(inputs, axis=-1): 845 | """ 846 | Computes clipping threshold for entmax1.5 over specified axis 847 | NOTE this implementation uses the same heuristic as 848 | the original code: https://tinyurl.com/pytorch-entmax-line-203 849 | :param inputs: (entmax1.5 inputs - max) / 2 850 | :param axis: entmax1.5 outputs will sum to 1 over this axis 851 | """ 852 
| 853 | with tf.name_scope('entmax_threshold_and_supportTF'): 854 | num_outcomes = tf.shape(inputs)[axis] 855 | inputs_sorted, _ = top_k_over_axisTF(inputs, k=num_outcomes, axis=axis, sorted=True) 856 | 857 | rho = _make_ix_likeTF(inputs, axis=axis) 858 | 859 | mean = tf.cumsum(inputs_sorted, axis=axis) / rho 860 | 861 | mean_sq = tf.cumsum(tf.square(inputs_sorted), axis=axis) / rho 862 | delta = (1 - rho * (mean_sq - tf.square(mean))) / rho 863 | 864 | delta_nz = tf.nn.relu(delta) 865 | tau = mean - tf.sqrt(delta_nz) 866 | 867 | support_size = tf.reduce_sum(tf.cast(tf.less_equal(tau, inputs_sorted), tf.int64), axis=axis, keepdims=True) 868 | 869 | tau_star = gather_over_axisTF(tau, support_size - 1, axis) 870 | return tau_star, support_size 871 | 872 | 873 | def loss_function_weighting(loss_function, temp=0.25): 874 | 875 | # Implementation of "Stochastic Re-weighted Gradient Descent via Distributionally Robust Optimization" from https://arxiv.org/abs/2306.09222 876 | 877 | loss_function.reduction = tf.keras.losses.Reduction.NONE 878 | def _loss_function_weighting(y_true, y_pred): 879 | loss = loss_function(y_true, y_pred) 880 | 881 | if temp > 0: 882 | clamped_loss = tf.clip_by_value(loss, clip_value_min=float('-inf'), clip_value_max=temp) 883 | 884 | out = loss * tf.stop_gradient(tf.exp(clamped_loss / (temp + 1))) 885 | else: 886 | out = loss 887 | 888 | return tf.reduce_mean(out) 889 | return _loss_function_weighting 890 | 891 | 892 | def loss_function_regression(loss_name, mean, std): #mean, log, 893 | loss_function = tf.keras.losses.get(loss_name) 894 | def _loss_function_regression(y_true, y_pred): 895 | #if tf.keras.backend.learning_phase(): 896 | y_true = (y_true - mean) / std 897 | 898 | loss = loss_function(y_true, y_pred) 899 | 900 | return loss 901 | return _loss_function_regression 902 | 903 | def _threshold_and_supportTF(input, dim=-1): 904 | Xsrt = tf.sort(input, axis=dim, direction='DESCENDING') 905 | 906 | rho = tf.range(1, tf.shape(input)[dim] + 
1, dtype=input.dtype) 907 | mean = tf.math.cumsum(Xsrt, axis=dim) / rho 908 | mean_sq = tf.math.cumsum(tf.square(Xsrt), axis=dim) / rho 909 | ss = rho * (mean_sq - tf.square(mean)) 910 | delta = (1 - ss) / rho 911 | 912 | delta_nz = tf.maximum(delta, 0) 913 | tau = mean - tf.sqrt(delta_nz) 914 | 915 | support_size = tf.reduce_sum(tf.cast(tau <= Xsrt, tf.int32), axis=dim) 916 | tau_star = tf.gather(tau, support_size - 1, batch_dims=-1) 917 | return tau_star, support_size 918 | 919 | def get_optimizer_by_name(optimizer_name, learning_rate, warmup_steps, cosine_decay_steps): 920 | 921 | 922 | if cosine_decay_steps > 0: 923 | learning_rate = tf.keras.optimizers.schedules.CosineDecayRestarts( 924 | initial_learning_rate=learning_rate, 925 | first_decay_steps=cosine_decay_steps, 926 | #first_decay_steps=steps_per_epoch, 927 | ) 928 | 929 | if optimizer_name== 'SWA' or optimizer_name== 'EMA': 930 | #optimizer = tfa.optimizers.SWA(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=learning_rate), average_period=5) 931 | frequency = 10 932 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, 933 | use_ema = True, 934 | ema_momentum = 1/frequency, 935 | ema_overwrite_frequency = 1 936 | ) 937 | else: 938 | optimizer = tf.keras.optimizers.get(optimizer_name) 939 | optimizer.learning_rate = learning_rate 940 | 941 | return optimizer 942 | 943 | 944 | -------------------------------------------------------------------------------- /GRANDE_minimal_example_with_comparison_REG.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bf584ce8-8849-4a53-8ce9-2de8f9dd752d", 7 | "metadata": { 8 | "execution": { 9 | "iopub.execute_input": "2024-03-15T11:56:56.253578Z", 10 | "iopub.status.busy": "2024-03-15T11:56:56.253464Z", 11 | "iopub.status.idle": "2024-03-15T11:56:56.265905Z", 12 | "shell.execute_reply": "2024-03-15T11:56:56.265512Z", 13 | 
"shell.execute_reply.started": "2024-03-15T11:56:56.253566Z" 14 | } 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "#specify GPU to use\n", 19 | "import os\n", 20 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "52b07238-f972-45df-86ce-a11cc11d72e3", 27 | "metadata": { 28 | "execution": { 29 | "iopub.execute_input": "2024-03-15T11:56:56.266511Z", 30 | "iopub.status.busy": "2024-03-15T11:56:56.266392Z", 31 | "iopub.status.idle": "2024-03-15T11:56:57.956110Z", 32 | "shell.execute_reply": "2024-03-15T11:56:57.955562Z", 33 | "shell.execute_reply.started": "2024-03-15T11:56:56.266499Z" 34 | } 35 | }, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Training set size: 856\n", 42 | "Validation set size: 214\n", 43 | "Test set size: 268\n" 44 | ] 45 | }, 46 | { 47 | "data": { 48 | "text/html": [ 49 | "
<table><thead><tr><th></th><th>age</th><th>sex</th><th>bmi</th><th>children</th><th>smoker</th><th>region</th></tr></thead><tbody><tr><th>94</th><td>33</td><td>male</td><td>35.750</td><td>1</td><td>yes</td><td>southeast</td></tr><tr><th>814</th><td>46</td><td>male</td><td>19.855</td><td>0</td><td>no</td><td>northwest</td></tr><tr><th>246</th><td>28</td><td>female</td><td>27.500</td><td>2</td><td>no</td><td>southwest</td></tr><tr><th>794</th><td>36</td><td>female</td><td>25.900</td><td>1</td><td>no</td><td>southwest</td></tr><tr><th>1239</th><td>42</td><td>female</td><td>36.195</td><td>1</td><td>no</td><td>northwest</td></tr></tbody></table>
" 124 | ], 125 | "text/plain": [ 126 | " age sex bmi children smoker region\n", 127 | "94 33 male 35.750 1 yes southeast\n", 128 | "814 46 male 19.855 0 no northwest\n", 129 | "246 28 female 27.500 2 no southwest\n", 130 | "794 36 female 25.900 1 no southwest\n", 131 | "1239 42 female 36.195 1 no northwest" 132 | ] 133 | }, 134 | "execution_count": 2, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "from sklearn.model_selection import train_test_split\n", 141 | "import openml\n", 142 | "import category_encoders as ce\n", 143 | "import numpy as np\n", 144 | "import sklearn\n", 145 | "\n", 146 | "# Load healthcare_insurance_expenses dataset\n", 147 | "dataset = openml.datasets.get_dataset(46931, download_data=True, download_qualities=True, download_features_meta_data=True)\n", 148 | "X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", 149 | "categorical_feature_indices = [idx for idx, idx_bool in enumerate(categorical_indicator) if idx_bool]\n", 150 | "\n", 151 | "X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 152 | "\n", 153 | "X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)\n", 154 | "\n", 155 | "print(\"Training set size:\", len(X_train))\n", 156 | "print(\"Validation set size:\", len(X_valid))\n", 157 | "print(\"Test set size:\", len(X_test))\n", 158 | "\n", 159 | "X_train.head()" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 3, 165 | "id": "add65bc2", 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "[1, 4, 5]" 172 | ] 173 | }, 174 | "execution_count": 3, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "categorical_feature_indices" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": 
"bf0a93c2-e9a4-4fa5-a7f6-44689dd07d37", 187 | "metadata": { 188 | "execution": { 189 | "iopub.execute_input": "2024-03-15T11:56:57.957002Z", 190 | "iopub.status.busy": "2024-03-15T11:56:57.956779Z", 191 | "iopub.status.idle": "2024-03-15T11:57:30.600318Z", 192 | "shell.execute_reply": "2024-03-15T11:57:30.599632Z", 193 | "shell.execute_reply.started": "2024-03-15T11:56:57.956987Z" 194 | }, 195 | "scrolled": true 196 | }, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "self.params {'depth': 5, 'n_estimators': 1024, 'learning_rate_weights': 0.001, 'learning_rate_index': 0.01, 'learning_rate_values': 0.05, 'learning_rate_leaf': 0.15, 'learning_rate_embedding': 0.01, 'use_category_embeddings': True, 'embedding_dim_cat': 8, 'use_numeric_embeddings': False, 'embedding_dim_num': 8, 'embedding_threshold': 1, 'loo_cardinality': 10, 'dropout': 0.2, 'selected_variables': 0.8, 'data_subset_fraction': 1.0, 'bootstrap': False, 'missing_values': False, 'optimizer': 'adam', 'cosine_decay_restarts': False, 'reduce_on_plateau_scheduler': True, 'label_smoothing': 0.0, 'use_class_weights': False, 'focal_loss': False, 'swa': False, 'es_metric': True, 'epochs': 250, 'batch_size': 256, 'early_stopping_epochs': 50, 'use_freq_enc': False, 'use_robust_scale_smoothing': False, 'problem_type': 'regression', 'random_seed': 42, 'verbose': 2, 'device': 'cuda:0', 'objective': 'regression'}\n", 203 | "X_train shape before embeddings (856, 6)\n", 204 | "categories [5]\n", 205 | "number_of_variables 1 6\n", 206 | "number_of_variables 2 13\n", 207 | "{'self.column_names_dataframe': ['age',\n", 208 | " 'sex',\n", 209 | " 'bmi',\n", 210 | " 'children',\n", 211 | " 'smoker',\n", 212 | " 'region'],\n", 213 | " 'encoded_columns': ['sex', 'smoker'],\n", 214 | " 'encoded_columns_indices': [1, 4],\n", 215 | " 'num_columns': ['age', 'bmi', 'children'],\n", 216 | " 'num_columns_indices': [0, 2, 3],\n", 217 | " 'categorical_features_raw_indices': [5],\n", 
218 | " 'not_encoded_columns': ['age', 'bmi', 'children', 'region']}\n", 219 | "Use Category Embeddings: Embedding(5, 8)\n" 220 | ] 221 | }, 222 | { 223 | "name": "stderr", 224 | "output_type": "stream", 225 | "text": [ 226 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 227 | "Online softmax is disabled on the fly since Inductor decides to\n", 228 | "split the reduction. Cut an issue to PyTorch if this is an\n", 229 | "important use case and you want to speed it up with online\n", 230 | "softmax.\n", 231 | "\n", 232 | " warnings.warn(\n", 233 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 234 | "Online softmax is disabled on the fly since Inductor decides to\n", 235 | "split the reduction. Cut an issue to PyTorch if this is an\n", 236 | "important use case and you want to speed it up with online\n", 237 | "softmax.\n", 238 | "\n", 239 | " warnings.warn(\n" 240 | ] 241 | }, 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "Epoch 001 | TrainLoss: 0.9442 | ValLoss: 0.6983 | MSE: 110805304.0000 | MAE: 7907.7832 | R2: 0.2430 | Time: 6.61s\n", 247 | "Epoch 002 | TrainLoss: 0.6745 | ValLoss: 0.4770 | MSE: 75694136.0000 | MAE: 6047.0703 | R2: 0.4829 | Time: 0.02s\n", 248 | "Epoch 003 | TrainLoss: 0.4664 | ValLoss: 0.3108 | MSE: 49315516.0000 | MAE: 4471.3096 | R2: 0.6631 | Time: 0.02s\n", 249 | "Epoch 004 | TrainLoss: 0.3181 | ValLoss: 0.2167 | MSE: 34391280.0000 | MAE: 4079.7856 | R2: 0.7651 | Time: 0.02s\n", 250 | "Epoch 005 | TrainLoss: 0.2304 | ValLoss: 0.1863 | MSE: 29557700.0000 | MAE: 3613.9580 | R2: 0.7981 | Time: 0.02s\n", 251 | "Epoch 006 | TrainLoss: 0.2101 | ValLoss: 0.1661 | MSE: 26352892.0000 | MAE: 3145.3518 | R2: 0.8200 | Time: 0.02s\n", 252 | "Epoch 007 | TrainLoss: 0.1697 | ValLoss: 0.1398 | MSE: 22176866.0000 | MAE: 2775.7600 | R2: 0.8485 | Time: 0.02s\n", 253 | "Epoch 008 | 
TrainLoss: 0.1484 | ValLoss: 0.1316 | MSE: 20874780.0000 | MAE: 2580.6177 | R2: 0.8574 | Time: 0.02s\n", 254 | "Epoch 009 | TrainLoss: 0.1400 | ValLoss: 0.1328 | MSE: 21068986.0000 | MAE: 2652.9009 | R2: 0.8561 | Time: 0.02s\n", 255 | "[EarlyStop TRAINING] no improve (1/50). best=20874780.000000, curr=21068986.000000\n", 256 | "Epoch 010 | TrainLoss: 0.1354 | ValLoss: 0.1315 | MSE: 20857652.0000 | MAE: 2709.4941 | R2: 0.8575 | Time: 0.02s\n", 257 | "Epoch 011 | TrainLoss: 0.1323 | ValLoss: 0.1329 | MSE: 21088880.0000 | MAE: 2658.9639 | R2: 0.8559 | Time: 0.02s\n", 258 | "[EarlyStop TRAINING] no improve (1/50). best=20857652.000000, curr=21088880.000000\n", 259 | "Epoch 012 | TrainLoss: 0.1407 | ValLoss: 0.1376 | MSE: 21836154.0000 | MAE: 2449.4062 | R2: 0.8508 | Time: 0.02s\n", 260 | "[EarlyStop TRAINING] no improve (2/50). best=20857652.000000, curr=21836154.000000\n", 261 | "Epoch 013 | TrainLoss: 0.1340 | ValLoss: 0.1359 | MSE: 21563870.0000 | MAE: 2985.1030 | R2: 0.8527 | Time: 0.02s\n", 262 | "[EarlyStop TRAINING] no improve (3/50). best=20857652.000000, curr=21563870.000000\n", 263 | "Epoch 014 | TrainLoss: 0.1363 | ValLoss: 0.1375 | MSE: 21812660.0000 | MAE: 2578.1575 | R2: 0.8510 | Time: 0.02s\n", 264 | "[EarlyStop TRAINING] no improve (4/50). best=20857652.000000, curr=21812660.000000\n", 265 | "Epoch 015 | TrainLoss: 0.1294 | ValLoss: 0.1369 | MSE: 21724846.0000 | MAE: 2606.5918 | R2: 0.8516 | Time: 0.02s\n", 266 | "[EarlyStop TRAINING] no improve (5/50). best=20857652.000000, curr=21724846.000000\n", 267 | "Epoch 016 | TrainLoss: 0.1150 | ValLoss: 0.1372 | MSE: 21774892.0000 | MAE: 2736.3572 | R2: 0.8512 | Time: 0.02s\n", 268 | "[EarlyStop TRAINING] no improve (6/50). best=20857652.000000, curr=21774892.000000\n", 269 | "Epoch 017 | TrainLoss: 0.1263 | ValLoss: 0.1347 | MSE: 21375200.0000 | MAE: 2670.4187 | R2: 0.8540 | Time: 0.02s\n", 270 | "[EarlyStop TRAINING] no improve (7/50). 
best=20857652.000000, curr=21375200.000000\n", 271 | "Epoch 018 | TrainLoss: 0.1145 | ValLoss: 0.1361 | MSE: 21592756.0000 | MAE: 2709.2681 | R2: 0.8525 | Time: 0.02s\n", 272 | "[EarlyStop TRAINING] no improve (8/50). best=20857652.000000, curr=21592756.000000\n", 273 | "Epoch 019 | TrainLoss: 0.1116 | ValLoss: 0.1443 | MSE: 22900512.0000 | MAE: 2851.0134 | R2: 0.8436 | Time: 0.02s\n", 274 | "[EarlyStop TRAINING] no improve (9/50). best=20857652.000000, curr=22900512.000000\n", 275 | "Epoch 020 | TrainLoss: 0.1201 | ValLoss: 0.1437 | MSE: 22796944.0000 | MAE: 2674.2185 | R2: 0.8443 | Time: 0.02s\n", 276 | "[EarlyStop TRAINING] no improve (10/50). best=20857652.000000, curr=22796944.000000\n", 277 | "Epoch 021 | TrainLoss: 0.1125 | ValLoss: 0.1405 | MSE: 22288538.0000 | MAE: 2823.5095 | R2: 0.8477 | Time: 0.02s\n", 278 | "[EarlyStop TRAINING] no improve (11/50). best=20857652.000000, curr=22288538.000000\n", 279 | "Epoch 022 | TrainLoss: 0.1150 | ValLoss: 0.1425 | MSE: 22610858.0000 | MAE: 2622.5259 | R2: 0.8455 | Time: 0.02s\n", 280 | "[EarlyStop TRAINING] no improve (12/50). best=20857652.000000, curr=22610858.000000\n", 281 | "Epoch 023 | TrainLoss: 0.1120 | ValLoss: 0.1439 | MSE: 22834784.0000 | MAE: 3120.4143 | R2: 0.8440 | Time: 0.02s\n", 282 | "[EarlyStop TRAINING] no improve (13/50). best=20857652.000000, curr=22834784.000000\n", 283 | "Epoch 024 | TrainLoss: 0.1106 | ValLoss: 0.1451 | MSE: 23015768.0000 | MAE: 2632.4998 | R2: 0.8428 | Time: 0.02s\n", 284 | "[EarlyStop TRAINING] no improve (14/50). best=20857652.000000, curr=23015768.000000\n", 285 | "Epoch 025 | TrainLoss: 0.1116 | ValLoss: 0.1432 | MSE: 22729254.0000 | MAE: 3018.4402 | R2: 0.8447 | Time: 0.02s\n", 286 | "[EarlyStop TRAINING] no improve (15/50). best=20857652.000000, curr=22729254.000000\n", 287 | "Epoch 026 | TrainLoss: 0.1055 | ValLoss: 0.1451 | MSE: 23025022.0000 | MAE: 2757.2036 | R2: 0.8427 | Time: 0.02s\n", 288 | "[EarlyStop TRAINING] no improve (16/50). 
best=20857652.000000, curr=23025022.000000\n", 289 | "Epoch 027 | TrainLoss: 0.1120 | ValLoss: 0.1428 | MSE: 22663312.0000 | MAE: 2849.4290 | R2: 0.8452 | Time: 0.02s\n", 290 | "[EarlyStop TRAINING] no improve (17/50). best=20857652.000000, curr=22663312.000000\n", 291 | "Epoch 028 | TrainLoss: 0.1024 | ValLoss: 0.1442 | MSE: 22888228.0000 | MAE: 2710.1841 | R2: 0.8436 | Time: 0.02s\n", 292 | "[EarlyStop TRAINING] no improve (18/50). best=20857652.000000, curr=22888228.000000\n", 293 | "Epoch 029 | TrainLoss: 0.1029 | ValLoss: 0.1486 | MSE: 23580132.0000 | MAE: 2878.5706 | R2: 0.8389 | Time: 0.02s\n", 294 | "[EarlyStop TRAINING] no improve (19/50). best=20857652.000000, curr=23580132.000000\n", 295 | "Epoch 030 | TrainLoss: 0.0962 | ValLoss: 0.1455 | MSE: 23092736.0000 | MAE: 2784.8010 | R2: 0.8422 | Time: 0.02s\n", 296 | "[EarlyStop TRAINING] no improve (20/50). best=20857652.000000, curr=23092736.000000\n", 297 | "Epoch 031 | TrainLoss: 0.1075 | ValLoss: 0.1473 | MSE: 23373388.0000 | MAE: 2717.4727 | R2: 0.8403 | Time: 0.02s\n", 298 | "[EarlyStop TRAINING] no improve (21/50). best=20857652.000000, curr=23373388.000000\n", 299 | "Epoch 032 | TrainLoss: 0.0999 | ValLoss: 0.1550 | MSE: 24594842.0000 | MAE: 2826.3821 | R2: 0.8320 | Time: 0.02s\n", 300 | "[EarlyStop TRAINING] no improve (22/50). best=20857652.000000, curr=24594842.000000\n", 301 | "Epoch 033 | TrainLoss: 0.1044 | ValLoss: 0.1545 | MSE: 24521254.0000 | MAE: 2981.0029 | R2: 0.8325 | Time: 0.02s\n", 302 | "[EarlyStop TRAINING] no improve (23/50). best=20857652.000000, curr=24521254.000000\n", 303 | "Epoch 034 | TrainLoss: 0.1000 | ValLoss: 0.1488 | MSE: 23607160.0000 | MAE: 2794.5957 | R2: 0.8387 | Time: 0.02s\n", 304 | "[EarlyStop TRAINING] no improve (24/50). best=20857652.000000, curr=23607160.000000\n", 305 | "Epoch 035 | TrainLoss: 0.1003 | ValLoss: 0.1524 | MSE: 24181636.0000 | MAE: 2961.8938 | R2: 0.8348 | Time: 0.02s\n", 306 | "[EarlyStop TRAINING] no improve (25/50). 
best=20857652.000000, curr=24181636.000000\n", 307 | "Epoch 036 | TrainLoss: 0.0970 | ValLoss: 0.1529 | MSE: 24256992.0000 | MAE: 2983.7617 | R2: 0.8343 | Time: 0.02s\n", 308 | "[EarlyStop TRAINING] no improve (26/50). best=20857652.000000, curr=24256992.000000\n", 309 | "Epoch 037 | TrainLoss: 0.0917 | ValLoss: 0.1549 | MSE: 24575086.0000 | MAE: 2947.2693 | R2: 0.8321 | Time: 0.02s\n", 310 | "[EarlyStop TRAINING] no improve (27/50). best=20857652.000000, curr=24575086.000000\n", 311 | "Epoch 038 | TrainLoss: 0.0932 | ValLoss: 0.1538 | MSE: 24408730.0000 | MAE: 2851.0256 | R2: 0.8333 | Time: 0.02s\n", 312 | "[EarlyStop TRAINING] no improve (28/50). best=20857652.000000, curr=24408730.000000\n", 313 | "Epoch 039 | TrainLoss: 0.0938 | ValLoss: 0.1576 | MSE: 25012860.0000 | MAE: 2975.3767 | R2: 0.8291 | Time: 0.02s\n", 314 | "[EarlyStop TRAINING] no improve (29/50). best=20857652.000000, curr=25012860.000000\n", 315 | "Epoch 040 | TrainLoss: 0.0978 | ValLoss: 0.1578 | MSE: 25035894.0000 | MAE: 3047.7266 | R2: 0.8290 | Time: 0.02s\n", 316 | "[EarlyStop TRAINING] no improve (30/50). best=20857652.000000, curr=25035894.000000\n", 317 | "Epoch 041 | TrainLoss: 0.0941 | ValLoss: 0.1551 | MSE: 24604590.0000 | MAE: 2932.2500 | R2: 0.8319 | Time: 0.02s\n", 318 | "[EarlyStop TRAINING] no improve (31/50). best=20857652.000000, curr=24604590.000000\n", 319 | "Epoch 042 | TrainLoss: 0.0862 | ValLoss: 0.1587 | MSE: 25179846.0000 | MAE: 3095.2544 | R2: 0.8280 | Time: 0.02s\n", 320 | "[EarlyStop TRAINING] no improve (32/50). best=20857652.000000, curr=25179846.000000\n", 321 | "Epoch 043 | TrainLoss: 0.0967 | ValLoss: 0.1597 | MSE: 25341562.0000 | MAE: 3101.8044 | R2: 0.8269 | Time: 0.02s\n", 322 | "[EarlyStop TRAINING] no improve (33/50). best=20857652.000000, curr=25341562.000000\n", 323 | "Epoch 044 | TrainLoss: 0.0841 | ValLoss: 0.1600 | MSE: 25389414.0000 | MAE: 2996.4243 | R2: 0.8266 | Time: 0.02s\n", 324 | "[EarlyStop TRAINING] no improve (34/50). 
best=20857652.000000, curr=25389414.000000\n", 325 | "Epoch 045 | TrainLoss: 0.0937 | ValLoss: 0.1590 | MSE: 25234082.0000 | MAE: 3053.0994 | R2: 0.8276 | Time: 0.02s\n", 326 | "[EarlyStop TRAINING] no improve (35/50). best=20857652.000000, curr=25234082.000000\n", 327 | "Epoch 046 | TrainLoss: 0.0850 | ValLoss: 0.1643 | MSE: 26070064.0000 | MAE: 3206.8984 | R2: 0.8219 | Time: 0.02s\n", 328 | "[EarlyStop TRAINING] no improve (36/50). best=20857652.000000, curr=26070064.000000\n", 329 | "Epoch 047 | TrainLoss: 0.0836 | ValLoss: 0.1669 | MSE: 26475984.0000 | MAE: 3009.0151 | R2: 0.8191 | Time: 0.02s\n", 330 | "[EarlyStop TRAINING] no improve (37/50). best=20857652.000000, curr=26475984.000000\n", 331 | "Epoch 048 | TrainLoss: 0.0814 | ValLoss: 0.1633 | MSE: 25911006.0000 | MAE: 3283.3823 | R2: 0.8230 | Time: 0.02s\n", 332 | "[EarlyStop TRAINING] no improve (38/50). best=20857652.000000, curr=25911006.000000\n", 333 | "Epoch 049 | TrainLoss: 0.0842 | ValLoss: 0.1626 | MSE: 25801604.0000 | MAE: 2987.0801 | R2: 0.8237 | Time: 0.02s\n", 334 | "[EarlyStop TRAINING] no improve (39/50). best=20857652.000000, curr=25801604.000000\n", 335 | "Epoch 050 | TrainLoss: 0.0828 | ValLoss: 0.1612 | MSE: 25575234.0000 | MAE: 3190.6045 | R2: 0.8253 | Time: 0.02s\n", 336 | "[EarlyStop TRAINING] no improve (40/50). best=20857652.000000, curr=25575234.000000\n", 337 | "Epoch 051 | TrainLoss: 0.0822 | ValLoss: 0.1709 | MSE: 27122086.0000 | MAE: 3196.4832 | R2: 0.8147 | Time: 0.02s\n", 338 | "[EarlyStop TRAINING] no improve (41/50). best=20857652.000000, curr=27122086.000000\n", 339 | "Epoch 052 | TrainLoss: 0.0838 | ValLoss: 0.1670 | MSE: 26501512.0000 | MAE: 3267.1182 | R2: 0.8190 | Time: 0.02s\n", 340 | "[EarlyStop TRAINING] no improve (42/50). best=20857652.000000, curr=26501512.000000\n", 341 | "Epoch 053 | TrainLoss: 0.0850 | ValLoss: 0.1626 | MSE: 25795368.0000 | MAE: 3034.9268 | R2: 0.8238 | Time: 0.02s\n", 342 | "[EarlyStop TRAINING] no improve (43/50). 
best=20857652.000000, curr=25795368.000000\n", 343 | "Epoch 054 | TrainLoss: 0.0825 | ValLoss: 0.1644 | MSE: 26090838.0000 | MAE: 3215.0332 | R2: 0.8218 | Time: 0.02s\n", 344 | "[EarlyStop TRAINING] no improve (44/50). best=20857652.000000, curr=26090838.000000\n", 345 | "Epoch 055 | TrainLoss: 0.0807 | ValLoss: 0.1652 | MSE: 26218004.0000 | MAE: 3167.4287 | R2: 0.8209 | Time: 0.02s\n", 346 | "[EarlyStop TRAINING] no improve (45/50). best=20857652.000000, curr=26218004.000000\n", 347 | "Epoch 056 | TrainLoss: 0.0816 | ValLoss: 0.1737 | MSE: 27564982.0000 | MAE: 3368.5759 | R2: 0.8117 | Time: 0.02s\n", 348 | "[EarlyStop TRAINING] no improve (46/50). best=20857652.000000, curr=27564982.000000\n", 349 | "Epoch 057 | TrainLoss: 0.0711 | ValLoss: 0.1764 | MSE: 27982072.0000 | MAE: 3213.2036 | R2: 0.8088 | Time: 0.02s\n", 350 | "[EarlyStop TRAINING] no improve (47/50). best=20857652.000000, curr=27982072.000000\n", 351 | "Epoch 058 | TrainLoss: 0.0840 | ValLoss: 0.1687 | MSE: 26770818.0000 | MAE: 3231.4912 | R2: 0.8171 | Time: 0.02s\n", 352 | "[EarlyStop TRAINING] no improve (48/50). best=20857652.000000, curr=26770818.000000\n", 353 | "Epoch 059 | TrainLoss: 0.0809 | ValLoss: 0.1716 | MSE: 27229828.0000 | MAE: 3382.3977 | R2: 0.8140 | Time: 0.02s\n", 354 | "[EarlyStop TRAINING] no improve (49/50). best=20857652.000000, curr=27229828.000000\n", 355 | "Epoch 060 | TrainLoss: 0.0767 | ValLoss: 0.1727 | MSE: 27396762.0000 | MAE: 3288.3879 | R2: 0.8128 | Time: 0.02s\n", 356 | "[EarlyStop TRAINING] no improve (50/50). 
best=20857652.000000, curr=27396762.000000\n", 357 | "[EarlyStop TRAINING] restoring best weights and stopping.\n", 358 | "Restoring best model from epoch with val score: 20857652.0\n" 359 | ] 360 | }, 361 | { 362 | "name": "stderr", 363 | "output_type": "stream", 364 | "text": [ 365 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 366 | "Online softmax is disabled on the fly since Inductor decides to\n", 367 | "split the reduction. Cut an issue to PyTorch if this is an\n", 368 | "important use case and you want to speed it up with online\n", 369 | "softmax.\n", 370 | "\n", 371 | " warnings.warn(\n", 372 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 373 | "Online softmax is disabled on the fly since Inductor decides to\n", 374 | "split the reduction. Cut an issue to PyTorch if this is an\n", 375 | "important use case and you want to speed it up with online\n", 376 | "softmax.\n", 377 | "\n", 378 | " warnings.warn(\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "from GRANDE import GRANDE\n", 384 | "\n", 385 | "params = {\n", 386 | " 'depth': 5,\n", 387 | " 'n_estimators': 1024,\n", 388 | " \n", 389 | " 'learning_rate_weights': 0.001,\n", 390 | " 'learning_rate_index': 0.01,\n", 391 | " 'learning_rate_values': 0.05,\n", 392 | " 'learning_rate_leaf': 0.15,\n", 393 | " 'learning_rate_embedding': 0.01,\n", 394 | "\n", 395 | " 'use_category_embeddings': True,\n", 396 | " 'embedding_dim_cat': 8,\n", 397 | " 'use_numeric_embeddings': False,\n", 398 | " 'embedding_dim_num': 8,\n", 399 | " 'embedding_threshold': 1,\n", 400 | " 'loo_cardinality': 10,\n", 401 | "\n", 402 | "\n", 403 | " 'dropout': 0.2,\n", 404 | " 'selected_variables': 0.8,\n", 405 | " 'data_subset_fraction': 1.0,\n", 406 | " 'bootstrap': False,\n", 407 | " 'missing_values': False,\n", 408 | "\n", 409 | " 'optimizer': 'adam', #nadam, radam, adamw, adam \n", 410 | 
" 'cosine_decay_restarts': False,\n", 411 | " 'reduce_on_plateau_scheduler': True,\n", 412 | " 'label_smoothing': 0.0,\n", 413 | " 'use_class_weights': False,\n", 414 | " 'focal_loss': False,\n", 415 | " 'swa': False,\n", 416 | " 'es_metric': True, # if True use AUC for binary, MSE for regression, val_loss for multiclass\n", 417 | "\n", 418 | "\n", 419 | " 'epochs': 250,\n", 420 | " 'batch_size': 256,\n", 421 | " 'early_stopping_epochs': 50,\n", 422 | "\n", 423 | " 'use_freq_enc': False,\n", 424 | " 'use_robust_scale_smoothing': False,\n", 425 | " 'problem_type': 'regression',\n", 426 | " \n", 427 | " 'random_seed': 42,\n", 428 | " 'verbose': 2,\n", 429 | "}\n", 430 | "\n", 431 | "model_grande = GRANDE(params=params)\n", 432 | "\n", 433 | "model_grande.fit(X=X_train,\n", 434 | " y=y_train,\n", 435 | " X_val=X_valid,\n", 436 | " y_val=y_valid)\n", 437 | "\n", 438 | "preds_grande = model_grande.predict_proba(X_test)" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 5, 444 | "id": "b251df9d-67f0-4dd1-a0cc-379621e33fad", 445 | "metadata": { 446 | "execution": { 447 | "iopub.execute_input": "2024-03-15T11:57:30.601353Z", 448 | "iopub.status.busy": "2024-03-15T11:57:30.601004Z", 449 | "iopub.status.idle": "2024-03-15T11:57:30.604147Z", 450 | "shell.execute_reply": "2024-03-15T11:57:30.603752Z", 451 | "shell.execute_reply.started": "2024-03-15T11:57:30.601338Z" 452 | } 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "def calculate_sample_weights(y_data):\n", 457 | " class_weights = sklearn.utils.class_weight.compute_class_weight(class_weight = 'balanced', classes = np.unique(y_data), y = y_data)\n", 458 | " sample_weights = sklearn.utils.class_weight.compute_sample_weight(class_weight = 'balanced', y =y_data)\n", 459 | " return sample_weights" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 6, 465 | "id": "c17e805c-0110-4e39-97d3-34573804296b", 466 | "metadata": { 467 | "execution": { 468 | 
"iopub.execute_input": "2024-03-15T11:57:30.605592Z", 469 | "iopub.status.busy": "2024-03-15T11:57:30.605338Z", 470 | "iopub.status.idle": "2024-03-15T11:57:30.611465Z", 471 | "shell.execute_reply": "2024-03-15T11:57:30.611074Z", 472 | "shell.execute_reply.started": "2024-03-15T11:57:30.605578Z" 473 | } 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "try:\n", 478 | " y_train = y_train.values.codes.astype(np.float64)\n", 479 | " y_valid = y_valid.values.codes.astype(np.float64)\n", 480 | " y_test = y_test.values.codes.astype(np.float64)\n", 481 | "except:\n", 482 | " y_train = y_train.values.astype(np.float64)\n", 483 | " y_valid = y_valid.values.astype(np.float64)\n", 484 | " y_test = y_test.values.astype(np.float64)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 7, 490 | "id": "cfc121e3-2d7b-4fb1-8616-0318fcc20ac1", 491 | "metadata": { 492 | "execution": { 493 | "iopub.execute_input": "2024-03-15T11:57:30.612164Z", 494 | "iopub.status.busy": "2024-03-15T11:57:30.611960Z", 495 | "iopub.status.idle": "2024-03-15T11:57:30.627515Z", 496 | "shell.execute_reply": "2024-03-15T11:57:30.627108Z", 497 | "shell.execute_reply.started": "2024-03-15T11:57:30.612151Z" 498 | } 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "binary_indices = []\n", 503 | "low_cardinality_indices = []\n", 504 | "high_cardinality_indices = []\n", 505 | "num_columns = []\n", 506 | "for column_index, column in enumerate(X_train.columns):\n", 507 | " if column_index in categorical_feature_indices:\n", 508 | " if len(X_train.iloc[:,column_index].unique()) <= 2:\n", 509 | " binary_indices.append(column)\n", 510 | " if len(X_train.iloc[:,column_index].unique()) < 5:\n", 511 | " low_cardinality_indices.append(column)\n", 512 | " else:\n", 513 | " high_cardinality_indices.append(column)\n", 514 | " else:\n", 515 | " num_columns.append(column) \n", 516 | "cat_columns = [col for col in X_train.columns if col not in num_columns]" 517 | ] 518 | }, 519 | { 520 | 
"cell_type": "code", 521 | "execution_count": 8, 522 | "id": "f7e55785-bf56-4aad-9825-7c5737be5ab7", 523 | "metadata": { 524 | "execution": { 525 | "iopub.execute_input": "2024-03-15T11:57:30.628204Z", 526 | "iopub.status.busy": "2024-03-15T11:57:30.627992Z", 527 | "iopub.status.idle": "2024-03-15T11:57:30.691371Z", 528 | "shell.execute_reply": "2024-03-15T11:57:30.690745Z", 529 | "shell.execute_reply.started": "2024-03-15T11:57:30.628191Z" 530 | } 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "if len(num_columns) > 0:\n", 535 | " mean_train_num = X_train[num_columns].mean(axis=0).iloc[0]\n", 536 | " X_train[num_columns] = X_train[num_columns].fillna(mean_train_num)\n", 537 | " X_valid[num_columns] = X_valid[num_columns].fillna(mean_train_num)\n", 538 | " X_test[num_columns] = X_test[num_columns].fillna(mean_train_num)\n", 539 | "if len(cat_columns) > 0:\n", 540 | " mode_train_cat = X_train[cat_columns].mode(axis=0).iloc[0]\n", 541 | " X_train[cat_columns] = X_train[cat_columns].fillna(mode_train_cat)\n", 542 | " X_valid[cat_columns] = X_valid[cat_columns].fillna(mode_train_cat)\n", 543 | " X_test[cat_columns] = X_test[cat_columns].fillna(mode_train_cat)\n", 544 | "\n", 545 | "X_train_raw = X_train.copy()\n", 546 | "X_valid_raw = X_valid.copy()\n", 547 | "X_test_raw = X_test.copy()" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 9, 553 | "id": "b5d5f00f", 554 | "metadata": { 555 | "execution": { 556 | "iopub.execute_input": "2024-03-15T11:57:30.692101Z", 557 | "iopub.status.busy": "2024-03-15T11:57:30.691967Z", 558 | "iopub.status.idle": "2024-03-15T11:57:31.276758Z", 559 | "shell.execute_reply": "2024-03-15T11:57:31.276086Z", 560 | "shell.execute_reply.started": "2024-03-15T11:57:30.692089Z" 561 | } 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "encoder_ordinal = ce.OrdinalEncoder(cols=binary_indices)\n", 566 | "encoder_ordinal.fit(X_train)\n", 567 | "X_train = encoder_ordinal.transform(X_train)\n", 568 | "X_valid = 
encoder_ordinal.transform(X_valid) \n", 569 | "X_test = encoder_ordinal.transform(X_test) \n", 570 | "\n", 571 | "encoder = ce.LeaveOneOutEncoder(cols=high_cardinality_indices)\n", 572 | "encoder.fit(X_train, y_train)\n", 573 | "X_train = encoder.transform(X_train)\n", 574 | "X_valid = encoder.transform(X_valid)\n", 575 | "X_test = encoder.transform(X_test)\n", 576 | "\n", 577 | "encoder = ce.OneHotEncoder(cols=low_cardinality_indices)\n", 578 | "encoder.fit(X_train)\n", 579 | "X_train = encoder.transform(X_train)\n", 580 | "X_valid = encoder.transform(X_valid)\n", 581 | "X_test = encoder.transform(X_test)\n", 582 | "\n", 583 | "X_train = X_train.astype(np.float32)\n", 584 | "X_valid = X_valid.astype(np.float32)\n", 585 | "X_test = X_test.astype(np.float32)" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 10, 591 | "id": "361f7d9f-7e58-4bc1-8855-153b9e972742", 592 | "metadata": { 593 | "execution": { 594 | "iopub.execute_input": "2024-03-15T11:57:31.277594Z", 595 | "iopub.status.busy": "2024-03-15T11:57:31.277439Z", 596 | "iopub.status.idle": "2024-03-15T11:57:31.575860Z", 597 | "shell.execute_reply": "2024-03-15T11:57:31.575224Z", 598 | "shell.execute_reply.started": "2024-03-15T11:57:31.277581Z" 599 | } 600 | }, 601 | "outputs": [ 602 | { 603 | "name": "stdout", 604 | "output_type": "stream", 605 | "text": [ 606 | "[0]\tvalidation_0-rmse:9117.37519\n", 607 | "[1]\tvalidation_0-rmse:7182.91726\n", 608 | "[2]\tvalidation_0-rmse:6008.67928\n", 609 | "[3]\tvalidation_0-rmse:5384.74130\n", 610 | "[4]\tvalidation_0-rmse:5041.28295\n", 611 | "[5]\tvalidation_0-rmse:4821.53449\n", 612 | "[6]\tvalidation_0-rmse:4714.24431\n", 613 | "[7]\tvalidation_0-rmse:4660.87726\n", 614 | "[8]\tvalidation_0-rmse:4676.10315\n", 615 | "[9]\tvalidation_0-rmse:4700.87577\n", 616 | "[10]\tvalidation_0-rmse:4677.24837\n", 617 | "[11]\tvalidation_0-rmse:4667.92836\n", 618 | "[12]\tvalidation_0-rmse:4681.76173\n", 619 | "[13]\tvalidation_0-rmse:4687.39543\n", 620 
| "[14]\tvalidation_0-rmse:4754.37928\n", 621 | "[15]\tvalidation_0-rmse:4738.09888\n", 622 | "[16]\tvalidation_0-rmse:4735.64695\n", 623 | "[17]\tvalidation_0-rmse:4751.13172\n", 624 | "[18]\tvalidation_0-rmse:4773.50335\n", 625 | "[19]\tvalidation_0-rmse:4790.19284\n", 626 | "[20]\tvalidation_0-rmse:4773.47564\n", 627 | "[21]\tvalidation_0-rmse:4774.69551\n", 628 | "[22]\tvalidation_0-rmse:4778.65167\n", 629 | "[23]\tvalidation_0-rmse:4775.53956\n", 630 | "[24]\tvalidation_0-rmse:4781.68505\n", 631 | "[25]\tvalidation_0-rmse:4795.20388\n", 632 | "[26]\tvalidation_0-rmse:4829.21788\n", 633 | "[27]\tvalidation_0-rmse:4837.16893\n" 634 | ] 635 | } 636 | ], 637 | "source": [ 638 | "if params['problem_type'] == 'regression':\n", 639 | " from xgboost import XGBRegressor\n", 640 | " model_xgb = XGBRegressor(n_estimators=1000, early_stopping_rounds=20)\n", 641 | " model_xgb.fit(X_train, \n", 642 | " y_train, \n", 643 | " eval_set=[(X_valid, y_valid)], \n", 644 | " )\n", 645 | " preds_xgb = model_xgb.predict(X_test)\n", 646 | "else:\n", 647 | " from xgboost import XGBClassifier\n", 648 | " model_xgb = XGBClassifier(n_estimators=1000, early_stopping_rounds=20)\n", 649 | " model_xgb.fit(X_train, \n", 650 | " y_train, \n", 651 | " #sample_weight=calculate_sample_weights(y_train), \n", 652 | " eval_set=[(X_valid, y_valid)], \n", 653 | " #sample_weight_eval_set=[calculate_sample_weights(y_valid)]\n", 654 | " )\n", 655 | "\n", 656 | "\n", 657 | " preds_xgb = model_xgb.predict_proba(X_test)" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": 11, 663 | "id": "6b23c67e-cdde-40c2-842a-61383e1a62d0", 664 | "metadata": { 665 | "execution": { 666 | "iopub.execute_input": "2024-03-15T11:57:31.576940Z", 667 | "iopub.status.busy": "2024-03-15T11:57:31.576558Z", 668 | "iopub.status.idle": "2024-03-15T11:57:52.427379Z", 669 | "shell.execute_reply": "2024-03-15T11:57:52.426427Z", 670 | "shell.execute_reply.started": "2024-03-15T11:57:31.576923Z" 671 | } 672 | }, 
673 | "outputs": [ 674 | { 675 | "name": "stdout", 676 | "output_type": "stream", 677 | "text": [ 678 | "Learning rate set to 0.049693\n", 679 | "0:\tlearn: 12135.8956932\ttest: 11637.9908345\tbest: 11637.9908345 (0)\ttotal: 48.8ms\tremaining: 48.8s\n", 680 | "1:\tlearn: 11663.8838516\ttest: 11175.5245901\tbest: 11175.5245901 (1)\ttotal: 50.8ms\tremaining: 25.4s\n", 681 | "2:\tlearn: 11273.2960222\ttest: 10835.8030762\tbest: 10835.8030762 (2)\ttotal: 52.5ms\tremaining: 17.5s\n", 682 | "3:\tlearn: 10889.3829166\ttest: 10474.9396901\tbest: 10474.9396901 (3)\ttotal: 53.4ms\tremaining: 13.3s\n", 683 | "4:\tlearn: 10532.8002095\ttest: 10132.1552031\tbest: 10132.1552031 (4)\ttotal: 55ms\tremaining: 10.9s\n", 684 | "5:\tlearn: 10196.7503986\ttest: 9816.3979358\tbest: 9816.3979358 (5)\ttotal: 56.6ms\tremaining: 9.38s\n", 685 | "6:\tlearn: 9836.9135106\ttest: 9465.9962973\tbest: 9465.9962973 (6)\ttotal: 58.7ms\tremaining: 8.32s\n", 686 | "7:\tlearn: 9523.0102491\ttest: 9176.2224294\tbest: 9176.2224294 (7)\ttotal: 59.7ms\tremaining: 7.4s\n", 687 | "8:\tlearn: 9234.9423371\ttest: 8905.7040385\tbest: 8905.7040385 (8)\ttotal: 60.2ms\tremaining: 6.63s\n", 688 | "9:\tlearn: 8946.3797800\ttest: 8636.1173694\tbest: 8636.1173694 (9)\ttotal: 61ms\tremaining: 6.04s\n", 689 | "10:\tlearn: 8693.7650289\ttest: 8391.1433126\tbest: 8391.1433126 (10)\ttotal: 61.8ms\tremaining: 5.56s\n", 690 | "11:\tlearn: 8430.3329885\ttest: 8151.9293896\tbest: 8151.9293896 (11)\ttotal: 62.5ms\tremaining: 5.14s\n", 691 | "12:\tlearn: 8197.7535505\ttest: 7937.7069961\tbest: 7937.7069961 (12)\ttotal: 63.8ms\tremaining: 4.84s\n", 692 | "13:\tlearn: 7985.5741708\ttest: 7728.4296985\tbest: 7728.4296985 (13)\ttotal: 64.8ms\tremaining: 4.57s\n", 693 | "14:\tlearn: 7806.9133166\ttest: 7562.0935256\tbest: 7562.0935256 (14)\ttotal: 65.3ms\tremaining: 4.29s\n", 694 | "15:\tlearn: 7651.9441909\ttest: 7419.6808464\tbest: 7419.6808464 (15)\ttotal: 65.5ms\tremaining: 4.03s\n", 695 | "16:\tlearn: 7446.1069448\ttest: 
7236.7268464\tbest: 7236.7268464 (16)\ttotal: 66.2ms\tremaining: 3.83s\n", 696 | "17:\tlearn: 7274.6704815\ttest: 7064.1042181\tbest: 7064.1042181 (17)\ttotal: 66.9ms\tremaining: 3.65s\n", 697 | "18:\tlearn: 7090.3414827\ttest: 6880.2568747\tbest: 6880.2568747 (18)\ttotal: 67.5ms\tremaining: 3.48s\n", 698 | "19:\tlearn: 6935.8494326\ttest: 6736.2596571\tbest: 6736.2596571 (19)\ttotal: 68.7ms\tremaining: 3.37s\n", 699 | "20:\tlearn: 6799.9531942\ttest: 6603.7128076\tbest: 6603.7128076 (20)\ttotal: 69ms\tremaining: 3.22s\n", 700 | "21:\tlearn: 6682.0661648\ttest: 6498.8882793\tbest: 6498.8882793 (21)\ttotal: 69.5ms\tremaining: 3.09s\n", 701 | "22:\tlearn: 6538.5362251\ttest: 6367.2327244\tbest: 6367.2327244 (22)\ttotal: 70.2ms\tremaining: 2.98s\n", 702 | "23:\tlearn: 6413.3233042\ttest: 6249.6403121\tbest: 6249.6403121 (23)\ttotal: 70.8ms\tremaining: 2.88s\n", 703 | "24:\tlearn: 6289.2337797\ttest: 6136.6353585\tbest: 6136.6353585 (24)\ttotal: 71.7ms\tremaining: 2.79s\n", 704 | "25:\tlearn: 6166.4909063\ttest: 6026.3136729\tbest: 6026.3136729 (25)\ttotal: 72.4ms\tremaining: 2.71s\n", 705 | "26:\tlearn: 6056.6980556\ttest: 5923.1889379\tbest: 5923.1889379 (26)\ttotal: 72.9ms\tremaining: 2.63s\n", 706 | "27:\tlearn: 5955.0594430\ttest: 5826.9385445\tbest: 5826.9385445 (27)\ttotal: 73.5ms\tremaining: 2.55s\n", 707 | "28:\tlearn: 5860.7535599\ttest: 5738.5303225\tbest: 5738.5303225 (28)\ttotal: 74.2ms\tremaining: 2.48s\n", 708 | "29:\tlearn: 5774.8792282\ttest: 5658.1048594\tbest: 5658.1048594 (29)\ttotal: 75.5ms\tremaining: 2.44s\n", 709 | "30:\tlearn: 5700.9796461\ttest: 5591.5205624\tbest: 5591.5205624 (30)\ttotal: 76.1ms\tremaining: 2.38s\n", 710 | "31:\tlearn: 5622.6454958\ttest: 5518.4609745\tbest: 5518.4609745 (31)\ttotal: 76.8ms\tremaining: 2.32s\n", 711 | "32:\tlearn: 5546.7784980\ttest: 5456.5329749\tbest: 5456.5329749 (32)\ttotal: 78.2ms\tremaining: 2.29s\n", 712 | "33:\tlearn: 5480.4837742\ttest: 5393.4714671\tbest: 5393.4714671 (33)\ttotal: 
78.8ms\tremaining: 2.24s\n", 713 | "34:\tlearn: 5426.0583102\ttest: 5346.8814621\tbest: 5346.8814621 (34)\ttotal: 80ms\tremaining: 2.21s\n", 714 | "35:\tlearn: 5364.8177436\ttest: 5285.3885735\tbest: 5285.3885735 (35)\ttotal: 81.6ms\tremaining: 2.18s\n", 715 | "36:\tlearn: 5317.6277425\ttest: 5245.2405633\tbest: 5245.2405633 (36)\ttotal: 81.9ms\tremaining: 2.13s\n", 716 | "37:\tlearn: 5273.9836551\ttest: 5199.5748939\tbest: 5199.5748939 (37)\ttotal: 82.2ms\tremaining: 2.08s\n", 717 | "38:\tlearn: 5226.4476525\ttest: 5152.4112228\tbest: 5152.4112228 (38)\ttotal: 83.5ms\tremaining: 2.06s\n", 718 | "39:\tlearn: 5187.3186203\ttest: 5121.6387926\tbest: 5121.6387926 (39)\ttotal: 84.3ms\tremaining: 2.02s\n", 719 | "40:\tlearn: 5148.1007653\ttest: 5091.1530033\tbest: 5091.1530033 (40)\ttotal: 84.9ms\tremaining: 1.99s\n", 720 | "41:\tlearn: 5118.4574272\ttest: 5065.1781067\tbest: 5065.1781067 (41)\ttotal: 85.1ms\tremaining: 1.94s\n", 721 | "42:\tlearn: 5082.0201202\ttest: 5034.1461523\tbest: 5034.1461523 (42)\ttotal: 86ms\tremaining: 1.91s\n", 722 | "43:\tlearn: 5046.9010507\ttest: 5001.7382661\tbest: 5001.7382661 (43)\ttotal: 86.6ms\tremaining: 1.88s\n", 723 | "44:\tlearn: 5014.4682607\ttest: 4967.7026153\tbest: 4967.7026153 (44)\ttotal: 87.4ms\tremaining: 1.85s\n", 724 | "45:\tlearn: 4984.8106676\ttest: 4937.7427381\tbest: 4937.7427381 (45)\ttotal: 88.1ms\tremaining: 1.83s\n", 725 | "46:\tlearn: 4956.7529140\ttest: 4911.5675003\tbest: 4911.5675003 (46)\ttotal: 88.6ms\tremaining: 1.8s\n", 726 | "47:\tlearn: 4925.9004687\ttest: 4885.6560194\tbest: 4885.6560194 (47)\ttotal: 89.4ms\tremaining: 1.77s\n", 727 | "48:\tlearn: 4897.7500534\ttest: 4862.7024286\tbest: 4862.7024286 (48)\ttotal: 90.2ms\tremaining: 1.75s\n", 728 | "49:\tlearn: 4876.4231527\ttest: 4844.1445645\tbest: 4844.1445645 (49)\ttotal: 90.9ms\tremaining: 1.73s\n", 729 | "50:\tlearn: 4850.3335001\ttest: 4820.7833775\tbest: 4820.7833775 (50)\ttotal: 91.8ms\tremaining: 1.71s\n", 730 | "51:\tlearn: 
4830.3468948\ttest: 4801.6432589\tbest: 4801.6432589 (51)\ttotal: 92.7ms\tremaining: 1.69s\n", 731 | "52:\tlearn: 4804.6721278\ttest: 4781.7853353\tbest: 4781.7853353 (52)\ttotal: 93.4ms\tremaining: 1.67s\n", 732 | "53:\tlearn: 4785.7884488\ttest: 4765.1396970\tbest: 4765.1396970 (53)\ttotal: 94.1ms\tremaining: 1.65s\n", 733 | "54:\tlearn: 4764.6931172\ttest: 4745.1402240\tbest: 4745.1402240 (54)\ttotal: 95ms\tremaining: 1.63s\n", 734 | "55:\tlearn: 4745.0014477\ttest: 4731.9202365\tbest: 4731.9202365 (55)\ttotal: 95.8ms\tremaining: 1.61s\n", 735 | "56:\tlearn: 4733.6924249\ttest: 4721.4346631\tbest: 4721.4346631 (56)\ttotal: 96.7ms\tremaining: 1.6s\n", 736 | "57:\tlearn: 4718.4285051\ttest: 4706.8977435\tbest: 4706.8977435 (57)\ttotal: 97.5ms\tremaining: 1.58s\n", 737 | "58:\tlearn: 4708.3924650\ttest: 4700.5360220\tbest: 4700.5360220 (58)\ttotal: 98ms\tremaining: 1.56s\n", 738 | "59:\tlearn: 4697.4051340\ttest: 4688.5274928\tbest: 4688.5274928 (59)\ttotal: 98.6ms\tremaining: 1.54s\n", 739 | "60:\tlearn: 4687.6765171\ttest: 4675.1671119\tbest: 4675.1671119 (60)\ttotal: 99.3ms\tremaining: 1.53s\n", 740 | "61:\tlearn: 4673.5371214\ttest: 4660.3489143\tbest: 4660.3489143 (61)\ttotal: 100ms\tremaining: 1.51s\n", 741 | "62:\tlearn: 4664.5705858\ttest: 4650.6639993\tbest: 4650.6639993 (62)\ttotal: 101ms\tremaining: 1.5s\n", 742 | "63:\tlearn: 4655.5121100\ttest: 4642.6593428\tbest: 4642.6593428 (63)\ttotal: 101ms\tremaining: 1.48s\n", 743 | "64:\tlearn: 4640.6477105\ttest: 4632.1934679\tbest: 4632.1934679 (64)\ttotal: 103ms\tremaining: 1.47s\n", 744 | "65:\tlearn: 4631.2775709\ttest: 4628.6864938\tbest: 4628.6864938 (65)\ttotal: 103ms\tremaining: 1.46s\n", 745 | "66:\tlearn: 4617.5614730\ttest: 4620.7160823\tbest: 4620.7160823 (66)\ttotal: 104ms\tremaining: 1.45s\n", 746 | "67:\tlearn: 4609.6661970\ttest: 4615.9897070\tbest: 4615.9897070 (67)\ttotal: 105ms\tremaining: 1.44s\n", 747 | "68:\tlearn: 4594.9269326\ttest: 4607.7160626\tbest: 4607.7160626 (68)\ttotal: 
107ms\tremaining: 1.44s\n", 748 | "69:\tlearn: 4586.4000778\ttest: 4599.2989263\tbest: 4599.2989263 (69)\ttotal: 107ms\tremaining: 1.42s\n", 749 | "70:\tlearn: 4584.2924168\ttest: 4595.2372445\tbest: 4595.2372445 (70)\ttotal: 108ms\tremaining: 1.41s\n", 750 | "71:\tlearn: 4580.6310982\ttest: 4590.8254827\tbest: 4590.8254827 (71)\ttotal: 108ms\tremaining: 1.4s\n", 751 | "72:\tlearn: 4573.0512085\ttest: 4584.5655906\tbest: 4584.5655906 (72)\ttotal: 109ms\tremaining: 1.39s\n", 752 | "73:\tlearn: 4564.9022767\ttest: 4582.2014024\tbest: 4582.2014024 (73)\ttotal: 110ms\tremaining: 1.38s\n", 753 | "74:\tlearn: 4554.9704480\ttest: 4575.4565012\tbest: 4575.4565012 (74)\ttotal: 111ms\tremaining: 1.37s\n", 754 | "75:\tlearn: 4550.6381327\ttest: 4575.0104562\tbest: 4575.0104562 (75)\ttotal: 111ms\tremaining: 1.35s\n", 755 | "76:\tlearn: 4543.7447385\ttest: 4574.9770248\tbest: 4574.9770248 (76)\ttotal: 112ms\tremaining: 1.34s\n", 756 | "77:\tlearn: 4541.5607817\ttest: 4571.2443289\tbest: 4571.2443289 (77)\ttotal: 113ms\tremaining: 1.33s\n", 757 | "78:\tlearn: 4535.4912581\ttest: 4564.2435780\tbest: 4564.2435780 (78)\ttotal: 113ms\tremaining: 1.32s\n", 758 | "79:\tlearn: 4531.0415122\ttest: 4559.8685708\tbest: 4559.8685708 (79)\ttotal: 114ms\tremaining: 1.31s\n", 759 | "80:\tlearn: 4527.2248414\ttest: 4555.9682281\tbest: 4555.9682281 (80)\ttotal: 115ms\tremaining: 1.3s\n", 760 | "81:\tlearn: 4519.2937423\ttest: 4549.6245735\tbest: 4549.6245735 (81)\ttotal: 116ms\tremaining: 1.3s\n", 761 | "82:\tlearn: 4515.3573293\ttest: 4545.7439628\tbest: 4545.7439628 (82)\ttotal: 117ms\tremaining: 1.29s\n", 762 | "83:\tlearn: 4510.5792299\ttest: 4540.8327696\tbest: 4540.8327696 (83)\ttotal: 117ms\tremaining: 1.28s\n", 763 | "84:\tlearn: 4506.0538270\ttest: 4541.5568665\tbest: 4540.8327696 (83)\ttotal: 118ms\tremaining: 1.27s\n", 764 | "85:\tlearn: 4503.1630183\ttest: 4537.6270336\tbest: 4537.6270336 (85)\ttotal: 119ms\tremaining: 1.26s\n", 765 | "86:\tlearn: 4496.0911875\ttest: 
4531.5680559\tbest: 4531.5680559 (86)\ttotal: 120ms\tremaining: 1.25s\n", 766 | "87:\tlearn: 4492.2828701\ttest: 4531.0761587\tbest: 4531.0761587 (87)\ttotal: 120ms\tremaining: 1.24s\n", 767 | "88:\tlearn: 4488.9621285\ttest: 4528.3046536\tbest: 4528.3046536 (88)\ttotal: 121ms\tremaining: 1.24s\n", 768 | "89:\tlearn: 4488.8701209\ttest: 4528.2794489\tbest: 4528.2794489 (89)\ttotal: 121ms\tremaining: 1.22s\n", 769 | "90:\tlearn: 4485.2112888\ttest: 4524.9720945\tbest: 4524.9720945 (90)\ttotal: 122ms\tremaining: 1.22s\n", 770 | "91:\tlearn: 4481.2354323\ttest: 4521.7897667\tbest: 4521.7897667 (91)\ttotal: 123ms\tremaining: 1.21s\n", 771 | "92:\tlearn: 4474.3205123\ttest: 4517.9420374\tbest: 4517.9420374 (92)\ttotal: 124ms\tremaining: 1.21s\n", 772 | "93:\tlearn: 4466.7718387\ttest: 4514.6783108\tbest: 4514.6783108 (93)\ttotal: 124ms\tremaining: 1.2s\n", 773 | "94:\tlearn: 4466.1435996\ttest: 4514.5008349\tbest: 4514.5008349 (94)\ttotal: 125ms\tremaining: 1.19s\n", 774 | "95:\tlearn: 4464.3661156\ttest: 4514.0741901\tbest: 4514.0741901 (95)\ttotal: 125ms\tremaining: 1.18s\n", 775 | "96:\tlearn: 4461.7752866\ttest: 4510.9960879\tbest: 4510.9960879 (96)\ttotal: 126ms\tremaining: 1.17s\n", 776 | "97:\tlearn: 4461.5146567\ttest: 4510.9796991\tbest: 4510.9796991 (97)\ttotal: 126ms\tremaining: 1.16s\n", 777 | "98:\tlearn: 4456.6558845\ttest: 4508.3112522\tbest: 4508.3112522 (98)\ttotal: 126ms\tremaining: 1.15s\n", 778 | "99:\tlearn: 4454.3871627\ttest: 4505.7860654\tbest: 4505.7860654 (99)\ttotal: 127ms\tremaining: 1.14s\n", 779 | "100:\tlearn: 4448.9852680\ttest: 4503.1081417\tbest: 4503.1081417 (100)\ttotal: 128ms\tremaining: 1.14s\n", 780 | "101:\tlearn: 4440.5592360\ttest: 4502.4170018\tbest: 4502.4170018 (101)\ttotal: 129ms\tremaining: 1.13s\n", 781 | "102:\tlearn: 4431.6324937\ttest: 4500.9596119\tbest: 4500.9596119 (102)\ttotal: 129ms\tremaining: 1.13s\n", 782 | "103:\tlearn: 4429.3687673\ttest: 4501.4204720\tbest: 4500.9596119 (102)\ttotal: 130ms\tremaining: 
1.12s\n", 783 | "104:\tlearn: 4429.3629093\ttest: 4501.3898254\tbest: 4500.9596119 (102)\ttotal: 131ms\tremaining: 1.11s\n", 784 | "105:\tlearn: 4428.8235974\ttest: 4501.0482453\tbest: 4500.9596119 (102)\ttotal: 131ms\tremaining: 1.11s\n", 785 | "106:\tlearn: 4425.9788709\ttest: 4500.6305192\tbest: 4500.6305192 (106)\ttotal: 132ms\tremaining: 1.1s\n", 786 | "107:\tlearn: 4424.3855718\ttest: 4499.2275541\tbest: 4499.2275541 (107)\ttotal: 133ms\tremaining: 1.09s\n", 787 | "108:\tlearn: 4422.0120302\ttest: 4499.8376373\tbest: 4499.2275541 (107)\ttotal: 133ms\tremaining: 1.09s\n", 788 | "109:\tlearn: 4418.3186434\ttest: 4498.5182120\tbest: 4498.5182120 (109)\ttotal: 134ms\tremaining: 1.08s\n", 789 | "110:\tlearn: 4417.2920763\ttest: 4497.8476563\tbest: 4497.8476563 (110)\ttotal: 134ms\tremaining: 1.07s\n", 790 | "111:\tlearn: 4416.7227745\ttest: 4498.9462544\tbest: 4497.8476563 (110)\ttotal: 135ms\tremaining: 1.07s\n", 791 | "112:\tlearn: 4412.7025996\ttest: 4498.7626385\tbest: 4497.8476563 (110)\ttotal: 136ms\tremaining: 1.06s\n", 792 | "113:\tlearn: 4407.8196920\ttest: 4498.6152396\tbest: 4497.8476563 (110)\ttotal: 136ms\tremaining: 1.06s\n", 793 | "114:\tlearn: 4404.0040730\ttest: 4496.1760525\tbest: 4496.1760525 (114)\ttotal: 137ms\tremaining: 1.05s\n", 794 | "115:\tlearn: 4402.6820993\ttest: 4493.5948619\tbest: 4493.5948619 (115)\ttotal: 138ms\tremaining: 1.05s\n", 795 | "116:\tlearn: 4400.7635771\ttest: 4491.6903454\tbest: 4491.6903454 (116)\ttotal: 139ms\tremaining: 1.05s\n", 796 | "117:\tlearn: 4399.7669773\ttest: 4489.9523133\tbest: 4489.9523133 (117)\ttotal: 140ms\tremaining: 1.04s\n", 797 | "118:\tlearn: 4393.0292150\ttest: 4486.7754864\tbest: 4486.7754864 (118)\ttotal: 141ms\tremaining: 1.04s\n", 798 | "119:\tlearn: 4385.1056174\ttest: 4482.0289943\tbest: 4482.0289943 (119)\ttotal: 142ms\tremaining: 1.04s\n", 799 | "120:\tlearn: 4382.9640733\ttest: 4482.4460160\tbest: 4482.0289943 (119)\ttotal: 142ms\tremaining: 1.03s\n", 800 | "121:\tlearn: 
4381.6342302\ttest: 4479.7860514\tbest: 4479.7860514 (121)\ttotal: 143ms\tremaining: 1.03s\n", 801 | "122:\tlearn: 4381.1578021\ttest: 4479.6616481\tbest: 4479.6616481 (122)\ttotal: 144ms\tremaining: 1.02s\n", 802 | "123:\tlearn: 4375.5387552\ttest: 4478.0570824\tbest: 4478.0570824 (123)\ttotal: 144ms\tremaining: 1.02s\n", 803 | "124:\tlearn: 4375.1315479\ttest: 4477.1129288\tbest: 4477.1129288 (124)\ttotal: 145ms\tremaining: 1.01s\n", 804 | "125:\tlearn: 4373.8160483\ttest: 4475.2068758\tbest: 4475.2068758 (125)\ttotal: 145ms\tremaining: 1.01s\n", 805 | "126:\tlearn: 4370.8924010\ttest: 4477.4286380\tbest: 4475.2068758 (125)\ttotal: 146ms\tremaining: 1s\n", 806 | "127:\tlearn: 4369.3506544\ttest: 4476.3365496\tbest: 4475.2068758 (125)\ttotal: 147ms\tremaining: 1s\n", 807 | "128:\tlearn: 4367.1210753\ttest: 4474.4884111\tbest: 4474.4884111 (128)\ttotal: 149ms\tremaining: 1s\n", 808 | "129:\tlearn: 4366.7213657\ttest: 4474.6765012\tbest: 4474.4884111 (128)\ttotal: 149ms\tremaining: 997ms\n", 809 | "130:\tlearn: 4361.4266661\ttest: 4472.3525687\tbest: 4472.3525687 (130)\ttotal: 150ms\tremaining: 993ms\n", 810 | "131:\tlearn: 4359.1617519\ttest: 4468.8500676\tbest: 4468.8500676 (131)\ttotal: 150ms\tremaining: 989ms\n", 811 | "132:\tlearn: 4357.3157065\ttest: 4467.6042087\tbest: 4467.6042087 (132)\ttotal: 151ms\tremaining: 985ms\n", 812 | "133:\tlearn: 4355.4754571\ttest: 4466.7978797\tbest: 4466.7978797 (133)\ttotal: 152ms\tremaining: 982ms\n", 813 | "134:\tlearn: 4353.0920681\ttest: 4467.2216987\tbest: 4466.7978797 (133)\ttotal: 153ms\tremaining: 980ms\n", 814 | "135:\tlearn: 4349.3034898\ttest: 4462.7715216\tbest: 4462.7715216 (135)\ttotal: 154ms\tremaining: 976ms\n", 815 | "136:\tlearn: 4348.0399740\ttest: 4461.4525517\tbest: 4461.4525517 (136)\ttotal: 154ms\tremaining: 973ms\n", 816 | "137:\tlearn: 4342.2369116\ttest: 4463.1921563\tbest: 4461.4525517 (136)\ttotal: 155ms\tremaining: 969ms\n", 817 | "138:\tlearn: 4337.4500163\ttest: 4462.3466978\tbest: 4461.4525517 
(136)\ttotal: 156ms\tremaining: 968ms\n", 818 | "139:\tlearn: 4336.7592406\ttest: 4462.1731148\tbest: 4461.4525517 (136)\ttotal: 157ms\tremaining: 964ms\n", 819 | "140:\tlearn: 4336.5533039\ttest: 4462.1081990\tbest: 4461.4525517 (136)\ttotal: 157ms\tremaining: 959ms\n", 820 | "141:\tlearn: 4335.4981537\ttest: 4461.7103284\tbest: 4461.4525517 (136)\ttotal: 158ms\tremaining: 957ms\n", 821 | "142:\tlearn: 4335.1834055\ttest: 4461.6170483\tbest: 4461.4525517 (136)\ttotal: 159ms\tremaining: 951ms\n", 822 | "143:\tlearn: 4333.4594557\ttest: 4461.5227009\tbest: 4461.4525517 (136)\ttotal: 160ms\tremaining: 950ms\n", 823 | "144:\tlearn: 4331.8637410\ttest: 4462.5664485\tbest: 4461.4525517 (136)\ttotal: 161ms\tremaining: 947ms\n", 824 | "145:\tlearn: 4330.8941104\ttest: 4462.2997026\tbest: 4461.4525517 (136)\ttotal: 161ms\tremaining: 943ms\n", 825 | "146:\tlearn: 4327.6644517\ttest: 4462.2039381\tbest: 4461.4525517 (136)\ttotal: 162ms\tremaining: 940ms\n", 826 | "147:\tlearn: 4326.5420613\ttest: 4462.8728516\tbest: 4461.4525517 (136)\ttotal: 163ms\tremaining: 936ms\n", 827 | "148:\tlearn: 4322.3927087\ttest: 4462.6585701\tbest: 4461.4525517 (136)\ttotal: 164ms\tremaining: 934ms\n", 828 | "149:\tlearn: 4321.0124809\ttest: 4463.2566088\tbest: 4461.4525517 (136)\ttotal: 164ms\tremaining: 931ms\n", 829 | "150:\tlearn: 4319.9787726\ttest: 4464.1263409\tbest: 4461.4525517 (136)\ttotal: 165ms\tremaining: 927ms\n", 830 | "151:\tlearn: 4318.2804397\ttest: 4463.9719555\tbest: 4461.4525517 (136)\ttotal: 165ms\tremaining: 923ms\n", 831 | "152:\tlearn: 4317.1621473\ttest: 4464.3190037\tbest: 4461.4525517 (136)\ttotal: 167ms\tremaining: 922ms\n", 832 | "153:\tlearn: 4315.3082889\ttest: 4462.3371173\tbest: 4461.4525517 (136)\ttotal: 168ms\tremaining: 922ms\n", 833 | "154:\tlearn: 4313.6240588\ttest: 4460.9090316\tbest: 4460.9090316 (154)\ttotal: 169ms\tremaining: 919ms\n", 834 | "155:\tlearn: 4308.7408461\ttest: 4461.9994794\tbest: 4460.9090316 (154)\ttotal: 170ms\tremaining: 917ms\n", 
835 | "156:\tlearn: 4304.3605530\ttest: 4460.8561739\tbest: 4460.8561739 (156)\ttotal: 170ms\tremaining: 914ms\n", 836 | "157:\tlearn: 4296.5265000\ttest: 4458.2121464\tbest: 4458.2121464 (157)\ttotal: 171ms\tremaining: 912ms\n", 837 | "158:\tlearn: 4292.9219052\ttest: 4457.3077812\tbest: 4457.3077812 (158)\ttotal: 172ms\tremaining: 911ms\n", 838 | "159:\tlearn: 4291.2528338\ttest: 4457.1570880\tbest: 4457.1570880 (159)\ttotal: 173ms\tremaining: 908ms\n", 839 | "160:\tlearn: 4289.4613153\ttest: 4456.5441936\tbest: 4456.5441936 (160)\ttotal: 174ms\tremaining: 906ms\n", 840 | "161:\tlearn: 4285.0691200\ttest: 4456.8069797\tbest: 4456.5441936 (160)\ttotal: 175ms\tremaining: 903ms\n", 841 | "162:\tlearn: 4283.4649878\ttest: 4455.2279840\tbest: 4455.2279840 (162)\ttotal: 175ms\tremaining: 900ms\n", 842 | "163:\tlearn: 4281.2809661\ttest: 4455.3887740\tbest: 4455.2279840 (162)\ttotal: 176ms\tremaining: 897ms\n", 843 | "164:\tlearn: 4281.1947633\ttest: 4455.3803900\tbest: 4455.2279840 (162)\ttotal: 176ms\tremaining: 892ms\n", 844 | "165:\tlearn: 4279.2680603\ttest: 4456.9433372\tbest: 4455.2279840 (162)\ttotal: 177ms\tremaining: 890ms\n", 845 | "166:\tlearn: 4274.1633733\ttest: 4459.1915051\tbest: 4455.2279840 (162)\ttotal: 178ms\tremaining: 888ms\n", 846 | "167:\tlearn: 4271.5861198\ttest: 4457.6296101\tbest: 4455.2279840 (162)\ttotal: 179ms\tremaining: 885ms\n", 847 | "168:\tlearn: 4271.3799469\ttest: 4457.6197179\tbest: 4455.2279840 (162)\ttotal: 179ms\tremaining: 881ms\n", 848 | "169:\tlearn: 4269.6933288\ttest: 4456.6673339\tbest: 4455.2279840 (162)\ttotal: 180ms\tremaining: 878ms\n", 849 | "170:\tlearn: 4266.6707750\ttest: 4456.0106319\tbest: 4455.2279840 (162)\ttotal: 181ms\tremaining: 876ms\n", 850 | "171:\tlearn: 4265.4996249\ttest: 4456.0314402\tbest: 4455.2279840 (162)\ttotal: 181ms\tremaining: 873ms\n", 851 | "172:\tlearn: 4264.2493406\ttest: 4456.6100371\tbest: 4455.2279840 (162)\ttotal: 182ms\tremaining: 870ms\n", 852 | "173:\tlearn: 4259.6207748\ttest: 
4457.9821824\tbest: 4455.2279840 (162)\ttotal: 183ms\tremaining: 867ms\n", 853 | "174:\tlearn: 4253.2039726\ttest: 4458.0419803\tbest: 4455.2279840 (162)\ttotal: 183ms\tremaining: 864ms\n", 854 | "175:\tlearn: 4252.1777721\ttest: 4458.1837295\tbest: 4455.2279840 (162)\ttotal: 184ms\tremaining: 861ms\n", 855 | "176:\tlearn: 4251.4820019\ttest: 4458.6473260\tbest: 4455.2279840 (162)\ttotal: 185ms\tremaining: 859ms\n", 856 | "177:\tlearn: 4248.0365555\ttest: 4457.1442368\tbest: 4455.2279840 (162)\ttotal: 185ms\tremaining: 856ms\n", 857 | "178:\tlearn: 4246.5116860\ttest: 4459.3102022\tbest: 4455.2279840 (162)\ttotal: 186ms\tremaining: 855ms\n", 858 | "179:\tlearn: 4244.3109488\ttest: 4459.7738963\tbest: 4455.2279840 (162)\ttotal: 187ms\tremaining: 852ms\n", 859 | "180:\tlearn: 4243.5314996\ttest: 4460.0109834\tbest: 4455.2279840 (162)\ttotal: 188ms\tremaining: 851ms\n", 860 | "181:\tlearn: 4243.3602516\ttest: 4460.2987130\tbest: 4455.2279840 (162)\ttotal: 189ms\tremaining: 848ms\n", 861 | "182:\tlearn: 4240.7892696\ttest: 4459.7722494\tbest: 4455.2279840 (162)\ttotal: 190ms\tremaining: 847ms\n", 862 | "Stopped by overfitting detector (20 iterations wait)\n", 863 | "\n", 864 | "bestTest = 4455.227984\n", 865 | "bestIteration = 162\n", 866 | "\n", 867 | "Shrink model to first 163 iterations.\n" 868 | ] 869 | } 870 | ], 871 | "source": [ 872 | "if params['problem_type'] == 'regression':\n", 873 | " from catboost import CatBoostRegressor, Pool\n", 874 | "\n", 875 | " model_catboost = CatBoostRegressor(n_estimators=1000, \n", 876 | " early_stopping_rounds=20)\n", 877 | " train_data = Pool(\n", 878 | " data=X_train_raw,\n", 879 | " label=y_train,\n", 880 | " cat_features=categorical_feature_indices,\n", 881 | " )\n", 882 | "\n", 883 | " eval_data = Pool(\n", 884 | " data=X_valid_raw,\n", 885 | " label=y_valid,\n", 886 | " cat_features=categorical_feature_indices,\n", 887 | " )\n", 888 | "\n", 889 | " model_catboost.fit(X=train_data, \n", 890 | " eval_set=eval_data)\n", 891 
| "\n", 892 | "\n", 893 | "\n", 894 | "\n", 895 | " preds_catboost = model_catboost.predict(X_test_raw)\n", 896 | "else:\n", 897 | " from catboost import CatBoostClassifier, Pool\n", 898 | "\n", 899 | " model_catboost = CatBoostClassifier(n_estimators=1000, \n", 900 | " early_stopping_rounds=20)\n", 901 | " train_data = Pool(\n", 902 | " data=X_train_raw,\n", 903 | " label=y_train,\n", 904 | " cat_features=categorical_feature_indices,\n", 905 | " #weight=calculate_sample_weights(y_train)\n", 906 | " )\n", 907 | "\n", 908 | " eval_data = Pool(\n", 909 | " data=X_valid_raw,\n", 910 | " label=y_valid,\n", 911 | " cat_features=categorical_feature_indices,\n", 912 | " #weight=calculate_sample_weights(y_valid),\n", 913 | " )\n", 914 | "\n", 915 | " model_catboost.fit(X=train_data, \n", 916 | " eval_set=eval_data)\n", 917 | "\n", 918 | "\n", 919 | "\n", 920 | " preds_catboost = model_catboost.predict_proba(X_test_raw)\n" 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": 12, 926 | "id": "21b9be01-ba30-4300-b3e0-f4fd43b55c87", 927 | "metadata": { 928 | "execution": { 929 | "iopub.execute_input": "2024-03-15T11:57:52.429019Z", 930 | "iopub.status.busy": "2024-03-15T11:57:52.428555Z", 931 | "iopub.status.idle": "2024-03-15T11:57:52.454151Z", 932 | "shell.execute_reply": "2024-03-15T11:57:52.453620Z", 933 | "shell.execute_reply.started": "2024-03-15T11:57:52.428990Z" 934 | } 935 | }, 936 | "outputs": [ 937 | { 938 | "name": "stdout", 939 | "output_type": "stream", 940 | "text": [ 941 | "MAE GRANDE: 2481.6191651649956\n", 942 | "RMSE GRANDE: 3761.773606283763\n", 943 | "R2 Score GRANDE: 0.8653776732830805\n", 944 | "\n", 945 | "\n", 946 | "MAE XGB: 2437.6996424431554\n", 947 | "RMSE XGB: 3930.3302577846944\n", 948 | "R2 Score XGB: 0.8530431372729947\n", 949 | "\n", 950 | "\n", 951 | "MAE CatBoost: 2194.5464048630265\n", 952 | "RMSE CatBoost: 3637.838113587122\n", 953 | "R2 Score CatBoost: 0.8741020902072107\n", 954 | "\n", 955 | "\n" 956 | ] 957 | } 
958 | ], 959 | "source": [ 960 | "if params['problem_type'] == 'binary':\n", 961 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_grande[:,1]))\n", 962 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_grande[:,1]), average='macro')\n", 963 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande[:,1], average='macro', multi_class='ovo')\n", 964 | "\n", 965 | " print('Accuracy GRANDE:', accuracy)\n", 966 | " print('F1 Score GRANDE:', f1_score)\n", 967 | " print('ROC AUC GRANDE:', roc_auc)\n", 968 | " print('\\n')\n", 969 | "\n", 970 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_xgb[:,1]))\n", 971 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_xgb[:,1]), average='macro')\n", 972 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb[:,1], average='macro', multi_class='ovo')\n", 973 | "\n", 974 | " print('Accuracy XGB:', accuracy)\n", 975 | " print('F1 Score XGB:', f1_score)\n", 976 | " print('ROC AUC XGB:', roc_auc)\n", 977 | " print('\\n')\n", 978 | "\n", 979 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_catboost[:,1]))\n", 980 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_catboost[:,1]), average='macro')\n", 981 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost[:,1], average='macro', multi_class='ovo')\n", 982 | "\n", 983 | " print('Accuracy CatBoost:', accuracy)\n", 984 | " print('F1 Score CatBoost:', f1_score)\n", 985 | " print('ROC AUC CatBoost:', roc_auc)\n", 986 | " print('\\n')\n", 987 | "elif params['problem_type'] == 'multiclass':\n", 988 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_grande, axis=1))\n", 989 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_grande, axis=1), average='macro')\n", 990 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n", 991 | "\n", 992 | " 
print('Accuracy GRANDE:', accuracy)\n", 993 | " print('F1 Score GRANDE:', f1_score)\n", 994 | " print('ROC AUC GRANDE:', roc_auc)\n", 995 | " print('\\n')\n", 996 | "\n", 997 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_xgb, axis=1))\n", 998 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_xgb, axis=1), average='macro')\n", 999 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n", 1000 | "\n", 1001 | " print('Accuracy XGB:', accuracy)\n", 1002 | " print('F1 Score XGB:', f1_score)\n", 1003 | " print('ROC AUC XGB:', roc_auc)\n", 1004 | " print('\\n')\n", 1005 | "\n", 1006 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_catboost, axis=1))\n", 1007 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_catboost, axis=1), average='macro')\n", 1008 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n", 1009 | "\n", 1010 | " print('Accuracy CatBoost:', accuracy)\n", 1011 | " print('F1 Score CatBoost:', f1_score)\n", 1012 | " print('ROC AUC CatBoost:', roc_auc)\n", 1013 | " print('\\n')\n", 1014 | "else:\n", 1015 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_grande)\n", 1016 | " root_mean_squared_error = np.sqrt(((y_test - preds_grande) ** 2).mean())\n", 1017 | " r2_score = sklearn.metrics.r2_score(y_test, preds_grande)\n", 1018 | "\n", 1019 | " print('MAE GRANDE:', mean_absolute_error)\n", 1020 | " print('RMSE GRANDE:', root_mean_squared_error)\n", 1021 | " print('R2 Score GRANDE:', r2_score)\n", 1022 | " print('\\n')\n", 1023 | "\n", 1024 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_xgb)\n", 1025 | " root_mean_squared_error = np.sqrt(((y_test - preds_xgb) ** 2).mean())\n", 1026 | " r2_score = sklearn.metrics.r2_score(y_test, preds_xgb)\n", 
1027 | "\n", 1028 | " print('MAE XGB:', mean_absolute_error)\n", 1029 | " print('RMSE XGB:', root_mean_squared_error)\n", 1030 | " print('R2 Score XGB:', r2_score)\n", 1031 | " print('\\n')\n", 1032 | "\n", 1033 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_catboost)\n", 1034 | " root_mean_squared_error = np.sqrt(((y_test - preds_catboost) ** 2).mean())\n", 1035 | " r2_score = sklearn.metrics.r2_score(y_test, preds_catboost)\n", 1036 | "\n", 1037 | " print('MAE CatBoost:', mean_absolute_error)\n", 1038 | " print('RMSE CatBoost:', root_mean_squared_error)\n", 1039 | " print('R2 Score CatBoost:', r2_score)\n", 1040 | " print('\\n')" 1041 | ] 1042 | }, 1043 | { 1044 | "cell_type": "code", 1045 | "execution_count": null, 1046 | "id": "c7ef884c-1269-4e0b-8ca7-7ea74c6e380c", 1047 | "metadata": {}, 1048 | "outputs": [], 1049 | "source": [] 1050 | }, 1051 | { 1052 | "cell_type": "code", 1053 | "execution_count": null, 1054 | "id": "2fc02cd3", 1055 | "metadata": {}, 1056 | "outputs": [], 1057 | "source": [] 1058 | } 1059 | ], 1060 | "metadata": { 1061 | "kernelspec": { 1062 | "display_name": "ReMeDe", 1063 | "language": "python", 1064 | "name": "python3" 1065 | }, 1066 | "language_info": { 1067 | "codemirror_mode": { 1068 | "name": "ipython", 1069 | "version": 3 1070 | }, 1071 | "file_extension": ".py", 1072 | "mimetype": "text/x-python", 1073 | "name": "python", 1074 | "nbconvert_exporter": "python", 1075 | "pygments_lexer": "ipython3", 1076 | "version": "3.12.9" 1077 | } 1078 | }, 1079 | "nbformat": 4, 1080 | "nbformat_minor": 5 1081 | } 1082 | -------------------------------------------------------------------------------- /GRANDE_minimal_example_with_comparison_BINARY.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bf584ce8-8849-4a53-8ce9-2de8f9dd752d", 7 | "metadata": { 8 | "execution": { 9 | 
"iopub.execute_input": "2024-03-15T11:56:56.253578Z", 10 | "iopub.status.busy": "2024-03-15T11:56:56.253464Z", 11 | "iopub.status.idle": "2024-03-15T11:56:56.265905Z", 12 | "shell.execute_reply": "2024-03-15T11:56:56.265512Z", 13 | "shell.execute_reply.started": "2024-03-15T11:56:56.253566Z" 14 | } 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "#specify GPU to use\n", 19 | "import os\n", 20 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "52b07238-f972-45df-86ce-a11cc11d72e3", 27 | "metadata": { 28 | "execution": { 29 | "iopub.execute_input": "2024-03-15T11:56:56.266511Z", 30 | "iopub.status.busy": "2024-03-15T11:56:56.266392Z", 31 | "iopub.status.idle": "2024-03-15T11:56:57.956110Z", 32 | "shell.execute_reply": "2024-03-15T11:56:57.955562Z", 33 | "shell.execute_reply.started": "2024-03-15T11:56:56.266499Z" 34 | } 35 | }, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Training set size: 3200\n", 42 | "Validation set size: 800\n", 43 | "Test set size: 1000\n" 44 | ] 45 | }, 46 | { 47 | "data": { 48 | "text/html": [ 49 | "
" 202 | ], 203 | "text/plain": [ 204 | " state account_length area_code international_plan voice_mail_plan \\\n", 205 | "3444 1 120 415.0 0 1 \n", 206 | "2063 15 130 408.0 0 0 \n", 207 | "3714 26 66 510.0 0 0 \n", 208 | "2671 13 60 510.0 0 0 \n", 209 | "2154 35 134 408.0 0 0 \n", 210 | "\n", 211 | " number_vmail_messages total_day_minutes total_day_calls \\\n", 212 | "3444 27 179.6 142 \n", 213 | "2063 0 115.6 129 \n", 214 | "3714 0 201.3 95 \n", 215 | "2671 0 221.1 106 \n", 216 | "2154 0 202.7 105 \n", 217 | "\n", 218 | " total_day_charge total_eve_minutes total_eve_calls total_eve_charge \\\n", 219 | "3444 30.53 262.8 103 22.34 \n", 220 | "2063 19.65 167.8 104 14.26 \n", 221 | "3714 34.22 152.8 66 12.99 \n", 222 | "2671 37.59 178.6 48 15.18 \n", 223 | "2154 34.46 224.9 90 19.12 \n", 224 | "\n", 225 | " total_night_minutes total_night_calls total_night_charge \\\n", 226 | "3444 239.9 81 10.80 \n", 227 | "2063 141.8 124 6.38 \n", 228 | "3714 233.2 101 10.49 \n", 229 | "2671 202.7 90 9.12 \n", 230 | "2154 253.9 108 11.43 \n", 231 | "\n", 232 | " total_intl_minutes total_intl_calls total_intl_charge \\\n", 233 | "3444 11.1 2 3.00 \n", 234 | "2063 12.6 9 3.40 \n", 235 | "3714 7.5 4 2.03 \n", 236 | "2671 7.4 3 2.00 \n", 237 | "2154 12.1 7 3.27 \n", 238 | "\n", 239 | " number_customer_service_calls \n", 240 | "3444 1 \n", 241 | "2063 1 \n", 242 | "3714 1 \n", 243 | "2671 1 \n", 244 | "2154 0 " 245 | ] 246 | }, 247 | "execution_count": 2, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "from sklearn.model_selection import train_test_split\n", 254 | "import openml\n", 255 | "import category_encoders as ce\n", 256 | "import numpy as np\n", 257 | "import sklearn\n", 258 | "\n", 259 | "# Load churn dataset\n", 260 | "dataset = openml.datasets.get_dataset(46915, download_data=True, download_qualities=True, download_features_meta_data=True)\n", 261 | "X, y, categorical_indicator, attribute_names = 
dataset.get_data(target=dataset.default_target_attribute)\n", 262 | "categorical_feature_indices = [idx for idx, idx_bool in enumerate(categorical_indicator) if idx_bool]\n", 263 | "\n", 264 | "X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 265 | "\n", 266 | "X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)\n", 267 | "\n", 268 | "print(\"Training set size:\", len(X_train))\n", 269 | "print(\"Validation set size:\", len(X_valid))\n", 270 | "print(\"Test set size:\", len(X_test))\n", 271 | "\n", 272 | "X_train.head()" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 3, 278 | "id": "d8364d07", 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "data": { 283 | "text/plain": [ 284 | "state category\n", 285 | "account_length uint8\n", 286 | "area_code category\n", 287 | "international_plan category\n", 288 | "voice_mail_plan category\n", 289 | "number_vmail_messages uint8\n", 290 | "total_day_minutes float64\n", 291 | "total_day_calls uint8\n", 292 | "total_day_charge float64\n", 293 | "total_eve_minutes float64\n", 294 | "total_eve_calls uint8\n", 295 | "total_eve_charge float64\n", 296 | "total_night_minutes float64\n", 297 | "total_night_calls uint8\n", 298 | "total_night_charge float64\n", 299 | "total_intl_minutes float64\n", 300 | "total_intl_calls uint8\n", 301 | "total_intl_charge float64\n", 302 | "number_customer_service_calls uint8\n", 303 | "dtype: object" 304 | ] 305 | }, 306 | "execution_count": 3, 307 | "metadata": {}, 308 | "output_type": "execute_result" 309 | } 310 | ], 311 | "source": [ 312 | "X_train.dtypes" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "id": "bf0a93c2-e9a4-4fa5-a7f6-44689dd07d37", 319 | "metadata": { 320 | "execution": { 321 | "iopub.execute_input": "2024-03-15T11:56:57.957002Z", 322 | "iopub.status.busy": "2024-03-15T11:56:57.956779Z", 323 | "iopub.status.idle": 
"2024-03-15T11:57:30.600318Z", 324 | "shell.execute_reply": "2024-03-15T11:57:30.599632Z", 325 | "shell.execute_reply.started": "2024-03-15T11:56:57.956987Z" 326 | }, 327 | "scrolled": true 328 | }, 329 | "outputs": [ 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "self.params {'depth': 5, 'n_estimators': 1024, 'learning_rate_weights': 0.001, 'learning_rate_index': 0.01, 'learning_rate_values': 0.05, 'learning_rate_leaf': 0.05, 'learning_rate_embedding': 0.02, 'use_category_embeddings': False, 'embedding_dim_cat': 8, 'use_numeric_embeddings': False, 'embedding_dim_num': 8, 'embedding_threshold': 1, 'loo_cardinality': 10, 'dropout': 0.2, 'selected_variables': 0.8, 'data_subset_fraction': 1.0, 'bootstrap': False, 'missing_values': False, 'optimizer': 'adam', 'cosine_decay_restarts': False, 'reduce_on_plateau_scheduler': True, 'label_smoothing': 0.0, 'use_class_weights': False, 'focal_loss': False, 'swa': False, 'es_metric': True, 'epochs': 250, 'batch_size': 256, 'early_stopping_epochs': 50, 'use_freq_enc': False, 'use_robust_scale_smoothing': False, 'problem_type': 'binary', 'random_seed': 42, 'verbose': 2, 'device': 'cuda:0', 'objective': 'binary'}\n", 335 | "{'self.column_names_dataframe': ['state',\n", 336 | " 'account_length',\n", 337 | " 'area_code_1',\n", 338 | " 'area_code_2',\n", 339 | " 'area_code_3',\n", 340 | " 'international_plan',\n", 341 | " 'voice_mail_plan',\n", 342 | " 'number_vmail_messages',\n", 343 | " 'total_day_minutes',\n", 344 | " 'total_day_calls',\n", 345 | " 'total_day_charge',\n", 346 | " 'total_eve_minutes',\n", 347 | " 'total_eve_calls',\n", 348 | " 'total_eve_charge',\n", 349 | " 'total_night_minutes',\n", 350 | " 'total_night_calls',\n", 351 | " 'total_night_charge',\n", 352 | " 'total_intl_minutes',\n", 353 | " 'total_intl_calls',\n", 354 | " 'total_intl_charge',\n", 355 | " 'number_customer_service_calls'],\n", 356 | " 'encoded_columns': ['state',\n", 357 | " 'area_code_1',\n", 358 | " 
'area_code_2',\n", 359 | " 'area_code_3',\n", 360 | " 'international_plan',\n", 361 | " 'voice_mail_plan'],\n", 362 | " 'encoded_columns_indices': [0, 2, 3, 4, 5, 6],\n", 363 | " 'num_columns': ['account_length',\n", 364 | " 'number_vmail_messages',\n", 365 | " 'total_day_minutes',\n", 366 | " 'total_day_calls',\n", 367 | " 'total_day_charge',\n", 368 | " 'total_eve_minutes',\n", 369 | " 'total_eve_calls',\n", 370 | " 'total_eve_charge',\n", 371 | " 'total_night_minutes',\n", 372 | " 'total_night_calls',\n", 373 | " 'total_night_charge',\n", 374 | " 'total_intl_minutes',\n", 375 | " 'total_intl_calls',\n", 376 | " 'total_intl_charge',\n", 377 | " 'number_customer_service_calls'],\n", 378 | " 'num_columns_indices': [1,\n", 379 | " 7,\n", 380 | " 8,\n", 381 | " 9,\n", 382 | " 10,\n", 383 | " 11,\n", 384 | " 12,\n", 385 | " 13,\n", 386 | " 14,\n", 387 | " 15,\n", 388 | " 16,\n", 389 | " 17,\n", 390 | " 18,\n", 391 | " 19,\n", 392 | " 20],\n", 393 | " 'categorical_features_raw_indices': [],\n", 394 | " 'not_encoded_columns': []}\n" 395 | ] 396 | }, 397 | { 398 | "name": "stderr", 399 | "output_type": "stream", 400 | "text": [ 401 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 402 | "Online softmax is disabled on the fly since Inductor decides to\n", 403 | "split the reduction. Cut an issue to PyTorch if this is an\n", 404 | "important use case and you want to speed it up with online\n", 405 | "softmax.\n", 406 | "\n", 407 | " warnings.warn(\n", 408 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 409 | "Online softmax is disabled on the fly since Inductor decides to\n", 410 | "split the reduction. 
Cut an issue to PyTorch if this is an\n", 411 | "important use case and you want to speed it up with online\n", 412 | "softmax.\n", 413 | "\n", 414 | " warnings.warn(\n", 415 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 416 | "Online softmax is disabled on the fly since Inductor decides to\n", 417 | "split the reduction. Cut an issue to PyTorch if this is an\n", 418 | "important use case and you want to speed it up with online\n", 419 | "softmax.\n", 420 | "\n", 421 | " warnings.warn(\n" 422 | ] 423 | }, 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "Epoch 001 | TrainLoss: 0.5474 | ValLoss: 0.4453 | ValAcc: 0.8462 | ValAUC: 0.7365 | ValF1: 0.4584 | Time: 7.30s\n", 429 | "Epoch 002 | TrainLoss: 0.3948 | ValLoss: 0.3909 | ValAcc: 0.8462 | ValAUC: 0.7867 | ValF1: 0.4584 | Time: 0.07s\n", 430 | "Epoch 003 | TrainLoss: 0.3378 | ValLoss: 0.3544 | ValAcc: 0.8512 | ValAUC: 0.8047 | ValF1: 0.4911 | Time: 0.07s\n", 431 | "Epoch 004 | TrainLoss: 0.2971 | ValLoss: 0.3175 | ValAcc: 0.8688 | ValAUC: 0.8550 | ValF1: 0.6018 | Time: 0.07s\n", 432 | "Epoch 005 | TrainLoss: 0.2589 | ValLoss: 0.2886 | ValAcc: 0.8712 | ValAUC: 0.8808 | ValF1: 0.6189 | Time: 0.07s\n", 433 | "Epoch 006 | TrainLoss: 0.2375 | ValLoss: 0.2643 | ValAcc: 0.8775 | ValAUC: 0.8945 | ValF1: 0.6479 | Time: 0.07s\n", 434 | "Epoch 007 | TrainLoss: 0.2188 | ValLoss: 0.2462 | ValAcc: 0.8912 | ValAUC: 0.9017 | ValF1: 0.7092 | Time: 0.07s\n", 435 | "Epoch 008 | TrainLoss: 0.2062 | ValLoss: 0.2279 | ValAcc: 0.9125 | ValAUC: 0.9067 | ValF1: 0.7929 | Time: 0.07s\n", 436 | "Epoch 009 | TrainLoss: 0.1965 | ValLoss: 0.2168 | ValAcc: 0.9200 | ValAUC: 0.9133 | ValF1: 0.8123 | Time: 0.07s\n", 437 | "Epoch 010 | TrainLoss: 0.1835 | ValLoss: 0.2026 | ValAcc: 0.9413 | ValAUC: 0.9144 | ValF1: 0.8717 | Time: 0.07s\n", 438 | "Epoch 011 | TrainLoss: 0.1750 | ValLoss: 0.1998 | ValAcc: 0.9463 | ValAUC: 0.9169 | ValF1: 0.8807 
| Time: 0.07s\n", 439 | "Epoch 012 | TrainLoss: 0.1680 | ValLoss: 0.1945 | ValAcc: 0.9375 | ValAUC: 0.9203 | ValF1: 0.8584 | Time: 0.07s\n", 440 | "Epoch 013 | TrainLoss: 0.1612 | ValLoss: 0.1865 | ValAcc: 0.9550 | ValAUC: 0.9199 | ValF1: 0.9044 | Time: 0.07s\n", 441 | "[EarlyStop TRAINING] no improve (1/50). best=-0.920284, curr=-0.919936\n", 442 | "Epoch 014 | TrainLoss: 0.1590 | ValLoss: 0.1786 | ValAcc: 0.9575 | ValAUC: 0.9222 | ValF1: 0.9083 | Time: 0.07s\n", 443 | "Epoch 015 | TrainLoss: 0.1521 | ValLoss: 0.1729 | ValAcc: 0.9650 | ValAUC: 0.9227 | ValF1: 0.9268 | Time: 0.07s\n", 444 | "[EarlyStop TRAINING] no improve (1/50). best=-0.922158, curr=-0.922722\n", 445 | "Epoch 016 | TrainLoss: 0.1500 | ValLoss: 0.1707 | ValAcc: 0.9600 | ValAUC: 0.9223 | ValF1: 0.9163 | Time: 0.07s\n", 446 | "[EarlyStop TRAINING] no improve (2/50). best=-0.922158, curr=-0.922326\n", 447 | "Epoch 017 | TrainLoss: 0.1441 | ValLoss: 0.1696 | ValAcc: 0.9613 | ValAUC: 0.9226 | ValF1: 0.9180 | Time: 0.07s\n", 448 | "[EarlyStop TRAINING] no improve (3/50). best=-0.922158, curr=-0.922578\n", 449 | "Epoch 018 | TrainLoss: 0.1402 | ValLoss: 0.1669 | ValAcc: 0.9600 | ValAUC: 0.9283 | ValF1: 0.9150 | Time: 0.07s\n", 450 | "Epoch 019 | TrainLoss: 0.1379 | ValLoss: 0.1611 | ValAcc: 0.9613 | ValAUC: 0.9261 | ValF1: 0.9174 | Time: 0.07s\n", 451 | "[EarlyStop TRAINING] no improve (1/50). best=-0.928258, curr=-0.926085\n", 452 | "Epoch 020 | TrainLoss: 0.1307 | ValLoss: 0.1609 | ValAcc: 0.9663 | ValAUC: 0.9257 | ValF1: 0.9286 | Time: 0.07s\n", 453 | "[EarlyStop TRAINING] no improve (2/50). best=-0.928258, curr=-0.925724\n", 454 | "Epoch 021 | TrainLoss: 0.1305 | ValLoss: 0.1588 | ValAcc: 0.9650 | ValAUC: 0.9254 | ValF1: 0.9268 | Time: 0.07s\n", 455 | "[EarlyStop TRAINING] no improve (3/50). best=-0.928258, curr=-0.925436\n", 456 | "Epoch 022 | TrainLoss: 0.1302 | ValLoss: 0.1556 | ValAcc: 0.9637 | ValAUC: 0.9291 | ValF1: 0.9233 | Time: 0.07s\n", 457 | "[EarlyStop TRAINING] no improve (4/50). 
best=-0.928258, curr=-0.929087\n", 458 | "Epoch 023 | TrainLoss: 0.1240 | ValLoss: 0.1555 | ValAcc: 0.9650 | ValAUC: 0.9262 | ValF1: 0.9268 | Time: 0.07s\n", 459 | "[EarlyStop TRAINING] no improve (5/50). best=-0.928258, curr=-0.926193\n", 460 | "Epoch 024 | TrainLoss: 0.1186 | ValLoss: 0.1511 | ValAcc: 0.9637 | ValAUC: 0.9299 | ValF1: 0.9239 | Time: 0.07s\n", 461 | "Epoch 025 | TrainLoss: 0.1205 | ValLoss: 0.1556 | ValAcc: 0.9600 | ValAUC: 0.9273 | ValF1: 0.9150 | Time: 0.07s\n", 462 | "[EarlyStop TRAINING] no improve (1/50). best=-0.929868, curr=-0.927334\n", 463 | "Epoch 026 | TrainLoss: 0.1175 | ValLoss: 0.1524 | ValAcc: 0.9675 | ValAUC: 0.9271 | ValF1: 0.9330 | Time: 0.07s\n", 464 | "[EarlyStop TRAINING] no improve (2/50). best=-0.929868, curr=-0.927093\n", 465 | "Epoch 027 | TrainLoss: 0.1147 | ValLoss: 0.1469 | ValAcc: 0.9675 | ValAUC: 0.9283 | ValF1: 0.9330 | Time: 0.07s\n", 466 | "[EarlyStop TRAINING] no improve (3/50). best=-0.929868, curr=-0.928270\n", 467 | "Epoch 028 | TrainLoss: 0.1148 | ValLoss: 0.1520 | ValAcc: 0.9637 | ValAUC: 0.9266 | ValF1: 0.9250 | Time: 0.07s\n", 468 | "[EarlyStop TRAINING] no improve (4/50). best=-0.929868, curr=-0.926577\n", 469 | "Epoch 029 | TrainLoss: 0.1127 | ValLoss: 0.1524 | ValAcc: 0.9663 | ValAUC: 0.9304 | ValF1: 0.9297 | Time: 0.07s\n", 470 | "[EarlyStop TRAINING] no improve (5/50). best=-0.929868, curr=-0.930408\n", 471 | "Epoch 030 | TrainLoss: 0.1154 | ValLoss: 0.1515 | ValAcc: 0.9650 | ValAUC: 0.9277 | ValF1: 0.9268 | Time: 0.07s\n", 472 | "[EarlyStop TRAINING] no improve (6/50). best=-0.929868, curr=-0.927658\n", 473 | "Epoch 031 | TrainLoss: 0.1106 | ValLoss: 0.1499 | ValAcc: 0.9650 | ValAUC: 0.9312 | ValF1: 0.9268 | Time: 0.07s\n", 474 | "Epoch 032 | TrainLoss: 0.1100 | ValLoss: 0.1485 | ValAcc: 0.9663 | ValAUC: 0.9305 | ValF1: 0.9302 | Time: 0.07s\n", 475 | "[EarlyStop TRAINING] no improve (1/50). 
best=-0.931189, curr=-0.930480\n", 476 | "Epoch 033 | TrainLoss: 0.1067 | ValLoss: 0.1470 | ValAcc: 0.9688 | ValAUC: 0.9312 | ValF1: 0.9354 | Time: 0.07s\n", 477 | "[EarlyStop TRAINING] no improve (2/50). best=-0.931189, curr=-0.931237\n", 478 | "Epoch 034 | TrainLoss: 0.1061 | ValLoss: 0.1527 | ValAcc: 0.9650 | ValAUC: 0.9299 | ValF1: 0.9273 | Time: 0.07s\n", 479 | "[EarlyStop TRAINING] no improve (3/50). best=-0.931189, curr=-0.929856\n", 480 | "Epoch 035 | TrainLoss: 0.1047 | ValLoss: 0.1492 | ValAcc: 0.9625 | ValAUC: 0.9280 | ValF1: 0.9215 | Time: 0.07s\n", 481 | "[EarlyStop TRAINING] no improve (4/50). best=-0.931189, curr=-0.927970\n", 482 | "Epoch 036 | TrainLoss: 0.1035 | ValLoss: 0.1483 | ValAcc: 0.9663 | ValAUC: 0.9278 | ValF1: 0.9307 | Time: 0.07s\n", 483 | "[EarlyStop TRAINING] no improve (5/50). best=-0.931189, curr=-0.927838\n", 484 | "Epoch 037 | TrainLoss: 0.1021 | ValLoss: 0.1567 | ValAcc: 0.9650 | ValAUC: 0.9291 | ValF1: 0.9262 | Time: 0.07s\n", 485 | "[EarlyStop TRAINING] no improve (6/50). best=-0.931189, curr=-0.929075\n", 486 | "Epoch 038 | TrainLoss: 0.0999 | ValLoss: 0.1498 | ValAcc: 0.9675 | ValAUC: 0.9279 | ValF1: 0.9325 | Time: 0.07s\n", 487 | "[EarlyStop TRAINING] no improve (7/50). best=-0.931189, curr=-0.927898\n", 488 | "Epoch 039 | TrainLoss: 0.1008 | ValLoss: 0.1490 | ValAcc: 0.9663 | ValAUC: 0.9317 | ValF1: 0.9297 | Time: 0.07s\n", 489 | "[EarlyStop TRAINING] no improve (8/50). best=-0.931189, curr=-0.931681\n", 490 | "Epoch 040 | TrainLoss: 0.0934 | ValLoss: 0.1425 | ValAcc: 0.9650 | ValAUC: 0.9309 | ValF1: 0.9284 | Time: 0.07s\n", 491 | "[EarlyStop TRAINING] no improve (9/50). best=-0.931189, curr=-0.930924\n", 492 | "Epoch 041 | TrainLoss: 0.0974 | ValLoss: 0.1554 | ValAcc: 0.9600 | ValAUC: 0.9309 | ValF1: 0.9163 | Time: 0.07s\n", 493 | "[EarlyStop TRAINING] no improve (10/50). 
best=-0.931189, curr=-0.930924\n", 494 | "Epoch 042 | TrainLoss: 0.0959 | ValLoss: 0.1494 | ValAcc: 0.9637 | ValAUC: 0.9328 | ValF1: 0.9244 | Time: 0.07s\n", 495 | "Epoch 043 | TrainLoss: 0.0944 | ValLoss: 0.1482 | ValAcc: 0.9625 | ValAUC: 0.9292 | ValF1: 0.9244 | Time: 0.07s\n", 496 | "[EarlyStop TRAINING] no improve (1/50). best=-0.932774, curr=-0.929207\n", 497 | "Epoch 044 | TrainLoss: 0.0944 | ValLoss: 0.1477 | ValAcc: 0.9675 | ValAUC: 0.9328 | ValF1: 0.9330 | Time: 0.07s\n", 498 | "[EarlyStop TRAINING] no improve (2/50). best=-0.932774, curr=-0.932798\n", 499 | "Epoch 045 | TrainLoss: 0.0895 | ValLoss: 0.1522 | ValAcc: 0.9663 | ValAUC: 0.9321 | ValF1: 0.9297 | Time: 0.07s\n", 500 | "[EarlyStop TRAINING] no improve (3/50). best=-0.932774, curr=-0.932149\n", 501 | "Epoch 046 | TrainLoss: 0.0899 | ValLoss: 0.1515 | ValAcc: 0.9637 | ValAUC: 0.9296 | ValF1: 0.9250 | Time: 0.07s\n", 502 | "[EarlyStop TRAINING] no improve (4/50). best=-0.932774, curr=-0.929603\n", 503 | "Epoch 047 | TrainLoss: 0.0879 | ValLoss: 0.1504 | ValAcc: 0.9650 | ValAUC: 0.9282 | ValF1: 0.9279 | Time: 0.07s\n", 504 | "[EarlyStop TRAINING] no improve (5/50). best=-0.932774, curr=-0.928198\n", 505 | "Epoch 048 | TrainLoss: 0.0885 | ValLoss: 0.1496 | ValAcc: 0.9625 | ValAUC: 0.9324 | ValF1: 0.9209 | Time: 0.07s\n", 506 | "[EarlyStop TRAINING] no improve (6/50). best=-0.932774, curr=-0.932365\n", 507 | "Epoch 049 | TrainLoss: 0.0881 | ValLoss: 0.1506 | ValAcc: 0.9625 | ValAUC: 0.9311 | ValF1: 0.9227 | Time: 0.07s\n", 508 | "[EarlyStop TRAINING] no improve (7/50). best=-0.932774, curr=-0.931080\n", 509 | "Epoch 050 | TrainLoss: 0.0889 | ValLoss: 0.1470 | ValAcc: 0.9663 | ValAUC: 0.9321 | ValF1: 0.9307 | Time: 0.07s\n", 510 | "[EarlyStop TRAINING] no improve (8/50). best=-0.932774, curr=-0.932089\n", 511 | "Epoch 051 | TrainLoss: 0.0854 | ValLoss: 0.1552 | ValAcc: 0.9613 | ValAUC: 0.9301 | ValF1: 0.9192 | Time: 0.07s\n", 512 | "[EarlyStop TRAINING] no improve (9/50). 
best=-0.932774, curr=-0.930120\n", 513 | "Epoch 052 | TrainLoss: 0.0853 | ValLoss: 0.1479 | ValAcc: 0.9650 | ValAUC: 0.9310 | ValF1: 0.9279 | Time: 0.07s\n", 514 | "[EarlyStop TRAINING] no improve (10/50). best=-0.932774, curr=-0.930960\n", 515 | "Epoch 053 | TrainLoss: 0.0836 | ValLoss: 0.1521 | ValAcc: 0.9600 | ValAUC: 0.9294 | ValF1: 0.9188 | Time: 0.07s\n", 516 | "[EarlyStop TRAINING] no improve (11/50). best=-0.932774, curr=-0.929375\n", 517 | "Epoch 054 | TrainLoss: 0.0848 | ValLoss: 0.1556 | ValAcc: 0.9637 | ValAUC: 0.9287 | ValF1: 0.9250 | Time: 0.07s\n", 518 | "[EarlyStop TRAINING] no improve (12/50). best=-0.932774, curr=-0.928679\n", 519 | "Epoch 055 | TrainLoss: 0.0813 | ValLoss: 0.1457 | ValAcc: 0.9688 | ValAUC: 0.9304 | ValF1: 0.9358 | Time: 0.07s\n", 520 | "[EarlyStop TRAINING] no improve (13/50). best=-0.932774, curr=-0.930432\n", 521 | "Epoch 056 | TrainLoss: 0.0790 | ValLoss: 0.1501 | ValAcc: 0.9663 | ValAUC: 0.9313 | ValF1: 0.9302 | Time: 0.07s\n", 522 | "[EarlyStop TRAINING] no improve (14/50). best=-0.932774, curr=-0.931273\n", 523 | "Epoch 057 | TrainLoss: 0.0786 | ValLoss: 0.1516 | ValAcc: 0.9650 | ValAUC: 0.9286 | ValF1: 0.9273 | Time: 0.07s\n", 524 | "[EarlyStop TRAINING] no improve (15/50). best=-0.932774, curr=-0.928607\n", 525 | "Epoch 058 | TrainLoss: 0.0790 | ValLoss: 0.1556 | ValAcc: 0.9625 | ValAUC: 0.9300 | ValF1: 0.9227 | Time: 0.07s\n", 526 | "[EarlyStop TRAINING] no improve (16/50). best=-0.932774, curr=-0.930000\n", 527 | "Epoch 059 | TrainLoss: 0.0801 | ValLoss: 0.1562 | ValAcc: 0.9637 | ValAUC: 0.9296 | ValF1: 0.9244 | Time: 0.07s\n", 528 | "[EarlyStop TRAINING] no improve (17/50). best=-0.932774, curr=-0.929579\n", 529 | "Epoch 060 | TrainLoss: 0.0751 | ValLoss: 0.1561 | ValAcc: 0.9625 | ValAUC: 0.9308 | ValF1: 0.9221 | Time: 0.07s\n", 530 | "[EarlyStop TRAINING] no improve (18/50). 
best=-0.932774, curr=-0.930792\n", 531 | "Epoch 061 | TrainLoss: 0.0762 | ValLoss: 0.1549 | ValAcc: 0.9675 | ValAUC: 0.9313 | ValF1: 0.9325 | Time: 0.07s\n", 532 | "[EarlyStop TRAINING] no improve (19/50). best=-0.932774, curr=-0.931321\n", 533 | "Epoch 062 | TrainLoss: 0.0745 | ValLoss: 0.1514 | ValAcc: 0.9625 | ValAUC: 0.9285 | ValF1: 0.9244 | Time: 0.07s\n", 534 | "[EarlyStop TRAINING] no improve (20/50). best=-0.932774, curr=-0.928487\n", 535 | "Epoch 063 | TrainLoss: 0.0741 | ValLoss: 0.1551 | ValAcc: 0.9625 | ValAUC: 0.9263 | ValF1: 0.9238 | Time: 0.07s\n", 536 | "[EarlyStop TRAINING] no improve (21/50). best=-0.932774, curr=-0.926337\n", 537 | "Epoch 064 | TrainLoss: 0.0734 | ValLoss: 0.1568 | ValAcc: 0.9613 | ValAUC: 0.9262 | ValF1: 0.9216 | Time: 0.07s\n", 538 | "[EarlyStop TRAINING] no improve (22/50). best=-0.932774, curr=-0.926181\n", 539 | "Epoch 065 | TrainLoss: 0.0744 | ValLoss: 0.1549 | ValAcc: 0.9650 | ValAUC: 0.9260 | ValF1: 0.9289 | Time: 0.07s\n", 540 | "[EarlyStop TRAINING] no improve (23/50). best=-0.932774, curr=-0.925965\n", 541 | "Epoch 066 | TrainLoss: 0.0713 | ValLoss: 0.1528 | ValAcc: 0.9637 | ValAUC: 0.9269 | ValF1: 0.9261 | Time: 0.07s\n", 542 | "[EarlyStop TRAINING] no improve (24/50). best=-0.932774, curr=-0.926913\n", 543 | "Epoch 067 | TrainLoss: 0.0724 | ValLoss: 0.1587 | ValAcc: 0.9587 | ValAUC: 0.9246 | ValF1: 0.9159 | Time: 0.07s\n", 544 | "[EarlyStop TRAINING] no improve (25/50). best=-0.932774, curr=-0.924644\n", 545 | "Epoch 068 | TrainLoss: 0.0670 | ValLoss: 0.1577 | ValAcc: 0.9637 | ValAUC: 0.9261 | ValF1: 0.9256 | Time: 0.07s\n", 546 | "[EarlyStop TRAINING] no improve (26/50). best=-0.932774, curr=-0.926109\n", 547 | "Epoch 069 | TrainLoss: 0.0703 | ValLoss: 0.1552 | ValAcc: 0.9613 | ValAUC: 0.9290 | ValF1: 0.9210 | Time: 0.07s\n", 548 | "[EarlyStop TRAINING] no improve (27/50). 
best=-0.932774, curr=-0.928967\n", 549 | "Epoch 070 | TrainLoss: 0.0673 | ValLoss: 0.1601 | ValAcc: 0.9650 | ValAUC: 0.9306 | ValF1: 0.9273 | Time: 0.07s\n", 550 | "[EarlyStop TRAINING] no improve (28/50). best=-0.932774, curr=-0.930600\n", 551 | "Epoch 071 | TrainLoss: 0.0698 | ValLoss: 0.1594 | ValAcc: 0.9600 | ValAUC: 0.9305 | ValF1: 0.9182 | Time: 0.07s\n", 552 | "[EarlyStop TRAINING] no improve (29/50). best=-0.932774, curr=-0.930468\n", 553 | "Epoch 072 | TrainLoss: 0.0683 | ValLoss: 0.1585 | ValAcc: 0.9625 | ValAUC: 0.9278 | ValF1: 0.9227 | Time: 0.07s\n", 554 | "[EarlyStop TRAINING] no improve (30/50). best=-0.932774, curr=-0.927838\n", 555 | "Epoch 073 | TrainLoss: 0.0656 | ValLoss: 0.1548 | ValAcc: 0.9625 | ValAUC: 0.9283 | ValF1: 0.9244 | Time: 0.07s\n", 556 | "[EarlyStop TRAINING] no improve (31/50). best=-0.932774, curr=-0.928258\n", 557 | "Epoch 074 | TrainLoss: 0.0647 | ValLoss: 0.1610 | ValAcc: 0.9625 | ValAUC: 0.9265 | ValF1: 0.9227 | Time: 0.07s\n", 558 | "[EarlyStop TRAINING] no improve (32/50). best=-0.932774, curr=-0.926457\n", 559 | "Epoch 075 | TrainLoss: 0.0618 | ValLoss: 0.1574 | ValAcc: 0.9625 | ValAUC: 0.9267 | ValF1: 0.9238 | Time: 0.07s\n", 560 | "[EarlyStop TRAINING] no improve (33/50). best=-0.932774, curr=-0.926685\n", 561 | "Epoch 076 | TrainLoss: 0.0649 | ValLoss: 0.1592 | ValAcc: 0.9625 | ValAUC: 0.9297 | ValF1: 0.9227 | Time: 0.07s\n", 562 | "[EarlyStop TRAINING] no improve (34/50). best=-0.932774, curr=-0.929663\n", 563 | "Epoch 077 | TrainLoss: 0.0645 | ValLoss: 0.1579 | ValAcc: 0.9625 | ValAUC: 0.9285 | ValF1: 0.9244 | Time: 0.07s\n", 564 | "[EarlyStop TRAINING] no improve (35/50). best=-0.932774, curr=-0.928462\n", 565 | "Epoch 078 | TrainLoss: 0.0631 | ValLoss: 0.1680 | ValAcc: 0.9587 | ValAUC: 0.9268 | ValF1: 0.9153 | Time: 0.07s\n", 566 | "[EarlyStop TRAINING] no improve (36/50). 
best=-0.932774, curr=-0.926829\n", 567 | "Epoch 079 | TrainLoss: 0.0616 | ValLoss: 0.1701 | ValAcc: 0.9575 | ValAUC: 0.9266 | ValF1: 0.9111 | Time: 0.08s\n", 568 | "[EarlyStop TRAINING] no improve (37/50). best=-0.932774, curr=-0.926625\n", 569 | "Epoch 080 | TrainLoss: 0.0631 | ValLoss: 0.1588 | ValAcc: 0.9587 | ValAUC: 0.9266 | ValF1: 0.9165 | Time: 0.07s\n", 570 | "[EarlyStop TRAINING] no improve (38/50). best=-0.932774, curr=-0.926649\n", 571 | "Epoch 081 | TrainLoss: 0.0612 | ValLoss: 0.1628 | ValAcc: 0.9587 | ValAUC: 0.9277 | ValF1: 0.9171 | Time: 0.07s\n", 572 | "[EarlyStop TRAINING] no improve (39/50). best=-0.932774, curr=-0.927694\n", 573 | "Epoch 082 | TrainLoss: 0.0611 | ValLoss: 0.1618 | ValAcc: 0.9625 | ValAUC: 0.9280 | ValF1: 0.9238 | Time: 0.07s\n", 574 | "[EarlyStop TRAINING] no improve (40/50). best=-0.932774, curr=-0.928042\n", 575 | "Epoch 083 | TrainLoss: 0.0601 | ValLoss: 0.1600 | ValAcc: 0.9575 | ValAUC: 0.9290 | ValF1: 0.9149 | Time: 0.07s\n", 576 | "[EarlyStop TRAINING] no improve (41/50). best=-0.932774, curr=-0.929027\n", 577 | "Epoch 084 | TrainLoss: 0.0641 | ValLoss: 0.1614 | ValAcc: 0.9550 | ValAUC: 0.9310 | ValF1: 0.9112 | Time: 0.07s\n", 578 | "[EarlyStop TRAINING] no improve (42/50). best=-0.932774, curr=-0.930960\n", 579 | "Epoch 085 | TrainLoss: 0.0615 | ValLoss: 0.1635 | ValAcc: 0.9537 | ValAUC: 0.9291 | ValF1: 0.9077 | Time: 0.07s\n", 580 | "[EarlyStop TRAINING] no improve (43/50). best=-0.932774, curr=-0.929087\n", 581 | "Epoch 086 | TrainLoss: 0.0595 | ValLoss: 0.1597 | ValAcc: 0.9625 | ValAUC: 0.9305 | ValF1: 0.9238 | Time: 0.07s\n", 582 | "[EarlyStop TRAINING] no improve (44/50). best=-0.932774, curr=-0.930492\n", 583 | "Epoch 087 | TrainLoss: 0.0599 | ValLoss: 0.1655 | ValAcc: 0.9613 | ValAUC: 0.9287 | ValF1: 0.9204 | Time: 0.07s\n", 584 | "[EarlyStop TRAINING] no improve (45/50). 
best=-0.932774, curr=-0.928727\n", 585 | "Epoch 088 | TrainLoss: 0.0586 | ValLoss: 0.1715 | ValAcc: 0.9587 | ValAUC: 0.9277 | ValF1: 0.9159 | Time: 0.07s\n", 586 | "[EarlyStop TRAINING] no improve (46/50). best=-0.932774, curr=-0.927682\n", 587 | "Epoch 089 | TrainLoss: 0.0570 | ValLoss: 0.1698 | ValAcc: 0.9625 | ValAUC: 0.9266 | ValF1: 0.9227 | Time: 0.07s\n", 588 | "[EarlyStop TRAINING] no improve (47/50). best=-0.932774, curr=-0.926589\n", 589 | "Epoch 090 | TrainLoss: 0.0592 | ValLoss: 0.1741 | ValAcc: 0.9563 | ValAUC: 0.9266 | ValF1: 0.9088 | Time: 0.07s\n", 590 | "[EarlyStop TRAINING] no improve (48/50). best=-0.932774, curr=-0.926577\n", 591 | "Epoch 091 | TrainLoss: 0.0547 | ValLoss: 0.1732 | ValAcc: 0.9625 | ValAUC: 0.9256 | ValF1: 0.9227 | Time: 0.07s\n", 592 | "[EarlyStop TRAINING] no improve (49/50). best=-0.932774, curr=-0.925592\n", 593 | "Epoch 092 | TrainLoss: 0.0566 | ValLoss: 0.1760 | ValAcc: 0.9587 | ValAUC: 0.9273 | ValF1: 0.9134 | Time: 0.07s\n", 594 | "[EarlyStop TRAINING] no improve (50/50). best=-0.932774, curr=-0.927274\n", 595 | "[EarlyStop TRAINING] restoring best weights and stopping.\n", 596 | "Restoring best model from epoch with val score: -0.9327977327040627\n" 597 | ] 598 | }, 599 | { 600 | "name": "stderr", 601 | "output_type": "stream", 602 | "text": [ 603 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n", 604 | "Online softmax is disabled on the fly since Inductor decides to\n", 605 | "split the reduction. 
Cut an issue to PyTorch if this is an\n", 606 | "important use case and you want to speed it up with online\n", 607 | "softmax.\n", 608 | "\n", 609 | " warnings.warn(\n" 610 | ] 611 | } 612 | ], 613 | "source": [ 614 | "from GRANDE import GRANDE\n", 615 | "\n", 616 | "params = {\n", 617 | " 'depth': 5,\n", 618 | " 'n_estimators': 1024,\n", 619 | "\n", 620 | " 'learning_rate_weights': 0.001,\n", 621 | " 'learning_rate_index': 0.01,\n", 622 | " 'learning_rate_values': 0.05,\n", 623 | " 'learning_rate_leaf': 0.05,\n", 624 | " 'learning_rate_embedding': 0.02,\n", 625 | "\n", 626 | " 'use_category_embeddings': False,\n", 627 | " 'embedding_dim_cat': 8,\n", 628 | " 'use_numeric_embeddings': False,\n", 629 | " 'embedding_dim_num': 8,\n", 630 | " 'embedding_threshold': 1,\n", 631 | " 'loo_cardinality': 10,\n", 632 | "\n", 633 | "\n", 634 | " 'dropout': 0.2,\n", 635 | " 'selected_variables': 0.8,\n", 636 | " 'data_subset_fraction': 1.0,\n", 637 | " 'bootstrap': False,\n", 638 | " 'missing_values': False,\n", 639 | "\n", 640 | " 'optimizer': 'adam', #nadam, radam, adamw, adam \n", 641 | " 'cosine_decay_restarts': False,\n", 642 | " 'reduce_on_plateau_scheduler': True,\n", 643 | " 'label_smoothing': 0.0,\n", 644 | " 'use_class_weights': False,\n", 645 | " 'focal_loss': False,\n", 646 | " 'swa': False,\n", 647 | " 'es_metric': True, # if True use AUC for binary, MSE for regression, val_loss for multiclass\n", 648 | "\n", 649 | "\n", 650 | " 'epochs': 250,\n", 651 | " 'batch_size': 256,\n", 652 | " 'early_stopping_epochs': 50,\n", 653 | "\n", 654 | " 'use_freq_enc': False,\n", 655 | " 'use_robust_scale_smoothing': False,\n", 656 | " 'problem_type': 'binary',\n", 657 | " \n", 658 | " 'random_seed': 42,\n", 659 | " 'verbose': 2,\n", 660 | "}\n", 661 | "\n", 662 | "model_grande = GRANDE(params=params)\n", 663 | "\n", 664 | "model_grande.fit(X=X_train,\n", 665 | " y=y_train,\n", 666 | " X_val=X_valid,\n", 667 | " y_val=y_valid)\n", 668 | "\n", 669 | "preds_grande = 
model_grande.predict_proba(X_test)" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 5, 675 | "id": "b251df9d-67f0-4dd1-a0cc-379621e33fad", 676 | "metadata": { 677 | "execution": { 678 | "iopub.execute_input": "2024-03-15T11:57:30.601353Z", 679 | "iopub.status.busy": "2024-03-15T11:57:30.601004Z", 680 | "iopub.status.idle": "2024-03-15T11:57:30.604147Z", 681 | "shell.execute_reply": "2024-03-15T11:57:30.603752Z", 682 | "shell.execute_reply.started": "2024-03-15T11:57:30.601338Z" 683 | } 684 | }, 685 | "outputs": [], 686 | "source": [ 687 | "def calculate_sample_weights(y_data):\n", 688 | " class_weights = sklearn.utils.class_weight.compute_class_weight(class_weight = 'balanced', classes = np.unique(y_data), y = y_data)\n", 689 | " sample_weights = sklearn.utils.class_weight.compute_sample_weight(class_weight = 'balanced', y = y_data)\n", 690 | " return sample_weights" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": 6, 696 | "id": "c17e805c-0110-4e39-97d3-34573804296b", 697 | "metadata": { 698 | "execution": { 699 | "iopub.execute_input": "2024-03-15T11:57:30.605592Z", 700 | "iopub.status.busy": "2024-03-15T11:57:30.605338Z", 701 | "iopub.status.idle": "2024-03-15T11:57:30.611465Z", 702 | "shell.execute_reply": "2024-03-15T11:57:30.611074Z", 703 | "shell.execute_reply.started": "2024-03-15T11:57:30.605578Z" 704 | } 705 | }, 706 | "outputs": [], 707 | "source": [ 708 | "try:\n", 709 | " y_train = y_train.values.codes.astype(np.float64)\n", 710 | " y_valid = y_valid.values.codes.astype(np.float64)\n", 711 | " y_test = y_test.values.codes.astype(np.float64)\n", 712 | "except AttributeError:\n", 713 | " y_train = y_train.values.astype(np.float64)\n", 714 | " y_valid = y_valid.values.astype(np.float64)\n", 715 | " y_test = y_test.values.astype(np.float64)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": 7, 721 | "id": "cfc121e3-2d7b-4fb1-8616-0318fcc20ac1", 722 | "metadata": { 723 | "execution": { 724
| "iopub.execute_input": "2024-03-15T11:57:30.612164Z", 725 | "iopub.status.busy": "2024-03-15T11:57:30.611960Z", 726 | "iopub.status.idle": "2024-03-15T11:57:30.627515Z", 727 | "shell.execute_reply": "2024-03-15T11:57:30.627108Z", 728 | "shell.execute_reply.started": "2024-03-15T11:57:30.612151Z" 729 | } 730 | }, 731 | "outputs": [], 732 | "source": [ 733 | "binary_indices = []\n", 734 | "low_cardinality_indices = []\n", 735 | "high_cardinality_indices = []\n", 736 | "num_columns = []\n", 737 | "for column_index, column in enumerate(X_train.columns):\n", 738 | " if column_index in categorical_feature_indices or X_train.iloc[:,column_index].dtype.name == 'category' or X_train.iloc[:,column_index].dtype.name == 'object':\n", 739 | " if len(X_train.iloc[:,column_index].unique()) <= 2:\n", 740 | " binary_indices.append(column)\n", 741 | " elif len(X_train.iloc[:,column_index].unique()) < 5:\n", 742 | " low_cardinality_indices.append(column)\n", 743 | " else:\n", 744 | " high_cardinality_indices.append(column)\n", 745 | " else:\n", 746 | " num_columns.append(column) \n", 747 | "cat_columns = [col for col in X_train.columns if col not in num_columns]\n", 748 | "categorical_feature_indices = [X_train.columns.get_loc(col) for col in cat_columns]\n" 749 | ] 750 | }, 751 | { 752 | "cell_type": "code", 753 | "execution_count": 8, 754 | "id": "f7e55785-bf56-4aad-9825-7c5737be5ab7", 755 | "metadata": { 756 | "execution": { 757 | "iopub.execute_input": "2024-03-15T11:57:30.628204Z", 758 | "iopub.status.busy": "2024-03-15T11:57:30.627992Z", 759 | "iopub.status.idle": "2024-03-15T11:57:30.691371Z", 760 | "shell.execute_reply": "2024-03-15T11:57:30.690745Z", 761 | "shell.execute_reply.started": "2024-03-15T11:57:30.628191Z" 762 | } 763 | }, 764 | "outputs": [], 765 | "source": [ 766 | "if len(num_columns) > 0:\n", 767 | " mean_train_num = X_train[num_columns].mean(axis=0)\n", 768 | " X_train[num_columns] = X_train[num_columns].fillna(mean_train_num)\n", 769 | "
X_valid[num_columns] = X_valid[num_columns].fillna(mean_train_num)\n", 770 | " X_test[num_columns] = X_test[num_columns].fillna(mean_train_num)\n", 771 | "if len(cat_columns) > 0:\n", 772 | " mode_train_cat = X_train[cat_columns].mode(axis=0).iloc[0]\n", 773 | " X_train[cat_columns] = X_train[cat_columns].fillna(mode_train_cat)\n", 774 | " X_valid[cat_columns] = X_valid[cat_columns].fillna(mode_train_cat)\n", 775 | " X_test[cat_columns] = X_test[cat_columns].fillna(mode_train_cat)\n", 776 | "\n", 777 | "X_train_raw = X_train.copy()\n", 778 | "X_valid_raw = X_valid.copy()\n", 779 | "X_test_raw = X_test.copy()" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 9, 785 | "id": "b5d5f00f", 786 | "metadata": { 787 | "execution": { 788 | "iopub.execute_input": "2024-03-15T11:57:30.692101Z", 789 | "iopub.status.busy": "2024-03-15T11:57:30.691967Z", 790 | "iopub.status.idle": "2024-03-15T11:57:31.276758Z", 791 | "shell.execute_reply": "2024-03-15T11:57:31.276086Z", 792 | "shell.execute_reply.started": "2024-03-15T11:57:30.692089Z" 793 | } 794 | }, 795 | "outputs": [], 796 | "source": [ 797 | "encoder_ordinal = ce.OrdinalEncoder(cols=binary_indices)\n", 798 | "encoder_ordinal.fit(X_train)\n", 799 | "X_train = encoder_ordinal.transform(X_train)\n", 800 | "X_valid = encoder_ordinal.transform(X_valid) \n", 801 | "X_test = encoder_ordinal.transform(X_test) \n", 802 | "\n", 803 | "encoder = ce.LeaveOneOutEncoder(cols=high_cardinality_indices)\n", 804 | "encoder.fit(X_train, y_train)\n", 805 | "X_train = encoder.transform(X_train)\n", 806 | "X_valid = encoder.transform(X_valid)\n", 807 | "X_test = encoder.transform(X_test)\n", 808 | "\n", 809 | "encoder = ce.OneHotEncoder(cols=low_cardinality_indices)\n", 810 | "encoder.fit(X_train)\n", 811 | "X_train = encoder.transform(X_train)\n", 812 | "X_valid = encoder.transform(X_valid)\n", 813 | "X_test = encoder.transform(X_test)\n", 814 | "\n", 815 | "X_train = X_train.astype(np.float32)\n", 816 | "X_valid = 
X_valid.astype(np.float32)\n", 817 | "X_test = X_test.astype(np.float32)" 818 | ] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "execution_count": 10, 823 | "id": "361f7d9f-7e58-4bc1-8855-153b9e972742", 824 | "metadata": { 825 | "execution": { 826 | "iopub.execute_input": "2024-03-15T11:57:31.277594Z", 827 | "iopub.status.busy": "2024-03-15T11:57:31.277439Z", 828 | "iopub.status.idle": "2024-03-15T11:57:31.575860Z", 829 | "shell.execute_reply": "2024-03-15T11:57:31.575224Z", 830 | "shell.execute_reply.started": "2024-03-15T11:57:31.277581Z" 831 | } 832 | }, 833 | "outputs": [ 834 | { 835 | "name": "stdout", 836 | "output_type": "stream", 837 | "text": [ 838 | "[0]\tvalidation_0-logloss:0.28315\n", 839 | "[1]\tvalidation_0-logloss:0.23763\n", 840 | "[2]\tvalidation_0-logloss:0.20618\n", 841 | "[3]\tvalidation_0-logloss:0.18811\n", 842 | "[4]\tvalidation_0-logloss:0.17719\n", 843 | "[5]\tvalidation_0-logloss:0.16986\n", 844 | "[6]\tvalidation_0-logloss:0.16414\n", 845 | "[7]\tvalidation_0-logloss:0.16136\n", 846 | "[8]\tvalidation_0-logloss:0.15800\n", 847 | "[9]\tvalidation_0-logloss:0.15785\n", 848 | "[10]\tvalidation_0-logloss:0.15713\n", 849 | "[11]\tvalidation_0-logloss:0.15692\n", 850 | "[12]\tvalidation_0-logloss:0.15799\n", 851 | "[13]\tvalidation_0-logloss:0.15559\n", 852 | "[14]\tvalidation_0-logloss:0.15664\n", 853 | "[15]\tvalidation_0-logloss:0.15613\n", 854 | "[16]\tvalidation_0-logloss:0.15517\n", 855 | "[17]\tvalidation_0-logloss:0.15363\n", 856 | "[18]\tvalidation_0-logloss:0.15438\n", 857 | "[19]\tvalidation_0-logloss:0.15495\n", 858 | "[20]\tvalidation_0-logloss:0.15732\n", 859 | "[21]\tvalidation_0-logloss:0.15925\n", 860 | "[22]\tvalidation_0-logloss:0.15958\n", 861 | "[23]\tvalidation_0-logloss:0.15969\n", 862 | "[24]\tvalidation_0-logloss:0.16180\n", 863 | "[25]\tvalidation_0-logloss:0.16142\n", 864 | "[26]\tvalidation_0-logloss:0.16153\n", 865 | "[27]\tvalidation_0-logloss:0.16176\n", 866 | "[28]\tvalidation_0-logloss:0.16241\n", 867 | 
"[29]\tvalidation_0-logloss:0.16312\n", 868 | "[30]\tvalidation_0-logloss:0.16496\n", 869 | "[31]\tvalidation_0-logloss:0.16497\n", 870 | "[32]\tvalidation_0-logloss:0.16543\n", 871 | "[33]\tvalidation_0-logloss:0.16679\n", 872 | "[34]\tvalidation_0-logloss:0.16724\n", 873 | "[35]\tvalidation_0-logloss:0.16847\n", 874 | "[36]\tvalidation_0-logloss:0.17016\n" 875 | ] 876 | } 877 | ], 878 | "source": [ 879 | "if params['problem_type'] == 'regression':\n", 880 | " from xgboost import XGBRegressor\n", 881 | " model_xgb = XGBRegressor(n_estimators=1000, early_stopping_rounds=20)\n", 882 | " model_xgb.fit(X_train, \n", 883 | " y_train, \n", 884 | " eval_set=[(X_valid, y_valid)], \n", 885 | " )\n", 886 | " preds_xgb = model_xgb.predict(X_test)\n", 887 | "else:\n", 888 | " from xgboost import XGBClassifier\n", 889 | " model_xgb = XGBClassifier(n_estimators=1000, early_stopping_rounds=20)\n", 890 | " model_xgb.fit(X_train, \n", 891 | " y_train, \n", 892 | " #sample_weight=calculate_sample_weights(y_train), \n", 893 | " eval_set=[(X_valid, y_valid)], \n", 894 | " #sample_weight_eval_set=[calculate_sample_weights(y_valid)]\n", 895 | " )\n", 896 | "\n", 897 | "\n", 898 | " preds_xgb = model_xgb.predict_proba(X_test)" 899 | ] 900 | }, 901 | { 902 | "cell_type": "code", 903 | "execution_count": 11, 904 | "id": "6b23c67e-cdde-40c2-842a-61383e1a62d0", 905 | "metadata": { 906 | "execution": { 907 | "iopub.execute_input": "2024-03-15T11:57:31.576940Z", 908 | "iopub.status.busy": "2024-03-15T11:57:31.576558Z", 909 | "iopub.status.idle": "2024-03-15T11:57:52.427379Z", 910 | "shell.execute_reply": "2024-03-15T11:57:52.426427Z", 911 | "shell.execute_reply.started": "2024-03-15T11:57:31.576923Z" 912 | } 913 | }, 914 | "outputs": [ 915 | { 916 | "name": "stdout", 917 | "output_type": "stream", 918 | "text": [ 919 | "Learning rate set to 0.042236\n", 920 | "0:\tlearn: 0.6403699\ttest: 0.6408011\tbest: 0.6408011 (0)\ttotal: 53.5ms\tremaining: 53.4s\n", 921 | "1:\tlearn: 0.6121572\ttest: 
0.6148520\tbest: 0.6148520 (1)\ttotal: 56.1ms\tremaining: 28s\n", 922 | "2:\tlearn: 0.5688593\ttest: 0.5737767\tbest: 0.5737767 (2)\ttotal: 62ms\tremaining: 20.6s\n", 923 | "3:\tlearn: 0.5353572\ttest: 0.5418521\tbest: 0.5418521 (3)\ttotal: 65.5ms\tremaining: 16.3s\n", 924 | "4:\tlearn: 0.5080383\ttest: 0.5151234\tbest: 0.5151234 (4)\ttotal: 69.9ms\tremaining: 13.9s\n", 925 | "5:\tlearn: 0.4885596\ttest: 0.4971972\tbest: 0.4971972 (5)\ttotal: 74.7ms\tremaining: 12.4s\n", 926 | "6:\tlearn: 0.4654316\ttest: 0.4759393\tbest: 0.4759393 (6)\ttotal: 79.2ms\tremaining: 11.2s\n", 927 | "7:\tlearn: 0.4445447\ttest: 0.4561330\tbest: 0.4561330 (7)\ttotal: 82.5ms\tremaining: 10.2s\n", 928 | "8:\tlearn: 0.4288701\ttest: 0.4418229\tbest: 0.4418229 (8)\ttotal: 85.3ms\tremaining: 9.39s\n", 929 | "9:\tlearn: 0.4052309\ttest: 0.4181304\tbest: 0.4181304 (9)\ttotal: 88.3ms\tremaining: 8.74s\n", 930 | "10:\tlearn: 0.3831187\ttest: 0.3959788\tbest: 0.3959788 (10)\ttotal: 94.1ms\tremaining: 8.46s\n", 931 | "11:\tlearn: 0.3703703\ttest: 0.3838890\tbest: 0.3838890 (11)\ttotal: 97ms\tremaining: 7.99s\n", 932 | "12:\tlearn: 0.3566437\ttest: 0.3710660\tbest: 0.3710660 (12)\ttotal: 102ms\tremaining: 7.74s\n", 933 | "13:\tlearn: 0.3436633\ttest: 0.3582671\tbest: 0.3582671 (13)\ttotal: 105ms\tremaining: 7.42s\n", 934 | "14:\tlearn: 0.3299345\ttest: 0.3452403\tbest: 0.3452403 (14)\ttotal: 109ms\tremaining: 7.17s\n", 935 | "15:\tlearn: 0.3192912\ttest: 0.3344290\tbest: 0.3344290 (15)\ttotal: 112ms\tremaining: 6.91s\n", 936 | "16:\tlearn: 0.3104381\ttest: 0.3260456\tbest: 0.3260456 (16)\ttotal: 117ms\tremaining: 6.74s\n", 937 | "17:\tlearn: 0.3004882\ttest: 0.3165658\tbest: 0.3165658 (17)\ttotal: 122ms\tremaining: 6.67s\n", 938 | "18:\tlearn: 0.2912218\ttest: 0.3075181\tbest: 0.3075181 (18)\ttotal: 127ms\tremaining: 6.54s\n", 939 | "19:\tlearn: 0.2848872\ttest: 0.3029670\tbest: 0.3029670 (19)\ttotal: 131ms\tremaining: 6.42s\n", 940 | "20:\tlearn: 0.2819334\ttest: 0.3005996\tbest: 0.3005996 
(20)\ttotal: 135ms\tremaining: 6.3s\n", 941 | "21:\tlearn: 0.2745152\ttest: 0.2935901\tbest: 0.2935901 (21)\ttotal: 140ms\tremaining: 6.21s\n", 942 | "22:\tlearn: 0.2719255\ttest: 0.2916018\tbest: 0.2916018 (22)\ttotal: 142ms\tremaining: 6.05s\n", 943 | "23:\tlearn: 0.2654353\ttest: 0.2855682\tbest: 0.2855682 (23)\ttotal: 146ms\tremaining: 5.93s\n", 944 | "24:\tlearn: 0.2590030\ttest: 0.2792839\tbest: 0.2792839 (24)\ttotal: 150ms\tremaining: 5.86s\n", 945 | "25:\tlearn: 0.2536191\ttest: 0.2739589\tbest: 0.2739589 (25)\ttotal: 154ms\tremaining: 5.75s\n", 946 | "26:\tlearn: 0.2493147\ttest: 0.2699481\tbest: 0.2699481 (26)\ttotal: 157ms\tremaining: 5.65s\n", 947 | "27:\tlearn: 0.2445817\ttest: 0.2658077\tbest: 0.2658077 (27)\ttotal: 160ms\tremaining: 5.56s\n", 948 | "28:\tlearn: 0.2389872\ttest: 0.2601796\tbest: 0.2601796 (28)\ttotal: 163ms\tremaining: 5.46s\n", 949 | "29:\tlearn: 0.2343432\ttest: 0.2557029\tbest: 0.2557029 (29)\ttotal: 166ms\tremaining: 5.38s\n", 950 | "30:\tlearn: 0.2306654\ttest: 0.2521669\tbest: 0.2521669 (30)\ttotal: 169ms\tremaining: 5.29s\n", 951 | "31:\tlearn: 0.2266281\ttest: 0.2486322\tbest: 0.2486322 (31)\ttotal: 172ms\tremaining: 5.21s\n", 952 | "32:\tlearn: 0.2238766\ttest: 0.2457850\tbest: 0.2457850 (32)\ttotal: 175ms\tremaining: 5.12s\n", 953 | "33:\tlearn: 0.2197826\ttest: 0.2419484\tbest: 0.2419484 (33)\ttotal: 178ms\tremaining: 5.05s\n", 954 | "34:\tlearn: 0.2164592\ttest: 0.2386139\tbest: 0.2386139 (34)\ttotal: 181ms\tremaining: 4.99s\n", 955 | "35:\tlearn: 0.2126253\ttest: 0.2345415\tbest: 0.2345415 (35)\ttotal: 184ms\tremaining: 4.93s\n", 956 | "36:\tlearn: 0.2091321\ttest: 0.2315528\tbest: 0.2315528 (36)\ttotal: 187ms\tremaining: 4.87s\n", 957 | "37:\tlearn: 0.2062252\ttest: 0.2289329\tbest: 0.2289329 (37)\ttotal: 190ms\tremaining: 4.8s\n", 958 | "38:\tlearn: 0.2037497\ttest: 0.2266702\tbest: 0.2266702 (38)\ttotal: 193ms\tremaining: 4.75s\n", 959 | "39:\tlearn: 0.2009984\ttest: 0.2239214\tbest: 0.2239214 (39)\ttotal: 
196ms\tremaining: 4.69s\n", 960 | "40:\tlearn: 0.1989320\ttest: 0.2221738\tbest: 0.2221738 (40)\ttotal: 198ms\tremaining: 4.64s\n", 961 | "41:\tlearn: 0.1964169\ttest: 0.2201016\tbest: 0.2201016 (41)\ttotal: 201ms\tremaining: 4.59s\n", 962 | "42:\tlearn: 0.1946265\ttest: 0.2182654\tbest: 0.2182654 (42)\ttotal: 204ms\tremaining: 4.54s\n", 963 | "43:\tlearn: 0.1925528\ttest: 0.2164779\tbest: 0.2164779 (43)\ttotal: 207ms\tremaining: 4.5s\n", 964 | "44:\tlearn: 0.1906966\ttest: 0.2145020\tbest: 0.2145020 (44)\ttotal: 210ms\tremaining: 4.46s\n", 965 | "45:\tlearn: 0.1884299\ttest: 0.2120531\tbest: 0.2120531 (45)\ttotal: 213ms\tremaining: 4.42s\n", 966 | "46:\tlearn: 0.1871897\ttest: 0.2111221\tbest: 0.2111221 (46)\ttotal: 216ms\tremaining: 4.38s\n", 967 | "47:\tlearn: 0.1859624\ttest: 0.2102293\tbest: 0.2102293 (47)\ttotal: 219ms\tremaining: 4.35s\n", 968 | "48:\tlearn: 0.1844340\ttest: 0.2088910\tbest: 0.2088910 (48)\ttotal: 223ms\tremaining: 4.33s\n", 969 | "49:\tlearn: 0.1823357\ttest: 0.2068230\tbest: 0.2068230 (49)\ttotal: 226ms\tremaining: 4.29s\n", 970 | "50:\tlearn: 0.1806636\ttest: 0.2050706\tbest: 0.2050706 (50)\ttotal: 229ms\tremaining: 4.26s\n", 971 | "51:\tlearn: 0.1787616\ttest: 0.2029668\tbest: 0.2029668 (51)\ttotal: 231ms\tremaining: 4.22s\n", 972 | "52:\tlearn: 0.1773778\ttest: 0.2014285\tbest: 0.2014285 (52)\ttotal: 236ms\tremaining: 4.21s\n", 973 | "53:\tlearn: 0.1755366\ttest: 0.1996299\tbest: 0.1996299 (53)\ttotal: 239ms\tremaining: 4.18s\n", 974 | "54:\tlearn: 0.1741232\ttest: 0.1984643\tbest: 0.1984643 (54)\ttotal: 242ms\tremaining: 4.15s\n", 975 | "55:\tlearn: 0.1725883\ttest: 0.1966571\tbest: 0.1966571 (55)\ttotal: 245ms\tremaining: 4.12s\n", 976 | "56:\tlearn: 0.1710051\ttest: 0.1955276\tbest: 0.1955276 (56)\ttotal: 248ms\tremaining: 4.1s\n", 977 | "57:\tlearn: 0.1698981\ttest: 0.1942014\tbest: 0.1942014 (57)\ttotal: 251ms\tremaining: 4.08s\n", 978 | "58:\tlearn: 0.1684351\ttest: 0.1929899\tbest: 0.1929899 (58)\ttotal: 254ms\tremaining: 
4.06s\n", 979 | "59:\tlearn: 0.1673545\ttest: 0.1921834\tbest: 0.1921834 (59)\ttotal: 258ms\tremaining: 4.04s\n", 980 | "60:\tlearn: 0.1661590\ttest: 0.1911501\tbest: 0.1911501 (60)\ttotal: 261ms\tremaining: 4.02s\n", 981 | "61:\tlearn: 0.1650260\ttest: 0.1899028\tbest: 0.1899028 (61)\ttotal: 264ms\tremaining: 4s\n", 982 | "62:\tlearn: 0.1638037\ttest: 0.1885839\tbest: 0.1885839 (62)\ttotal: 268ms\tremaining: 3.99s\n", 983 | "63:\tlearn: 0.1624676\ttest: 0.1877544\tbest: 0.1877544 (63)\ttotal: 272ms\tremaining: 3.98s\n", 984 | "64:\tlearn: 0.1610981\ttest: 0.1865590\tbest: 0.1865590 (64)\ttotal: 275ms\tremaining: 3.96s\n", 985 | "65:\tlearn: 0.1604604\ttest: 0.1865760\tbest: 0.1865590 (64)\ttotal: 278ms\tremaining: 3.94s\n", 986 | "66:\tlearn: 0.1589032\ttest: 0.1855309\tbest: 0.1855309 (66)\ttotal: 282ms\tremaining: 3.93s\n", 987 | "67:\tlearn: 0.1581581\ttest: 0.1847642\tbest: 0.1847642 (67)\ttotal: 285ms\tremaining: 3.91s\n", 988 | "68:\tlearn: 0.1572876\ttest: 0.1842998\tbest: 0.1842998 (68)\ttotal: 288ms\tremaining: 3.89s\n", 989 | "69:\tlearn: 0.1561982\ttest: 0.1833719\tbest: 0.1833719 (69)\ttotal: 291ms\tremaining: 3.87s\n", 990 | "70:\tlearn: 0.1550831\ttest: 0.1825030\tbest: 0.1825030 (70)\ttotal: 296ms\tremaining: 3.87s\n", 991 | "71:\tlearn: 0.1542308\ttest: 0.1817230\tbest: 0.1817230 (71)\ttotal: 298ms\tremaining: 3.85s\n", 992 | "72:\tlearn: 0.1535367\ttest: 0.1812467\tbest: 0.1812467 (72)\ttotal: 302ms\tremaining: 3.83s\n", 993 | "73:\tlearn: 0.1529628\ttest: 0.1810243\tbest: 0.1810243 (73)\ttotal: 305ms\tremaining: 3.81s\n", 994 | "74:\tlearn: 0.1520269\ttest: 0.1802059\tbest: 0.1802059 (74)\ttotal: 308ms\tremaining: 3.8s\n", 995 | "75:\tlearn: 0.1508695\ttest: 0.1787649\tbest: 0.1787649 (75)\ttotal: 311ms\tremaining: 3.78s\n", 996 | "76:\tlearn: 0.1498385\ttest: 0.1776925\tbest: 0.1776925 (76)\ttotal: 315ms\tremaining: 3.77s\n", 997 | "77:\tlearn: 0.1490603\ttest: 0.1768350\tbest: 0.1768350 (77)\ttotal: 318ms\tremaining: 3.76s\n", 998 | 
"78:\tlearn: 0.1482503\ttest: 0.1761571\tbest: 0.1761571 (78)\ttotal: 321ms\tremaining: 3.75s\n", 999 | "79:\tlearn: 0.1473683\ttest: 0.1753244\tbest: 0.1753244 (79)\ttotal: 324ms\tremaining: 3.73s\n", 1000 | "80:\tlearn: 0.1466100\ttest: 0.1745125\tbest: 0.1745125 (80)\ttotal: 327ms\tremaining: 3.71s\n", 1001 | "81:\tlearn: 0.1462307\ttest: 0.1742048\tbest: 0.1742048 (81)\ttotal: 330ms\tremaining: 3.69s\n", 1002 | "82:\tlearn: 0.1460834\ttest: 0.1743457\tbest: 0.1742048 (81)\ttotal: 333ms\tremaining: 3.67s\n", 1003 | "83:\tlearn: 0.1453986\ttest: 0.1739246\tbest: 0.1739246 (83)\ttotal: 335ms\tremaining: 3.66s\n", 1004 | "84:\tlearn: 0.1445232\ttest: 0.1731785\tbest: 0.1731785 (84)\ttotal: 338ms\tremaining: 3.64s\n", 1005 | "85:\tlearn: 0.1439600\ttest: 0.1732023\tbest: 0.1731785 (84)\ttotal: 342ms\tremaining: 3.63s\n", 1006 | "86:\tlearn: 0.1433421\ttest: 0.1728107\tbest: 0.1728107 (86)\ttotal: 345ms\tremaining: 3.62s\n", 1007 | "87:\tlearn: 0.1427270\ttest: 0.1723026\tbest: 0.1723026 (87)\ttotal: 348ms\tremaining: 3.6s\n", 1008 | "88:\tlearn: 0.1420955\ttest: 0.1717022\tbest: 0.1717022 (88)\ttotal: 351ms\tremaining: 3.59s\n", 1009 | "89:\tlearn: 0.1415378\ttest: 0.1715440\tbest: 0.1715440 (89)\ttotal: 354ms\tremaining: 3.58s\n", 1010 | "90:\tlearn: 0.1412110\ttest: 0.1715418\tbest: 0.1715418 (90)\ttotal: 357ms\tremaining: 3.56s\n", 1011 | "91:\tlearn: 0.1404274\ttest: 0.1708362\tbest: 0.1708362 (91)\ttotal: 360ms\tremaining: 3.55s\n", 1012 | "92:\tlearn: 0.1397125\ttest: 0.1707287\tbest: 0.1707287 (92)\ttotal: 364ms\tremaining: 3.55s\n", 1013 | "93:\tlearn: 0.1393044\ttest: 0.1705506\tbest: 0.1705506 (93)\ttotal: 367ms\tremaining: 3.53s\n", 1014 | "94:\tlearn: 0.1387496\ttest: 0.1699821\tbest: 0.1699821 (94)\ttotal: 369ms\tremaining: 3.51s\n", 1015 | "95:\tlearn: 0.1381773\ttest: 0.1697968\tbest: 0.1697968 (95)\ttotal: 372ms\tremaining: 3.5s\n", 1016 | "96:\tlearn: 0.1376471\ttest: 0.1693315\tbest: 0.1693315 (96)\ttotal: 375ms\tremaining: 3.49s\n", 1017 | 
"97:\tlearn: 0.1372260\ttest: 0.1694167\tbest: 0.1693315 (96)\ttotal: 378ms\tremaining: 3.48s\n", 1018 | "98:\tlearn: 0.1369108\ttest: 0.1691435\tbest: 0.1691435 (98)\ttotal: 380ms\tremaining: 3.46s\n", 1019 | "99:\tlearn: 0.1365988\ttest: 0.1690779\tbest: 0.1690779 (99)\ttotal: 383ms\tremaining: 3.45s\n", 1020 | "100:\tlearn: 0.1361507\ttest: 0.1685021\tbest: 0.1685021 (100)\ttotal: 386ms\tremaining: 3.44s\n", 1021 | "101:\tlearn: 0.1354671\ttest: 0.1681311\tbest: 0.1681311 (101)\ttotal: 389ms\tremaining: 3.42s\n", 1022 | "102:\tlearn: 0.1348790\ttest: 0.1677541\tbest: 0.1677541 (102)\ttotal: 393ms\tremaining: 3.42s\n", 1023 | "103:\tlearn: 0.1342635\ttest: 0.1669889\tbest: 0.1669889 (103)\ttotal: 395ms\tremaining: 3.41s\n", 1024 | "104:\tlearn: 0.1339653\ttest: 0.1668001\tbest: 0.1668001 (104)\ttotal: 400ms\tremaining: 3.41s\n", 1025 | "105:\tlearn: 0.1333219\ttest: 0.1665222\tbest: 0.1665222 (105)\ttotal: 403ms\tremaining: 3.4s\n", 1026 | "106:\tlearn: 0.1327973\ttest: 0.1662342\tbest: 0.1662342 (106)\ttotal: 408ms\tremaining: 3.4s\n", 1027 | "107:\tlearn: 0.1322664\ttest: 0.1661034\tbest: 0.1661034 (107)\ttotal: 410ms\tremaining: 3.39s\n", 1028 | "108:\tlearn: 0.1317347\ttest: 0.1658443\tbest: 0.1658443 (108)\ttotal: 413ms\tremaining: 3.38s\n", 1029 | "109:\tlearn: 0.1310987\ttest: 0.1652628\tbest: 0.1652628 (109)\ttotal: 416ms\tremaining: 3.37s\n", 1030 | "110:\tlearn: 0.1305917\ttest: 0.1649934\tbest: 0.1649934 (110)\ttotal: 419ms\tremaining: 3.36s\n", 1031 | "111:\tlearn: 0.1305812\ttest: 0.1650006\tbest: 0.1649934 (110)\ttotal: 422ms\tremaining: 3.34s\n", 1032 | "112:\tlearn: 0.1301225\ttest: 0.1647297\tbest: 0.1647297 (112)\ttotal: 425ms\tremaining: 3.34s\n", 1033 | "113:\tlearn: 0.1296899\ttest: 0.1643031\tbest: 0.1643031 (113)\ttotal: 428ms\tremaining: 3.32s\n", 1034 | "114:\tlearn: 0.1294766\ttest: 0.1642072\tbest: 0.1642072 (114)\ttotal: 431ms\tremaining: 3.32s\n", 1035 | "115:\tlearn: 0.1291889\ttest: 0.1640454\tbest: 0.1640454 (115)\ttotal: 
434ms\tremaining: 3.3s\n", 1036 | "116:\tlearn: 0.1291278\ttest: 0.1640400\tbest: 0.1640400 (116)\ttotal: 436ms\tremaining: 3.29s\n", 1037 | "117:\tlearn: 0.1288480\ttest: 0.1639734\tbest: 0.1639734 (117)\ttotal: 439ms\tremaining: 3.28s\n", 1038 | "118:\tlearn: 0.1286497\ttest: 0.1639575\tbest: 0.1639575 (118)\ttotal: 442ms\tremaining: 3.27s\n", 1039 | "119:\tlearn: 0.1282663\ttest: 0.1637441\tbest: 0.1637441 (119)\ttotal: 445ms\tremaining: 3.27s\n", 1040 | "120:\tlearn: 0.1281113\ttest: 0.1637939\tbest: 0.1637441 (119)\ttotal: 449ms\tremaining: 3.26s\n", 1041 | "121:\tlearn: 0.1276841\ttest: 0.1633714\tbest: 0.1633714 (121)\ttotal: 452ms\tremaining: 3.25s\n", 1042 | "122:\tlearn: 0.1274058\ttest: 0.1633725\tbest: 0.1633714 (121)\ttotal: 455ms\tremaining: 3.24s\n", 1043 | "123:\tlearn: 0.1269155\ttest: 0.1634801\tbest: 0.1633714 (121)\ttotal: 458ms\tremaining: 3.23s\n", 1044 | "124:\tlearn: 0.1268380\ttest: 0.1635373\tbest: 0.1633714 (121)\ttotal: 462ms\tremaining: 3.23s\n", 1045 | "125:\tlearn: 0.1267048\ttest: 0.1635890\tbest: 0.1633714 (121)\ttotal: 465ms\tremaining: 3.23s\n", 1046 | "126:\tlearn: 0.1262448\ttest: 0.1629872\tbest: 0.1629872 (126)\ttotal: 470ms\tremaining: 3.23s\n", 1047 | "127:\tlearn: 0.1258275\ttest: 0.1625769\tbest: 0.1625769 (127)\ttotal: 473ms\tremaining: 3.22s\n", 1048 | "128:\tlearn: 0.1253859\ttest: 0.1621351\tbest: 0.1621351 (128)\ttotal: 477ms\tremaining: 3.22s\n", 1049 | "129:\tlearn: 0.1249703\ttest: 0.1619379\tbest: 0.1619379 (129)\ttotal: 481ms\tremaining: 3.22s\n", 1050 | "130:\tlearn: 0.1247342\ttest: 0.1616268\tbest: 0.1616268 (130)\ttotal: 485ms\tremaining: 3.22s\n", 1051 | "131:\tlearn: 0.1244258\ttest: 0.1614230\tbest: 0.1614230 (131)\ttotal: 488ms\tremaining: 3.21s\n", 1052 | "132:\tlearn: 0.1236921\ttest: 0.1611347\tbest: 0.1611347 (132)\ttotal: 492ms\tremaining: 3.2s\n", 1053 | "133:\tlearn: 0.1234099\ttest: 0.1610724\tbest: 0.1610724 (133)\ttotal: 495ms\tremaining: 3.2s\n", 1054 | "134:\tlearn: 0.1233218\ttest: 
0.1609980\tbest: 0.1609980 (134)\ttotal: 497ms\tremaining: 3.19s\n", 1055 | "135:\tlearn: 0.1231105\ttest: 0.1609405\tbest: 0.1609405 (135)\ttotal: 501ms\tremaining: 3.18s\n", 1056 | "136:\tlearn: 0.1230489\ttest: 0.1608900\tbest: 0.1608900 (136)\ttotal: 503ms\tremaining: 3.17s\n", 1057 | "137:\tlearn: 0.1224238\ttest: 0.1606435\tbest: 0.1606435 (137)\ttotal: 507ms\tremaining: 3.17s\n", 1058 | "138:\tlearn: 0.1221154\ttest: 0.1606038\tbest: 0.1606038 (138)\ttotal: 510ms\tremaining: 3.16s\n", 1059 | "139:\tlearn: 0.1218823\ttest: 0.1604324\tbest: 0.1604324 (139)\ttotal: 514ms\tremaining: 3.16s\n", 1060 | "140:\tlearn: 0.1217761\ttest: 0.1602990\tbest: 0.1602990 (140)\ttotal: 517ms\tremaining: 3.15s\n", 1061 | "141:\tlearn: 0.1215440\ttest: 0.1601603\tbest: 0.1601603 (141)\ttotal: 520ms\tremaining: 3.14s\n", 1062 | "142:\tlearn: 0.1212282\ttest: 0.1599872\tbest: 0.1599872 (142)\ttotal: 523ms\tremaining: 3.14s\n", 1063 | "143:\tlearn: 0.1210965\ttest: 0.1599391\tbest: 0.1599391 (143)\ttotal: 526ms\tremaining: 3.13s\n", 1064 | "144:\tlearn: 0.1206828\ttest: 0.1596531\tbest: 0.1596531 (144)\ttotal: 530ms\tremaining: 3.12s\n", 1065 | "145:\tlearn: 0.1203065\ttest: 0.1597038\tbest: 0.1596531 (144)\ttotal: 533ms\tremaining: 3.12s\n", 1066 | "146:\tlearn: 0.1201046\ttest: 0.1594982\tbest: 0.1594982 (146)\ttotal: 536ms\tremaining: 3.11s\n", 1067 | "147:\tlearn: 0.1199300\ttest: 0.1595389\tbest: 0.1594982 (146)\ttotal: 540ms\tremaining: 3.1s\n", 1068 | "148:\tlearn: 0.1195486\ttest: 0.1594501\tbest: 0.1594501 (148)\ttotal: 542ms\tremaining: 3.1s\n", 1069 | "149:\tlearn: 0.1190920\ttest: 0.1597642\tbest: 0.1594501 (148)\ttotal: 546ms\tremaining: 3.09s\n", 1070 | "150:\tlearn: 0.1185979\ttest: 0.1595261\tbest: 0.1594501 (148)\ttotal: 548ms\tremaining: 3.08s\n", 1071 | "151:\tlearn: 0.1183069\ttest: 0.1592414\tbest: 0.1592414 (151)\ttotal: 551ms\tremaining: 3.07s\n", 1072 | "152:\tlearn: 0.1181374\ttest: 0.1592164\tbest: 0.1592164 (152)\ttotal: 554ms\tremaining: 3.06s\n", 1073 | 
"153:\tlearn: 0.1177571\ttest: 0.1589785\tbest: 0.1589785 (153)\ttotal: 557ms\tremaining: 3.06s\n", 1074 | "154:\tlearn: 0.1173531\ttest: 0.1588217\tbest: 0.1588217 (154)\ttotal: 561ms\tremaining: 3.06s\n", 1075 | "155:\tlearn: 0.1170582\ttest: 0.1587787\tbest: 0.1587787 (155)\ttotal: 563ms\tremaining: 3.05s\n", 1076 | "156:\tlearn: 0.1167705\ttest: 0.1586385\tbest: 0.1586385 (156)\ttotal: 569ms\tremaining: 3.05s\n", 1077 | "157:\tlearn: 0.1164068\ttest: 0.1585486\tbest: 0.1585486 (157)\ttotal: 572ms\tremaining: 3.04s\n", 1078 | "158:\tlearn: 0.1163820\ttest: 0.1585151\tbest: 0.1585151 (158)\ttotal: 574ms\tremaining: 3.03s\n", 1079 | "159:\tlearn: 0.1162444\ttest: 0.1585091\tbest: 0.1585091 (159)\ttotal: 577ms\tremaining: 3.03s\n", 1080 | "160:\tlearn: 0.1161463\ttest: 0.1584859\tbest: 0.1584859 (160)\ttotal: 579ms\tremaining: 3.02s\n", 1081 | "161:\tlearn: 0.1157391\ttest: 0.1580840\tbest: 0.1580840 (161)\ttotal: 582ms\tremaining: 3.01s\n", 1082 | "162:\tlearn: 0.1156908\ttest: 0.1580406\tbest: 0.1580406 (162)\ttotal: 586ms\tremaining: 3.01s\n", 1083 | "163:\tlearn: 0.1153443\ttest: 0.1578388\tbest: 0.1578388 (163)\ttotal: 588ms\tremaining: 3s\n", 1084 | "164:\tlearn: 0.1150407\ttest: 0.1576972\tbest: 0.1576972 (164)\ttotal: 591ms\tremaining: 2.99s\n", 1085 | "165:\tlearn: 0.1148386\ttest: 0.1577687\tbest: 0.1576972 (164)\ttotal: 595ms\tremaining: 2.99s\n", 1086 | "166:\tlearn: 0.1143723\ttest: 0.1572801\tbest: 0.1572801 (166)\ttotal: 599ms\tremaining: 2.99s\n", 1087 | "167:\tlearn: 0.1141279\ttest: 0.1574169\tbest: 0.1572801 (166)\ttotal: 602ms\tremaining: 2.98s\n", 1088 | "168:\tlearn: 0.1139445\ttest: 0.1573275\tbest: 0.1572801 (166)\ttotal: 605ms\tremaining: 2.97s\n", 1089 | "169:\tlearn: 0.1135710\ttest: 0.1570117\tbest: 0.1570117 (169)\ttotal: 608ms\tremaining: 2.97s\n", 1090 | "170:\tlearn: 0.1134066\ttest: 0.1569465\tbest: 0.1569465 (170)\ttotal: 612ms\tremaining: 2.96s\n", 1091 | "171:\tlearn: 0.1132348\ttest: 0.1569808\tbest: 0.1569465 (170)\ttotal: 
615ms\tremaining: 2.96s\n", 1092 | "172:\tlearn: 0.1129800\ttest: 0.1570577\tbest: 0.1569465 (170)\ttotal: 618ms\tremaining: 2.95s\n", 1093 | "173:\tlearn: 0.1126281\ttest: 0.1568818\tbest: 0.1568818 (173)\ttotal: 623ms\tremaining: 2.96s\n", 1094 | "174:\tlearn: 0.1124995\ttest: 0.1568308\tbest: 0.1568308 (174)\ttotal: 625ms\tremaining: 2.95s\n", 1095 | "175:\tlearn: 0.1123167\ttest: 0.1567765\tbest: 0.1567765 (175)\ttotal: 628ms\tremaining: 2.94s\n", 1096 | "176:\tlearn: 0.1122578\ttest: 0.1568048\tbest: 0.1567765 (175)\ttotal: 633ms\tremaining: 2.94s\n", 1097 | "177:\tlearn: 0.1120069\ttest: 0.1567596\tbest: 0.1567596 (177)\ttotal: 636ms\tremaining: 2.94s\n", 1098 | "178:\tlearn: 0.1118183\ttest: 0.1567967\tbest: 0.1567596 (177)\ttotal: 640ms\tremaining: 2.93s\n", 1099 | "179:\tlearn: 0.1116753\ttest: 0.1566245\tbest: 0.1566245 (179)\ttotal: 645ms\tremaining: 2.94s\n", 1100 | "180:\tlearn: 0.1114712\ttest: 0.1565494\tbest: 0.1565494 (180)\ttotal: 648ms\tremaining: 2.93s\n", 1101 | "181:\tlearn: 0.1110647\ttest: 0.1563492\tbest: 0.1563492 (181)\ttotal: 653ms\tremaining: 2.94s\n", 1102 | "182:\tlearn: 0.1109584\ttest: 0.1563476\tbest: 0.1563476 (182)\ttotal: 656ms\tremaining: 2.93s\n", 1103 | "183:\tlearn: 0.1108534\ttest: 0.1563045\tbest: 0.1563045 (183)\ttotal: 661ms\tremaining: 2.93s\n", 1104 | "184:\tlearn: 0.1104813\ttest: 0.1561775\tbest: 0.1561775 (184)\ttotal: 664ms\tremaining: 2.92s\n", 1105 | "185:\tlearn: 0.1100818\ttest: 0.1557271\tbest: 0.1557271 (185)\ttotal: 666ms\tremaining: 2.92s\n", 1106 | "186:\tlearn: 0.1099489\ttest: 0.1557738\tbest: 0.1557271 (185)\ttotal: 669ms\tremaining: 2.91s\n", 1107 | "187:\tlearn: 0.1097746\ttest: 0.1559275\tbest: 0.1557271 (185)\ttotal: 672ms\tremaining: 2.9s\n", 1108 | "188:\tlearn: 0.1094178\ttest: 0.1557098\tbest: 0.1557098 (188)\ttotal: 676ms\tremaining: 2.9s\n", 1109 | "189:\tlearn: 0.1091134\ttest: 0.1553536\tbest: 0.1553536 (189)\ttotal: 678ms\tremaining: 2.89s\n", 1110 | "190:\tlearn: 0.1089026\ttest: 
0.1553270\tbest: 0.1553270 (190)\ttotal: 681ms\tremaining: 2.88s\n", 1111 | "191:\tlearn: 0.1088370\ttest: 0.1552442\tbest: 0.1552442 (191)\ttotal: 684ms\tremaining: 2.88s\n", 1112 | "192:\tlearn: 0.1085390\ttest: 0.1552017\tbest: 0.1552017 (192)\ttotal: 687ms\tremaining: 2.87s\n", 1113 | "193:\tlearn: 0.1083417\ttest: 0.1552433\tbest: 0.1552017 (192)\ttotal: 689ms\tremaining: 2.86s\n", 1114 | "194:\tlearn: 0.1081520\ttest: 0.1553490\tbest: 0.1552017 (192)\ttotal: 693ms\tremaining: 2.86s\n", 1115 | "195:\tlearn: 0.1078712\ttest: 0.1555736\tbest: 0.1552017 (192)\ttotal: 696ms\tremaining: 2.85s\n", 1116 | "196:\tlearn: 0.1074892\ttest: 0.1556243\tbest: 0.1552017 (192)\ttotal: 699ms\tremaining: 2.85s\n", 1117 | "197:\tlearn: 0.1074879\ttest: 0.1556146\tbest: 0.1552017 (192)\ttotal: 701ms\tremaining: 2.84s\n", 1118 | "198:\tlearn: 0.1073743\ttest: 0.1555928\tbest: 0.1552017 (192)\ttotal: 703ms\tremaining: 2.83s\n", 1119 | "199:\tlearn: 0.1071665\ttest: 0.1554248\tbest: 0.1552017 (192)\ttotal: 706ms\tremaining: 2.82s\n", 1120 | "200:\tlearn: 0.1070061\ttest: 0.1554337\tbest: 0.1552017 (192)\ttotal: 709ms\tremaining: 2.82s\n", 1121 | "201:\tlearn: 0.1067106\ttest: 0.1552604\tbest: 0.1552017 (192)\ttotal: 712ms\tremaining: 2.81s\n", 1122 | "202:\tlearn: 0.1061867\ttest: 0.1547731\tbest: 0.1547731 (202)\ttotal: 715ms\tremaining: 2.81s\n", 1123 | "203:\tlearn: 0.1061642\ttest: 0.1547586\tbest: 0.1547586 (203)\ttotal: 716ms\tremaining: 2.79s\n", 1124 | "204:\tlearn: 0.1058801\ttest: 0.1544104\tbest: 0.1544104 (204)\ttotal: 719ms\tremaining: 2.79s\n", 1125 | "205:\tlearn: 0.1056995\ttest: 0.1543959\tbest: 0.1543959 (205)\ttotal: 722ms\tremaining: 2.78s\n", 1126 | "206:\tlearn: 0.1053829\ttest: 0.1540443\tbest: 0.1540443 (206)\ttotal: 725ms\tremaining: 2.78s\n", 1127 | "207:\tlearn: 0.1050172\ttest: 0.1537433\tbest: 0.1537433 (207)\ttotal: 730ms\tremaining: 2.78s\n", 1128 | "208:\tlearn: 0.1046315\ttest: 0.1538177\tbest: 0.1537433 (207)\ttotal: 735ms\tremaining: 2.78s\n", 1129 
| "209:\tlearn: 0.1042301\ttest: 0.1540596\tbest: 0.1537433 (207)\ttotal: 738ms\tremaining: 2.78s\n", 1130 | "210:\tlearn: 0.1038595\ttest: 0.1542492\tbest: 0.1537433 (207)\ttotal: 741ms\tremaining: 2.77s\n", 1131 | "211:\tlearn: 0.1038453\ttest: 0.1542450\tbest: 0.1537433 (207)\ttotal: 744ms\tremaining: 2.76s\n", 1132 | "212:\tlearn: 0.1037689\ttest: 0.1542358\tbest: 0.1537433 (207)\ttotal: 748ms\tremaining: 2.76s\n", 1133 | "213:\tlearn: 0.1035594\ttest: 0.1541500\tbest: 0.1537433 (207)\ttotal: 752ms\tremaining: 2.76s\n", 1134 | "214:\tlearn: 0.1032968\ttest: 0.1541669\tbest: 0.1537433 (207)\ttotal: 756ms\tremaining: 2.76s\n", 1135 | "215:\tlearn: 0.1031981\ttest: 0.1541580\tbest: 0.1537433 (207)\ttotal: 759ms\tremaining: 2.75s\n", 1136 | "216:\tlearn: 0.1029143\ttest: 0.1540810\tbest: 0.1537433 (207)\ttotal: 763ms\tremaining: 2.75s\n", 1137 | "217:\tlearn: 0.1027793\ttest: 0.1541172\tbest: 0.1537433 (207)\ttotal: 766ms\tremaining: 2.75s\n", 1138 | "218:\tlearn: 0.1026013\ttest: 0.1541685\tbest: 0.1537433 (207)\ttotal: 772ms\tremaining: 2.75s\n", 1139 | "219:\tlearn: 0.1024922\ttest: 0.1540507\tbest: 0.1537433 (207)\ttotal: 774ms\tremaining: 2.75s\n", 1140 | "220:\tlearn: 0.1024123\ttest: 0.1540590\tbest: 0.1537433 (207)\ttotal: 777ms\tremaining: 2.74s\n", 1141 | "221:\tlearn: 0.1019649\ttest: 0.1539730\tbest: 0.1537433 (207)\ttotal: 780ms\tremaining: 2.73s\n", 1142 | "222:\tlearn: 0.1015602\ttest: 0.1540056\tbest: 0.1537433 (207)\ttotal: 783ms\tremaining: 2.73s\n", 1143 | "223:\tlearn: 0.1014548\ttest: 0.1539750\tbest: 0.1537433 (207)\ttotal: 787ms\tremaining: 2.73s\n", 1144 | "224:\tlearn: 0.1013972\ttest: 0.1538995\tbest: 0.1537433 (207)\ttotal: 790ms\tremaining: 2.72s\n", 1145 | "225:\tlearn: 0.1013235\ttest: 0.1538275\tbest: 0.1537433 (207)\ttotal: 793ms\tremaining: 2.72s\n", 1146 | "226:\tlearn: 0.1011391\ttest: 0.1538708\tbest: 0.1537433 (207)\ttotal: 796ms\tremaining: 2.71s\n", 1147 | "227:\tlearn: 0.1011154\ttest: 0.1538654\tbest: 0.1537433 (207)\ttotal: 
800ms\tremaining: 2.71s\n", 1148 | "Stopped by overfitting detector (20 iterations wait)\n", 1149 | "\n", 1150 | "bestTest = 0.1537432969\n", 1151 | "bestIteration = 207\n", 1152 | "\n", 1153 | "Shrink model to first 208 iterations.\n" 1154 | ] 1155 | } 1156 | ], 1157 | "source": [ 1158 | "if params['problem_type'] == 'regression':\n", 1159 | " from catboost import CatBoostRegressor, Pool\n", 1160 | "\n", 1161 | " model_catboost = CatBoostRegressor(n_estimators=1000, \n", 1162 | " early_stopping_rounds=20)\n", 1163 | " train_data = Pool(\n", 1164 | " data=X_train_raw,\n", 1165 | " label=y_train,\n", 1166 | " cat_features=categorical_feature_indices,\n", 1167 | " )\n", 1168 | "\n", 1169 | " eval_data = Pool(\n", 1170 | " data=X_valid_raw,\n", 1171 | " label=y_valid,\n", 1172 | " cat_features=categorical_feature_indices,\n", 1173 | " )\n", 1174 | "\n", 1175 | " model_catboost.fit(X=train_data, \n", 1176 | " eval_set=eval_data)\n", 1177 | "\n", 1178 | "\n", 1179 | "\n", 1180 | "\n", 1181 | " preds_catboost = model_catboost.predict(X_test_raw)\n", 1182 | "else:\n", 1183 | " from catboost import CatBoostClassifier, Pool\n", 1184 | "\n", 1185 | " model_catboost = CatBoostClassifier(n_estimators=1000, \n", 1186 | " early_stopping_rounds=20)\n", 1187 | " train_data = Pool(\n", 1188 | " data=X_train_raw,\n", 1189 | " label=y_train,\n", 1190 | " cat_features=categorical_feature_indices,\n", 1191 | " #weight=calculate_sample_weights(y_train)\n", 1192 | " )\n", 1193 | "\n", 1194 | " eval_data = Pool(\n", 1195 | " data=X_valid_raw,\n", 1196 | " label=y_valid,\n", 1197 | " cat_features=categorical_feature_indices,\n", 1198 | " #weight=calculate_sample_weights(y_valid),\n", 1199 | " )\n", 1200 | "\n", 1201 | " model_catboost.fit(X=train_data, \n", 1202 | " eval_set=eval_data)\n", 1203 | "\n", 1204 | "\n", 1205 | "\n", 1206 | " preds_catboost = model_catboost.predict_proba(X_test_raw)\n" 1207 | ] 1208 | }, 1209 | { 1210 | "cell_type": "code", 1211 | "execution_count": 12, 1212 | 
"id": "21b9be01-ba30-4300-b3e0-f4fd43b55c87", 1213 | "metadata": { 1214 | "execution": { 1215 | "iopub.execute_input": "2024-03-15T11:57:52.429019Z", 1216 | "iopub.status.busy": "2024-03-15T11:57:52.428555Z", 1217 | "iopub.status.idle": "2024-03-15T11:57:52.454151Z", 1218 | "shell.execute_reply": "2024-03-15T11:57:52.453620Z", 1219 | "shell.execute_reply.started": "2024-03-15T11:57:52.428990Z" 1220 | } 1221 | }, 1222 | "outputs": [ 1223 | { 1224 | "name": "stdout", 1225 | "output_type": "stream", 1226 | "text": [ 1227 | "Accuracy GRANDE: 0.954\n", 1228 | "F1 Score GRANDE: 0.8976494984825425\n", 1229 | "ROC AUC GRANDE: 0.922450889462646\n", 1230 | "\n", 1231 | "\n", 1232 | "Accuracy XGB: 0.952\n", 1233 | "F1 Score XGB: 0.8902857142857143\n", 1234 | "ROC AUC XGB: 0.9152775340703048\n", 1235 | "\n", 1236 | "\n", 1237 | "Accuracy CatBoost: 0.957\n", 1238 | "F1 Score CatBoost: 0.8999941857084715\n", 1239 | "ROC AUC CatBoost: 0.9137693329656831\n", 1240 | "\n", 1241 | "\n" 1242 | ] 1243 | } 1244 | ], 1245 | "source": [ 1246 | "if params['problem_type'] == 'binary':\n", 1247 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_grande[:,1]))\n", 1248 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_grande[:,1]), average='macro')\n", 1249 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande[:,1], average='macro', multi_class='ovo')\n", 1250 | "\n", 1251 | " print('Accuracy GRANDE:', accuracy)\n", 1252 | " print('F1 Score GRANDE:', f1_score)\n", 1253 | " print('ROC AUC GRANDE:', roc_auc)\n", 1254 | " print('\\n')\n", 1255 | "\n", 1256 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_xgb[:,1]))\n", 1257 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_xgb[:,1]), average='macro')\n", 1258 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb[:,1], average='macro', multi_class='ovo')\n", 1259 | "\n", 1260 | " print('Accuracy XGB:', accuracy)\n", 1261 | " print('F1 Score XGB:', f1_score)\n", 1262 
| " print('ROC AUC XGB:', roc_auc)\n", 1263 | " print('\\n')\n", 1264 | "\n", 1265 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_catboost[:,1]))\n", 1266 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_catboost[:,1]), average='macro')\n", 1267 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost[:,1], average='macro', multi_class='ovo')\n", 1268 | "\n", 1269 | " print('Accuracy CatBoost:', accuracy)\n", 1270 | " print('F1 Score CatBoost:', f1_score)\n", 1271 | " print('ROC AUC CatBoost:', roc_auc)\n", 1272 | " print('\\n')\n", 1273 | "elif params['problem_type'] == 'multiclass':\n", 1274 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_grande, axis=1))\n", 1275 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_grande, axis=1), average='macro')\n", 1276 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n", 1277 | "\n", 1278 | " print('Accuracy GRANDE:', accuracy)\n", 1279 | " print('F1 Score GRANDE:', f1_score)\n", 1280 | " print('ROC AUC GRANDE:', roc_auc)\n", 1281 | " print('\\n')\n", 1282 | "\n", 1283 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_xgb, axis=1))\n", 1284 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_xgb, axis=1), average='macro')\n", 1285 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n", 1286 | "\n", 1287 | " print('Accuracy XGB:', accuracy)\n", 1288 | " print('F1 Score XGB:', f1_score)\n", 1289 | " print('ROC AUC XGB:', roc_auc)\n", 1290 | " print('\\n')\n", 1291 | "\n", 1292 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_catboost, axis=1))\n", 1293 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_catboost, axis=1), average='macro')\n", 1294 | " roc_auc = 
sklearn.metrics.roc_auc_score(y_test, preds_catboost, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n", 1295 | "\n", 1296 | " print('Accuracy CatBoost:', accuracy)\n", 1297 | " print('F1 Score CatBoost:', f1_score)\n", 1298 | " print('ROC AUC CatBoost:', roc_auc)\n", 1299 | " print('\\n')\n", 1300 | "else:\n", 1301 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_grande)\n", 1302 | " root_mean_squared_error = np.sqrt(((y_test - preds_grande) ** 2).mean())\n", 1303 | " r2_score = sklearn.metrics.r2_score(y_test, preds_grande)\n", 1304 | "\n", 1305 | " print('MAE GRANDE:', mean_absolute_error)\n", 1306 | " print('RMSE GRANDE:', root_mean_squared_error)\n", 1307 | " print('R2 Score GRANDE:', r2_score)\n", 1308 | " print('\\n')\n", 1309 | "\n", 1310 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_xgb)\n", 1311 | " root_mean_squared_error = np.sqrt(((y_test - preds_xgb) ** 2).mean())\n", 1312 | " r2_score = sklearn.metrics.r2_score(y_test, preds_xgb)\n", 1313 | "\n", 1314 | " print('MAE XGB:', mean_absolute_error)\n", 1315 | " print('RMSE XGB:', root_mean_squared_error)\n", 1316 | " print('R2 Score XGB:', r2_score)\n", 1317 | " print('\\n')\n", 1318 | "\n", 1319 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_catboost)\n", 1320 | " root_mean_squared_error = np.sqrt(((y_test - preds_catboost) ** 2).mean())\n", 1321 | " r2_score = sklearn.metrics.r2_score(y_test, preds_catboost)\n", 1322 | "\n", 1323 | " print('MAE CatBoost:', mean_absolute_error)\n", 1324 | " print('RMSE CatBoost:', root_mean_squared_error)\n", 1325 | " print('R2 Score CatBoost:', r2_score)\n", 1326 | " print('\\n')" 1327 | ] 1328 | }, 1329 | { 1330 | "cell_type": "code", 1331 | "execution_count": null, 1332 | "id": "c7ef884c-1269-4e0b-8ca7-7ea74c6e380c", 1333 | "metadata": {}, 1334 | "outputs": [], 1335 | "source": [] 1336 | }, 1337 | { 1338 | "cell_type": "code", 1339 | 
"execution_count": null, 1340 | "id": "2fc02cd3", 1341 | "metadata": {}, 1342 | "outputs": [], 1343 | "source": [] 1344 | } 1345 | ], 1346 | "metadata": { 1347 | "kernelspec": { 1348 | "display_name": "ReMeDe", 1349 | "language": "python", 1350 | "name": "python3" 1351 | }, 1352 | "language_info": { 1353 | "codemirror_mode": { 1354 | "name": "ipython", 1355 | "version": 3 1356 | }, 1357 | "file_extension": ".py", 1358 | "mimetype": "text/x-python", 1359 | "name": "python", 1360 | "nbconvert_exporter": "python", 1361 | "pygments_lexer": "ipython3", 1362 | "version": "3.12.9" 1363 | } 1364 | }, 1365 | "nbformat": 4, 1366 | "nbformat_minor": 5 1367 | } 1368 | --------------------------------------------------------------------------------