├── GRANDE
│   ├── __init__.py
│   └── tf_legacy
│       └── GRANDE.py
├── figures
│   ├── grande.jpg
│   └── tabarena_leaderboard.jpg
├── requirements.txt
├── setup.py
├── LICENSE
├── README.md
├── GRANDE_minimal_example_with_comparison_REG.ipynb
└── GRANDE_minimal_example_with_comparison_BINARY.ipynb
/GRANDE/__init__.py:
--------------------------------------------------------------------------------
1 | from .GRANDE import GRANDE
--------------------------------------------------------------------------------
/figures/grande.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-marton/GRANDE/HEAD/figures/grande.jpg
--------------------------------------------------------------------------------
/figures/tabarena_leaderboard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s-marton/GRANDE/HEAD/figures/tabarena_leaderboard.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core runtime requirements (GRANDE/GRANDE.py)
2 | numpy>=1.25.0,<2.4.0
3 | pandas>=2.0.0,<2.4.0
4 | torch>=2.6,<2.10
5 | scikit-learn>=1.4.0,<1.8.0
6 | category-encoders>=2.6.4,<=2.9.0
7 | tqdm>=4.38,<5
8 | autogluon>=1.3.0,<=1.4.0
9 |
10 | # Notebook / example-only requirements (GRANDE_minimal_example_with_comparison_BINARY.ipynb)
11 | openml>=0.15.0,<=0.15.1
12 | xgboost>=2.0,<3.1
13 | catboost>=1.2,<1.3
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="GRANDE",
5 | license="MIT",
6 | version="0.2.0",
7 | packages=find_packages(),
8 | install_requires=[
9 | 'numpy>=1.25.0,<2.4.0',
10 | 'pandas>=2.0.0,<2.4.0',
11 | 'torch>=2.6,<2.10',
12 | 'scikit-learn>=1.4.0,<1.8.0',
13 | 'category-encoders>=2.6.4,<=2.9.0',
14 | 'tqdm>=4.38,<5',
15 | 'autogluon>=1.3.0,<=1.4.0',
16 | ],
17 | author="Sascha Marton",
18 | author_email="sascha.marton@gmail.com",
19 | description="A novel ensemble method for hard, axis-aligned decision trees learned end-to-end with gradient descent.",
20 | long_description=open('README.md', encoding='utf-8').read(), # README contains non-ASCII characters (emoji)
21 | long_description_content_type="text/markdown",
22 | url="https://github.com/s-marton/GRANDE",
23 | )
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Sascha Marton
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌳 GRANDE: Gradient-Based Decision Tree Ensembles 🌳
2 |
3 | [PyPI](https://pypi.org/project/GRANDE/) · [OpenReview](https://openreview.net/forum?id=XEFWBxi075) · [arXiv](https://arxiv.org/abs/2309.17130)
4 |
5 | ![TabArena leaderboard](figures/tabarena_leaderboard.jpg)
6 |
7 | *TabArena: the updated PyTorch GRANDE has been evaluated on TabArena and achieved strong results.*
8 |
14 | 🔍 What's new?
15 | - PyTorch-native implementation for seamless integration; TensorFlow is maintained as a legacy version.
16 | - Strong results on TabArena, particularly for binary classification and regression. Multiclass results are weaker and pull down the overall ranking; we hope to improve this in a future release.
17 | - Method updates for improved performance, including optional categorical and numerical embeddings.
18 | - Training improvements (optimizers, schedulers, early stopping, optional SWA).
19 | - Enhanced preprocessing pipeline with optional frequency encoding and robust normalization.
20 |
21 | ![GRANDE overview](figures/grande.jpg)
22 |
23 | *Figure 1: Overview of GRANDE. GRANDE learns hard, axis-aligned trees end-to-end via gradient descent and uses dynamic instance-wise leaf weighting to combine estimators into a strong ensemble.*
24 |
29 | 🌳 GRANDE is a gradient-based decision tree ensemble for tabular data.
30 | GRANDE trains ensembles of hard, axis-aligned decision trees end-to-end with gradient descent. Each estimator contributes via instance-wise leaf weights that are learned jointly with split locations and leaf values. This combines the strong inductive bias of trees with the flexibility of neural optimization. The PyTorch version optionally augments inputs with learnable categorical and numerical embeddings, improving representation capacity while preserving interpretability of splits.
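The routing and weighting described above can be made concrete with a minimal NumPy sketch of the hard, inference-time forward pass. It is modeled on the `call()` method of the legacy TensorFlow implementation (`GRANDE/tf_legacy/GRANDE.py`): internal nodes are indexed breadth-first, a sample goes left when its feature value is below the split threshold, and the weight of the leaf each tree reaches is softmax-normalized across estimators separately for every instance. `grande_forward` and its arguments are hypothetical names. The sketch deliberately omits everything that makes training differentiable (entmax over the split feature, the softsign split function, straight-through rounding), as well as feature subsampling, estimator dropout and target de-normalization.

```python
import numpy as np


def grande_forward(X, feat, thr, leaf_values, leaf_weights):
    """Hard forward pass of a GRANDE-style tree ensemble (regression-style output).

    X:            (b, n) input batch
    feat:         (e, 2**d - 1) feature index tested at each internal node
    thr:          (e, 2**d - 1) split threshold at each internal node
    leaf_values:  (e, 2**d) prediction stored in each leaf
    leaf_weights: (e, 2**d) estimator weight stored in each leaf
    """
    n_estimators, n_internal = feat.shape
    depth = (n_internal + 1).bit_length() - 1  # n_internal = 2**depth - 1
    n_samples = X.shape[0]
    rows = np.arange(n_samples)
    preds = np.empty((n_samples, n_estimators))
    scores = np.empty((n_samples, n_estimators))
    for e in range(n_estimators):
        node = np.zeros(n_samples, dtype=int)  # every sample starts at the root
        for _ in range(depth):
            # hard, axis-aligned split: go to the left child if x[feature] < threshold
            go_left = X[rows, feat[e, node]] < thr[e, node]
            node = np.where(go_left, 2 * node + 1, 2 * node + 2)
        leaf = node - n_internal  # breadth-first node index -> leaf index
        preds[:, e] = leaf_values[e, leaf]
        scores[:, e] = leaf_weights[e, leaf]
    # instance-wise softmax over the leaf weights each estimator selected
    scores -= scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return (weights * preds).sum(axis=1)


# depth-1 trees, two estimators, one feature
X = np.array([[0.0], [2.0]])
feat = np.array([[0], [0]])
thr = np.array([[1.0], [1.0]])
leaf_values = np.array([[10.0, 20.0], [30.0, 40.0]])
print(grande_forward(X, feat, thr, leaf_values, np.zeros((2, 2))))  # [20. 30.]
print(grande_forward(X, feat, thr, leaf_values, np.array([[0.0, 100.0], [0.0, 0.0]])))  # [20. 20.]
```

Because the softmax is taken over the weights of the leaves each instance actually reaches, one estimator can dominate the ensemble for some inputs and be nearly ignored for others: in the second call, estimator 0 takes over only for the second sample.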
31 |
32 | 📝 More details in the paper: https://openreview.net/forum?id=XEFWBxi075
33 |
34 |
35 | ## Cite us
36 | ```text
37 | @inproceedings{
38 | marton2024grande,
39 | title={{GRANDE}: Gradient-Based Decision Tree Ensembles},
40 | author={Sascha Marton and Stefan L{\"u}dtke and Christian Bartelt and Heiner Stuckenschmidt},
41 | booktitle={The Twelfth International Conference on Learning Representations},
42 | year={2024},
43 | url={https://openreview.net/forum?id=XEFWBxi075}
44 | }
45 | ```
46 |
47 | ## Installation
48 | To install the latest version directly from GitHub:
49 | ```bash
50 | pip install git+https://github.com/s-marton/GRANDE.git
51 | ```
52 |
53 | ## Dependencies
54 | Install core runtime requirements (and optional notebook/example dependencies) via:
55 |
56 | ```bash
57 | pip install -r requirements.txt
58 | ```
59 |
60 | Notes:
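61 | - `requirements.txt` has a **core** section (library runtime dependencies) and a **notebook/example-only** section (OpenML, XGBoost, CatBoost).
62 | - The TensorFlow legacy implementation (`GRANDE/tf_legacy/GRANDE.py`) additionally imports `tensorflow` and `focal_loss`, neither of which is listed in `requirements.txt`.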
63 | ## Usage (PyTorch)
64 | The example below mirrors `GRANDE_minimal_example_with_comparison_BINARY.ipynb` (binary classification on OpenML dataset 46915). A GPU is recommended.
65 |
66 | ```python
67 | # Optional: expose only GPU 0 to PyTorch (must be set before CUDA is initialized)
68 | import os
69 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
70 |
71 | # Load data
72 | from sklearn.model_selection import train_test_split
73 | import openml
74 | import numpy as np
75 | import sklearn.metrics # binds `sklearn` and loads the metrics submodule used below
76 |
77 | dataset = openml.datasets.get_dataset(46915, download_data=True, download_qualities=True, download_features_meta_data=True)
78 | X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
79 |
80 | X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
81 | X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)
82 |
83 | # GRANDE (PyTorch)
84 | from GRANDE import GRANDE
85 |
86 | params = {
87 | 'depth': 5,
88 | 'n_estimators': 1024,
89 |
90 | 'learning_rate_weights': 0.001,
91 | 'learning_rate_index': 0.01,
92 | 'learning_rate_values': 0.05,
93 | 'learning_rate_leaf': 0.05,
94 | 'learning_rate_embedding': 0.02, # used if embeddings are enabled
95 |
96 | # Embeddings (set True to enable)
97 | 'use_category_embeddings': False, # True to enable
98 | 'embedding_dim_cat': 8,
99 | 'use_numeric_embeddings': False, # True to enable
100 | 'embedding_dim_num': 8,
101 | 'embedding_threshold': 1, # low-cardinality split for categorical embeddings
102 | 'loo_cardinality': 10, # high-cardinality split for encoders
103 |
104 | 'dropout': 0.2,
105 | 'selected_variables': 0.8,
106 | 'data_subset_fraction': 1.0,
107 | 'bootstrap': False,
108 | 'missing_values': False,
109 |
110 | 'optimizer': 'adam', # options: nadam, radam, adamw, adam
111 | 'cosine_decay_restarts': False,
112 | 'reduce_on_plateau_scheduler': True,
113 | 'label_smoothing': 0.0,
114 | 'use_class_weights': False,
115 | 'focal_loss': False,
116 | 'swa': False,
117 | 'es_metric': True, # AUC for binary, MSE for regression, val_loss for multiclass
118 |
119 | 'epochs': 250,
120 | 'batch_size': 256,
121 | 'early_stopping_epochs': 50,
122 |
123 | 'use_freq_enc': False,
124 | 'use_robust_scale_smoothing': False,
125 |
126 | # Important: use problem_type, not objective
127 | 'problem_type': 'binary', # {'binary', 'multiclass', 'regression'}
128 |
129 | 'random_seed': 42,
130 | 'verbose': 2,
131 | }
132 |
133 | model_grande = GRANDE(params=params)
134 | model_grande.fit(X=X_train, y=y_train, X_val=X_valid, y_val=y_valid)
135 |
136 | # Predict
137 | preds_grande = model_grande.predict_proba(X_test)
138 |
139 | # Evaluate (binary)
140 | accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_grande[:, 1]))
141 | f1 = sklearn.metrics.f1_score(y_test, np.round(preds_grande[:, 1]), average='macro')
142 | roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande[:, 1], average='macro')
143 |
144 | print('Accuracy GRANDE:', accuracy)
145 | print('F1 Score GRANDE:', f1)
146 | print('ROC AUC GRANDE:', roc_auc)
147 | ```
148 |
149 | Notes:
150 | - Set use_category_embeddings/use_numeric_embeddings to True to enable embeddings.
151 | - For multiclass, use problem_type='multiclass'. For regression, use 'regression'.
152 | - TensorFlow is supported as a legacy version (`GRANDE/tf_legacy`); the PyTorch implementation is the default and recommended one.
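The evaluation code in the example thresholds `preds_grande[:, 1]`, so it only applies to binary problems. Below is a minimal sketch of an evaluation helper that also covers `problem_type='multiclass'`. `evaluate_proba` is a hypothetical name, and the sketch assumes that `predict_proba` returns an array of shape `(n_samples, n_classes)` whose columns follow the sorted label order (`np.unique`); verify this for your data before relying on it.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def evaluate_proba(y_true, proba, classes=None):
    """Accuracy, macro F1 and ROC AUC from an (n_samples, n_classes) probability array.

    Assumes column k of `proba` belongs to classes[k]; `classes` must be sorted and
    defaults to the unique labels of y_true (np.unique order).
    """
    y_true = np.asarray(y_true)
    classes = np.unique(y_true) if classes is None else np.asarray(classes)
    y_pred = classes[np.argmax(proba, axis=1)]  # most probable class per sample
    if proba.shape[1] == 2:
        # binary: score of the second (positive) class
        roc_auc = roc_auc_score(y_true, proba[:, 1])
    else:
        # multiclass: one-vs-rest AUC, macro-averaged over classes
        roc_auc = roc_auc_score(y_true, proba, multi_class='ovr',
                                average='macro', labels=classes)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_macro': f1_score(y_true, y_pred, average='macro'),
        'roc_auc': roc_auc,
    }
```

For example, `evaluate_proba(y_test, model_grande.predict_proba(X_test))` would replace the three metric calls above. If the test split lacks a class seen during training, pass the sorted training labels as `classes` so that column indices still map to the right labels.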
153 |
154 | ## More
155 | This is an experimental implementation. If you encounter issues, please open an issue or report unexpected behavior.
156 |
--------------------------------------------------------------------------------
/GRANDE/tf_legacy/GRANDE.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | #import tensorflow_addons as tfa
4 |
5 | import sklearn
6 | from copy import deepcopy
7 | import category_encoders as ce
8 | import pandas as pd
9 | import math
10 | from focal_loss import SparseCategoricalFocalLoss
11 |
12 | import pickle
13 | import zipfile
14 | import os
15 |
16 | class GRANDE(tf.keras.Model):
17 | def __init__(self,
18 | params,
19 | args):
20 |
21 | params.update(args)
22 | self.config = None
23 |
24 | super(GRANDE, self).__init__()
25 | self.set_params(**params)
26 |
27 | self.is_fitted = False
28 |
29 | self.internal_node_num_ = 2 ** self.depth - 1
30 | self.leaf_node_num_ = 2 ** self.depth
31 |
32 | tf.keras.utils.set_random_seed(self.random_seed)
33 |
34 | def call(self, inputs, training):
35 | # einsum syntax:
36 | # - b is the batch size
37 | # - e is the number of estimators
38 | # - l the number of leaf nodes (i.e. the number of paths)
39 | # - i is the number of internal nodes
40 | # - d is the depth (i.e. the length of each path)
41 | # - n is the number of variables (one value is stored for each variable)
42 |
43 |
44 | # Adjust data: For each estimator select a subset of the features (output shape: (b, e, n))
45 | if self.data_subset_fraction < 1.0 and training:
46 | #select subset of samples for each estimator during training if hyperparameter set accordingly
47 | X_estimator = tf.nn.embedding_lookup(tf.transpose(inputs), self.features_by_estimator)
48 | X_estimator = tf.transpose(tf.gather_nd(tf.transpose(X_estimator, [0,2,1]), tf.expand_dims(self.data_select,2), batch_dims=1), [1,0,2])
49 | else:
50 | #use complete data
51 | X_estimator = tf.gather(inputs, self.features_by_estimator, axis=1)
52 |
53 | #entmax transformation
54 | split_index_array = entmax15TF(self.split_index_array)
55 |
56 | #use ST-Operator to get one-hot encoded vector for feature index
57 | split_index_array = split_index_array - tf.stop_gradient(split_index_array - tf.one_hot(tf.argmax(split_index_array, axis=-1), depth=split_index_array.shape[-1]))
58 |
59 | # as split_index_array is one-hot encoded, taking the sum over the last axis after multiplication selects the value at the chosen feature index
60 | s1_sum = tf.einsum("ein,ein->ei", self.split_values, split_index_array)
61 | s2_sum = tf.einsum("ben,ein->bei", X_estimator, split_index_array)
62 |
63 | # calculate the split (output shape: (b, e, i))
64 | node_result = (tf.nn.softsign(s1_sum-s2_sum) + 1) / 2
65 |
66 | #use round operation with ST operator to get hard decision for each node
67 | node_result_corrected = node_result - tf.stop_gradient(node_result - tf.round(node_result))
68 |
69 | #generate tensor for further calculation:
70 | # - internal_node_index_list holds the indices for the internal nodes traversed for each path (there are l paths) in the tree
71 | # - for each estimator and for each path in each estimator, the tensors hold the information for all internal nodes traversed
72 | # - the resulting shape of the tensors is (b, e, l, d):
73 | node_result_extended = tf.gather(node_result_corrected, self.internal_node_index_list, axis=2)
74 |
75 |
76 | #reduce the path via multiplication to get result for each path (in each estimator) based on the results of the corresponding internal nodes (output shape: (b, e, l))
77 | p = tf.reduce_prod(((1-self.path_identifier_list)*node_result_extended + self.path_identifier_list*(1-node_result_extended)), axis=3)
78 |
79 | #calculate instance-wise leaf weights for each estimator by selecting the weight of the selected path for each estimator
80 | estimator_weights_leaf = tf.einsum("el,bel->be", self.estimator_weights, p)
81 |
82 | #use softmax over weights for each instance
83 | estimator_weights_leaf_softmax = tf.keras.activations.softmax(estimator_weights_leaf)
84 |
85 | #optional dropout (deactivating random estimators)
86 | estimator_weights_leaf_softmax = self.apply_dropout_leaf(estimator_weights_leaf_softmax, training=training)
87 |
88 | #get raw prediction for each estimator
89 | #optionally transform to probability distribution before weighting
90 | if self.objective == 'regression':
91 | layer_output = tf.einsum('el,bel->be', self.leaf_classes_array, p)
92 | layer_output = tf.einsum('be,be->be', estimator_weights_leaf_softmax, layer_output)
93 | elif self.objective == 'binary':
94 | if self.from_logits:
95 | layer_output = tf.einsum('el,bel->be', self.leaf_classes_array, p)
96 | else:
97 | layer_output = tf.math.sigmoid(tf.einsum('el,bel->be', self.leaf_classes_array, p))
98 | layer_output = tf.einsum('be,be->be', estimator_weights_leaf_softmax, layer_output)
99 | elif self.objective == 'classification':
100 | if self.from_logits:
101 | layer_output = tf.einsum('elc,bel->bec', self.leaf_classes_array, p)
102 | layer_output = tf.einsum('be,bec->bec', estimator_weights_leaf_softmax, layer_output)
103 | else:
104 | layer_output = tf.keras.activations.softmax(tf.einsum('elc,bel->bec', self.leaf_classes_array, p))
105 | layer_output = tf.einsum('be,bec->bec', estimator_weights_leaf_softmax, layer_output)
106 |
107 | if self.data_subset_fraction < 1.0 and training:
108 | result = tf.scatter_nd(indices=tf.expand_dims(self.data_select, 2), updates=tf.transpose(layer_output), shape=[tf.shape(inputs)[0]])
109 | result = (result / self.counts) * self.n_estimators
110 | else:
111 | if self.objective == 'regression' or self.objective == 'binary':
112 | result = tf.einsum('be->b', layer_output)
113 | else:
114 | result = tf.einsum('bec->bc', layer_output)
115 |
116 | if self.objective == 'regression' or self.objective == 'binary':
117 | result = tf.expand_dims(result, 1)
118 |
119 | return result
120 |
121 | def apply_preprocessing(self, X):
122 |
123 | if not isinstance(X, pd.DataFrame):
124 | X = pd.DataFrame(X)
125 |
126 | if len(self.num_columns) > 0:
127 | X[self.num_columns] = X[self.num_columns].fillna(self.mean_train_num)
128 | if len(self.cat_columns) > 0:
129 | X[self.cat_columns] = X[self.cat_columns].fillna(self.mode_train_cat)
130 |
131 | X = self.encoder_ordinal.transform(X)
132 | X = self.encoder_loo.transform(X)
133 | X = self.encoder_ohe.transform(X)
134 |
135 | X = self.normalizer.transform(X.values.astype(np.float64))
136 |
137 | return X
138 |
139 | def perform_preprocessing(self,
140 | X_train,
141 | y_train,
142 | X_val,
143 | y_val):
144 |
145 | if isinstance(y_train, pd.Series):
146 | try:
147 | y_train = y_train.values.codes.astype(np.float64)
148 | except:
149 | pass
150 | if isinstance(y_val, pd.Series):
151 | try:
152 | y_val = y_val.values.codes.astype(np.float64)
153 | except:
154 | pass
155 |
156 | if not isinstance(X_train, pd.DataFrame):
157 | X_train = pd.DataFrame(X_train)
158 | if not isinstance(X_val, pd.DataFrame):
159 | X_val = pd.DataFrame(X_val)
160 |
161 | self.mean = np.mean(y_train)
162 | self.std = np.std(y_train)
163 |
164 | binary_indices = []
165 | low_cardinality_indices = []
166 | high_cardinality_indices = []
167 | num_columns = []
168 | for column_index, column in enumerate(X_train.columns):
169 | if column_index in self.cat_idx:
170 | if len(X_train.iloc[:,column_index].unique()) <= 2:
171 | binary_indices.append(column)
172 | elif len(X_train.iloc[:,column_index].unique()) < 5:
173 | low_cardinality_indices.append(column)
174 | else:
175 | high_cardinality_indices.append(column)
176 | else:
177 | num_columns.append(column)
178 |
179 | cat_columns = [col for col in X_train.columns if col not in num_columns]
180 |
181 | if len(num_columns) > 0:
182 | self.mean_train_num = X_train[num_columns].mean(axis=0) # per-column training means used to impute missing numerical values
183 | X_train[num_columns] = X_train[num_columns].fillna(self.mean_train_num)
184 | X_val[num_columns] = X_val[num_columns].fillna(self.mean_train_num)
185 | if len(cat_columns) > 0:
186 | self.mode_train_cat = X_train[cat_columns].mode(axis=0).iloc[0]
187 | X_train[cat_columns] = X_train[cat_columns].fillna(self.mode_train_cat)
188 | X_val[cat_columns] = X_val[cat_columns].fillna(self.mode_train_cat)
189 |
190 | self.cat_columns = cat_columns
191 | self.num_columns = num_columns
192 |
193 | self.encoder_ordinal = ce.OrdinalEncoder(cols=binary_indices)
194 | self.encoder_ordinal.fit(X_train)
195 | X_train = self.encoder_ordinal.transform(X_train)
196 | X_val = self.encoder_ordinal.transform(X_val)
197 |
198 | self.encoder_loo = ce.LeaveOneOutEncoder(cols=high_cardinality_indices)
199 | if self.objective == 'regression':
200 | self.encoder_loo.fit(X_train, (y_train-self.mean)/self.std)
201 | else:
202 | self.encoder_loo.fit(X_train, y_train)
203 | X_train = self.encoder_loo.transform(X_train)
204 | X_val = self.encoder_loo.transform(X_val)
205 |
206 | self.encoder_ohe = ce.OneHotEncoder(cols=low_cardinality_indices)
207 | self.encoder_ohe.fit(X_train)
208 | X_train = self.encoder_ohe.transform(X_train)
209 | X_val = self.encoder_ohe.transform(X_val)
210 |
211 | X_train = X_train.astype(np.float32)
212 | X_val = X_val.astype(np.float32)
213 |
214 | quantile_noise = 1e-4
215 | quantile_train = np.copy(X_train.values).astype(np.float64)
216 | np.random.seed(42)
217 | stds = np.std(quantile_train, axis=0, keepdims=True)
218 | noise_std = quantile_noise / np.maximum(stds, quantile_noise)
219 | quantile_train += noise_std * np.random.randn(*quantile_train.shape)
220 |
221 | quantile_train = pd.DataFrame(quantile_train, columns=X_train.columns, index=X_train.index)
222 |
223 | self.normalizer = sklearn.preprocessing.QuantileTransformer(
224 | n_quantiles=min(quantile_train.shape[0], 1000),
225 | output_distribution='normal',
226 | )
227 |
228 | self.normalizer.fit(quantile_train.values.astype(np.float64))
229 | X_train = self.normalizer.transform(X_train.values.astype(np.float64))
230 | X_val = self.normalizer.transform(X_val.values.astype(np.float64))
231 |
232 | return X_train, y_train, X_val, y_val
233 |
234 | def convert_to_numpy(self,data):
235 | """
236 | Converts input data (Pandas DataFrame/Series, TensorFlow tensor, list, tuple, or NumPy array) to a NumPy array.
237 |
238 | Args:
239 | data: Input data to be converted. Can be a Pandas DataFrame/Series, TensorFlow tensor, list, tuple, or NumPy array.
240 |
241 | Returns:
242 | numpy_array: A NumPy array representation of the input data.
243 | """
244 | # Check if the data is a Pandas DataFrame
245 | if isinstance(data, (pd.DataFrame, pd.Series)):
246 | return data.values
247 |
248 | # Check if the data is a TensorFlow tensor
249 | elif isinstance(data, tf.Tensor):
250 | return data.numpy()
251 |
252 | # Check if the data is a list or similar iterable (not including strings)
253 | elif isinstance(data, (list, tuple, np.ndarray)):
254 | return np.array(data)
255 |
256 | else:
257 | raise TypeError("The input data type is not supported for conversion.")
258 |
259 | def fit(self,
260 | X_train,
261 | y_train,
262 | X_val=None,
263 | y_val=None,
264 | **kwargs):
265 |
266 | if self.preprocess_data:
267 | X_train, y_train, X_val, y_val = self.perform_preprocessing(X_train, y_train, X_val, y_val)
268 | else:
269 | X_train = self.convert_to_numpy(X_train)
270 | y_train = self.convert_to_numpy(y_train)
271 | X_val = self.convert_to_numpy(X_val)
272 | y_val = self.convert_to_numpy(y_val)
273 |
274 | jit_compile = True #X_train.shape[0] < 10_000
275 |
276 | self.number_of_variables = X_train.shape[1]
277 | if self.use_class_weights:
278 | if self.objective == 'classification' or self.objective == 'binary':
279 | self.number_of_classes = len(np.unique(y_train))
280 | self.class_weights = sklearn.utils.class_weight.compute_class_weight(class_weight = 'balanced', classes = np.unique(y_train), y = y_train)
281 |
282 | self.class_weight_dict = {}
283 | for i in range(self.number_of_classes):
284 | self.class_weight_dict[i] = self.class_weights[i]
285 |
286 | else:
287 | self.number_of_classes = 1
288 | self.class_weights = np.ones_like(np.unique(y_train))
289 | self.class_weight_dict = None
290 | else:
291 | if self.objective == 'classification' or self.objective == 'binary':
292 | self.number_of_classes = len(np.unique(y_train))
293 | else:
294 | self.number_of_classes = 1
295 | self.class_weights = np.ones_like(np.unique(y_train))
296 | self.class_weight_dict = None
297 |
298 | self.build_model()
299 |
300 | self.compile(loss=self.loss_name, metrics=[], jit_compile=jit_compile, mean=self.mean, std=self.std, class_weight=self.class_weights)
301 |
302 |
303 | train_data = tf.data.Dataset.from_tensor_slices((tf.dtypes.cast(tf.convert_to_tensor(X_train), tf.float32),
304 | tf.dtypes.cast(tf.convert_to_tensor(y_train), tf.float32)))
305 |
306 | if self.data_subset_fraction < 1.0:
307 | train_data = (train_data
308 | #.shuffle(32_768)
309 | .cache()
310 | .batch(batch_size=self.batch_size, drop_remainder=True)
311 | .prefetch(tf.data.AUTOTUNE)
312 | )
313 | else:
314 | train_data = (train_data
315 | .shuffle(32_768)
316 | .cache()
317 | .batch(batch_size=self.batch_size, drop_remainder=False)
318 | .prefetch(tf.data.AUTOTUNE)
319 | )
320 |
321 | if X_val is not None and y_val is not None:
322 | validation_data = (X_val, y_val)
323 | validation_data = tf.data.Dataset.from_tensor_slices((tf.dtypes.cast(tf.convert_to_tensor(validation_data[0]), tf.float32),
324 | tf.dtypes.cast(tf.convert_to_tensor(validation_data[1]), tf.float32)))
325 |
326 |
327 | validation_data = (validation_data
328 | .cache()
329 | .batch(batch_size=self.batch_size, drop_remainder=False)
330 | .prefetch(tf.data.AUTOTUNE)
331 | )
332 |
333 | monitor = 'val_loss'
334 | else:
335 | monitor = 'loss'
336 |
337 |
338 | # keep any user-supplied callbacks; pop them so they are not passed to fit() twice
339 | callbacks = list(kwargs.pop('callbacks', []))
340 |
341 | early_stopping = tf.keras.callbacks.EarlyStopping(monitor=monitor,
342 | patience=self.early_stopping_epochs,
343 | min_delta=1e-3,
344 | restore_best_weights=True)
345 | callbacks.append(early_stopping)
346 |
347 | if kwargs.pop('reduce_lr', False): # pop: 'reduce_lr' is not a Keras fit() argument
348 | reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, patience=self.early_stopping_epochs//3)
349 | callbacks.append(reduce_lr)
350 |
351 | super(GRANDE, self).fit(train_data,
352 | validation_data = validation_data,
353 | epochs = self.epochs,
354 | callbacks = callbacks,
355 | class_weight = self.class_weight_dict,
356 | verbose=self.verbose,
357 | **kwargs)
358 |
359 | def build_model(self):
360 |
361 | tf.keras.utils.set_random_seed(self.random_seed)
362 |
363 | if self.selected_variables > 1:
364 | self.selected_variables = min(self.selected_variables, self.number_of_variables)
365 | else:
366 | self.selected_variables = int(self.number_of_variables * self.selected_variables)
367 | self.selected_variables = min(self.selected_variables, 50)
368 | self.selected_variables = max(self.selected_variables, 10)
369 | self.selected_variables = min(self.selected_variables, self.number_of_variables)
370 | if self.objective != 'binary':
371 | self.data_subset_fraction = 1.0
372 | if self.data_subset_fraction < 1.0:
373 | if self.batch_size * self.data_subset_fraction > 4:
374 | self.subset_size = tf.cast(self.batch_size * self.data_subset_fraction, tf.int32)
375 | else:
376 | self.subset_size = tf.cast(4, tf.int32)
377 |
378 | if self.bootstrap:
379 | self.data_select = tf.random.uniform(shape=(self.n_estimators, self.subset_size), minval=0, maxval=self.batch_size, dtype=tf.int32)
380 | else:
381 |
382 | indices = [np.random.choice(self.batch_size, size=self.subset_size, replace=False) for _ in range(self.n_estimators)]
383 | self.data_select = tf.stack(indices)
384 |
385 | items, self.counts = np.unique(self.data_select, return_counts=True)
386 | self.counts = tf.constant(self.counts, dtype=tf.float32)
387 |
388 | self.features_by_estimator = tf.stack([np.random.choice(self.number_of_variables, size=(self.selected_variables), replace=False, p=None) for _ in range(self.n_estimators)])
389 |
390 | self.path_identifier_list = []
391 | self.internal_node_index_list = []
392 | for leaf_index in tf.unstack(tf.constant([i for i in range(self.leaf_node_num_)])):
393 | for current_depth in tf.unstack(tf.constant([i for i in range(1, self.depth+1)])):
394 | path_identifier = tf.cast(tf.math.floormod(tf.math.floor(leaf_index/(tf.math.pow(2, (self.depth-current_depth)))), 2), tf.float32)
395 | internal_node_index = tf.cast(tf.cast(tf.math.pow(2, (current_depth-1)), tf.float32) + tf.cast(tf.math.floor(leaf_index/(tf.math.pow(2, (self.depth-(current_depth-1))))), tf.float32) - 1.0, tf.int64)
396 | self.path_identifier_list.append(path_identifier)
397 | self.internal_node_index_list.append(internal_node_index)
398 | self.path_identifier_list = tf.reshape(tf.stack(self.path_identifier_list), (-1,self.depth))
399 | self.internal_node_index_list = tf.reshape(tf.cast(tf.stack(self.internal_node_index_list), tf.int64), (-1,self.depth))
400 |
401 | leaf_classes_array_shape = [self.n_estimators,self.leaf_node_num_,] if self.objective == 'binary' or self.objective == 'regression' else [self.n_estimators, self.leaf_node_num_, self.number_of_classes]
402 |
403 | weight_shape = [self.n_estimators,self.leaf_node_num_]
404 |
405 | self.estimator_weights = self.add_weight(shape=weight_shape,
406 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 1}},
407 | trainable=True,
408 | name="estimator_weights",)
409 |
410 | self.split_values = self.add_weight(shape=(self.n_estimators, self.internal_node_num_, self.selected_variables),
411 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 2}},
412 | trainable=True,
413 | name="split_values",)
414 |
415 |
416 | self.split_index_array = self.add_weight(shape=(self.n_estimators, self.internal_node_num_, self.selected_variables),
417 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 3}},
418 | trainable=True,
419 | name="split_index_array",)
420 |
421 | self.leaf_classes_array = self.add_weight(shape=leaf_classes_array_shape,
422 | initializer={'class_name': self.initializer, 'config': {'seed': self.random_seed + 4}},
423 | trainable=True,
424 | name="leaf_classes_array",)
425 |
426 |
427 | def compile(self,
428 | loss,
429 | metrics,
430 | jit_compile,
431 | **kwargs):
432 |
433 | if self.objective == 'classification':
434 |
435 | if loss == 'crossentropy':
436 | if not self.focal_loss:
437 | loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=self.from_logits) #tf.keras.losses.get('categorical_crossentropy')
438 | else:
439 | loss_function = SparseCategoricalFocalLoss(gamma=2, class_weight=self.class_weights, from_logits=self.from_logits)
440 | else:
441 | loss_function = tf.keras.losses.get(loss)
442 | try:
443 | loss_function.from_logits = self.from_logits
444 | except:
445 | pass
446 | elif self.objective == 'binary':
447 | if loss == 'crossentropy':
448 | if not self.focal_loss:
449 | loss_function = tf.keras.losses.BinaryCrossentropy(from_logits=self.from_logits) #tf.keras.losses.get('binary_crossentropy')
450 | else:
451 | loss_function = tf.keras.losses.BinaryFocalCrossentropy(alpha=0.5, gamma=2, from_logits=self.from_logits)
452 | else:
453 | loss_function = tf.keras.losses.get(loss)
454 | try:
455 | loss_function.from_logits = self.from_logits
456 | except:
457 | pass
458 |
459 | elif self.objective == 'regression':
460 | loss_function = loss_function_regression(loss_name=loss, mean=kwargs['mean'], std=kwargs['std'])
461 |
462 | loss_function = loss_function_weighting(loss_function, temp=self.temperature)
463 | self.weights_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_weights, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps)
464 | self.index_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_index, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps)
465 | self.values_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_values, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps)
466 | self.leaf_optimizer = get_optimizer_by_name(optimizer_name=self.optimizer_name, learning_rate=self.learning_rate_leaf, warmup_steps=0, cosine_decay_steps=self.cosine_decay_steps)
467 |
468 | super(GRANDE, self).compile(loss=loss_function, metrics=metrics, jit_compile=jit_compile)
469 |
470 | def train_step(self, data):
471 |
472 | if len(data) == 3:
473 | x, y, sample_weight = data
474 | else:
475 | x, y = data
476 | sample_weight = None
477 |
478 | if not self.built:
479 | _ = self(x, training=True)
480 |
481 | with tf.GradientTape(persistent=True) as tape:
482 | y_pred = self(x, training=True)
483 | loss = self.compute_loss(x=None, y=y, y_pred=y_pred, sample_weight=sample_weight)
484 |
485 | weights_gradients = tape.gradient(loss, [self.estimator_weights])
486 | self.weights_optimizer.apply_gradients(zip(weights_gradients, [self.estimator_weights]))
487 |
488 | index_gradients = tape.gradient(loss, [self.split_index_array])
489 | self.index_optimizer.apply_gradients(zip(index_gradients, [self.split_index_array]))
490 |
491 | values_gradients = tape.gradient(loss, [self.split_values])
492 | self.values_optimizer.apply_gradients(zip(values_gradients, [self.split_values]))
493 |
494 | leaf_gradients = tape.gradient(loss, [self.leaf_classes_array])
495 | self.leaf_optimizer.apply_gradients(zip(leaf_gradients, [self.leaf_classes_array]))
496 |
497 |
498 | for metric in self.metrics:
499 | if metric.name == "loss":
500 | metric.update_state(loss)
501 | else:
502 | metric.update_state(y, y_pred)
503 |
504 | return {m.name: m.result() for m in self.metrics}
505 |
506 |
507 | def predict(self, X):
508 |
509 | if self.preprocess_data:
510 | X = self.apply_preprocessing(X)
511 | else:
512 | X = self.convert_to_numpy(X)
513 |
514 | preds = super(GRANDE, self).predict(X, self.batch_size, verbose=0)
515 | preds = tf.convert_to_tensor(preds)
516 | if self.objective == 'regression':
517 | preds = preds * self.std + self.mean
518 | else:
519 | if self.from_logits:
520 | if self.objective == 'binary':
521 | preds = tf.math.sigmoid(preds)
522 | elif self.objective == 'classification':
523 | preds = tf.keras.activations.softmax(preds)
524 |
525 | if self.objective == 'binary':
526 | preds = tf.stack([1-tf.reshape(preds, [-1]), tf.reshape(preds, [-1])], axis=-1) # reshape (not squeeze) keeps shape (n, 2) even when n == 1
527 |
528 | return preds.numpy()
529 |
530 | def set_params(self, **kwargs):
531 |
532 | if self.config is None:
533 | self.config = self.default_parameters()
534 |
535 | self.config.update(kwargs)
536 |
537 | if self.config['n_estimators'] == 1:
538 | self.config['selected_variables'] = 1.0
539 | self.config['data_subset_fraction'] = 1.0
540 | self.config['bootstrap'] = False
541 | self.config['dropout'] = 0.0
542 |
543 | if 'loss' not in self.config.keys() and 'loss_name' not in self.config.keys():
544 | if self.config['objective'] == 'classification' or self.config['objective'] == 'binary':
545 | self.config['loss'] = 'crossentropy'
546 | self.config['focal_loss'] = False
547 | elif self.config['objective'] == 'regression':
548 | self.config['loss'] = 'mse'
549 | self.config['focal_loss'] = False
550 |
551 | self.config['optimizer_name'] = self.config.pop('optimizer', self.config.get('optimizer_name')) # fallback avoids KeyError on repeated set_params calls
552 | self.config['loss_name'] = self.config.pop('loss', self.config.get('loss_name'))
553 |
554 | for arg_key, arg_value in self.config.items():
555 | setattr(self, arg_key, arg_value)
556 |
557 | tf.keras.utils.set_random_seed(self.random_seed)
558 |
559 |
560 | def get_params(self):
561 | return self.config
562 |
563 | def define_trial_parameters(self, trial, args):
564 | params = {
565 | 'depth': trial.suggest_int("depth", 3, 7),
566 | 'n_estimators': trial.suggest_int("n_estimators", 512, 2048),
567 |
568 | 'learning_rate_weights': trial.suggest_float("learning_rate_weights", 0.0001, 0.25),
569 | 'learning_rate_index': trial.suggest_float("learning_rate_index", 0.0001, 0.25),
570 | 'learning_rate_values': trial.suggest_float("learning_rate_values", 0.0001, 0.25),
571 | 'learning_rate_leaf': trial.suggest_float("learning_rate_leaf", 0.0001, 0.25),
572 |
573 | 'cosine_decay_steps': trial.suggest_categorical("cosine_decay_steps", [0, 100, 1000]),
574 |
575 | 'dropout': trial.suggest_categorical("dropout", [0, 0.25, 0.5]),
576 |
577 | 'selected_variables': trial.suggest_categorical("selected_variables", [1.0, 0.75, 0.5]),
578 | 'data_subset_fraction': trial.suggest_categorical("data_subset_fraction", [1.0, 0.8]),
579 | }
580 |
581 | try:
582 | if args['objective'] != 'regression':
583 | params['focal_loss'] = trial.suggest_categorical("focal_loss", [True, False])
584 | params['temperature'] = trial.suggest_categorical("temperature", [0, 0.25])
585 | except (KeyError, TypeError): # args is None or has no 'objective' entry
586 | if self.objective != 'regression':
587 | params['focal_loss'] = trial.suggest_categorical("focal_loss", [True, False])
588 | params['temperature'] = trial.suggest_categorical("temperature", [0, 0.25])
589 | return params
590 |
591 | def get_random_parameters(self, seed):
592 | rs = np.random.RandomState(seed)
593 | params = {
594 | 'depth': rs.randint(3, 7),
595 | 'n_estimators': rs.randint(512, 2048),
596 |
597 | 'learning_rate_weights': rs.uniform(0.0001, 0.25),
598 | 'learning_rate_index': rs.uniform(0.0001, 0.25),
599 | 'learning_rate_values': rs.uniform(0.0001, 0.25),
600 | 'learning_rate_leaf': rs.uniform(0.0001, 0.25),
601 |
602 | 'cosine_decay_steps': rs.choice([0, 100, 1000], p=[0.5, 0.25, 0.25]),
603 | 'dropout': rs.choice([0, 0.25]),
604 |
605 | 'selected_variables': rs.choice([1.0, 0.75, 0.5]),
606 | 'data_subset_fraction': rs.choice([1.0, 0.8]),
607 |
608 | }
609 |
610 | if self.objective != 'regression':
611 | params['focal_loss'] = rs.choice([True, False])
612 | params['temperature'] = rs.choice([1, 1/3, 1/5, 1/7, 1/9, 0], p=[0.1, 0.1, 0.1, 0.1, 0.1,0.5])
613 |
614 | return params
615 |
616 | def default_parameters(self):
617 | params = {
618 | 'depth': 5,
619 | 'n_estimators': 2048,
620 |
621 | 'learning_rate_weights': 0.005,
622 | 'learning_rate_index': 0.01,
623 | 'learning_rate_values': 0.01,
624 | 'learning_rate_leaf': 0.01,
625 |
626 | 'optimizer': 'adam',
627 | 'cosine_decay_steps': 0,
628 | 'temperature': 0.0,
629 |
630 | 'initializer': 'RandomNormal',
631 |
632 | 'loss': 'crossentropy',
633 | 'focal_loss': False,
634 |
635 | 'from_logits': True,
636 | 'use_class_weights': True,
637 | 'preprocess_data': True,
638 |
639 | 'dropout': 0.0,
640 |
641 | 'selected_variables': 0.8,
642 | 'data_subset_fraction': 1.0,
643 | 'bootstrap': False,
644 | }
645 |
646 | return params
647 |
648 | def save_model(self, save_path='model_gande'):
649 | config = {'params': {
650 | 'depth': self.depth,
651 | 'n_estimators': self.n_estimators,
652 |
653 | 'std': self.std,
654 | 'mean': self.mean,
655 | 'mode_train_cat': self.mode_train_cat,
656 | 'mean_train_num': self.mean_train_num,
657 |
658 | 'encoder_ordinal': self.encoder_ordinal,
659 | 'encoder_loo': self.encoder_loo,
660 | 'encoder_ohe': self.encoder_ohe,
661 | 'normalizer': self.normalizer,
662 |
663 | 'number_of_classes': self.number_of_classes,
664 | 'number_of_variables': self.number_of_variables,
665 |
666 | 'selected_variables': self.selected_variables,
667 | 'data_subset_fraction': self.data_subset_fraction,
668 | },
669 | 'args': {
670 | 'objective': self.objective,
671 |
672 | 'random_seed': self.random_seed,
673 | 'verbose': self.verbose,
674 |
675 | }
676 | }
677 |
678 | temp_dir = 'temp_model'
679 | os.makedirs(temp_dir, exist_ok=True)
680 |
681 | np.savez(os.path.join(temp_dir, 'variables.npz'),
682 | estimator_weights=self.estimator_weights.numpy(),
683 | split_values=self.split_values.numpy(),
684 | split_index_array=self.split_index_array.numpy(),
685 | leaf_classes_array=self.leaf_classes_array.numpy())
686 |
687 | config_path = os.path.join(temp_dir, 'config.pkl')
688 | with open(config_path, 'wb') as config_file:
689 | pickle.dump(config, config_file)
690 |
691 | with zipfile.ZipFile(save_path, 'w') as model_zip:
692 | for dirname, _, filenames in os.walk(temp_dir):
693 | for filename in filenames:
694 | file_path = os.path.join(dirname, filename)
695 | model_zip.write(file_path, arcname=os.path.relpath(file_path, temp_dir))
696 |
697 | for dirname, _, filenames in os.walk(temp_dir):
698 | for filename in filenames:
699 | os.remove(os.path.join(dirname, filename))
700 | os.rmdir(temp_dir)
701 |
702 | print(f"Model and config saved to {save_path}")
703 |
704 |
705 | def load_model(load_path='model_gande'):
706 | temp_dir = 'temp_model'
707 | with zipfile.ZipFile(load_path, 'r') as model_zip:
708 | model_zip.extractall(temp_dir)
709 |
710 | config_path = os.path.join(temp_dir, 'config.pkl')
711 | with open(config_path, 'rb') as config_file:
712 | config = pickle.load(config_file)
713 |
714 |
715 | model = GRANDE(params=config['params'], args=config['args'])
716 |
717 | model.build_model()
718 | model.predict(np.random.uniform(0,1, (1,config['params']['number_of_variables'])))
719 |
720 | variables = np.load(os.path.join(temp_dir, 'variables.npz'))
721 |
722 | model.estimator_weights.assign(variables['estimator_weights'])
723 | model.split_values.assign(variables['split_values'])
724 | model.split_index_array.assign(variables['split_index_array'])
725 | model.leaf_classes_array.assign(variables['leaf_classes_array'])
726 |
727 | for dirname, _, filenames in os.walk(temp_dir):
728 | for filename in filenames:
729 | os.remove(os.path.join(dirname, filename))
730 | os.rmdir(temp_dir)
731 |
732 | print("Model and config loaded successfully.")
733 |
734 | return model
735 |
736 | def apply_dropout_leaf(self,
737 | index_array: tf.Tensor,
738 | training: bool):
739 |
740 | if training and self.dropout > 0.0:
741 | index_array = tf.nn.dropout(index_array, rate=self.dropout)*(1-self.dropout)
742 | index_array = index_array/tf.expand_dims(tf.reduce_sum(index_array, axis=1), 1)
743 | else:
744 | index_array = index_array
745 |
746 | return index_array
747 |
748 |
749 | def entmax15TF(inputs, axis=-1):
750 |
751 | # Implementation taken from: https://github.com/deep-spin/entmax/tree/master/entmax
752 |
753 | """
754 | Entmax 1.5 implementation, heavily inspired by
755 | * paper: https://arxiv.org/pdf/1905.05702.pdf
756 | * pytorch code: https://github.com/deep-spin/entmax
757 | :param inputs: similar to softmax logits, but for entmax1.5
758 | :param axis: entmax1.5 outputs will sum to 1 over this axis
759 | :return: entmax activations of same shape as inputs
760 | """
761 | @tf.custom_gradient
762 | def _entmax_inner(inputs):
763 | with tf.name_scope('entmax'):
764 | inputs = inputs / 2 # divide by 2 so as to solve actual entmax
765 | inputs -= tf.reduce_max(inputs, axis, keepdims=True) # subtract max for stability
766 |
767 | threshold, _ = entmax_threshold_and_supportTF(inputs, axis)
768 | outputs_sqrt = tf.nn.relu(inputs - threshold)
769 | outputs = tf.square(outputs_sqrt)
770 |
771 | def grad_fn(d_outputs):
772 | with tf.name_scope('entmax_grad'):
773 | d_inputs = d_outputs * outputs_sqrt
774 | q = tf.reduce_sum(d_inputs, axis=axis, keepdims=True)
775 | q = q / tf.reduce_sum(outputs_sqrt, axis=axis, keepdims=True)
776 | d_inputs -= q * outputs_sqrt
777 | return d_inputs
778 |
779 | return outputs, grad_fn
780 |
781 | return _entmax_inner(inputs)
782 |
783 |
784 | def top_k_over_axisTF(inputs, k, axis=-1, **kwargs):
785 | """ performs tf.nn.top_k over any chosen axis """
786 | with tf.name_scope('top_k_along_axis'):
787 | if axis == -1:
788 | return tf.nn.top_k(inputs, k, **kwargs)
789 |
790 | perm_order = list(range(inputs.shape.ndims))
791 | perm_order.append(perm_order.pop(axis))
792 | inv_order = [perm_order.index(i) for i in range(len(perm_order))]
793 |
794 | input_perm = tf.transpose(inputs, perm_order)
795 | input_perm_sorted, sort_indices_perm = tf.nn.top_k(
796 | input_perm, k=k, **kwargs)
797 |
798 | input_sorted = tf.transpose(input_perm_sorted, inv_order)
799 | sort_indices = tf.transpose(sort_indices_perm, inv_order)
800 | return input_sorted, sort_indices
801 |
802 |
803 | def _make_ix_likeTF(inputs, axis=-1):
804 | """ creates indices 1, ..., inputs.shape[axis], reshaped to broadcast against inputs along axis """
805 | assert inputs.shape.ndims is not None
806 | rho = tf.cast(tf.range(1, tf.shape(inputs)[axis] + 1), dtype=inputs.dtype)
807 | view = [1] * inputs.shape.ndims
808 | view[axis] = -1
809 | return tf.reshape(rho, view)
810 |
811 |
812 | def gather_over_axisTF(values, indices, gather_axis):
813 | """
814 | replicates the behavior of torch.gather for tf<=1.8;
815 | for newer versions use tf.gather with batch_dims
816 | :param values: tensor [d0, ..., dn]
817 | :param indices: int64 tensor of same shape as values except for gather_axis
818 | :param gather_axis: performs gather along this axis
819 | :returns: gathered values, same shape as values except for gather_axis
820 | If gather_axis == 2
821 | gathered_values[i, j, k, ...] = values[i, j, indices[i, j, k, ...], ...]
822 | see torch.gather for more details
823 | """
824 | assert indices.shape.ndims is not None
825 | assert indices.shape.ndims == values.shape.ndims
826 |
827 | ndims = indices.shape.ndims
828 | gather_axis = gather_axis % ndims
829 | shape = tf.shape(indices)
830 |
831 | selectors = []
832 | for axis_i in range(ndims):
833 | if axis_i == gather_axis:
834 | selectors.append(indices)
835 | else:
836 | index_i = tf.range(tf.cast(shape[axis_i], dtype=indices.dtype), dtype=indices.dtype)
837 | index_i = tf.reshape(index_i, [-1 if i == axis_i else 1 for i in range(ndims)])
838 | index_i = tf.tile(index_i, [shape[i] if i != axis_i else 1 for i in range(ndims)])
839 | selectors.append(index_i)
840 |
841 | return tf.gather_nd(values, tf.stack(selectors, axis=-1))
842 |
843 |
844 | def entmax_threshold_and_supportTF(inputs, axis=-1):
845 | """
846 | Computes clipping threshold for entmax1.5 over specified axis
847 | NOTE this implementation uses the same heuristic as
848 | the original code: https://tinyurl.com/pytorch-entmax-line-203
849 | :param inputs: (entmax1.5 inputs - max) / 2
850 | :param axis: entmax1.5 outputs will sum to 1 over this axis
851 | """
852 |
853 | with tf.name_scope('entmax_threshold_and_supportTF'):
854 | num_outcomes = tf.shape(inputs)[axis]
855 | inputs_sorted, _ = top_k_over_axisTF(inputs, k=num_outcomes, axis=axis, sorted=True)
856 |
857 | rho = _make_ix_likeTF(inputs, axis=axis)
858 |
859 | mean = tf.cumsum(inputs_sorted, axis=axis) / rho
860 |
861 | mean_sq = tf.cumsum(tf.square(inputs_sorted), axis=axis) / rho
862 | delta = (1 - rho * (mean_sq - tf.square(mean))) / rho
863 |
864 | delta_nz = tf.nn.relu(delta)
865 | tau = mean - tf.sqrt(delta_nz)
866 |
867 | support_size = tf.reduce_sum(tf.cast(tf.less_equal(tau, inputs_sorted), tf.int64), axis=axis, keepdims=True)
868 |
869 | tau_star = gather_over_axisTF(tau, support_size - 1, axis)
870 | return tau_star, support_size
871 |
872 |
873 | def loss_function_weighting(loss_function, temp=0.25):
874 |
875 | # Implementation of "Stochastic Re-weighted Gradient Descent via Distributionally Robust Optimization" from https://arxiv.org/abs/2306.09222
876 |
877 | loss_function.reduction = tf.keras.losses.Reduction.NONE
878 | def _loss_function_weighting(y_true, y_pred):
879 | loss = loss_function(y_true, y_pred)
880 |
881 | if temp > 0:
882 | clamped_loss = tf.clip_by_value(loss, clip_value_min=float('-inf'), clip_value_max=temp)
883 |
884 | out = loss * tf.stop_gradient(tf.exp(clamped_loss / (temp + 1)))
885 | else:
886 | out = loss
887 |
888 | return tf.reduce_mean(out)
889 | return _loss_function_weighting
890 |
891 |
892 | def loss_function_regression(loss_name, mean, std): #mean, log,
893 | loss_function = tf.keras.losses.get(loss_name)
894 | def _loss_function_regression(y_true, y_pred):
895 | #if tf.keras.backend.learning_phase():
896 | y_true = (y_true - mean) / std
897 |
898 | loss = loss_function(y_true, y_pred)
899 |
900 | return loss
901 | return _loss_function_regression
902 |
903 | def _threshold_and_supportTF(input, dim=-1):
904 | Xsrt = tf.sort(input, axis=dim, direction='DESCENDING')
905 |
906 | rho = tf.range(1, tf.shape(input)[dim] + 1, dtype=input.dtype)
907 | mean = tf.math.cumsum(Xsrt, axis=dim) / rho
908 | mean_sq = tf.math.cumsum(tf.square(Xsrt), axis=dim) / rho
909 | ss = rho * (mean_sq - tf.square(mean))
910 | delta = (1 - ss) / rho
911 |
912 | delta_nz = tf.maximum(delta, 0)
913 | tau = mean - tf.sqrt(delta_nz)
914 |
915 | support_size = tf.reduce_sum(tf.cast(tau <= Xsrt, tf.int32), axis=dim)
916 | tau_star = tf.gather(tau, support_size - 1, batch_dims=-1)
917 | return tau_star, support_size
918 |
919 | def get_optimizer_by_name(optimizer_name, learning_rate, warmup_steps, cosine_decay_steps):
920 |
921 |
922 | if cosine_decay_steps > 0:
923 | learning_rate = tf.keras.optimizers.schedules.CosineDecayRestarts(
924 | initial_learning_rate=learning_rate,
925 | first_decay_steps=cosine_decay_steps,
926 | #first_decay_steps=steps_per_epoch,
927 | )
928 |
929 | if optimizer_name== 'SWA' or optimizer_name== 'EMA':
930 | #optimizer = tfa.optimizers.SWA(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=learning_rate), average_period=5)
931 | frequency = 10
932 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
933 | use_ema = True,
934 | ema_momentum = 1/frequency,
935 | ema_overwrite_frequency = 1
936 | )
937 | else:
938 | optimizer = tf.keras.optimizers.get(optimizer_name)
939 | optimizer.learning_rate = learning_rate
940 |
941 | return optimizer
942 |
943 |
944 |
--------------------------------------------------------------------------------
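The `entmax15TF` helper above halves the logits, subtracts their maximum, finds a threshold τ with `entmax_threshold_and_supportTF` (sorted cumulative means and mean squares), and returns `relu(z - τ)²`. As a reading aid, here is a minimal pure-Python sketch of the same arithmetic for a single 1-D vector. It assumes no batching and no custom gradient, and the name `entmax15` is hypothetical, not part of GRANDE.

```python
import math
from typing import List


def entmax15(logits: List[float]) -> List[float]:
    """Hypothetical 1-D sketch of entmax15TF / entmax_threshold_and_supportTF."""
    # Halve and shift by the max, as _entmax_inner does before thresholding.
    z = [x / 2.0 for x in logits]
    z_max = max(z)
    z = [x - z_max for x in z]

    # Sort descending and build running means of the values and of their squares.
    z_sorted = sorted(z, reverse=True)
    taus = []
    cum, cum_sq = 0.0, 0.0
    for rho, value in enumerate(z_sorted, start=1):
        cum += value
        cum_sq += value * value
        mean = cum / rho
        mean_sq = cum_sq / rho
        delta = (1.0 - rho * (mean_sq - mean * mean)) / rho
        taus.append(mean - math.sqrt(max(delta, 0.0)))

    # Support size = number of sorted entries that lie at or above their candidate threshold.
    support = sum(1 for tau, value in zip(taus, z_sorted) if tau <= value)
    tau_star = taus[support - 1]

    # Squared ReLU of the shifted inputs gives the (sparse) probabilities.
    return [max(x - tau_star, 0.0) ** 2 for x in z]
```

Entries whose shifted logit falls below τ receive exactly zero probability, which is the sparsity that distinguishes entmax-1.5 from softmax.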
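`loss_function_weighting` above switches the base loss to per-sample reduction and multiplies each sample's loss by `exp(min(loss, temp) / (temp + 1))`, holding that weight constant through `tf.stop_gradient`. The sketch below reproduces only the forward arithmetic on plain floats (no gradients). The name `reweighted_mean_loss` is hypothetical, and clipping at `temp` mirrors the code as written rather than any specific formulation in the cited paper.

```python
import math
from typing import Sequence


def reweighted_mean_loss(losses: Sequence[float], temp: float = 0.25) -> float:
    """Hypothetical sketch of loss_function_weighting's forward pass on plain floats."""
    if temp <= 0:
        # temp == 0 disables re-weighting: plain mean of the per-sample losses.
        return sum(losses) / len(losses)
    weighted = []
    for loss in losses:
        clamped = min(loss, temp)                # tf.clip_by_value(..., clip_value_max=temp)
        weight = math.exp(clamped / (temp + 1))  # constant w.r.t. gradients in the source
        weighted.append(loss * weight)
    return sum(weighted) / len(weighted)
```

Because the clipped loss is at most `temp`, for non-negative losses every weight lies between 1 and `exp(temp / (temp + 1))`, roughly 1.22 for the default `temp=0.25`, so hard samples are up-weighted only mildly.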
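After calling Keras' `predict`, the `predict` method above undoes the target standardization for regression (`preds * std + mean`), applies a sigmoid or softmax when `from_logits` is set, and for `binary` returns two columns `[1 - p, p]`. A minimal sketch of that post-processing on nested Python lists follows. `postprocess` and its signature are hypothetical, and the binary branch assumes one raw score per sample.

```python
import math
from typing import List, Sequence


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def _softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating to avoid overflow.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(raw: Sequence[Sequence[float]], objective: str, from_logits: bool = True,
                mean: float = 0.0, std: float = 1.0) -> List[List[float]]:
    """Hypothetical sketch of the post-processing in GRANDE.predict."""
    if objective == "regression":
        # Undo the target standardization applied during training.
        return [[v * std + mean for v in row] for row in raw]
    if objective == "binary":
        # One score per sample; emit [P(class 0), P(class 1)] for each row.
        probs = [_sigmoid(row[0]) if from_logits else row[0] for row in raw]
        return [[1.0 - p, p] for p in probs]
    if objective == "classification":
        return [_softmax(row) if from_logits else list(row) for row in raw]
    raise ValueError(f"unknown objective: {objective!r}")
```

Taking the single score of each row, instead of squeezing every size-1 dimension, keeps a one-sample batch at shape `(1, 2)`.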
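When `cosine_decay_steps > 0`, `get_optimizer_by_name` above wraps the learning rate in `tf.keras.optimizers.schedules.CosineDecayRestarts`, passing only `initial_learning_rate` and `first_decay_steps`, so the Keras defaults `t_mul=2.0`, `m_mul=1.0` and `alpha=0.0` apply. The sketch below re-derives that schedule in plain Python to show its shape. `cosine_decay_restarts` is a hypothetical name, and the formula is a transcription of the Keras schedule offered as an illustration, not a reference implementation.

```python
import math


def cosine_decay_restarts(step: int, initial_lr: float, first_decay_steps: int,
                          t_mul: float = 2.0, m_mul: float = 1.0, alpha: float = 0.0) -> float:
    """Hypothetical sketch of the CosineDecayRestarts learning rate at a given step."""
    if first_decay_steps <= 0:
        raise ValueError("first_decay_steps must be positive")
    completed = step / first_decay_steps
    if t_mul == 1.0:
        # All cycles have the same length.
        i_restart = math.floor(completed)
        completed -= i_restart
    else:
        # Cycle lengths grow geometrically by t_mul; locate the current cycle.
        i_restart = math.floor(math.log(1.0 - completed * (1.0 - t_mul)) / math.log(t_mul))
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        completed = (completed - sum_r) / t_mul ** i_restart
    m_fac = m_mul ** i_restart
    cosine = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)
```

With `t_mul=2.0` each cycle is twice as long as the previous one: for `cosine_decay_steps=100` the rate returns to its initial value after 100 steps, then decays over the next 200 steps, then 400, and so on.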
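`save_model` and `load_model` above stage `variables.npz` and `config.pkl` in a fixed `temp_model` directory under the working directory, zip it, and delete the files by hand, which can collide with an existing directory of that name or leave it behind if an exception occurs. A minimal alternative sketch writes the members straight into the archive with `zipfile.ZipFile.writestr`, so no staging directory is needed. It assumes plain picklable Python objects (the original stores the TensorFlow variables with `np.savez`), and `save_bundle`/`load_bundle` are hypothetical names.

```python
import os
import pickle
import tempfile
import zipfile
from typing import Any, Tuple


def save_bundle(save_path: str, config: Any, variables: Any) -> None:
    """Hypothetical helper: write config and variables as members of one zip archive."""
    with zipfile.ZipFile(save_path, "w") as archive:
        # writestr stores bytes directly, so nothing has to be staged on disk or cleaned up.
        archive.writestr("config.pkl", pickle.dumps(config))
        archive.writestr("variables.pkl", pickle.dumps(variables))


def load_bundle(load_path: str) -> Tuple[Any, Any]:
    """Hypothetical helper: read back what save_bundle wrote."""
    with zipfile.ZipFile(load_path, "r") as archive:
        config = pickle.loads(archive.read("config.pkl"))
        variables = pickle.loads(archive.read("variables.pkl"))
    return config, variables


if __name__ == "__main__":
    # Round-trip example inside a throwaway directory.
    with tempfile.TemporaryDirectory() as workdir:
        path = os.path.join(workdir, "model_bundle.zip")
        save_bundle(path, {"params": {"depth": 5}}, {"split_values": [[0.1, 0.2]]})
        print(load_bundle(path))
```

As with `pickle.load` in the original, only load archives from trusted sources, since unpickling can execute arbitrary code.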
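At training time, `apply_dropout_leaf` above zeroes entries with `tf.nn.dropout`, multiplies by `1 - dropout` to cancel the `1 / (1 - rate)` rescaling that `tf.nn.dropout` applies to kept entries, and renormalizes each row along axis 1 so it sums to one again. The pure-Python sketch below shows that effect on a small 2-D list. `dropout_leaf_weights` is a hypothetical name, and the fallback for a fully dropped row is an assumption of this sketch (in the TensorFlow code such a row would be divided by zero).

```python
import random
from typing import List


def dropout_leaf_weights(rows: List[List[float]], rate: float,
                         rng: random.Random) -> List[List[float]]:
    """Hypothetical sketch of apply_dropout_leaf: drop entries, then renormalize each row."""
    if rate <= 0.0:
        return [list(row) for row in rows]
    result = []
    for row in rows:
        # Keep each entry with probability 1 - rate; the dropout rescaling cancels out in the source.
        kept = [value if rng.random() >= rate else 0.0 for value in row]
        total = sum(kept)
        # Assumption: if every entry was dropped, keep the original row instead of dividing by zero.
        result.append([value / total for value in kept] if total > 0 else list(row))
    return result
```

Because of the renormalization, the surviving entries keep their relative proportions and each row remains a valid weighting.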
/GRANDE_minimal_example_with_comparison_REG.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "bf584ce8-8849-4a53-8ce9-2de8f9dd752d",
7 | "metadata": {
8 | "execution": {
9 | "iopub.execute_input": "2024-03-15T11:56:56.253578Z",
10 | "iopub.status.busy": "2024-03-15T11:56:56.253464Z",
11 | "iopub.status.idle": "2024-03-15T11:56:56.265905Z",
12 | "shell.execute_reply": "2024-03-15T11:56:56.265512Z",
13 | "shell.execute_reply.started": "2024-03-15T11:56:56.253566Z"
14 | }
15 | },
16 | "outputs": [],
17 | "source": [
18 | "#specify GPU to use\n",
19 | "import os\n",
20 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "52b07238-f972-45df-86ce-a11cc11d72e3",
27 | "metadata": {
28 | "execution": {
29 | "iopub.execute_input": "2024-03-15T11:56:56.266511Z",
30 | "iopub.status.busy": "2024-03-15T11:56:56.266392Z",
31 | "iopub.status.idle": "2024-03-15T11:56:57.956110Z",
32 | "shell.execute_reply": "2024-03-15T11:56:57.955562Z",
33 | "shell.execute_reply.started": "2024-03-15T11:56:56.266499Z"
34 | }
35 | },
36 | "outputs": [
37 | {
38 | "name": "stdout",
39 | "output_type": "stream",
40 | "text": [
41 | "Training set size: 856\n",
42 | "Validation set size: 214\n",
43 | "Test set size: 268\n"
44 | ]
45 | },
46 | {
47 | "data": {
125 | "text/plain": [
126 | " age sex bmi children smoker region\n",
127 | "94 33 male 35.750 1 yes southeast\n",
128 | "814 46 male 19.855 0 no northwest\n",
129 | "246 28 female 27.500 2 no southwest\n",
130 | "794 36 female 25.900 1 no southwest\n",
131 | "1239 42 female 36.195 1 no northwest"
132 | ]
133 | },
134 | "execution_count": 2,
135 | "metadata": {},
136 | "output_type": "execute_result"
137 | }
138 | ],
139 | "source": [
140 | "from sklearn.model_selection import train_test_split\n",
141 | "import openml\n",
142 | "import category_encoders as ce\n",
143 | "import numpy as np\n",
144 | "import sklearn\n",
145 | "\n",
146 | "# Load healthcare_insurance_expenses dataset\n",
147 | "dataset = openml.datasets.get_dataset(46931, download_data=True, download_qualities=True, download_features_meta_data=True)\n",
148 | "X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
149 | "categorical_feature_indices = [idx for idx, idx_bool in enumerate(categorical_indicator) if idx_bool]\n",
150 | "\n",
151 | "X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
152 | "\n",
153 | "X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)\n",
154 | "\n",
155 | "print(\"Training set size:\", len(X_train))\n",
156 | "print(\"Validation set size:\", len(X_valid))\n",
157 | "print(\"Test set size:\", len(X_test))\n",
158 | "\n",
159 | "X_train.head()"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 3,
165 | "id": "add65bc2",
166 | "metadata": {},
167 | "outputs": [
168 | {
169 | "data": {
170 | "text/plain": [
171 | "[1, 4, 5]"
172 | ]
173 | },
174 | "execution_count": 3,
175 | "metadata": {},
176 | "output_type": "execute_result"
177 | }
178 | ],
179 | "source": [
180 | "categorical_feature_indices"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "bf0a93c2-e9a4-4fa5-a7f6-44689dd07d37",
187 | "metadata": {
188 | "execution": {
189 | "iopub.execute_input": "2024-03-15T11:56:57.957002Z",
190 | "iopub.status.busy": "2024-03-15T11:56:57.956779Z",
191 | "iopub.status.idle": "2024-03-15T11:57:30.600318Z",
192 | "shell.execute_reply": "2024-03-15T11:57:30.599632Z",
193 | "shell.execute_reply.started": "2024-03-15T11:56:57.956987Z"
194 | },
195 | "scrolled": true
196 | },
197 | "outputs": [
198 | {
199 | "name": "stdout",
200 | "output_type": "stream",
201 | "text": [
202 | "self.params {'depth': 5, 'n_estimators': 1024, 'learning_rate_weights': 0.001, 'learning_rate_index': 0.01, 'learning_rate_values': 0.05, 'learning_rate_leaf': 0.15, 'learning_rate_embedding': 0.01, 'use_category_embeddings': True, 'embedding_dim_cat': 8, 'use_numeric_embeddings': False, 'embedding_dim_num': 8, 'embedding_threshold': 1, 'loo_cardinality': 10, 'dropout': 0.2, 'selected_variables': 0.8, 'data_subset_fraction': 1.0, 'bootstrap': False, 'missing_values': False, 'optimizer': 'adam', 'cosine_decay_restarts': False, 'reduce_on_plateau_scheduler': True, 'label_smoothing': 0.0, 'use_class_weights': False, 'focal_loss': False, 'swa': False, 'es_metric': True, 'epochs': 250, 'batch_size': 256, 'early_stopping_epochs': 50, 'use_freq_enc': False, 'use_robust_scale_smoothing': False, 'problem_type': 'regression', 'random_seed': 42, 'verbose': 2, 'device': 'cuda:0', 'objective': 'regression'}\n",
203 | "X_train shape before embeddings (856, 6)\n",
204 | "categories [5]\n",
205 | "number_of_variables 1 6\n",
206 | "number_of_variables 2 13\n",
207 | "{'self.column_names_dataframe': ['age',\n",
208 | " 'sex',\n",
209 | " 'bmi',\n",
210 | " 'children',\n",
211 | " 'smoker',\n",
212 | " 'region'],\n",
213 | " 'encoded_columns': ['sex', 'smoker'],\n",
214 | " 'encoded_columns_indices': [1, 4],\n",
215 | " 'num_columns': ['age', 'bmi', 'children'],\n",
216 | " 'num_columns_indices': [0, 2, 3],\n",
217 | " 'categorical_features_raw_indices': [5],\n",
218 | " 'not_encoded_columns': ['age', 'bmi', 'children', 'region']}\n",
219 | "Use Category Embeddings: Embedding(5, 8)\n"
220 | ]
221 | },
222 | {
223 | "name": "stderr",
224 | "output_type": "stream",
225 | "text": [
226 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
227 | "Online softmax is disabled on the fly since Inductor decides to\n",
228 | "split the reduction. Cut an issue to PyTorch if this is an\n",
229 | "important use case and you want to speed it up with online\n",
230 | "softmax.\n",
231 | "\n",
232 | " warnings.warn(\n",
233 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
234 | "Online softmax is disabled on the fly since Inductor decides to\n",
235 | "split the reduction. Cut an issue to PyTorch if this is an\n",
236 | "important use case and you want to speed it up with online\n",
237 | "softmax.\n",
238 | "\n",
239 | " warnings.warn(\n"
240 | ]
241 | },
242 | {
243 | "name": "stdout",
244 | "output_type": "stream",
245 | "text": [
246 | "Epoch 001 | TrainLoss: 0.9442 | ValLoss: 0.6983 | MSE: 110805304.0000 | MAE: 7907.7832 | R2: 0.2430 | Time: 6.61s\n",
247 | "Epoch 002 | TrainLoss: 0.6745 | ValLoss: 0.4770 | MSE: 75694136.0000 | MAE: 6047.0703 | R2: 0.4829 | Time: 0.02s\n",
248 | "Epoch 003 | TrainLoss: 0.4664 | ValLoss: 0.3108 | MSE: 49315516.0000 | MAE: 4471.3096 | R2: 0.6631 | Time: 0.02s\n",
249 | "Epoch 004 | TrainLoss: 0.3181 | ValLoss: 0.2167 | MSE: 34391280.0000 | MAE: 4079.7856 | R2: 0.7651 | Time: 0.02s\n",
250 | "Epoch 005 | TrainLoss: 0.2304 | ValLoss: 0.1863 | MSE: 29557700.0000 | MAE: 3613.9580 | R2: 0.7981 | Time: 0.02s\n",
251 | "Epoch 006 | TrainLoss: 0.2101 | ValLoss: 0.1661 | MSE: 26352892.0000 | MAE: 3145.3518 | R2: 0.8200 | Time: 0.02s\n",
252 | "Epoch 007 | TrainLoss: 0.1697 | ValLoss: 0.1398 | MSE: 22176866.0000 | MAE: 2775.7600 | R2: 0.8485 | Time: 0.02s\n",
253 | "Epoch 008 | TrainLoss: 0.1484 | ValLoss: 0.1316 | MSE: 20874780.0000 | MAE: 2580.6177 | R2: 0.8574 | Time: 0.02s\n",
254 | "Epoch 009 | TrainLoss: 0.1400 | ValLoss: 0.1328 | MSE: 21068986.0000 | MAE: 2652.9009 | R2: 0.8561 | Time: 0.02s\n",
255 | "[EarlyStop TRAINING] no improve (1/50). best=20874780.000000, curr=21068986.000000\n",
256 | "Epoch 010 | TrainLoss: 0.1354 | ValLoss: 0.1315 | MSE: 20857652.0000 | MAE: 2709.4941 | R2: 0.8575 | Time: 0.02s\n",
257 | "Epoch 011 | TrainLoss: 0.1323 | ValLoss: 0.1329 | MSE: 21088880.0000 | MAE: 2658.9639 | R2: 0.8559 | Time: 0.02s\n",
258 | "[EarlyStop TRAINING] no improve (1/50). best=20857652.000000, curr=21088880.000000\n",
259 | "Epoch 012 | TrainLoss: 0.1407 | ValLoss: 0.1376 | MSE: 21836154.0000 | MAE: 2449.4062 | R2: 0.8508 | Time: 0.02s\n",
260 | "[EarlyStop TRAINING] no improve (2/50). best=20857652.000000, curr=21836154.000000\n",
261 | "Epoch 013 | TrainLoss: 0.1340 | ValLoss: 0.1359 | MSE: 21563870.0000 | MAE: 2985.1030 | R2: 0.8527 | Time: 0.02s\n",
262 | "[EarlyStop TRAINING] no improve (3/50). best=20857652.000000, curr=21563870.000000\n",
263 | "Epoch 014 | TrainLoss: 0.1363 | ValLoss: 0.1375 | MSE: 21812660.0000 | MAE: 2578.1575 | R2: 0.8510 | Time: 0.02s\n",
264 | "[EarlyStop TRAINING] no improve (4/50). best=20857652.000000, curr=21812660.000000\n",
265 | "Epoch 015 | TrainLoss: 0.1294 | ValLoss: 0.1369 | MSE: 21724846.0000 | MAE: 2606.5918 | R2: 0.8516 | Time: 0.02s\n",
266 | "[EarlyStop TRAINING] no improve (5/50). best=20857652.000000, curr=21724846.000000\n",
267 | "Epoch 016 | TrainLoss: 0.1150 | ValLoss: 0.1372 | MSE: 21774892.0000 | MAE: 2736.3572 | R2: 0.8512 | Time: 0.02s\n",
268 | "[EarlyStop TRAINING] no improve (6/50). best=20857652.000000, curr=21774892.000000\n",
269 | "Epoch 017 | TrainLoss: 0.1263 | ValLoss: 0.1347 | MSE: 21375200.0000 | MAE: 2670.4187 | R2: 0.8540 | Time: 0.02s\n",
270 | "[EarlyStop TRAINING] no improve (7/50). best=20857652.000000, curr=21375200.000000\n",
271 | "Epoch 018 | TrainLoss: 0.1145 | ValLoss: 0.1361 | MSE: 21592756.0000 | MAE: 2709.2681 | R2: 0.8525 | Time: 0.02s\n",
272 | "[EarlyStop TRAINING] no improve (8/50). best=20857652.000000, curr=21592756.000000\n",
273 | "Epoch 019 | TrainLoss: 0.1116 | ValLoss: 0.1443 | MSE: 22900512.0000 | MAE: 2851.0134 | R2: 0.8436 | Time: 0.02s\n",
274 | "[EarlyStop TRAINING] no improve (9/50). best=20857652.000000, curr=22900512.000000\n",
275 | "Epoch 020 | TrainLoss: 0.1201 | ValLoss: 0.1437 | MSE: 22796944.0000 | MAE: 2674.2185 | R2: 0.8443 | Time: 0.02s\n",
276 | "[EarlyStop TRAINING] no improve (10/50). best=20857652.000000, curr=22796944.000000\n",
277 | "Epoch 021 | TrainLoss: 0.1125 | ValLoss: 0.1405 | MSE: 22288538.0000 | MAE: 2823.5095 | R2: 0.8477 | Time: 0.02s\n",
278 | "[EarlyStop TRAINING] no improve (11/50). best=20857652.000000, curr=22288538.000000\n",
279 | "Epoch 022 | TrainLoss: 0.1150 | ValLoss: 0.1425 | MSE: 22610858.0000 | MAE: 2622.5259 | R2: 0.8455 | Time: 0.02s\n",
280 | "[EarlyStop TRAINING] no improve (12/50). best=20857652.000000, curr=22610858.000000\n",
281 | "Epoch 023 | TrainLoss: 0.1120 | ValLoss: 0.1439 | MSE: 22834784.0000 | MAE: 3120.4143 | R2: 0.8440 | Time: 0.02s\n",
282 | "[EarlyStop TRAINING] no improve (13/50). best=20857652.000000, curr=22834784.000000\n",
283 | "Epoch 024 | TrainLoss: 0.1106 | ValLoss: 0.1451 | MSE: 23015768.0000 | MAE: 2632.4998 | R2: 0.8428 | Time: 0.02s\n",
284 | "[EarlyStop TRAINING] no improve (14/50). best=20857652.000000, curr=23015768.000000\n",
285 | "Epoch 025 | TrainLoss: 0.1116 | ValLoss: 0.1432 | MSE: 22729254.0000 | MAE: 3018.4402 | R2: 0.8447 | Time: 0.02s\n",
286 | "[EarlyStop TRAINING] no improve (15/50). best=20857652.000000, curr=22729254.000000\n",
287 | "Epoch 026 | TrainLoss: 0.1055 | ValLoss: 0.1451 | MSE: 23025022.0000 | MAE: 2757.2036 | R2: 0.8427 | Time: 0.02s\n",
288 | "[EarlyStop TRAINING] no improve (16/50). best=20857652.000000, curr=23025022.000000\n",
289 | "Epoch 027 | TrainLoss: 0.1120 | ValLoss: 0.1428 | MSE: 22663312.0000 | MAE: 2849.4290 | R2: 0.8452 | Time: 0.02s\n",
290 | "[EarlyStop TRAINING] no improve (17/50). best=20857652.000000, curr=22663312.000000\n",
291 | "Epoch 028 | TrainLoss: 0.1024 | ValLoss: 0.1442 | MSE: 22888228.0000 | MAE: 2710.1841 | R2: 0.8436 | Time: 0.02s\n",
292 | "[EarlyStop TRAINING] no improve (18/50). best=20857652.000000, curr=22888228.000000\n",
293 | "Epoch 029 | TrainLoss: 0.1029 | ValLoss: 0.1486 | MSE: 23580132.0000 | MAE: 2878.5706 | R2: 0.8389 | Time: 0.02s\n",
294 | "[EarlyStop TRAINING] no improve (19/50). best=20857652.000000, curr=23580132.000000\n",
295 | "Epoch 030 | TrainLoss: 0.0962 | ValLoss: 0.1455 | MSE: 23092736.0000 | MAE: 2784.8010 | R2: 0.8422 | Time: 0.02s\n",
296 | "[EarlyStop TRAINING] no improve (20/50). best=20857652.000000, curr=23092736.000000\n",
297 | "Epoch 031 | TrainLoss: 0.1075 | ValLoss: 0.1473 | MSE: 23373388.0000 | MAE: 2717.4727 | R2: 0.8403 | Time: 0.02s\n",
298 | "[EarlyStop TRAINING] no improve (21/50). best=20857652.000000, curr=23373388.000000\n",
299 | "Epoch 032 | TrainLoss: 0.0999 | ValLoss: 0.1550 | MSE: 24594842.0000 | MAE: 2826.3821 | R2: 0.8320 | Time: 0.02s\n",
300 | "[EarlyStop TRAINING] no improve (22/50). best=20857652.000000, curr=24594842.000000\n",
301 | "Epoch 033 | TrainLoss: 0.1044 | ValLoss: 0.1545 | MSE: 24521254.0000 | MAE: 2981.0029 | R2: 0.8325 | Time: 0.02s\n",
302 | "[EarlyStop TRAINING] no improve (23/50). best=20857652.000000, curr=24521254.000000\n",
303 | "Epoch 034 | TrainLoss: 0.1000 | ValLoss: 0.1488 | MSE: 23607160.0000 | MAE: 2794.5957 | R2: 0.8387 | Time: 0.02s\n",
304 | "[EarlyStop TRAINING] no improve (24/50). best=20857652.000000, curr=23607160.000000\n",
305 | "Epoch 035 | TrainLoss: 0.1003 | ValLoss: 0.1524 | MSE: 24181636.0000 | MAE: 2961.8938 | R2: 0.8348 | Time: 0.02s\n",
306 | "[EarlyStop TRAINING] no improve (25/50). best=20857652.000000, curr=24181636.000000\n",
307 | "Epoch 036 | TrainLoss: 0.0970 | ValLoss: 0.1529 | MSE: 24256992.0000 | MAE: 2983.7617 | R2: 0.8343 | Time: 0.02s\n",
308 | "[EarlyStop TRAINING] no improve (26/50). best=20857652.000000, curr=24256992.000000\n",
309 | "Epoch 037 | TrainLoss: 0.0917 | ValLoss: 0.1549 | MSE: 24575086.0000 | MAE: 2947.2693 | R2: 0.8321 | Time: 0.02s\n",
310 | "[EarlyStop TRAINING] no improve (27/50). best=20857652.000000, curr=24575086.000000\n",
311 | "Epoch 038 | TrainLoss: 0.0932 | ValLoss: 0.1538 | MSE: 24408730.0000 | MAE: 2851.0256 | R2: 0.8333 | Time: 0.02s\n",
312 | "[EarlyStop TRAINING] no improve (28/50). best=20857652.000000, curr=24408730.000000\n",
313 | "Epoch 039 | TrainLoss: 0.0938 | ValLoss: 0.1576 | MSE: 25012860.0000 | MAE: 2975.3767 | R2: 0.8291 | Time: 0.02s\n",
314 | "[EarlyStop TRAINING] no improve (29/50). best=20857652.000000, curr=25012860.000000\n",
315 | "Epoch 040 | TrainLoss: 0.0978 | ValLoss: 0.1578 | MSE: 25035894.0000 | MAE: 3047.7266 | R2: 0.8290 | Time: 0.02s\n",
316 | "[EarlyStop TRAINING] no improve (30/50). best=20857652.000000, curr=25035894.000000\n",
317 | "Epoch 041 | TrainLoss: 0.0941 | ValLoss: 0.1551 | MSE: 24604590.0000 | MAE: 2932.2500 | R2: 0.8319 | Time: 0.02s\n",
318 | "[EarlyStop TRAINING] no improve (31/50). best=20857652.000000, curr=24604590.000000\n",
319 | "Epoch 042 | TrainLoss: 0.0862 | ValLoss: 0.1587 | MSE: 25179846.0000 | MAE: 3095.2544 | R2: 0.8280 | Time: 0.02s\n",
320 | "[EarlyStop TRAINING] no improve (32/50). best=20857652.000000, curr=25179846.000000\n",
321 | "Epoch 043 | TrainLoss: 0.0967 | ValLoss: 0.1597 | MSE: 25341562.0000 | MAE: 3101.8044 | R2: 0.8269 | Time: 0.02s\n",
322 | "[EarlyStop TRAINING] no improve (33/50). best=20857652.000000, curr=25341562.000000\n",
323 | "Epoch 044 | TrainLoss: 0.0841 | ValLoss: 0.1600 | MSE: 25389414.0000 | MAE: 2996.4243 | R2: 0.8266 | Time: 0.02s\n",
324 | "[EarlyStop TRAINING] no improve (34/50). best=20857652.000000, curr=25389414.000000\n",
325 | "Epoch 045 | TrainLoss: 0.0937 | ValLoss: 0.1590 | MSE: 25234082.0000 | MAE: 3053.0994 | R2: 0.8276 | Time: 0.02s\n",
326 | "[EarlyStop TRAINING] no improve (35/50). best=20857652.000000, curr=25234082.000000\n",
327 | "Epoch 046 | TrainLoss: 0.0850 | ValLoss: 0.1643 | MSE: 26070064.0000 | MAE: 3206.8984 | R2: 0.8219 | Time: 0.02s\n",
328 | "[EarlyStop TRAINING] no improve (36/50). best=20857652.000000, curr=26070064.000000\n",
329 | "Epoch 047 | TrainLoss: 0.0836 | ValLoss: 0.1669 | MSE: 26475984.0000 | MAE: 3009.0151 | R2: 0.8191 | Time: 0.02s\n",
330 | "[EarlyStop TRAINING] no improve (37/50). best=20857652.000000, curr=26475984.000000\n",
331 | "Epoch 048 | TrainLoss: 0.0814 | ValLoss: 0.1633 | MSE: 25911006.0000 | MAE: 3283.3823 | R2: 0.8230 | Time: 0.02s\n",
332 | "[EarlyStop TRAINING] no improve (38/50). best=20857652.000000, curr=25911006.000000\n",
333 | "Epoch 049 | TrainLoss: 0.0842 | ValLoss: 0.1626 | MSE: 25801604.0000 | MAE: 2987.0801 | R2: 0.8237 | Time: 0.02s\n",
334 | "[EarlyStop TRAINING] no improve (39/50). best=20857652.000000, curr=25801604.000000\n",
335 | "Epoch 050 | TrainLoss: 0.0828 | ValLoss: 0.1612 | MSE: 25575234.0000 | MAE: 3190.6045 | R2: 0.8253 | Time: 0.02s\n",
336 | "[EarlyStop TRAINING] no improve (40/50). best=20857652.000000, curr=25575234.000000\n",
337 | "Epoch 051 | TrainLoss: 0.0822 | ValLoss: 0.1709 | MSE: 27122086.0000 | MAE: 3196.4832 | R2: 0.8147 | Time: 0.02s\n",
338 | "[EarlyStop TRAINING] no improve (41/50). best=20857652.000000, curr=27122086.000000\n",
339 | "Epoch 052 | TrainLoss: 0.0838 | ValLoss: 0.1670 | MSE: 26501512.0000 | MAE: 3267.1182 | R2: 0.8190 | Time: 0.02s\n",
340 | "[EarlyStop TRAINING] no improve (42/50). best=20857652.000000, curr=26501512.000000\n",
341 | "Epoch 053 | TrainLoss: 0.0850 | ValLoss: 0.1626 | MSE: 25795368.0000 | MAE: 3034.9268 | R2: 0.8238 | Time: 0.02s\n",
342 | "[EarlyStop TRAINING] no improve (43/50). best=20857652.000000, curr=25795368.000000\n",
343 | "Epoch 054 | TrainLoss: 0.0825 | ValLoss: 0.1644 | MSE: 26090838.0000 | MAE: 3215.0332 | R2: 0.8218 | Time: 0.02s\n",
344 | "[EarlyStop TRAINING] no improve (44/50). best=20857652.000000, curr=26090838.000000\n",
345 | "Epoch 055 | TrainLoss: 0.0807 | ValLoss: 0.1652 | MSE: 26218004.0000 | MAE: 3167.4287 | R2: 0.8209 | Time: 0.02s\n",
346 | "[EarlyStop TRAINING] no improve (45/50). best=20857652.000000, curr=26218004.000000\n",
347 | "Epoch 056 | TrainLoss: 0.0816 | ValLoss: 0.1737 | MSE: 27564982.0000 | MAE: 3368.5759 | R2: 0.8117 | Time: 0.02s\n",
348 | "[EarlyStop TRAINING] no improve (46/50). best=20857652.000000, curr=27564982.000000\n",
349 | "Epoch 057 | TrainLoss: 0.0711 | ValLoss: 0.1764 | MSE: 27982072.0000 | MAE: 3213.2036 | R2: 0.8088 | Time: 0.02s\n",
350 | "[EarlyStop TRAINING] no improve (47/50). best=20857652.000000, curr=27982072.000000\n",
351 | "Epoch 058 | TrainLoss: 0.0840 | ValLoss: 0.1687 | MSE: 26770818.0000 | MAE: 3231.4912 | R2: 0.8171 | Time: 0.02s\n",
352 | "[EarlyStop TRAINING] no improve (48/50). best=20857652.000000, curr=26770818.000000\n",
353 | "Epoch 059 | TrainLoss: 0.0809 | ValLoss: 0.1716 | MSE: 27229828.0000 | MAE: 3382.3977 | R2: 0.8140 | Time: 0.02s\n",
354 | "[EarlyStop TRAINING] no improve (49/50). best=20857652.000000, curr=27229828.000000\n",
355 | "Epoch 060 | TrainLoss: 0.0767 | ValLoss: 0.1727 | MSE: 27396762.0000 | MAE: 3288.3879 | R2: 0.8128 | Time: 0.02s\n",
356 | "[EarlyStop TRAINING] no improve (50/50). best=20857652.000000, curr=27396762.000000\n",
357 | "[EarlyStop TRAINING] restoring best weights and stopping.\n",
358 | "Restoring best model from epoch with val score: 20857652.0\n"
359 | ]
360 | },
361 | {
362 | "name": "stderr",
363 | "output_type": "stream",
364 | "text": [
365 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
366 | "Online softmax is disabled on the fly since Inductor decides to\n",
367 | "split the reduction. Cut an issue to PyTorch if this is an\n",
368 | "important use case and you want to speed it up with online\n",
369 | "softmax.\n",
370 | "\n",
371 | " warnings.warn(\n",
372 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
373 | "Online softmax is disabled on the fly since Inductor decides to\n",
374 | "split the reduction. Cut an issue to PyTorch if this is an\n",
375 | "important use case and you want to speed it up with online\n",
376 | "softmax.\n",
377 | "\n",
378 | " warnings.warn(\n"
379 | ]
380 | }
381 | ],
382 | "source": [
383 | "from GRANDE import GRANDE\n",
384 | "\n",
385 | "params = {\n",
386 | " 'depth': 5,\n",
387 | " 'n_estimators': 1024,\n",
388 | " \n",
389 | " 'learning_rate_weights': 0.001,\n",
390 | " 'learning_rate_index': 0.01,\n",
391 | " 'learning_rate_values': 0.05,\n",
392 | " 'learning_rate_leaf': 0.15,\n",
393 | " 'learning_rate_embedding': 0.01,\n",
394 | "\n",
395 | " 'use_category_embeddings': True,\n",
396 | " 'embedding_dim_cat': 8,\n",
397 | " 'use_numeric_embeddings': False,\n",
398 | " 'embedding_dim_num': 8,\n",
399 | " 'embedding_threshold': 1,\n",
400 | " 'loo_cardinality': 10,\n",
401 | "\n",
402 | "\n",
403 | " 'dropout': 0.2,\n",
404 | " 'selected_variables': 0.8,\n",
405 | " 'data_subset_fraction': 1.0,\n",
406 | " 'bootstrap': False,\n",
407 | " 'missing_values': False,\n",
408 | "\n",
409 | " 'optimizer': 'adam', #nadam, radam, adamw, adam \n",
410 | " 'cosine_decay_restarts': False,\n",
411 | " 'reduce_on_plateau_scheduler': True,\n",
412 | " 'label_smoothing': 0.0,\n",
413 | " 'use_class_weights': False,\n",
414 | " 'focal_loss': False,\n",
415 | " 'swa': False,\n",
416 | " 'es_metric': True, # if True use AUC for binary, MSE for regression, val_loss for multiclass\n",
417 | "\n",
418 | "\n",
419 | " 'epochs': 250,\n",
420 | " 'batch_size': 256,\n",
421 | " 'early_stopping_epochs': 50,\n",
422 | "\n",
423 | " 'use_freq_enc': False,\n",
424 | " 'use_robust_scale_smoothing': False,\n",
425 | " 'problem_type': 'regression',\n",
426 | " \n",
427 | " 'random_seed': 42,\n",
428 | " 'verbose': 2,\n",
429 | "}\n",
430 | "\n",
431 | "model_grande = GRANDE(params=params)\n",
432 | "\n",
433 | "model_grande.fit(X=X_train,\n",
434 | " y=y_train,\n",
435 | " X_val=X_valid,\n",
436 | " y_val=y_valid)\n",
437 | "\n",
438 | "preds_grande = model_grande.predict_proba(X_test)"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": 5,
444 | "id": "b251df9d-67f0-4dd1-a0cc-379621e33fad",
445 | "metadata": {
446 | "execution": {
447 | "iopub.execute_input": "2024-03-15T11:57:30.601353Z",
448 | "iopub.status.busy": "2024-03-15T11:57:30.601004Z",
449 | "iopub.status.idle": "2024-03-15T11:57:30.604147Z",
450 | "shell.execute_reply": "2024-03-15T11:57:30.603752Z",
451 | "shell.execute_reply.started": "2024-03-15T11:57:30.601338Z"
452 | }
453 | },
454 | "outputs": [],
455 | "source": [
456 | "def calculate_sample_weights(y_data):\n",
457 | " # 'balanced': each sample is weighted by n_samples / (n_classes * count of its class)\n",
458 | " sample_weights = sklearn.utils.class_weight.compute_sample_weight(class_weight='balanced', y=y_data)\n",
459 | " return sample_weights"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 6,
465 | "id": "c17e805c-0110-4e39-97d3-34573804296b",
466 | "metadata": {
467 | "execution": {
468 | "iopub.execute_input": "2024-03-15T11:57:30.605592Z",
469 | "iopub.status.busy": "2024-03-15T11:57:30.605338Z",
470 | "iopub.status.idle": "2024-03-15T11:57:30.611465Z",
471 | "shell.execute_reply": "2024-03-15T11:57:30.611074Z",
472 | "shell.execute_reply.started": "2024-03-15T11:57:30.605578Z"
473 | }
474 | },
475 | "outputs": [],
476 | "source": [
477 | "try:\n",
478 | " y_train = y_train.values.codes.astype(np.float64)\n",
479 | " y_valid = y_valid.values.codes.astype(np.float64)\n",
480 | " y_test = y_test.values.codes.astype(np.float64)\n",
481 | "except AttributeError: # targets are plain numeric, not categorical, so .values has no .codes\n",
482 | " y_train = y_train.values.astype(np.float64)\n",
483 | " y_valid = y_valid.values.astype(np.float64)\n",
484 | " y_test = y_test.values.astype(np.float64)"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 7,
490 | "id": "cfc121e3-2d7b-4fb1-8616-0318fcc20ac1",
491 | "metadata": {
492 | "execution": {
493 | "iopub.execute_input": "2024-03-15T11:57:30.612164Z",
494 | "iopub.status.busy": "2024-03-15T11:57:30.611960Z",
495 | "iopub.status.idle": "2024-03-15T11:57:30.627515Z",
496 | "shell.execute_reply": "2024-03-15T11:57:30.627108Z",
497 | "shell.execute_reply.started": "2024-03-15T11:57:30.612151Z"
498 | }
499 | },
500 | "outputs": [],
501 | "source": [
502 | "binary_indices = []\n",
503 | "low_cardinality_indices = []\n",
504 | "high_cardinality_indices = []\n",
505 | "num_columns = []\n",
506 | "for column_index, column in enumerate(X_train.columns):\n",
507 | " if column_index in categorical_feature_indices:\n",
508 | " if len(X_train.iloc[:,column_index].unique()) <= 2:\n",
509 | " binary_indices.append(column)\n",
510 | " elif len(X_train.iloc[:,column_index].unique()) < 5: # binary columns are ordinal-encoded only, not also one-hot encoded\n",
511 | " low_cardinality_indices.append(column)\n",
512 | " else:\n",
513 | " high_cardinality_indices.append(column)\n",
514 | " else:\n",
515 | " num_columns.append(column) \n",
516 | "cat_columns = [col for col in X_train.columns if col not in num_columns]"
517 | ]
518 | },
519 | {
520 | "cell_type": "code",
521 | "execution_count": 8,
522 | "id": "f7e55785-bf56-4aad-9825-7c5737be5ab7",
523 | "metadata": {
524 | "execution": {
525 | "iopub.execute_input": "2024-03-15T11:57:30.628204Z",
526 | "iopub.status.busy": "2024-03-15T11:57:30.627992Z",
527 | "iopub.status.idle": "2024-03-15T11:57:30.691371Z",
528 | "shell.execute_reply": "2024-03-15T11:57:30.690745Z",
529 | "shell.execute_reply.started": "2024-03-15T11:57:30.628191Z"
530 | }
531 | },
532 | "outputs": [],
533 | "source": [
534 | "if len(num_columns) > 0:\n",
535 | " mean_train_num = X_train[num_columns].mean(axis=0) # per-column training means (a Series), so each column is filled with its own mean\n",
536 | " X_train[num_columns] = X_train[num_columns].fillna(mean_train_num)\n",
537 | " X_valid[num_columns] = X_valid[num_columns].fillna(mean_train_num)\n",
538 | " X_test[num_columns] = X_test[num_columns].fillna(mean_train_num)\n",
539 | "if len(cat_columns) > 0:\n",
540 | " mode_train_cat = X_train[cat_columns].mode(axis=0).iloc[0]\n",
541 | " X_train[cat_columns] = X_train[cat_columns].fillna(mode_train_cat)\n",
542 | " X_valid[cat_columns] = X_valid[cat_columns].fillna(mode_train_cat)\n",
543 | " X_test[cat_columns] = X_test[cat_columns].fillna(mode_train_cat)\n",
544 | "\n",
545 | "X_train_raw = X_train.copy()\n",
546 | "X_valid_raw = X_valid.copy()\n",
547 | "X_test_raw = X_test.copy()"
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "execution_count": 9,
553 | "id": "b5d5f00f",
554 | "metadata": {
555 | "execution": {
556 | "iopub.execute_input": "2024-03-15T11:57:30.692101Z",
557 | "iopub.status.busy": "2024-03-15T11:57:30.691967Z",
558 | "iopub.status.idle": "2024-03-15T11:57:31.276758Z",
559 | "shell.execute_reply": "2024-03-15T11:57:31.276086Z",
560 | "shell.execute_reply.started": "2024-03-15T11:57:30.692089Z"
561 | }
562 | },
563 | "outputs": [],
564 | "source": [
565 | "encoder_ordinal = ce.OrdinalEncoder(cols=binary_indices)\n",
566 | "encoder_ordinal.fit(X_train)\n",
567 | "X_train = encoder_ordinal.transform(X_train)\n",
568 | "X_valid = encoder_ordinal.transform(X_valid) \n",
569 | "X_test = encoder_ordinal.transform(X_test) \n",
570 | "\n",
571 | "encoder = ce.LeaveOneOutEncoder(cols=high_cardinality_indices)\n",
572 | "encoder.fit(X_train, y_train)\n",
573 | "X_train = encoder.transform(X_train)\n",
574 | "X_valid = encoder.transform(X_valid)\n",
575 | "X_test = encoder.transform(X_test)\n",
576 | "\n",
577 | "encoder = ce.OneHotEncoder(cols=low_cardinality_indices)\n",
578 | "encoder.fit(X_train)\n",
579 | "X_train = encoder.transform(X_train)\n",
580 | "X_valid = encoder.transform(X_valid)\n",
581 | "X_test = encoder.transform(X_test)\n",
582 | "\n",
583 | "X_train = X_train.astype(np.float32)\n",
584 | "X_valid = X_valid.astype(np.float32)\n",
585 | "X_test = X_test.astype(np.float32)"
586 | ]
587 | },
588 | {
589 | "cell_type": "code",
590 | "execution_count": 10,
591 | "id": "361f7d9f-7e58-4bc1-8855-153b9e972742",
592 | "metadata": {
593 | "execution": {
594 | "iopub.execute_input": "2024-03-15T11:57:31.277594Z",
595 | "iopub.status.busy": "2024-03-15T11:57:31.277439Z",
596 | "iopub.status.idle": "2024-03-15T11:57:31.575860Z",
597 | "shell.execute_reply": "2024-03-15T11:57:31.575224Z",
598 | "shell.execute_reply.started": "2024-03-15T11:57:31.277581Z"
599 | }
600 | },
601 | "outputs": [
602 | {
603 | "name": "stdout",
604 | "output_type": "stream",
605 | "text": [
606 | "[0]\tvalidation_0-rmse:9117.37519\n",
607 | "[1]\tvalidation_0-rmse:7182.91726\n",
608 | "[2]\tvalidation_0-rmse:6008.67928\n",
609 | "[3]\tvalidation_0-rmse:5384.74130\n",
610 | "[4]\tvalidation_0-rmse:5041.28295\n",
611 | "[5]\tvalidation_0-rmse:4821.53449\n",
612 | "[6]\tvalidation_0-rmse:4714.24431\n",
613 | "[7]\tvalidation_0-rmse:4660.87726\n",
614 | "[8]\tvalidation_0-rmse:4676.10315\n",
615 | "[9]\tvalidation_0-rmse:4700.87577\n",
616 | "[10]\tvalidation_0-rmse:4677.24837\n",
617 | "[11]\tvalidation_0-rmse:4667.92836\n",
618 | "[12]\tvalidation_0-rmse:4681.76173\n",
619 | "[13]\tvalidation_0-rmse:4687.39543\n",
620 | "[14]\tvalidation_0-rmse:4754.37928\n",
621 | "[15]\tvalidation_0-rmse:4738.09888\n",
622 | "[16]\tvalidation_0-rmse:4735.64695\n",
623 | "[17]\tvalidation_0-rmse:4751.13172\n",
624 | "[18]\tvalidation_0-rmse:4773.50335\n",
625 | "[19]\tvalidation_0-rmse:4790.19284\n",
626 | "[20]\tvalidation_0-rmse:4773.47564\n",
627 | "[21]\tvalidation_0-rmse:4774.69551\n",
628 | "[22]\tvalidation_0-rmse:4778.65167\n",
629 | "[23]\tvalidation_0-rmse:4775.53956\n",
630 | "[24]\tvalidation_0-rmse:4781.68505\n",
631 | "[25]\tvalidation_0-rmse:4795.20388\n",
632 | "[26]\tvalidation_0-rmse:4829.21788\n",
633 | "[27]\tvalidation_0-rmse:4837.16893\n"
634 | ]
635 | }
636 | ],
637 | "source": [
638 | "if params['problem_type'] == 'regression':\n",
639 | " from xgboost import XGBRegressor\n",
640 | " model_xgb = XGBRegressor(n_estimators=1000, early_stopping_rounds=20)\n",
641 | " model_xgb.fit(X_train, \n",
642 | " y_train, \n",
643 | " eval_set=[(X_valid, y_valid)], \n",
644 | " )\n",
645 | " preds_xgb = model_xgb.predict(X_test)\n",
646 | "else:\n",
647 | " from xgboost import XGBClassifier\n",
648 | " model_xgb = XGBClassifier(n_estimators=1000, early_stopping_rounds=20)\n",
649 | " model_xgb.fit(X_train, \n",
650 | " y_train, \n",
651 | " #sample_weight=calculate_sample_weights(y_train), \n",
652 | " eval_set=[(X_valid, y_valid)], \n",
653 | " #sample_weight_eval_set=[calculate_sample_weights(y_valid)]\n",
654 | " )\n",
655 | "\n",
656 | "\n",
657 | " preds_xgb = model_xgb.predict_proba(X_test)"
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": 11,
663 | "id": "6b23c67e-cdde-40c2-842a-61383e1a62d0",
664 | "metadata": {
665 | "execution": {
666 | "iopub.execute_input": "2024-03-15T11:57:31.576940Z",
667 | "iopub.status.busy": "2024-03-15T11:57:31.576558Z",
668 | "iopub.status.idle": "2024-03-15T11:57:52.427379Z",
669 | "shell.execute_reply": "2024-03-15T11:57:52.426427Z",
670 | "shell.execute_reply.started": "2024-03-15T11:57:31.576923Z"
671 | }
672 | },
673 | "outputs": [
674 | {
675 | "name": "stdout",
676 | "output_type": "stream",
677 | "text": [
678 | "Learning rate set to 0.049693\n",
679 | "0:\tlearn: 12135.8956932\ttest: 11637.9908345\tbest: 11637.9908345 (0)\ttotal: 48.8ms\tremaining: 48.8s\n",
680 | "1:\tlearn: 11663.8838516\ttest: 11175.5245901\tbest: 11175.5245901 (1)\ttotal: 50.8ms\tremaining: 25.4s\n",
681 | "2:\tlearn: 11273.2960222\ttest: 10835.8030762\tbest: 10835.8030762 (2)\ttotal: 52.5ms\tremaining: 17.5s\n",
682 | "3:\tlearn: 10889.3829166\ttest: 10474.9396901\tbest: 10474.9396901 (3)\ttotal: 53.4ms\tremaining: 13.3s\n",
683 | "4:\tlearn: 10532.8002095\ttest: 10132.1552031\tbest: 10132.1552031 (4)\ttotal: 55ms\tremaining: 10.9s\n",
684 | "5:\tlearn: 10196.7503986\ttest: 9816.3979358\tbest: 9816.3979358 (5)\ttotal: 56.6ms\tremaining: 9.38s\n",
685 | "6:\tlearn: 9836.9135106\ttest: 9465.9962973\tbest: 9465.9962973 (6)\ttotal: 58.7ms\tremaining: 8.32s\n",
686 | "7:\tlearn: 9523.0102491\ttest: 9176.2224294\tbest: 9176.2224294 (7)\ttotal: 59.7ms\tremaining: 7.4s\n",
687 | "8:\tlearn: 9234.9423371\ttest: 8905.7040385\tbest: 8905.7040385 (8)\ttotal: 60.2ms\tremaining: 6.63s\n",
688 | "9:\tlearn: 8946.3797800\ttest: 8636.1173694\tbest: 8636.1173694 (9)\ttotal: 61ms\tremaining: 6.04s\n",
689 | "10:\tlearn: 8693.7650289\ttest: 8391.1433126\tbest: 8391.1433126 (10)\ttotal: 61.8ms\tremaining: 5.56s\n",
690 | "11:\tlearn: 8430.3329885\ttest: 8151.9293896\tbest: 8151.9293896 (11)\ttotal: 62.5ms\tremaining: 5.14s\n",
691 | "12:\tlearn: 8197.7535505\ttest: 7937.7069961\tbest: 7937.7069961 (12)\ttotal: 63.8ms\tremaining: 4.84s\n",
692 | "13:\tlearn: 7985.5741708\ttest: 7728.4296985\tbest: 7728.4296985 (13)\ttotal: 64.8ms\tremaining: 4.57s\n",
693 | "14:\tlearn: 7806.9133166\ttest: 7562.0935256\tbest: 7562.0935256 (14)\ttotal: 65.3ms\tremaining: 4.29s\n",
694 | "15:\tlearn: 7651.9441909\ttest: 7419.6808464\tbest: 7419.6808464 (15)\ttotal: 65.5ms\tremaining: 4.03s\n",
695 | "16:\tlearn: 7446.1069448\ttest: 7236.7268464\tbest: 7236.7268464 (16)\ttotal: 66.2ms\tremaining: 3.83s\n",
696 | "17:\tlearn: 7274.6704815\ttest: 7064.1042181\tbest: 7064.1042181 (17)\ttotal: 66.9ms\tremaining: 3.65s\n",
697 | "18:\tlearn: 7090.3414827\ttest: 6880.2568747\tbest: 6880.2568747 (18)\ttotal: 67.5ms\tremaining: 3.48s\n",
698 | "19:\tlearn: 6935.8494326\ttest: 6736.2596571\tbest: 6736.2596571 (19)\ttotal: 68.7ms\tremaining: 3.37s\n",
699 | "20:\tlearn: 6799.9531942\ttest: 6603.7128076\tbest: 6603.7128076 (20)\ttotal: 69ms\tremaining: 3.22s\n",
700 | "21:\tlearn: 6682.0661648\ttest: 6498.8882793\tbest: 6498.8882793 (21)\ttotal: 69.5ms\tremaining: 3.09s\n",
701 | "22:\tlearn: 6538.5362251\ttest: 6367.2327244\tbest: 6367.2327244 (22)\ttotal: 70.2ms\tremaining: 2.98s\n",
702 | "23:\tlearn: 6413.3233042\ttest: 6249.6403121\tbest: 6249.6403121 (23)\ttotal: 70.8ms\tremaining: 2.88s\n",
703 | "24:\tlearn: 6289.2337797\ttest: 6136.6353585\tbest: 6136.6353585 (24)\ttotal: 71.7ms\tremaining: 2.79s\n",
704 | "25:\tlearn: 6166.4909063\ttest: 6026.3136729\tbest: 6026.3136729 (25)\ttotal: 72.4ms\tremaining: 2.71s\n",
705 | "26:\tlearn: 6056.6980556\ttest: 5923.1889379\tbest: 5923.1889379 (26)\ttotal: 72.9ms\tremaining: 2.63s\n",
706 | "27:\tlearn: 5955.0594430\ttest: 5826.9385445\tbest: 5826.9385445 (27)\ttotal: 73.5ms\tremaining: 2.55s\n",
707 | "28:\tlearn: 5860.7535599\ttest: 5738.5303225\tbest: 5738.5303225 (28)\ttotal: 74.2ms\tremaining: 2.48s\n",
708 | "29:\tlearn: 5774.8792282\ttest: 5658.1048594\tbest: 5658.1048594 (29)\ttotal: 75.5ms\tremaining: 2.44s\n",
709 | "30:\tlearn: 5700.9796461\ttest: 5591.5205624\tbest: 5591.5205624 (30)\ttotal: 76.1ms\tremaining: 2.38s\n",
710 | "31:\tlearn: 5622.6454958\ttest: 5518.4609745\tbest: 5518.4609745 (31)\ttotal: 76.8ms\tremaining: 2.32s\n",
711 | "32:\tlearn: 5546.7784980\ttest: 5456.5329749\tbest: 5456.5329749 (32)\ttotal: 78.2ms\tremaining: 2.29s\n",
712 | "33:\tlearn: 5480.4837742\ttest: 5393.4714671\tbest: 5393.4714671 (33)\ttotal: 78.8ms\tremaining: 2.24s\n",
713 | "34:\tlearn: 5426.0583102\ttest: 5346.8814621\tbest: 5346.8814621 (34)\ttotal: 80ms\tremaining: 2.21s\n",
714 | "35:\tlearn: 5364.8177436\ttest: 5285.3885735\tbest: 5285.3885735 (35)\ttotal: 81.6ms\tremaining: 2.18s\n",
715 | "36:\tlearn: 5317.6277425\ttest: 5245.2405633\tbest: 5245.2405633 (36)\ttotal: 81.9ms\tremaining: 2.13s\n",
716 | "37:\tlearn: 5273.9836551\ttest: 5199.5748939\tbest: 5199.5748939 (37)\ttotal: 82.2ms\tremaining: 2.08s\n",
717 | "38:\tlearn: 5226.4476525\ttest: 5152.4112228\tbest: 5152.4112228 (38)\ttotal: 83.5ms\tremaining: 2.06s\n",
718 | "39:\tlearn: 5187.3186203\ttest: 5121.6387926\tbest: 5121.6387926 (39)\ttotal: 84.3ms\tremaining: 2.02s\n",
719 | "40:\tlearn: 5148.1007653\ttest: 5091.1530033\tbest: 5091.1530033 (40)\ttotal: 84.9ms\tremaining: 1.99s\n",
720 | "41:\tlearn: 5118.4574272\ttest: 5065.1781067\tbest: 5065.1781067 (41)\ttotal: 85.1ms\tremaining: 1.94s\n",
721 | "42:\tlearn: 5082.0201202\ttest: 5034.1461523\tbest: 5034.1461523 (42)\ttotal: 86ms\tremaining: 1.91s\n",
722 | "43:\tlearn: 5046.9010507\ttest: 5001.7382661\tbest: 5001.7382661 (43)\ttotal: 86.6ms\tremaining: 1.88s\n",
723 | "44:\tlearn: 5014.4682607\ttest: 4967.7026153\tbest: 4967.7026153 (44)\ttotal: 87.4ms\tremaining: 1.85s\n",
724 | "45:\tlearn: 4984.8106676\ttest: 4937.7427381\tbest: 4937.7427381 (45)\ttotal: 88.1ms\tremaining: 1.83s\n",
725 | "46:\tlearn: 4956.7529140\ttest: 4911.5675003\tbest: 4911.5675003 (46)\ttotal: 88.6ms\tremaining: 1.8s\n",
726 | "47:\tlearn: 4925.9004687\ttest: 4885.6560194\tbest: 4885.6560194 (47)\ttotal: 89.4ms\tremaining: 1.77s\n",
727 | "48:\tlearn: 4897.7500534\ttest: 4862.7024286\tbest: 4862.7024286 (48)\ttotal: 90.2ms\tremaining: 1.75s\n",
728 | "49:\tlearn: 4876.4231527\ttest: 4844.1445645\tbest: 4844.1445645 (49)\ttotal: 90.9ms\tremaining: 1.73s\n",
729 | "50:\tlearn: 4850.3335001\ttest: 4820.7833775\tbest: 4820.7833775 (50)\ttotal: 91.8ms\tremaining: 1.71s\n",
730 | "51:\tlearn: 4830.3468948\ttest: 4801.6432589\tbest: 4801.6432589 (51)\ttotal: 92.7ms\tremaining: 1.69s\n",
731 | "52:\tlearn: 4804.6721278\ttest: 4781.7853353\tbest: 4781.7853353 (52)\ttotal: 93.4ms\tremaining: 1.67s\n",
732 | "53:\tlearn: 4785.7884488\ttest: 4765.1396970\tbest: 4765.1396970 (53)\ttotal: 94.1ms\tremaining: 1.65s\n",
733 | "54:\tlearn: 4764.6931172\ttest: 4745.1402240\tbest: 4745.1402240 (54)\ttotal: 95ms\tremaining: 1.63s\n",
734 | "55:\tlearn: 4745.0014477\ttest: 4731.9202365\tbest: 4731.9202365 (55)\ttotal: 95.8ms\tremaining: 1.61s\n",
735 | "56:\tlearn: 4733.6924249\ttest: 4721.4346631\tbest: 4721.4346631 (56)\ttotal: 96.7ms\tremaining: 1.6s\n",
736 | "57:\tlearn: 4718.4285051\ttest: 4706.8977435\tbest: 4706.8977435 (57)\ttotal: 97.5ms\tremaining: 1.58s\n",
737 | "58:\tlearn: 4708.3924650\ttest: 4700.5360220\tbest: 4700.5360220 (58)\ttotal: 98ms\tremaining: 1.56s\n",
738 | "59:\tlearn: 4697.4051340\ttest: 4688.5274928\tbest: 4688.5274928 (59)\ttotal: 98.6ms\tremaining: 1.54s\n",
739 | "60:\tlearn: 4687.6765171\ttest: 4675.1671119\tbest: 4675.1671119 (60)\ttotal: 99.3ms\tremaining: 1.53s\n",
740 | "61:\tlearn: 4673.5371214\ttest: 4660.3489143\tbest: 4660.3489143 (61)\ttotal: 100ms\tremaining: 1.51s\n",
741 | "62:\tlearn: 4664.5705858\ttest: 4650.6639993\tbest: 4650.6639993 (62)\ttotal: 101ms\tremaining: 1.5s\n",
742 | "63:\tlearn: 4655.5121100\ttest: 4642.6593428\tbest: 4642.6593428 (63)\ttotal: 101ms\tremaining: 1.48s\n",
743 | "64:\tlearn: 4640.6477105\ttest: 4632.1934679\tbest: 4632.1934679 (64)\ttotal: 103ms\tremaining: 1.47s\n",
744 | "65:\tlearn: 4631.2775709\ttest: 4628.6864938\tbest: 4628.6864938 (65)\ttotal: 103ms\tremaining: 1.46s\n",
745 | "66:\tlearn: 4617.5614730\ttest: 4620.7160823\tbest: 4620.7160823 (66)\ttotal: 104ms\tremaining: 1.45s\n",
746 | "67:\tlearn: 4609.6661970\ttest: 4615.9897070\tbest: 4615.9897070 (67)\ttotal: 105ms\tremaining: 1.44s\n",
747 | "68:\tlearn: 4594.9269326\ttest: 4607.7160626\tbest: 4607.7160626 (68)\ttotal: 107ms\tremaining: 1.44s\n",
748 | "69:\tlearn: 4586.4000778\ttest: 4599.2989263\tbest: 4599.2989263 (69)\ttotal: 107ms\tremaining: 1.42s\n",
749 | "70:\tlearn: 4584.2924168\ttest: 4595.2372445\tbest: 4595.2372445 (70)\ttotal: 108ms\tremaining: 1.41s\n",
750 | "71:\tlearn: 4580.6310982\ttest: 4590.8254827\tbest: 4590.8254827 (71)\ttotal: 108ms\tremaining: 1.4s\n",
751 | "72:\tlearn: 4573.0512085\ttest: 4584.5655906\tbest: 4584.5655906 (72)\ttotal: 109ms\tremaining: 1.39s\n",
752 | "73:\tlearn: 4564.9022767\ttest: 4582.2014024\tbest: 4582.2014024 (73)\ttotal: 110ms\tremaining: 1.38s\n",
753 | "74:\tlearn: 4554.9704480\ttest: 4575.4565012\tbest: 4575.4565012 (74)\ttotal: 111ms\tremaining: 1.37s\n",
754 | "75:\tlearn: 4550.6381327\ttest: 4575.0104562\tbest: 4575.0104562 (75)\ttotal: 111ms\tremaining: 1.35s\n",
755 | "76:\tlearn: 4543.7447385\ttest: 4574.9770248\tbest: 4574.9770248 (76)\ttotal: 112ms\tremaining: 1.34s\n",
756 | "77:\tlearn: 4541.5607817\ttest: 4571.2443289\tbest: 4571.2443289 (77)\ttotal: 113ms\tremaining: 1.33s\n",
757 | "78:\tlearn: 4535.4912581\ttest: 4564.2435780\tbest: 4564.2435780 (78)\ttotal: 113ms\tremaining: 1.32s\n",
758 | "79:\tlearn: 4531.0415122\ttest: 4559.8685708\tbest: 4559.8685708 (79)\ttotal: 114ms\tremaining: 1.31s\n",
759 | "80:\tlearn: 4527.2248414\ttest: 4555.9682281\tbest: 4555.9682281 (80)\ttotal: 115ms\tremaining: 1.3s\n",
760 | "81:\tlearn: 4519.2937423\ttest: 4549.6245735\tbest: 4549.6245735 (81)\ttotal: 116ms\tremaining: 1.3s\n",
761 | "82:\tlearn: 4515.3573293\ttest: 4545.7439628\tbest: 4545.7439628 (82)\ttotal: 117ms\tremaining: 1.29s\n",
762 | "83:\tlearn: 4510.5792299\ttest: 4540.8327696\tbest: 4540.8327696 (83)\ttotal: 117ms\tremaining: 1.28s\n",
763 | "84:\tlearn: 4506.0538270\ttest: 4541.5568665\tbest: 4540.8327696 (83)\ttotal: 118ms\tremaining: 1.27s\n",
764 | "85:\tlearn: 4503.1630183\ttest: 4537.6270336\tbest: 4537.6270336 (85)\ttotal: 119ms\tremaining: 1.26s\n",
765 | "86:\tlearn: 4496.0911875\ttest: 4531.5680559\tbest: 4531.5680559 (86)\ttotal: 120ms\tremaining: 1.25s\n",
766 | "87:\tlearn: 4492.2828701\ttest: 4531.0761587\tbest: 4531.0761587 (87)\ttotal: 120ms\tremaining: 1.24s\n",
767 | "88:\tlearn: 4488.9621285\ttest: 4528.3046536\tbest: 4528.3046536 (88)\ttotal: 121ms\tremaining: 1.24s\n",
768 | "89:\tlearn: 4488.8701209\ttest: 4528.2794489\tbest: 4528.2794489 (89)\ttotal: 121ms\tremaining: 1.22s\n",
769 | "90:\tlearn: 4485.2112888\ttest: 4524.9720945\tbest: 4524.9720945 (90)\ttotal: 122ms\tremaining: 1.22s\n",
770 | "91:\tlearn: 4481.2354323\ttest: 4521.7897667\tbest: 4521.7897667 (91)\ttotal: 123ms\tremaining: 1.21s\n",
771 | "92:\tlearn: 4474.3205123\ttest: 4517.9420374\tbest: 4517.9420374 (92)\ttotal: 124ms\tremaining: 1.21s\n",
772 | "93:\tlearn: 4466.7718387\ttest: 4514.6783108\tbest: 4514.6783108 (93)\ttotal: 124ms\tremaining: 1.2s\n",
773 | "94:\tlearn: 4466.1435996\ttest: 4514.5008349\tbest: 4514.5008349 (94)\ttotal: 125ms\tremaining: 1.19s\n",
774 | "95:\tlearn: 4464.3661156\ttest: 4514.0741901\tbest: 4514.0741901 (95)\ttotal: 125ms\tremaining: 1.18s\n",
775 | "96:\tlearn: 4461.7752866\ttest: 4510.9960879\tbest: 4510.9960879 (96)\ttotal: 126ms\tremaining: 1.17s\n",
776 | "97:\tlearn: 4461.5146567\ttest: 4510.9796991\tbest: 4510.9796991 (97)\ttotal: 126ms\tremaining: 1.16s\n",
777 | "98:\tlearn: 4456.6558845\ttest: 4508.3112522\tbest: 4508.3112522 (98)\ttotal: 126ms\tremaining: 1.15s\n",
778 | "99:\tlearn: 4454.3871627\ttest: 4505.7860654\tbest: 4505.7860654 (99)\ttotal: 127ms\tremaining: 1.14s\n",
779 | "100:\tlearn: 4448.9852680\ttest: 4503.1081417\tbest: 4503.1081417 (100)\ttotal: 128ms\tremaining: 1.14s\n",
780 | "101:\tlearn: 4440.5592360\ttest: 4502.4170018\tbest: 4502.4170018 (101)\ttotal: 129ms\tremaining: 1.13s\n",
781 | "102:\tlearn: 4431.6324937\ttest: 4500.9596119\tbest: 4500.9596119 (102)\ttotal: 129ms\tremaining: 1.13s\n",
782 | "103:\tlearn: 4429.3687673\ttest: 4501.4204720\tbest: 4500.9596119 (102)\ttotal: 130ms\tremaining: 1.12s\n",
783 | "104:\tlearn: 4429.3629093\ttest: 4501.3898254\tbest: 4500.9596119 (102)\ttotal: 131ms\tremaining: 1.11s\n",
784 | "105:\tlearn: 4428.8235974\ttest: 4501.0482453\tbest: 4500.9596119 (102)\ttotal: 131ms\tremaining: 1.11s\n",
785 | "106:\tlearn: 4425.9788709\ttest: 4500.6305192\tbest: 4500.6305192 (106)\ttotal: 132ms\tremaining: 1.1s\n",
786 | "107:\tlearn: 4424.3855718\ttest: 4499.2275541\tbest: 4499.2275541 (107)\ttotal: 133ms\tremaining: 1.09s\n",
787 | "108:\tlearn: 4422.0120302\ttest: 4499.8376373\tbest: 4499.2275541 (107)\ttotal: 133ms\tremaining: 1.09s\n",
788 | "109:\tlearn: 4418.3186434\ttest: 4498.5182120\tbest: 4498.5182120 (109)\ttotal: 134ms\tremaining: 1.08s\n",
789 | "110:\tlearn: 4417.2920763\ttest: 4497.8476563\tbest: 4497.8476563 (110)\ttotal: 134ms\tremaining: 1.07s\n",
790 | "111:\tlearn: 4416.7227745\ttest: 4498.9462544\tbest: 4497.8476563 (110)\ttotal: 135ms\tremaining: 1.07s\n",
791 | "112:\tlearn: 4412.7025996\ttest: 4498.7626385\tbest: 4497.8476563 (110)\ttotal: 136ms\tremaining: 1.06s\n",
792 | "113:\tlearn: 4407.8196920\ttest: 4498.6152396\tbest: 4497.8476563 (110)\ttotal: 136ms\tremaining: 1.06s\n",
793 | "114:\tlearn: 4404.0040730\ttest: 4496.1760525\tbest: 4496.1760525 (114)\ttotal: 137ms\tremaining: 1.05s\n",
794 | "115:\tlearn: 4402.6820993\ttest: 4493.5948619\tbest: 4493.5948619 (115)\ttotal: 138ms\tremaining: 1.05s\n",
795 | "116:\tlearn: 4400.7635771\ttest: 4491.6903454\tbest: 4491.6903454 (116)\ttotal: 139ms\tremaining: 1.05s\n",
796 | "117:\tlearn: 4399.7669773\ttest: 4489.9523133\tbest: 4489.9523133 (117)\ttotal: 140ms\tremaining: 1.04s\n",
797 | "118:\tlearn: 4393.0292150\ttest: 4486.7754864\tbest: 4486.7754864 (118)\ttotal: 141ms\tremaining: 1.04s\n",
798 | "119:\tlearn: 4385.1056174\ttest: 4482.0289943\tbest: 4482.0289943 (119)\ttotal: 142ms\tremaining: 1.04s\n",
799 | "120:\tlearn: 4382.9640733\ttest: 4482.4460160\tbest: 4482.0289943 (119)\ttotal: 142ms\tremaining: 1.03s\n",
800 | "121:\tlearn: 4381.6342302\ttest: 4479.7860514\tbest: 4479.7860514 (121)\ttotal: 143ms\tremaining: 1.03s\n",
801 | "122:\tlearn: 4381.1578021\ttest: 4479.6616481\tbest: 4479.6616481 (122)\ttotal: 144ms\tremaining: 1.02s\n",
802 | "123:\tlearn: 4375.5387552\ttest: 4478.0570824\tbest: 4478.0570824 (123)\ttotal: 144ms\tremaining: 1.02s\n",
803 | "124:\tlearn: 4375.1315479\ttest: 4477.1129288\tbest: 4477.1129288 (124)\ttotal: 145ms\tremaining: 1.01s\n",
804 | "125:\tlearn: 4373.8160483\ttest: 4475.2068758\tbest: 4475.2068758 (125)\ttotal: 145ms\tremaining: 1.01s\n",
805 | "126:\tlearn: 4370.8924010\ttest: 4477.4286380\tbest: 4475.2068758 (125)\ttotal: 146ms\tremaining: 1s\n",
806 | "127:\tlearn: 4369.3506544\ttest: 4476.3365496\tbest: 4475.2068758 (125)\ttotal: 147ms\tremaining: 1s\n",
807 | "128:\tlearn: 4367.1210753\ttest: 4474.4884111\tbest: 4474.4884111 (128)\ttotal: 149ms\tremaining: 1s\n",
808 | "129:\tlearn: 4366.7213657\ttest: 4474.6765012\tbest: 4474.4884111 (128)\ttotal: 149ms\tremaining: 997ms\n",
809 | "130:\tlearn: 4361.4266661\ttest: 4472.3525687\tbest: 4472.3525687 (130)\ttotal: 150ms\tremaining: 993ms\n",
810 | "131:\tlearn: 4359.1617519\ttest: 4468.8500676\tbest: 4468.8500676 (131)\ttotal: 150ms\tremaining: 989ms\n",
811 | "132:\tlearn: 4357.3157065\ttest: 4467.6042087\tbest: 4467.6042087 (132)\ttotal: 151ms\tremaining: 985ms\n",
812 | "133:\tlearn: 4355.4754571\ttest: 4466.7978797\tbest: 4466.7978797 (133)\ttotal: 152ms\tremaining: 982ms\n",
813 | "134:\tlearn: 4353.0920681\ttest: 4467.2216987\tbest: 4466.7978797 (133)\ttotal: 153ms\tremaining: 980ms\n",
814 | "135:\tlearn: 4349.3034898\ttest: 4462.7715216\tbest: 4462.7715216 (135)\ttotal: 154ms\tremaining: 976ms\n",
815 | "136:\tlearn: 4348.0399740\ttest: 4461.4525517\tbest: 4461.4525517 (136)\ttotal: 154ms\tremaining: 973ms\n",
816 | "137:\tlearn: 4342.2369116\ttest: 4463.1921563\tbest: 4461.4525517 (136)\ttotal: 155ms\tremaining: 969ms\n",
817 | "138:\tlearn: 4337.4500163\ttest: 4462.3466978\tbest: 4461.4525517 (136)\ttotal: 156ms\tremaining: 968ms\n",
818 | "139:\tlearn: 4336.7592406\ttest: 4462.1731148\tbest: 4461.4525517 (136)\ttotal: 157ms\tremaining: 964ms\n",
819 | "140:\tlearn: 4336.5533039\ttest: 4462.1081990\tbest: 4461.4525517 (136)\ttotal: 157ms\tremaining: 959ms\n",
820 | "141:\tlearn: 4335.4981537\ttest: 4461.7103284\tbest: 4461.4525517 (136)\ttotal: 158ms\tremaining: 957ms\n",
821 | "142:\tlearn: 4335.1834055\ttest: 4461.6170483\tbest: 4461.4525517 (136)\ttotal: 159ms\tremaining: 951ms\n",
822 | "143:\tlearn: 4333.4594557\ttest: 4461.5227009\tbest: 4461.4525517 (136)\ttotal: 160ms\tremaining: 950ms\n",
823 | "144:\tlearn: 4331.8637410\ttest: 4462.5664485\tbest: 4461.4525517 (136)\ttotal: 161ms\tremaining: 947ms\n",
824 | "145:\tlearn: 4330.8941104\ttest: 4462.2997026\tbest: 4461.4525517 (136)\ttotal: 161ms\tremaining: 943ms\n",
825 | "146:\tlearn: 4327.6644517\ttest: 4462.2039381\tbest: 4461.4525517 (136)\ttotal: 162ms\tremaining: 940ms\n",
826 | "147:\tlearn: 4326.5420613\ttest: 4462.8728516\tbest: 4461.4525517 (136)\ttotal: 163ms\tremaining: 936ms\n",
827 | "148:\tlearn: 4322.3927087\ttest: 4462.6585701\tbest: 4461.4525517 (136)\ttotal: 164ms\tremaining: 934ms\n",
828 | "149:\tlearn: 4321.0124809\ttest: 4463.2566088\tbest: 4461.4525517 (136)\ttotal: 164ms\tremaining: 931ms\n",
829 | "150:\tlearn: 4319.9787726\ttest: 4464.1263409\tbest: 4461.4525517 (136)\ttotal: 165ms\tremaining: 927ms\n",
830 | "151:\tlearn: 4318.2804397\ttest: 4463.9719555\tbest: 4461.4525517 (136)\ttotal: 165ms\tremaining: 923ms\n",
831 | "152:\tlearn: 4317.1621473\ttest: 4464.3190037\tbest: 4461.4525517 (136)\ttotal: 167ms\tremaining: 922ms\n",
832 | "153:\tlearn: 4315.3082889\ttest: 4462.3371173\tbest: 4461.4525517 (136)\ttotal: 168ms\tremaining: 922ms\n",
833 | "154:\tlearn: 4313.6240588\ttest: 4460.9090316\tbest: 4460.9090316 (154)\ttotal: 169ms\tremaining: 919ms\n",
834 | "155:\tlearn: 4308.7408461\ttest: 4461.9994794\tbest: 4460.9090316 (154)\ttotal: 170ms\tremaining: 917ms\n",
835 | "156:\tlearn: 4304.3605530\ttest: 4460.8561739\tbest: 4460.8561739 (156)\ttotal: 170ms\tremaining: 914ms\n",
836 | "157:\tlearn: 4296.5265000\ttest: 4458.2121464\tbest: 4458.2121464 (157)\ttotal: 171ms\tremaining: 912ms\n",
837 | "158:\tlearn: 4292.9219052\ttest: 4457.3077812\tbest: 4457.3077812 (158)\ttotal: 172ms\tremaining: 911ms\n",
838 | "159:\tlearn: 4291.2528338\ttest: 4457.1570880\tbest: 4457.1570880 (159)\ttotal: 173ms\tremaining: 908ms\n",
839 | "160:\tlearn: 4289.4613153\ttest: 4456.5441936\tbest: 4456.5441936 (160)\ttotal: 174ms\tremaining: 906ms\n",
840 | "161:\tlearn: 4285.0691200\ttest: 4456.8069797\tbest: 4456.5441936 (160)\ttotal: 175ms\tremaining: 903ms\n",
841 | "162:\tlearn: 4283.4649878\ttest: 4455.2279840\tbest: 4455.2279840 (162)\ttotal: 175ms\tremaining: 900ms\n",
842 | "163:\tlearn: 4281.2809661\ttest: 4455.3887740\tbest: 4455.2279840 (162)\ttotal: 176ms\tremaining: 897ms\n",
843 | "164:\tlearn: 4281.1947633\ttest: 4455.3803900\tbest: 4455.2279840 (162)\ttotal: 176ms\tremaining: 892ms\n",
844 | "165:\tlearn: 4279.2680603\ttest: 4456.9433372\tbest: 4455.2279840 (162)\ttotal: 177ms\tremaining: 890ms\n",
845 | "166:\tlearn: 4274.1633733\ttest: 4459.1915051\tbest: 4455.2279840 (162)\ttotal: 178ms\tremaining: 888ms\n",
846 | "167:\tlearn: 4271.5861198\ttest: 4457.6296101\tbest: 4455.2279840 (162)\ttotal: 179ms\tremaining: 885ms\n",
847 | "168:\tlearn: 4271.3799469\ttest: 4457.6197179\tbest: 4455.2279840 (162)\ttotal: 179ms\tremaining: 881ms\n",
848 | "169:\tlearn: 4269.6933288\ttest: 4456.6673339\tbest: 4455.2279840 (162)\ttotal: 180ms\tremaining: 878ms\n",
849 | "170:\tlearn: 4266.6707750\ttest: 4456.0106319\tbest: 4455.2279840 (162)\ttotal: 181ms\tremaining: 876ms\n",
850 | "171:\tlearn: 4265.4996249\ttest: 4456.0314402\tbest: 4455.2279840 (162)\ttotal: 181ms\tremaining: 873ms\n",
851 | "172:\tlearn: 4264.2493406\ttest: 4456.6100371\tbest: 4455.2279840 (162)\ttotal: 182ms\tremaining: 870ms\n",
852 | "173:\tlearn: 4259.6207748\ttest: 4457.9821824\tbest: 4455.2279840 (162)\ttotal: 183ms\tremaining: 867ms\n",
853 | "174:\tlearn: 4253.2039726\ttest: 4458.0419803\tbest: 4455.2279840 (162)\ttotal: 183ms\tremaining: 864ms\n",
854 | "175:\tlearn: 4252.1777721\ttest: 4458.1837295\tbest: 4455.2279840 (162)\ttotal: 184ms\tremaining: 861ms\n",
855 | "176:\tlearn: 4251.4820019\ttest: 4458.6473260\tbest: 4455.2279840 (162)\ttotal: 185ms\tremaining: 859ms\n",
856 | "177:\tlearn: 4248.0365555\ttest: 4457.1442368\tbest: 4455.2279840 (162)\ttotal: 185ms\tremaining: 856ms\n",
857 | "178:\tlearn: 4246.5116860\ttest: 4459.3102022\tbest: 4455.2279840 (162)\ttotal: 186ms\tremaining: 855ms\n",
858 | "179:\tlearn: 4244.3109488\ttest: 4459.7738963\tbest: 4455.2279840 (162)\ttotal: 187ms\tremaining: 852ms\n",
859 | "180:\tlearn: 4243.5314996\ttest: 4460.0109834\tbest: 4455.2279840 (162)\ttotal: 188ms\tremaining: 851ms\n",
860 | "181:\tlearn: 4243.3602516\ttest: 4460.2987130\tbest: 4455.2279840 (162)\ttotal: 189ms\tremaining: 848ms\n",
861 | "182:\tlearn: 4240.7892696\ttest: 4459.7722494\tbest: 4455.2279840 (162)\ttotal: 190ms\tremaining: 847ms\n",
862 | "Stopped by overfitting detector (20 iterations wait)\n",
863 | "\n",
864 | "bestTest = 4455.227984\n",
865 | "bestIteration = 162\n",
866 | "\n",
867 | "Shrink model to first 163 iterations.\n"
868 | ]
869 | }
870 | ],
871 | "source": [
872 | "if params['problem_type'] == 'regression':\n",
873 | " from catboost import CatBoostRegressor, Pool\n",
874 | "\n",
875 | " model_catboost = CatBoostRegressor(n_estimators=1000, \n",
876 | " early_stopping_rounds=20)\n",
877 | " train_data = Pool(\n",
878 | " data=X_train_raw,\n",
879 | " label=y_train,\n",
880 | " cat_features=categorical_feature_indices,\n",
881 | " )\n",
882 | "\n",
883 | " eval_data = Pool(\n",
884 | " data=X_valid_raw,\n",
885 | " label=y_valid,\n",
886 | " cat_features=categorical_feature_indices,\n",
887 | " )\n",
888 | "\n",
889 | " model_catboost.fit(X=train_data, \n",
890 | " eval_set=eval_data)\n",
891 | "\n",
892 | "\n",
893 | "\n",
894 | "\n",
895 | " preds_catboost = model_catboost.predict(X_test_raw)\n",
896 | "else:\n",
897 | " from catboost import CatBoostClassifier, Pool\n",
898 | "\n",
899 | " model_catboost = CatBoostClassifier(n_estimators=1000, \n",
900 | " early_stopping_rounds=20)\n",
901 | " train_data = Pool(\n",
902 | " data=X_train_raw,\n",
903 | " label=y_train,\n",
904 | " cat_features=categorical_feature_indices,\n",
905 | " #weight=calculate_sample_weights(y_train)\n",
906 | " )\n",
907 | "\n",
908 | " eval_data = Pool(\n",
909 | " data=X_valid_raw,\n",
910 | " label=y_valid,\n",
911 | " cat_features=categorical_feature_indices,\n",
912 | " #weight=calculate_sample_weights(y_valid),\n",
913 | " )\n",
914 | "\n",
915 | " model_catboost.fit(X=train_data, \n",
916 | " eval_set=eval_data)\n",
917 | "\n",
918 | "\n",
919 | "\n",
920 | " preds_catboost = model_catboost.predict_proba(X_test_raw)\n"
921 | ]
922 | },
923 | {
924 | "cell_type": "code",
925 | "execution_count": 12,
926 | "id": "21b9be01-ba30-4300-b3e0-f4fd43b55c87",
927 | "metadata": {
928 | "execution": {
929 | "iopub.execute_input": "2024-03-15T11:57:52.429019Z",
930 | "iopub.status.busy": "2024-03-15T11:57:52.428555Z",
931 | "iopub.status.idle": "2024-03-15T11:57:52.454151Z",
932 | "shell.execute_reply": "2024-03-15T11:57:52.453620Z",
933 | "shell.execute_reply.started": "2024-03-15T11:57:52.428990Z"
934 | }
935 | },
936 | "outputs": [
937 | {
938 | "name": "stdout",
939 | "output_type": "stream",
940 | "text": [
941 | "MAE GRANDE: 2481.6191651649956\n",
942 | "RMSE GRANDE: 3761.773606283763\n",
943 | "R2 Score GRANDE: 0.8653776732830805\n",
944 | "\n",
945 | "\n",
946 | "MAE XGB: 2437.6996424431554\n",
947 | "RMSE XGB: 3930.3302577846944\n",
948 | "R2 Score XGB: 0.8530431372729947\n",
949 | "\n",
950 | "\n",
951 | "MAE CatBoost: 2194.5464048630265\n",
952 | "RMSE CatBoost: 3637.838113587122\n",
953 | "R2 Score CatBoost: 0.8741020902072107\n",
954 | "\n",
955 | "\n"
956 | ]
957 | }
958 | ],
959 | "source": [
960 | "if params['problem_type'] == 'binary':\n",
961 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_grande[:,1]))\n",
962 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_grande[:,1]), average='macro')\n",
963 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande[:,1], average='macro', multi_class='ovo')\n",
964 | "\n",
965 | " print('Accuracy GRANDE:', accuracy)\n",
966 | " print('F1 Score GRANDE:', f1_score)\n",
967 | " print('ROC AUC GRANDE:', roc_auc)\n",
968 | " print('\\n')\n",
969 | "\n",
970 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_xgb[:,1]))\n",
971 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_xgb[:,1]), average='macro')\n",
972 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb[:,1], average='macro', multi_class='ovo')\n",
973 | "\n",
974 | " print('Accuracy XGB:', accuracy)\n",
975 | " print('F1 Score XGB:', f1_score)\n",
976 | " print('ROC AUC XGB:', roc_auc)\n",
977 | " print('\\n')\n",
978 | "\n",
979 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_catboost[:,1]))\n",
980 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_catboost[:,1]), average='macro')\n",
981 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost[:,1], average='macro', multi_class='ovo')\n",
982 | "\n",
983 | " print('Accuracy CatBoost:', accuracy)\n",
984 | " print('F1 Score CatBoost:', f1_score)\n",
985 | " print('ROC AUC CatBoost:', roc_auc)\n",
986 | " print('\\n')\n",
987 | "elif params['problem_type'] == 'multiclass':\n",
988 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_grande, axis=1))\n",
989 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_grande, axis=1), average='macro')\n",
990 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande, average='macro', multi_class='ovo', labels=[i for i in range(preds_grande.shape[1])])\n",
991 | "\n",
992 | " print('Accuracy GRANDE:', accuracy)\n",
993 | " print('F1 Score GRANDE:', f1_score)\n",
994 | " print('ROC AUC GRANDE:', roc_auc)\n",
995 | " print('\\n')\n",
996 | "\n",
997 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_xgb, axis=1))\n",
998 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_xgb, axis=1), average='macro')\n",
999 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb, average='macro', multi_class='ovo', labels=[i for i in range(preds_xgb.shape[1])])\n",
1000 | "\n",
1001 | " print('Accuracy XGB:', accuracy)\n",
1002 | " print('F1 Score XGB:', f1_score)\n",
1003 | " print('ROC AUC XGB:', roc_auc)\n",
1004 | " print('\\n')\n",
1005 | "\n",
1006 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_catboost, axis=1))\n",
1007 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_catboost, axis=1), average='macro')\n",
1008 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost, average='macro', multi_class='ovo', labels=[i for i in range(preds_catboost.shape[1])])\n",
1009 | "\n",
1010 | " print('Accuracy CatBoost:', accuracy)\n",
1011 | " print('F1 Score CatBoost:', f1_score)\n",
1012 | " print('ROC AUC CatBoost:', roc_auc)\n",
1013 | " print('\\n')\n",
1014 | "else:\n",
1015 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_grande)\n",
1016 | " root_mean_squared_error = np.sqrt(((y_test - preds_grande) ** 2).mean())\n",
1017 | " r2_score = sklearn.metrics.r2_score(y_test, preds_grande)\n",
1018 | "\n",
1019 | " print('MAE GRANDE:', mean_absolute_error)\n",
1020 | " print('RMSE GRANDE:', root_mean_squared_error)\n",
1021 | " print('R2 Score GRANDE:', r2_score)\n",
1022 | " print('\\n')\n",
1023 | "\n",
1024 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_xgb)\n",
1025 | " root_mean_squared_error = np.sqrt(((y_test - preds_xgb) ** 2).mean())\n",
1026 | " r2_score = sklearn.metrics.r2_score(y_test, preds_xgb)\n",
1027 | "\n",
1028 | " print('MAE XGB:', mean_absolute_error)\n",
1029 | " print('RMSE XGB:', root_mean_squared_error)\n",
1030 | " print('R2 Score XGB:', r2_score)\n",
1031 | " print('\\n')\n",
1032 | "\n",
1033 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_catboost)\n",
1034 | " root_mean_squared_error = np.sqrt(((y_test - preds_catboost) ** 2).mean())\n",
1035 | " r2_score = sklearn.metrics.r2_score(y_test, preds_catboost)\n",
1036 | "\n",
1037 | " print('MAE CatBoost:', mean_absolute_error)\n",
1038 | " print('RMSE CatBoost:', root_mean_squared_error)\n",
1039 | " print('R2 Score CatBoost:', r2_score)\n",
1040 | " print('\\n')"
1041 | ]
1042 | }
1059 | ],
1060 | "metadata": {
1061 | "kernelspec": {
1062 | "display_name": "ReMeDe",
1063 | "language": "python",
1064 | "name": "python3"
1065 | },
1066 | "language_info": {
1067 | "codemirror_mode": {
1068 | "name": "ipython",
1069 | "version": 3
1070 | },
1071 | "file_extension": ".py",
1072 | "mimetype": "text/x-python",
1073 | "name": "python",
1074 | "nbconvert_exporter": "python",
1075 | "pygments_lexer": "ipython3",
1076 | "version": "3.12.9"
1077 | }
1078 | },
1079 | "nbformat": 4,
1080 | "nbformat_minor": 5
1081 | }
1082 |
--------------------------------------------------------------------------------
/GRANDE_minimal_example_with_comparison_BINARY.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "bf584ce8-8849-4a53-8ce9-2de8f9dd752d",
7 | "metadata": {
8 | "execution": {
9 | "iopub.execute_input": "2024-03-15T11:56:56.253578Z",
10 | "iopub.status.busy": "2024-03-15T11:56:56.253464Z",
11 | "iopub.status.idle": "2024-03-15T11:56:56.265905Z",
12 | "shell.execute_reply": "2024-03-15T11:56:56.265512Z",
13 | "shell.execute_reply.started": "2024-03-15T11:56:56.253566Z"
14 | }
15 | },
16 | "outputs": [],
17 | "source": [
18 | "# Specify which GPU to use (must be set before torch initializes CUDA)\n",
19 | "import os\n",
20 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "52b07238-f972-45df-86ce-a11cc11d72e3",
27 | "metadata": {
28 | "execution": {
29 | "iopub.execute_input": "2024-03-15T11:56:56.266511Z",
30 | "iopub.status.busy": "2024-03-15T11:56:56.266392Z",
31 | "iopub.status.idle": "2024-03-15T11:56:57.956110Z",
32 | "shell.execute_reply": "2024-03-15T11:56:57.955562Z",
33 | "shell.execute_reply.started": "2024-03-15T11:56:56.266499Z"
34 | }
35 | },
36 | "outputs": [
37 | {
38 | "name": "stdout",
39 | "output_type": "stream",
40 | "text": [
41 | "Training set size: 3200\n",
42 | "Validation set size: 800\n",
43 | "Test set size: 1000\n"
44 | ]
45 | },
46 | {
47 | "data": {
203 | "text/plain": [
204 | " state account_length area_code international_plan voice_mail_plan \\\n",
205 | "3444 1 120 415.0 0 1 \n",
206 | "2063 15 130 408.0 0 0 \n",
207 | "3714 26 66 510.0 0 0 \n",
208 | "2671 13 60 510.0 0 0 \n",
209 | "2154 35 134 408.0 0 0 \n",
210 | "\n",
211 | " number_vmail_messages total_day_minutes total_day_calls \\\n",
212 | "3444 27 179.6 142 \n",
213 | "2063 0 115.6 129 \n",
214 | "3714 0 201.3 95 \n",
215 | "2671 0 221.1 106 \n",
216 | "2154 0 202.7 105 \n",
217 | "\n",
218 | " total_day_charge total_eve_minutes total_eve_calls total_eve_charge \\\n",
219 | "3444 30.53 262.8 103 22.34 \n",
220 | "2063 19.65 167.8 104 14.26 \n",
221 | "3714 34.22 152.8 66 12.99 \n",
222 | "2671 37.59 178.6 48 15.18 \n",
223 | "2154 34.46 224.9 90 19.12 \n",
224 | "\n",
225 | " total_night_minutes total_night_calls total_night_charge \\\n",
226 | "3444 239.9 81 10.80 \n",
227 | "2063 141.8 124 6.38 \n",
228 | "3714 233.2 101 10.49 \n",
229 | "2671 202.7 90 9.12 \n",
230 | "2154 253.9 108 11.43 \n",
231 | "\n",
232 | " total_intl_minutes total_intl_calls total_intl_charge \\\n",
233 | "3444 11.1 2 3.00 \n",
234 | "2063 12.6 9 3.40 \n",
235 | "3714 7.5 4 2.03 \n",
236 | "2671 7.4 3 2.00 \n",
237 | "2154 12.1 7 3.27 \n",
238 | "\n",
239 | " number_customer_service_calls \n",
240 | "3444 1 \n",
241 | "2063 1 \n",
242 | "3714 1 \n",
243 | "2671 1 \n",
244 | "2154 0 "
245 | ]
246 | },
247 | "execution_count": 2,
248 | "metadata": {},
249 | "output_type": "execute_result"
250 | }
251 | ],
252 | "source": [
253 | "from sklearn.model_selection import train_test_split\n",
254 | "import openml\n",
255 | "import category_encoders as ce\n",
256 | "import numpy as np\n",
257 | "import sklearn\n",
258 | "\n",
259 | "# Load churn dataset\n",
260 | "dataset = openml.datasets.get_dataset(46915, download_data=True, download_qualities=True, download_features_meta_data=True)\n",
261 | "X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n",
262 | "categorical_feature_indices = [idx for idx, idx_bool in enumerate(categorical_indicator) if idx_bool]\n",
263 | "\n",
264 | "X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
265 | "\n",
266 | "X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)\n",
267 | "\n",
268 | "print(\"Training set size:\", len(X_train))\n",
269 | "print(\"Validation set size:\", len(X_valid))\n",
270 | "print(\"Test set size:\", len(X_test))\n",
271 | "\n",
272 | "X_train.head()"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 3,
278 | "id": "d8364d07",
279 | "metadata": {},
280 | "outputs": [
281 | {
282 | "data": {
283 | "text/plain": [
284 | "state category\n",
285 | "account_length uint8\n",
286 | "area_code category\n",
287 | "international_plan category\n",
288 | "voice_mail_plan category\n",
289 | "number_vmail_messages uint8\n",
290 | "total_day_minutes float64\n",
291 | "total_day_calls uint8\n",
292 | "total_day_charge float64\n",
293 | "total_eve_minutes float64\n",
294 | "total_eve_calls uint8\n",
295 | "total_eve_charge float64\n",
296 | "total_night_minutes float64\n",
297 | "total_night_calls uint8\n",
298 | "total_night_charge float64\n",
299 | "total_intl_minutes float64\n",
300 | "total_intl_calls uint8\n",
301 | "total_intl_charge float64\n",
302 | "number_customer_service_calls uint8\n",
303 | "dtype: object"
304 | ]
305 | },
306 | "execution_count": 3,
307 | "metadata": {},
308 | "output_type": "execute_result"
309 | }
310 | ],
311 | "source": [
312 | "X_train.dtypes"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "id": "bf0a93c2-e9a4-4fa5-a7f6-44689dd07d37",
319 | "metadata": {
320 | "execution": {
321 | "iopub.execute_input": "2024-03-15T11:56:57.957002Z",
322 | "iopub.status.busy": "2024-03-15T11:56:57.956779Z",
323 | "iopub.status.idle": "2024-03-15T11:57:30.600318Z",
324 | "shell.execute_reply": "2024-03-15T11:57:30.599632Z",
325 | "shell.execute_reply.started": "2024-03-15T11:56:57.956987Z"
326 | },
327 | "scrolled": true
328 | },
329 | "outputs": [
330 | {
331 | "name": "stdout",
332 | "output_type": "stream",
333 | "text": [
334 | "self.params {'depth': 5, 'n_estimators': 1024, 'learning_rate_weights': 0.001, 'learning_rate_index': 0.01, 'learning_rate_values': 0.05, 'learning_rate_leaf': 0.05, 'learning_rate_embedding': 0.02, 'use_category_embeddings': False, 'embedding_dim_cat': 8, 'use_numeric_embeddings': False, 'embedding_dim_num': 8, 'embedding_threshold': 1, 'loo_cardinality': 10, 'dropout': 0.2, 'selected_variables': 0.8, 'data_subset_fraction': 1.0, 'bootstrap': False, 'missing_values': False, 'optimizer': 'adam', 'cosine_decay_restarts': False, 'reduce_on_plateau_scheduler': True, 'label_smoothing': 0.0, 'use_class_weights': False, 'focal_loss': False, 'swa': False, 'es_metric': True, 'epochs': 250, 'batch_size': 256, 'early_stopping_epochs': 50, 'use_freq_enc': False, 'use_robust_scale_smoothing': False, 'problem_type': 'binary', 'random_seed': 42, 'verbose': 2, 'device': 'cuda:0', 'objective': 'binary'}\n",
335 | "{'self.column_names_dataframe': ['state',\n",
336 | " 'account_length',\n",
337 | " 'area_code_1',\n",
338 | " 'area_code_2',\n",
339 | " 'area_code_3',\n",
340 | " 'international_plan',\n",
341 | " 'voice_mail_plan',\n",
342 | " 'number_vmail_messages',\n",
343 | " 'total_day_minutes',\n",
344 | " 'total_day_calls',\n",
345 | " 'total_day_charge',\n",
346 | " 'total_eve_minutes',\n",
347 | " 'total_eve_calls',\n",
348 | " 'total_eve_charge',\n",
349 | " 'total_night_minutes',\n",
350 | " 'total_night_calls',\n",
351 | " 'total_night_charge',\n",
352 | " 'total_intl_minutes',\n",
353 | " 'total_intl_calls',\n",
354 | " 'total_intl_charge',\n",
355 | " 'number_customer_service_calls'],\n",
356 | " 'encoded_columns': ['state',\n",
357 | " 'area_code_1',\n",
358 | " 'area_code_2',\n",
359 | " 'area_code_3',\n",
360 | " 'international_plan',\n",
361 | " 'voice_mail_plan'],\n",
362 | " 'encoded_columns_indices': [0, 2, 3, 4, 5, 6],\n",
363 | " 'num_columns': ['account_length',\n",
364 | " 'number_vmail_messages',\n",
365 | " 'total_day_minutes',\n",
366 | " 'total_day_calls',\n",
367 | " 'total_day_charge',\n",
368 | " 'total_eve_minutes',\n",
369 | " 'total_eve_calls',\n",
370 | " 'total_eve_charge',\n",
371 | " 'total_night_minutes',\n",
372 | " 'total_night_calls',\n",
373 | " 'total_night_charge',\n",
374 | " 'total_intl_minutes',\n",
375 | " 'total_intl_calls',\n",
376 | " 'total_intl_charge',\n",
377 | " 'number_customer_service_calls'],\n",
378 | " 'num_columns_indices': [1,\n",
379 | " 7,\n",
380 | " 8,\n",
381 | " 9,\n",
382 | " 10,\n",
383 | " 11,\n",
384 | " 12,\n",
385 | " 13,\n",
386 | " 14,\n",
387 | " 15,\n",
388 | " 16,\n",
389 | " 17,\n",
390 | " 18,\n",
391 | " 19,\n",
392 | " 20],\n",
393 | " 'categorical_features_raw_indices': [],\n",
394 | " 'not_encoded_columns': []}\n"
395 | ]
396 | },
397 | {
398 | "name": "stderr",
399 | "output_type": "stream",
400 | "text": [
401 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
402 | "Online softmax is disabled on the fly since Inductor decides to\n",
403 | "split the reduction. Cut an issue to PyTorch if this is an\n",
404 | "important use case and you want to speed it up with online\n",
405 | "softmax.\n",
406 | "\n",
407 | " warnings.warn(\n",
408 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
409 | "Online softmax is disabled on the fly since Inductor decides to\n",
410 | "split the reduction. Cut an issue to PyTorch if this is an\n",
411 | "important use case and you want to speed it up with online\n",
412 | "softmax.\n",
413 | "\n",
414 | " warnings.warn(\n",
415 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
416 | "Online softmax is disabled on the fly since Inductor decides to\n",
417 | "split the reduction. Cut an issue to PyTorch if this is an\n",
418 | "important use case and you want to speed it up with online\n",
419 | "softmax.\n",
420 | "\n",
421 | " warnings.warn(\n"
422 | ]
423 | },
424 | {
425 | "name": "stdout",
426 | "output_type": "stream",
427 | "text": [
428 | "Epoch 001 | TrainLoss: 0.5474 | ValLoss: 0.4453 | ValAcc: 0.8462 | ValAUC: 0.7365 | ValF1: 0.4584 | Time: 7.30s\n",
429 | "Epoch 002 | TrainLoss: 0.3948 | ValLoss: 0.3909 | ValAcc: 0.8462 | ValAUC: 0.7867 | ValF1: 0.4584 | Time: 0.07s\n",
430 | "Epoch 003 | TrainLoss: 0.3378 | ValLoss: 0.3544 | ValAcc: 0.8512 | ValAUC: 0.8047 | ValF1: 0.4911 | Time: 0.07s\n",
431 | "Epoch 004 | TrainLoss: 0.2971 | ValLoss: 0.3175 | ValAcc: 0.8688 | ValAUC: 0.8550 | ValF1: 0.6018 | Time: 0.07s\n",
432 | "Epoch 005 | TrainLoss: 0.2589 | ValLoss: 0.2886 | ValAcc: 0.8712 | ValAUC: 0.8808 | ValF1: 0.6189 | Time: 0.07s\n",
433 | "Epoch 006 | TrainLoss: 0.2375 | ValLoss: 0.2643 | ValAcc: 0.8775 | ValAUC: 0.8945 | ValF1: 0.6479 | Time: 0.07s\n",
434 | "Epoch 007 | TrainLoss: 0.2188 | ValLoss: 0.2462 | ValAcc: 0.8912 | ValAUC: 0.9017 | ValF1: 0.7092 | Time: 0.07s\n",
435 | "Epoch 008 | TrainLoss: 0.2062 | ValLoss: 0.2279 | ValAcc: 0.9125 | ValAUC: 0.9067 | ValF1: 0.7929 | Time: 0.07s\n",
436 | "Epoch 009 | TrainLoss: 0.1965 | ValLoss: 0.2168 | ValAcc: 0.9200 | ValAUC: 0.9133 | ValF1: 0.8123 | Time: 0.07s\n",
437 | "Epoch 010 | TrainLoss: 0.1835 | ValLoss: 0.2026 | ValAcc: 0.9413 | ValAUC: 0.9144 | ValF1: 0.8717 | Time: 0.07s\n",
438 | "Epoch 011 | TrainLoss: 0.1750 | ValLoss: 0.1998 | ValAcc: 0.9463 | ValAUC: 0.9169 | ValF1: 0.8807 | Time: 0.07s\n",
439 | "Epoch 012 | TrainLoss: 0.1680 | ValLoss: 0.1945 | ValAcc: 0.9375 | ValAUC: 0.9203 | ValF1: 0.8584 | Time: 0.07s\n",
440 | "Epoch 013 | TrainLoss: 0.1612 | ValLoss: 0.1865 | ValAcc: 0.9550 | ValAUC: 0.9199 | ValF1: 0.9044 | Time: 0.07s\n",
441 | "[EarlyStop TRAINING] no improve (1/50). best=-0.920284, curr=-0.919936\n",
442 | "Epoch 014 | TrainLoss: 0.1590 | ValLoss: 0.1786 | ValAcc: 0.9575 | ValAUC: 0.9222 | ValF1: 0.9083 | Time: 0.07s\n",
443 | "Epoch 015 | TrainLoss: 0.1521 | ValLoss: 0.1729 | ValAcc: 0.9650 | ValAUC: 0.9227 | ValF1: 0.9268 | Time: 0.07s\n",
444 | "[EarlyStop TRAINING] no improve (1/50). best=-0.922158, curr=-0.922722\n",
445 | "Epoch 016 | TrainLoss: 0.1500 | ValLoss: 0.1707 | ValAcc: 0.9600 | ValAUC: 0.9223 | ValF1: 0.9163 | Time: 0.07s\n",
446 | "[EarlyStop TRAINING] no improve (2/50). best=-0.922158, curr=-0.922326\n",
447 | "Epoch 017 | TrainLoss: 0.1441 | ValLoss: 0.1696 | ValAcc: 0.9613 | ValAUC: 0.9226 | ValF1: 0.9180 | Time: 0.07s\n",
448 | "[EarlyStop TRAINING] no improve (3/50). best=-0.922158, curr=-0.922578\n",
449 | "Epoch 018 | TrainLoss: 0.1402 | ValLoss: 0.1669 | ValAcc: 0.9600 | ValAUC: 0.9283 | ValF1: 0.9150 | Time: 0.07s\n",
450 | "Epoch 019 | TrainLoss: 0.1379 | ValLoss: 0.1611 | ValAcc: 0.9613 | ValAUC: 0.9261 | ValF1: 0.9174 | Time: 0.07s\n",
451 | "[EarlyStop TRAINING] no improve (1/50). best=-0.928258, curr=-0.926085\n",
452 | "Epoch 020 | TrainLoss: 0.1307 | ValLoss: 0.1609 | ValAcc: 0.9663 | ValAUC: 0.9257 | ValF1: 0.9286 | Time: 0.07s\n",
453 | "[EarlyStop TRAINING] no improve (2/50). best=-0.928258, curr=-0.925724\n",
454 | "Epoch 021 | TrainLoss: 0.1305 | ValLoss: 0.1588 | ValAcc: 0.9650 | ValAUC: 0.9254 | ValF1: 0.9268 | Time: 0.07s\n",
455 | "[EarlyStop TRAINING] no improve (3/50). best=-0.928258, curr=-0.925436\n",
456 | "Epoch 022 | TrainLoss: 0.1302 | ValLoss: 0.1556 | ValAcc: 0.9637 | ValAUC: 0.9291 | ValF1: 0.9233 | Time: 0.07s\n",
457 | "[EarlyStop TRAINING] no improve (4/50). best=-0.928258, curr=-0.929087\n",
458 | "Epoch 023 | TrainLoss: 0.1240 | ValLoss: 0.1555 | ValAcc: 0.9650 | ValAUC: 0.9262 | ValF1: 0.9268 | Time: 0.07s\n",
459 | "[EarlyStop TRAINING] no improve (5/50). best=-0.928258, curr=-0.926193\n",
460 | "Epoch 024 | TrainLoss: 0.1186 | ValLoss: 0.1511 | ValAcc: 0.9637 | ValAUC: 0.9299 | ValF1: 0.9239 | Time: 0.07s\n",
461 | "Epoch 025 | TrainLoss: 0.1205 | ValLoss: 0.1556 | ValAcc: 0.9600 | ValAUC: 0.9273 | ValF1: 0.9150 | Time: 0.07s\n",
462 | "[EarlyStop TRAINING] no improve (1/50). best=-0.929868, curr=-0.927334\n",
463 | "Epoch 026 | TrainLoss: 0.1175 | ValLoss: 0.1524 | ValAcc: 0.9675 | ValAUC: 0.9271 | ValF1: 0.9330 | Time: 0.07s\n",
464 | "[EarlyStop TRAINING] no improve (2/50). best=-0.929868, curr=-0.927093\n",
465 | "Epoch 027 | TrainLoss: 0.1147 | ValLoss: 0.1469 | ValAcc: 0.9675 | ValAUC: 0.9283 | ValF1: 0.9330 | Time: 0.07s\n",
466 | "[EarlyStop TRAINING] no improve (3/50). best=-0.929868, curr=-0.928270\n",
467 | "Epoch 028 | TrainLoss: 0.1148 | ValLoss: 0.1520 | ValAcc: 0.9637 | ValAUC: 0.9266 | ValF1: 0.9250 | Time: 0.07s\n",
468 | "[EarlyStop TRAINING] no improve (4/50). best=-0.929868, curr=-0.926577\n",
469 | "Epoch 029 | TrainLoss: 0.1127 | ValLoss: 0.1524 | ValAcc: 0.9663 | ValAUC: 0.9304 | ValF1: 0.9297 | Time: 0.07s\n",
470 | "[EarlyStop TRAINING] no improve (5/50). best=-0.929868, curr=-0.930408\n",
471 | "Epoch 030 | TrainLoss: 0.1154 | ValLoss: 0.1515 | ValAcc: 0.9650 | ValAUC: 0.9277 | ValF1: 0.9268 | Time: 0.07s\n",
472 | "[EarlyStop TRAINING] no improve (6/50). best=-0.929868, curr=-0.927658\n",
473 | "Epoch 031 | TrainLoss: 0.1106 | ValLoss: 0.1499 | ValAcc: 0.9650 | ValAUC: 0.9312 | ValF1: 0.9268 | Time: 0.07s\n",
474 | "Epoch 032 | TrainLoss: 0.1100 | ValLoss: 0.1485 | ValAcc: 0.9663 | ValAUC: 0.9305 | ValF1: 0.9302 | Time: 0.07s\n",
475 | "[EarlyStop TRAINING] no improve (1/50). best=-0.931189, curr=-0.930480\n",
476 | "Epoch 033 | TrainLoss: 0.1067 | ValLoss: 0.1470 | ValAcc: 0.9688 | ValAUC: 0.9312 | ValF1: 0.9354 | Time: 0.07s\n",
477 | "[EarlyStop TRAINING] no improve (2/50). best=-0.931189, curr=-0.931237\n",
478 | "Epoch 034 | TrainLoss: 0.1061 | ValLoss: 0.1527 | ValAcc: 0.9650 | ValAUC: 0.9299 | ValF1: 0.9273 | Time: 0.07s\n",
479 | "[EarlyStop TRAINING] no improve (3/50). best=-0.931189, curr=-0.929856\n",
480 | "Epoch 035 | TrainLoss: 0.1047 | ValLoss: 0.1492 | ValAcc: 0.9625 | ValAUC: 0.9280 | ValF1: 0.9215 | Time: 0.07s\n",
481 | "[EarlyStop TRAINING] no improve (4/50). best=-0.931189, curr=-0.927970\n",
482 | "Epoch 036 | TrainLoss: 0.1035 | ValLoss: 0.1483 | ValAcc: 0.9663 | ValAUC: 0.9278 | ValF1: 0.9307 | Time: 0.07s\n",
483 | "[EarlyStop TRAINING] no improve (5/50). best=-0.931189, curr=-0.927838\n",
484 | "Epoch 037 | TrainLoss: 0.1021 | ValLoss: 0.1567 | ValAcc: 0.9650 | ValAUC: 0.9291 | ValF1: 0.9262 | Time: 0.07s\n",
485 | "[EarlyStop TRAINING] no improve (6/50). best=-0.931189, curr=-0.929075\n",
486 | "Epoch 038 | TrainLoss: 0.0999 | ValLoss: 0.1498 | ValAcc: 0.9675 | ValAUC: 0.9279 | ValF1: 0.9325 | Time: 0.07s\n",
487 | "[EarlyStop TRAINING] no improve (7/50). best=-0.931189, curr=-0.927898\n",
488 | "Epoch 039 | TrainLoss: 0.1008 | ValLoss: 0.1490 | ValAcc: 0.9663 | ValAUC: 0.9317 | ValF1: 0.9297 | Time: 0.07s\n",
489 | "[EarlyStop TRAINING] no improve (8/50). best=-0.931189, curr=-0.931681\n",
490 | "Epoch 040 | TrainLoss: 0.0934 | ValLoss: 0.1425 | ValAcc: 0.9650 | ValAUC: 0.9309 | ValF1: 0.9284 | Time: 0.07s\n",
491 | "[EarlyStop TRAINING] no improve (9/50). best=-0.931189, curr=-0.930924\n",
492 | "Epoch 041 | TrainLoss: 0.0974 | ValLoss: 0.1554 | ValAcc: 0.9600 | ValAUC: 0.9309 | ValF1: 0.9163 | Time: 0.07s\n",
493 | "[EarlyStop TRAINING] no improve (10/50). best=-0.931189, curr=-0.930924\n",
494 | "Epoch 042 | TrainLoss: 0.0959 | ValLoss: 0.1494 | ValAcc: 0.9637 | ValAUC: 0.9328 | ValF1: 0.9244 | Time: 0.07s\n",
495 | "Epoch 043 | TrainLoss: 0.0944 | ValLoss: 0.1482 | ValAcc: 0.9625 | ValAUC: 0.9292 | ValF1: 0.9244 | Time: 0.07s\n",
496 | "[EarlyStop TRAINING] no improve (1/50). best=-0.932774, curr=-0.929207\n",
497 | "Epoch 044 | TrainLoss: 0.0944 | ValLoss: 0.1477 | ValAcc: 0.9675 | ValAUC: 0.9328 | ValF1: 0.9330 | Time: 0.07s\n",
498 | "[EarlyStop TRAINING] no improve (2/50). best=-0.932774, curr=-0.932798\n",
499 | "Epoch 045 | TrainLoss: 0.0895 | ValLoss: 0.1522 | ValAcc: 0.9663 | ValAUC: 0.9321 | ValF1: 0.9297 | Time: 0.07s\n",
500 | "[EarlyStop TRAINING] no improve (3/50). best=-0.932774, curr=-0.932149\n",
501 | "Epoch 046 | TrainLoss: 0.0899 | ValLoss: 0.1515 | ValAcc: 0.9637 | ValAUC: 0.9296 | ValF1: 0.9250 | Time: 0.07s\n",
502 | "[EarlyStop TRAINING] no improve (4/50). best=-0.932774, curr=-0.929603\n",
503 | "Epoch 047 | TrainLoss: 0.0879 | ValLoss: 0.1504 | ValAcc: 0.9650 | ValAUC: 0.9282 | ValF1: 0.9279 | Time: 0.07s\n",
504 | "[EarlyStop TRAINING] no improve (5/50). best=-0.932774, curr=-0.928198\n",
505 | "Epoch 048 | TrainLoss: 0.0885 | ValLoss: 0.1496 | ValAcc: 0.9625 | ValAUC: 0.9324 | ValF1: 0.9209 | Time: 0.07s\n",
506 | "[EarlyStop TRAINING] no improve (6/50). best=-0.932774, curr=-0.932365\n",
507 | "Epoch 049 | TrainLoss: 0.0881 | ValLoss: 0.1506 | ValAcc: 0.9625 | ValAUC: 0.9311 | ValF1: 0.9227 | Time: 0.07s\n",
508 | "[EarlyStop TRAINING] no improve (7/50). best=-0.932774, curr=-0.931080\n",
509 | "Epoch 050 | TrainLoss: 0.0889 | ValLoss: 0.1470 | ValAcc: 0.9663 | ValAUC: 0.9321 | ValF1: 0.9307 | Time: 0.07s\n",
510 | "[EarlyStop TRAINING] no improve (8/50). best=-0.932774, curr=-0.932089\n",
511 | "Epoch 051 | TrainLoss: 0.0854 | ValLoss: 0.1552 | ValAcc: 0.9613 | ValAUC: 0.9301 | ValF1: 0.9192 | Time: 0.07s\n",
512 | "[EarlyStop TRAINING] no improve (9/50). best=-0.932774, curr=-0.930120\n",
513 | "Epoch 052 | TrainLoss: 0.0853 | ValLoss: 0.1479 | ValAcc: 0.9650 | ValAUC: 0.9310 | ValF1: 0.9279 | Time: 0.07s\n",
514 | "[EarlyStop TRAINING] no improve (10/50). best=-0.932774, curr=-0.930960\n",
515 | "Epoch 053 | TrainLoss: 0.0836 | ValLoss: 0.1521 | ValAcc: 0.9600 | ValAUC: 0.9294 | ValF1: 0.9188 | Time: 0.07s\n",
516 | "[EarlyStop TRAINING] no improve (11/50). best=-0.932774, curr=-0.929375\n",
517 | "Epoch 054 | TrainLoss: 0.0848 | ValLoss: 0.1556 | ValAcc: 0.9637 | ValAUC: 0.9287 | ValF1: 0.9250 | Time: 0.07s\n",
518 | "[EarlyStop TRAINING] no improve (12/50). best=-0.932774, curr=-0.928679\n",
519 | "Epoch 055 | TrainLoss: 0.0813 | ValLoss: 0.1457 | ValAcc: 0.9688 | ValAUC: 0.9304 | ValF1: 0.9358 | Time: 0.07s\n",
520 | "[EarlyStop TRAINING] no improve (13/50). best=-0.932774, curr=-0.930432\n",
521 | "Epoch 056 | TrainLoss: 0.0790 | ValLoss: 0.1501 | ValAcc: 0.9663 | ValAUC: 0.9313 | ValF1: 0.9302 | Time: 0.07s\n",
522 | "[EarlyStop TRAINING] no improve (14/50). best=-0.932774, curr=-0.931273\n",
523 | "Epoch 057 | TrainLoss: 0.0786 | ValLoss: 0.1516 | ValAcc: 0.9650 | ValAUC: 0.9286 | ValF1: 0.9273 | Time: 0.07s\n",
524 | "[EarlyStop TRAINING] no improve (15/50). best=-0.932774, curr=-0.928607\n",
525 | "Epoch 058 | TrainLoss: 0.0790 | ValLoss: 0.1556 | ValAcc: 0.9625 | ValAUC: 0.9300 | ValF1: 0.9227 | Time: 0.07s\n",
526 | "[EarlyStop TRAINING] no improve (16/50). best=-0.932774, curr=-0.930000\n",
527 | "Epoch 059 | TrainLoss: 0.0801 | ValLoss: 0.1562 | ValAcc: 0.9637 | ValAUC: 0.9296 | ValF1: 0.9244 | Time: 0.07s\n",
528 | "[EarlyStop TRAINING] no improve (17/50). best=-0.932774, curr=-0.929579\n",
529 | "Epoch 060 | TrainLoss: 0.0751 | ValLoss: 0.1561 | ValAcc: 0.9625 | ValAUC: 0.9308 | ValF1: 0.9221 | Time: 0.07s\n",
530 | "[EarlyStop TRAINING] no improve (18/50). best=-0.932774, curr=-0.930792\n",
531 | "Epoch 061 | TrainLoss: 0.0762 | ValLoss: 0.1549 | ValAcc: 0.9675 | ValAUC: 0.9313 | ValF1: 0.9325 | Time: 0.07s\n",
532 | "[EarlyStop TRAINING] no improve (19/50). best=-0.932774, curr=-0.931321\n",
533 | "Epoch 062 | TrainLoss: 0.0745 | ValLoss: 0.1514 | ValAcc: 0.9625 | ValAUC: 0.9285 | ValF1: 0.9244 | Time: 0.07s\n",
534 | "[EarlyStop TRAINING] no improve (20/50). best=-0.932774, curr=-0.928487\n",
535 | "Epoch 063 | TrainLoss: 0.0741 | ValLoss: 0.1551 | ValAcc: 0.9625 | ValAUC: 0.9263 | ValF1: 0.9238 | Time: 0.07s\n",
536 | "[EarlyStop TRAINING] no improve (21/50). best=-0.932774, curr=-0.926337\n",
537 | "Epoch 064 | TrainLoss: 0.0734 | ValLoss: 0.1568 | ValAcc: 0.9613 | ValAUC: 0.9262 | ValF1: 0.9216 | Time: 0.07s\n",
538 | "[EarlyStop TRAINING] no improve (22/50). best=-0.932774, curr=-0.926181\n",
539 | "Epoch 065 | TrainLoss: 0.0744 | ValLoss: 0.1549 | ValAcc: 0.9650 | ValAUC: 0.9260 | ValF1: 0.9289 | Time: 0.07s\n",
540 | "[EarlyStop TRAINING] no improve (23/50). best=-0.932774, curr=-0.925965\n",
541 | "Epoch 066 | TrainLoss: 0.0713 | ValLoss: 0.1528 | ValAcc: 0.9637 | ValAUC: 0.9269 | ValF1: 0.9261 | Time: 0.07s\n",
542 | "[EarlyStop TRAINING] no improve (24/50). best=-0.932774, curr=-0.926913\n",
543 | "Epoch 067 | TrainLoss: 0.0724 | ValLoss: 0.1587 | ValAcc: 0.9587 | ValAUC: 0.9246 | ValF1: 0.9159 | Time: 0.07s\n",
544 | "[EarlyStop TRAINING] no improve (25/50). best=-0.932774, curr=-0.924644\n",
545 | "Epoch 068 | TrainLoss: 0.0670 | ValLoss: 0.1577 | ValAcc: 0.9637 | ValAUC: 0.9261 | ValF1: 0.9256 | Time: 0.07s\n",
546 | "[EarlyStop TRAINING] no improve (26/50). best=-0.932774, curr=-0.926109\n",
547 | "Epoch 069 | TrainLoss: 0.0703 | ValLoss: 0.1552 | ValAcc: 0.9613 | ValAUC: 0.9290 | ValF1: 0.9210 | Time: 0.07s\n",
548 | "[EarlyStop TRAINING] no improve (27/50). best=-0.932774, curr=-0.928967\n",
549 | "Epoch 070 | TrainLoss: 0.0673 | ValLoss: 0.1601 | ValAcc: 0.9650 | ValAUC: 0.9306 | ValF1: 0.9273 | Time: 0.07s\n",
550 | "[EarlyStop TRAINING] no improve (28/50). best=-0.932774, curr=-0.930600\n",
551 | "Epoch 071 | TrainLoss: 0.0698 | ValLoss: 0.1594 | ValAcc: 0.9600 | ValAUC: 0.9305 | ValF1: 0.9182 | Time: 0.07s\n",
552 | "[EarlyStop TRAINING] no improve (29/50). best=-0.932774, curr=-0.930468\n",
553 | "Epoch 072 | TrainLoss: 0.0683 | ValLoss: 0.1585 | ValAcc: 0.9625 | ValAUC: 0.9278 | ValF1: 0.9227 | Time: 0.07s\n",
554 | "[EarlyStop TRAINING] no improve (30/50). best=-0.932774, curr=-0.927838\n",
555 | "Epoch 073 | TrainLoss: 0.0656 | ValLoss: 0.1548 | ValAcc: 0.9625 | ValAUC: 0.9283 | ValF1: 0.9244 | Time: 0.07s\n",
556 | "[EarlyStop TRAINING] no improve (31/50). best=-0.932774, curr=-0.928258\n",
557 | "Epoch 074 | TrainLoss: 0.0647 | ValLoss: 0.1610 | ValAcc: 0.9625 | ValAUC: 0.9265 | ValF1: 0.9227 | Time: 0.07s\n",
558 | "[EarlyStop TRAINING] no improve (32/50). best=-0.932774, curr=-0.926457\n",
559 | "Epoch 075 | TrainLoss: 0.0618 | ValLoss: 0.1574 | ValAcc: 0.9625 | ValAUC: 0.9267 | ValF1: 0.9238 | Time: 0.07s\n",
560 | "[EarlyStop TRAINING] no improve (33/50). best=-0.932774, curr=-0.926685\n",
561 | "Epoch 076 | TrainLoss: 0.0649 | ValLoss: 0.1592 | ValAcc: 0.9625 | ValAUC: 0.9297 | ValF1: 0.9227 | Time: 0.07s\n",
562 | "[EarlyStop TRAINING] no improve (34/50). best=-0.932774, curr=-0.929663\n",
563 | "Epoch 077 | TrainLoss: 0.0645 | ValLoss: 0.1579 | ValAcc: 0.9625 | ValAUC: 0.9285 | ValF1: 0.9244 | Time: 0.07s\n",
564 | "[EarlyStop TRAINING] no improve (35/50). best=-0.932774, curr=-0.928462\n",
565 | "Epoch 078 | TrainLoss: 0.0631 | ValLoss: 0.1680 | ValAcc: 0.9587 | ValAUC: 0.9268 | ValF1: 0.9153 | Time: 0.07s\n",
566 | "[EarlyStop TRAINING] no improve (36/50). best=-0.932774, curr=-0.926829\n",
567 | "Epoch 079 | TrainLoss: 0.0616 | ValLoss: 0.1701 | ValAcc: 0.9575 | ValAUC: 0.9266 | ValF1: 0.9111 | Time: 0.08s\n",
568 | "[EarlyStop TRAINING] no improve (37/50). best=-0.932774, curr=-0.926625\n",
569 | "Epoch 080 | TrainLoss: 0.0631 | ValLoss: 0.1588 | ValAcc: 0.9587 | ValAUC: 0.9266 | ValF1: 0.9165 | Time: 0.07s\n",
570 | "[EarlyStop TRAINING] no improve (38/50). best=-0.932774, curr=-0.926649\n",
571 | "Epoch 081 | TrainLoss: 0.0612 | ValLoss: 0.1628 | ValAcc: 0.9587 | ValAUC: 0.9277 | ValF1: 0.9171 | Time: 0.07s\n",
572 | "[EarlyStop TRAINING] no improve (39/50). best=-0.932774, curr=-0.927694\n",
573 | "Epoch 082 | TrainLoss: 0.0611 | ValLoss: 0.1618 | ValAcc: 0.9625 | ValAUC: 0.9280 | ValF1: 0.9238 | Time: 0.07s\n",
574 | "[EarlyStop TRAINING] no improve (40/50). best=-0.932774, curr=-0.928042\n",
575 | "Epoch 083 | TrainLoss: 0.0601 | ValLoss: 0.1600 | ValAcc: 0.9575 | ValAUC: 0.9290 | ValF1: 0.9149 | Time: 0.07s\n",
576 | "[EarlyStop TRAINING] no improve (41/50). best=-0.932774, curr=-0.929027\n",
577 | "Epoch 084 | TrainLoss: 0.0641 | ValLoss: 0.1614 | ValAcc: 0.9550 | ValAUC: 0.9310 | ValF1: 0.9112 | Time: 0.07s\n",
578 | "[EarlyStop TRAINING] no improve (42/50). best=-0.932774, curr=-0.930960\n",
579 | "Epoch 085 | TrainLoss: 0.0615 | ValLoss: 0.1635 | ValAcc: 0.9537 | ValAUC: 0.9291 | ValF1: 0.9077 | Time: 0.07s\n",
580 | "[EarlyStop TRAINING] no improve (43/50). best=-0.932774, curr=-0.929087\n",
581 | "Epoch 086 | TrainLoss: 0.0595 | ValLoss: 0.1597 | ValAcc: 0.9625 | ValAUC: 0.9305 | ValF1: 0.9238 | Time: 0.07s\n",
582 | "[EarlyStop TRAINING] no improve (44/50). best=-0.932774, curr=-0.930492\n",
583 | "Epoch 087 | TrainLoss: 0.0599 | ValLoss: 0.1655 | ValAcc: 0.9613 | ValAUC: 0.9287 | ValF1: 0.9204 | Time: 0.07s\n",
584 | "[EarlyStop TRAINING] no improve (45/50). best=-0.932774, curr=-0.928727\n",
585 | "Epoch 088 | TrainLoss: 0.0586 | ValLoss: 0.1715 | ValAcc: 0.9587 | ValAUC: 0.9277 | ValF1: 0.9159 | Time: 0.07s\n",
586 | "[EarlyStop TRAINING] no improve (46/50). best=-0.932774, curr=-0.927682\n",
587 | "Epoch 089 | TrainLoss: 0.0570 | ValLoss: 0.1698 | ValAcc: 0.9625 | ValAUC: 0.9266 | ValF1: 0.9227 | Time: 0.07s\n",
588 | "[EarlyStop TRAINING] no improve (47/50). best=-0.932774, curr=-0.926589\n",
589 | "Epoch 090 | TrainLoss: 0.0592 | ValLoss: 0.1741 | ValAcc: 0.9563 | ValAUC: 0.9266 | ValF1: 0.9088 | Time: 0.07s\n",
590 | "[EarlyStop TRAINING] no improve (48/50). best=-0.932774, curr=-0.926577\n",
591 | "Epoch 091 | TrainLoss: 0.0547 | ValLoss: 0.1732 | ValAcc: 0.9625 | ValAUC: 0.9256 | ValF1: 0.9227 | Time: 0.07s\n",
592 | "[EarlyStop TRAINING] no improve (49/50). best=-0.932774, curr=-0.925592\n",
593 | "Epoch 092 | TrainLoss: 0.0566 | ValLoss: 0.1760 | ValAcc: 0.9587 | ValAUC: 0.9273 | ValF1: 0.9134 | Time: 0.07s\n",
594 | "[EarlyStop TRAINING] no improve (50/50). best=-0.932774, curr=-0.927274\n",
595 | "[EarlyStop TRAINING] restoring best weights and stopping.\n",
596 | "Restoring best model from epoch with val score: -0.9327977327040627\n"
597 | ]
598 | },
599 | {
600 | "name": "stderr",
601 | "output_type": "stream",
602 | "text": [
603 | "/home/smarton/anaconda3/envs/ReMeDe/lib/python3.12/site-packages/torch/_inductor/lowering.py:7242: UserWarning: \n",
604 | "Online softmax is disabled on the fly since Inductor decides to\n",
605 | "split the reduction. Cut an issue to PyTorch if this is an\n",
606 | "important use case and you want to speed it up with online\n",
607 | "softmax.\n",
608 | "\n",
609 | " warnings.warn(\n"
610 | ]
611 | }
612 | ],
613 | "source": [
614 | "from GRANDE import GRANDE\n",
615 | "\n",
616 | "params = {\n",
617 | " 'depth': 5,\n",
618 | " 'n_estimators': 1024,\n",
619 | "\n",
620 | " 'learning_rate_weights': 0.001,\n",
621 | " 'learning_rate_index': 0.01,\n",
622 | " 'learning_rate_values': 0.05,\n",
623 | " 'learning_rate_leaf': 0.05,\n",
624 | " 'learning_rate_embedding': 0.02,\n",
625 | "\n",
626 | " 'use_category_embeddings': False,\n",
627 | " 'embedding_dim_cat': 8,\n",
628 | " 'use_numeric_embeddings': False,\n",
629 | " 'embedding_dim_num': 8,\n",
630 | " 'embedding_threshold': 1,\n",
631 | " 'loo_cardinality': 10,\n",
632 | "\n",
633 | "\n",
634 | " 'dropout': 0.2,\n",
635 | " 'selected_variables': 0.8,\n",
636 | " 'data_subset_fraction': 1.0,\n",
637 | " 'bootstrap': False,\n",
638 | " 'missing_values': False,\n",
639 | "\n",
640 | "    'optimizer': 'adam',  # options: 'adam', 'nadam', 'radam', 'adamw'\n",
641 | " 'cosine_decay_restarts': False,\n",
642 | " 'reduce_on_plateau_scheduler': True,\n",
643 | " 'label_smoothing': 0.0,\n",
644 | " 'use_class_weights': False,\n",
645 | " 'focal_loss': False,\n",
646 | " 'swa': False,\n",
647 | " 'es_metric': True, # if True use AUC for binary, MSE for regression, val_loss for multiclass\n",
648 | "\n",
649 | "\n",
650 | " 'epochs': 250,\n",
651 | " 'batch_size': 256,\n",
652 | " 'early_stopping_epochs': 50,\n",
653 | "\n",
654 | " 'use_freq_enc': False,\n",
655 | " 'use_robust_scale_smoothing': False,\n",
656 | " 'problem_type': 'binary',\n",
657 | " \n",
658 | " 'random_seed': 42,\n",
659 | " 'verbose': 2,\n",
660 | "}\n",
661 | "\n",
662 | "model_grande = GRANDE(params=params)\n",
663 | "\n",
664 | "model_grande.fit(X=X_train,\n",
665 | " y=y_train,\n",
666 | " X_val=X_valid,\n",
667 | " y_val=y_valid)\n",
668 | "\n",
669 | "preds_grande = model_grande.predict_proba(X_test)"
670 | ]
671 | },
672 | {
673 | "cell_type": "code",
674 | "execution_count": 5,
675 | "id": "b251df9d-67f0-4dd1-a0cc-379621e33fad",
676 | "metadata": {
677 | "execution": {
678 | "iopub.execute_input": "2024-03-15T11:57:30.601353Z",
679 | "iopub.status.busy": "2024-03-15T11:57:30.601004Z",
680 | "iopub.status.idle": "2024-03-15T11:57:30.604147Z",
681 | "shell.execute_reply": "2024-03-15T11:57:30.603752Z",
682 | "shell.execute_reply.started": "2024-03-15T11:57:30.601338Z"
683 | }
684 | },
685 | "outputs": [],
686 | "source": [
687 | "def calculate_sample_weights(y_data):\n",
688 | "    # 'balanced' weights each sample by n_samples / (n_classes * count of its class)\n",
689 | "    sample_weights = sklearn.utils.class_weight.compute_sample_weight(class_weight='balanced', y=y_data)\n",
690 | " return sample_weights"
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": 6,
696 | "id": "c17e805c-0110-4e39-97d3-34573804296b",
697 | "metadata": {
698 | "execution": {
699 | "iopub.execute_input": "2024-03-15T11:57:30.605592Z",
700 | "iopub.status.busy": "2024-03-15T11:57:30.605338Z",
701 | "iopub.status.idle": "2024-03-15T11:57:30.611465Z",
702 | "shell.execute_reply": "2024-03-15T11:57:30.611074Z",
703 | "shell.execute_reply.started": "2024-03-15T11:57:30.605578Z"
704 | }
705 | },
706 | "outputs": [],
707 | "source": [
708 | "try:\n",
709 | " y_train = y_train.values.codes.astype(np.float64)\n",
710 | " y_valid = y_valid.values.codes.astype(np.float64)\n",
711 | " y_test = y_test.values.codes.astype(np.float64)\n",
712 | "except AttributeError:  # labels are not categorical, so .values has no .codes\n",
713 | " y_train = y_train.values.astype(np.float64)\n",
714 | " y_valid = y_valid.values.astype(np.float64)\n",
715 | " y_test = y_test.values.astype(np.float64)"
716 | ]
717 | },
718 | {
719 | "cell_type": "code",
720 | "execution_count": 7,
721 | "id": "cfc121e3-2d7b-4fb1-8616-0318fcc20ac1",
722 | "metadata": {
723 | "execution": {
724 | "iopub.execute_input": "2024-03-15T11:57:30.612164Z",
725 | "iopub.status.busy": "2024-03-15T11:57:30.611960Z",
726 | "iopub.status.idle": "2024-03-15T11:57:30.627515Z",
727 | "shell.execute_reply": "2024-03-15T11:57:30.627108Z",
728 | "shell.execute_reply.started": "2024-03-15T11:57:30.612151Z"
729 | }
730 | },
731 | "outputs": [],
732 | "source": [
733 | "binary_indices = []\n",
734 | "low_cardinality_indices = []\n",
735 | "high_cardinality_indices = []\n",
736 | "num_columns = []\n",
737 | "for column_index, column in enumerate(X_train.columns):\n",
738 | " if column_index in categorical_feature_indices or X_train.iloc[:,column_index].dtype.name == 'category' or X_train.iloc[:,column_index].dtype.name == 'object':\n",
739 | " if len(X_train.iloc[:,column_index].unique()) <= 2:\n",
740 | " binary_indices.append(column)\n",
741 | "        elif len(X_train.iloc[:,column_index].unique()) < 5:\n",
742 | " low_cardinality_indices.append(column)\n",
743 | " else:\n",
744 | " high_cardinality_indices.append(column)\n",
745 | " else:\n",
746 | " num_columns.append(column) \n",
747 | "cat_columns = [col for col in X_train.columns if col not in num_columns]\n",
748 | "categorical_feature_indices = [X_train.columns.get_loc(col) for col in cat_columns]\n"
749 | ]
750 | },
751 | {
752 | "cell_type": "code",
753 | "execution_count": 8,
754 | "id": "f7e55785-bf56-4aad-9825-7c5737be5ab7",
755 | "metadata": {
756 | "execution": {
757 | "iopub.execute_input": "2024-03-15T11:57:30.628204Z",
758 | "iopub.status.busy": "2024-03-15T11:57:30.627992Z",
759 | "iopub.status.idle": "2024-03-15T11:57:30.691371Z",
760 | "shell.execute_reply": "2024-03-15T11:57:30.690745Z",
761 | "shell.execute_reply.started": "2024-03-15T11:57:30.628191Z"
762 | }
763 | },
764 | "outputs": [],
765 | "source": [
766 | "if len(num_columns) > 0:\n",
767 | "    mean_train_num = X_train[num_columns].mean(axis=0)  # one mean per numeric column\n",
768 | " X_train[num_columns] = X_train[num_columns].fillna(mean_train_num)\n",
769 | " X_valid[num_columns] = X_valid[num_columns].fillna(mean_train_num)\n",
770 | " X_test[num_columns] = X_test[num_columns].fillna(mean_train_num)\n",
771 | "if len(cat_columns) > 0:\n",
772 | " mode_train_cat = X_train[cat_columns].mode(axis=0).iloc[0]\n",
773 | " X_train[cat_columns] = X_train[cat_columns].fillna(mode_train_cat)\n",
774 | " X_valid[cat_columns] = X_valid[cat_columns].fillna(mode_train_cat)\n",
775 | " X_test[cat_columns] = X_test[cat_columns].fillna(mode_train_cat)\n",
776 | "\n",
777 | "X_train_raw = X_train.copy()\n",
778 | "X_valid_raw = X_valid.copy()\n",
779 | "X_test_raw = X_test.copy()"
780 | ]
781 | },
782 | {
783 | "cell_type": "code",
784 | "execution_count": 9,
785 | "id": "b5d5f00f",
786 | "metadata": {
787 | "execution": {
788 | "iopub.execute_input": "2024-03-15T11:57:30.692101Z",
789 | "iopub.status.busy": "2024-03-15T11:57:30.691967Z",
790 | "iopub.status.idle": "2024-03-15T11:57:31.276758Z",
791 | "shell.execute_reply": "2024-03-15T11:57:31.276086Z",
792 | "shell.execute_reply.started": "2024-03-15T11:57:30.692089Z"
793 | }
794 | },
795 | "outputs": [],
796 | "source": [
797 | "encoder_ordinal = ce.OrdinalEncoder(cols=binary_indices)\n",
798 | "encoder_ordinal.fit(X_train)\n",
799 | "X_train = encoder_ordinal.transform(X_train)\n",
800 | "X_valid = encoder_ordinal.transform(X_valid) \n",
801 | "X_test = encoder_ordinal.transform(X_test) \n",
802 | "\n",
803 | "encoder = ce.LeaveOneOutEncoder(cols=high_cardinality_indices)\n",
804 | "encoder.fit(X_train, y_train)\n",
805 | "X_train = encoder.transform(X_train, y_train)  # pass y so each row's own target is left out\n",
806 | "X_valid = encoder.transform(X_valid)\n",
807 | "X_test = encoder.transform(X_test)\n",
808 | "\n",
809 | "encoder = ce.OneHotEncoder(cols=low_cardinality_indices)\n",
810 | "encoder.fit(X_train)\n",
811 | "X_train = encoder.transform(X_train)\n",
812 | "X_valid = encoder.transform(X_valid)\n",
813 | "X_test = encoder.transform(X_test)\n",
814 | "\n",
815 | "X_train = X_train.astype(np.float32)\n",
816 | "X_valid = X_valid.astype(np.float32)\n",
817 | "X_test = X_test.astype(np.float32)"
818 | ]
819 | },
820 | {
821 | "cell_type": "code",
822 | "execution_count": 10,
823 | "id": "361f7d9f-7e58-4bc1-8855-153b9e972742",
824 | "metadata": {
825 | "execution": {
826 | "iopub.execute_input": "2024-03-15T11:57:31.277594Z",
827 | "iopub.status.busy": "2024-03-15T11:57:31.277439Z",
828 | "iopub.status.idle": "2024-03-15T11:57:31.575860Z",
829 | "shell.execute_reply": "2024-03-15T11:57:31.575224Z",
830 | "shell.execute_reply.started": "2024-03-15T11:57:31.277581Z"
831 | }
832 | },
833 | "outputs": [
834 | {
835 | "name": "stdout",
836 | "output_type": "stream",
837 | "text": [
838 | "[0]\tvalidation_0-logloss:0.28315\n",
839 | "[1]\tvalidation_0-logloss:0.23763\n",
840 | "[2]\tvalidation_0-logloss:0.20618\n",
841 | "[3]\tvalidation_0-logloss:0.18811\n",
842 | "[4]\tvalidation_0-logloss:0.17719\n",
843 | "[5]\tvalidation_0-logloss:0.16986\n",
844 | "[6]\tvalidation_0-logloss:0.16414\n",
845 | "[7]\tvalidation_0-logloss:0.16136\n",
846 | "[8]\tvalidation_0-logloss:0.15800\n",
847 | "[9]\tvalidation_0-logloss:0.15785\n",
848 | "[10]\tvalidation_0-logloss:0.15713\n",
849 | "[11]\tvalidation_0-logloss:0.15692\n",
850 | "[12]\tvalidation_0-logloss:0.15799\n",
851 | "[13]\tvalidation_0-logloss:0.15559\n",
852 | "[14]\tvalidation_0-logloss:0.15664\n",
853 | "[15]\tvalidation_0-logloss:0.15613\n",
854 | "[16]\tvalidation_0-logloss:0.15517\n",
855 | "[17]\tvalidation_0-logloss:0.15363\n",
856 | "[18]\tvalidation_0-logloss:0.15438\n",
857 | "[19]\tvalidation_0-logloss:0.15495\n",
858 | "[20]\tvalidation_0-logloss:0.15732\n",
859 | "[21]\tvalidation_0-logloss:0.15925\n",
860 | "[22]\tvalidation_0-logloss:0.15958\n",
861 | "[23]\tvalidation_0-logloss:0.15969\n",
862 | "[24]\tvalidation_0-logloss:0.16180\n",
863 | "[25]\tvalidation_0-logloss:0.16142\n",
864 | "[26]\tvalidation_0-logloss:0.16153\n",
865 | "[27]\tvalidation_0-logloss:0.16176\n",
866 | "[28]\tvalidation_0-logloss:0.16241\n",
867 | "[29]\tvalidation_0-logloss:0.16312\n",
868 | "[30]\tvalidation_0-logloss:0.16496\n",
869 | "[31]\tvalidation_0-logloss:0.16497\n",
870 | "[32]\tvalidation_0-logloss:0.16543\n",
871 | "[33]\tvalidation_0-logloss:0.16679\n",
872 | "[34]\tvalidation_0-logloss:0.16724\n",
873 | "[35]\tvalidation_0-logloss:0.16847\n",
874 | "[36]\tvalidation_0-logloss:0.17016\n"
875 | ]
876 | }
877 | ],
878 | "source": [
879 | "if params['problem_type'] == 'regression':\n",
880 | " from xgboost import XGBRegressor\n",
881 | " model_xgb = XGBRegressor(n_estimators=1000, early_stopping_rounds=20)\n",
882 | " model_xgb.fit(X_train, \n",
883 | " y_train, \n",
884 | " eval_set=[(X_valid, y_valid)], \n",
885 | " )\n",
886 | " preds_xgb = model_xgb.predict(X_test)\n",
887 | "else:\n",
888 | " from xgboost import XGBClassifier\n",
889 | " model_xgb = XGBClassifier(n_estimators=1000, early_stopping_rounds=20)\n",
890 | " model_xgb.fit(X_train, \n",
891 | " y_train, \n",
892 | " #sample_weight=calculate_sample_weights(y_train), \n",
893 | " eval_set=[(X_valid, y_valid)], \n",
894 | " #sample_weight_eval_set=[calculate_sample_weights(y_valid)]\n",
895 | " )\n",
896 | "\n",
897 | "\n",
898 | " preds_xgb = model_xgb.predict_proba(X_test)"
899 | ]
900 | },
901 | {
902 | "cell_type": "code",
903 | "execution_count": 11,
904 | "id": "6b23c67e-cdde-40c2-842a-61383e1a62d0",
905 | "metadata": {
906 | "execution": {
907 | "iopub.execute_input": "2024-03-15T11:57:31.576940Z",
908 | "iopub.status.busy": "2024-03-15T11:57:31.576558Z",
909 | "iopub.status.idle": "2024-03-15T11:57:52.427379Z",
910 | "shell.execute_reply": "2024-03-15T11:57:52.426427Z",
911 | "shell.execute_reply.started": "2024-03-15T11:57:31.576923Z"
912 | }
913 | },
914 | "outputs": [
915 | {
916 | "name": "stdout",
917 | "output_type": "stream",
918 | "text": [
919 | "Learning rate set to 0.042236\n",
920 | "0:\tlearn: 0.6403699\ttest: 0.6408011\tbest: 0.6408011 (0)\ttotal: 53.5ms\tremaining: 53.4s\n",
921 | "1:\tlearn: 0.6121572\ttest: 0.6148520\tbest: 0.6148520 (1)\ttotal: 56.1ms\tremaining: 28s\n",
922 | "2:\tlearn: 0.5688593\ttest: 0.5737767\tbest: 0.5737767 (2)\ttotal: 62ms\tremaining: 20.6s\n",
923 | "3:\tlearn: 0.5353572\ttest: 0.5418521\tbest: 0.5418521 (3)\ttotal: 65.5ms\tremaining: 16.3s\n",
924 | "4:\tlearn: 0.5080383\ttest: 0.5151234\tbest: 0.5151234 (4)\ttotal: 69.9ms\tremaining: 13.9s\n",
925 | "5:\tlearn: 0.4885596\ttest: 0.4971972\tbest: 0.4971972 (5)\ttotal: 74.7ms\tremaining: 12.4s\n",
926 | "6:\tlearn: 0.4654316\ttest: 0.4759393\tbest: 0.4759393 (6)\ttotal: 79.2ms\tremaining: 11.2s\n",
927 | "7:\tlearn: 0.4445447\ttest: 0.4561330\tbest: 0.4561330 (7)\ttotal: 82.5ms\tremaining: 10.2s\n",
928 | "8:\tlearn: 0.4288701\ttest: 0.4418229\tbest: 0.4418229 (8)\ttotal: 85.3ms\tremaining: 9.39s\n",
929 | "9:\tlearn: 0.4052309\ttest: 0.4181304\tbest: 0.4181304 (9)\ttotal: 88.3ms\tremaining: 8.74s\n",
930 | "10:\tlearn: 0.3831187\ttest: 0.3959788\tbest: 0.3959788 (10)\ttotal: 94.1ms\tremaining: 8.46s\n",
931 | "11:\tlearn: 0.3703703\ttest: 0.3838890\tbest: 0.3838890 (11)\ttotal: 97ms\tremaining: 7.99s\n",
932 | "12:\tlearn: 0.3566437\ttest: 0.3710660\tbest: 0.3710660 (12)\ttotal: 102ms\tremaining: 7.74s\n",
933 | "13:\tlearn: 0.3436633\ttest: 0.3582671\tbest: 0.3582671 (13)\ttotal: 105ms\tremaining: 7.42s\n",
934 | "14:\tlearn: 0.3299345\ttest: 0.3452403\tbest: 0.3452403 (14)\ttotal: 109ms\tremaining: 7.17s\n",
935 | "15:\tlearn: 0.3192912\ttest: 0.3344290\tbest: 0.3344290 (15)\ttotal: 112ms\tremaining: 6.91s\n",
936 | "16:\tlearn: 0.3104381\ttest: 0.3260456\tbest: 0.3260456 (16)\ttotal: 117ms\tremaining: 6.74s\n",
937 | "17:\tlearn: 0.3004882\ttest: 0.3165658\tbest: 0.3165658 (17)\ttotal: 122ms\tremaining: 6.67s\n",
938 | "18:\tlearn: 0.2912218\ttest: 0.3075181\tbest: 0.3075181 (18)\ttotal: 127ms\tremaining: 6.54s\n",
939 | "19:\tlearn: 0.2848872\ttest: 0.3029670\tbest: 0.3029670 (19)\ttotal: 131ms\tremaining: 6.42s\n",
940 | "20:\tlearn: 0.2819334\ttest: 0.3005996\tbest: 0.3005996 (20)\ttotal: 135ms\tremaining: 6.3s\n",
941 | "21:\tlearn: 0.2745152\ttest: 0.2935901\tbest: 0.2935901 (21)\ttotal: 140ms\tremaining: 6.21s\n",
942 | "22:\tlearn: 0.2719255\ttest: 0.2916018\tbest: 0.2916018 (22)\ttotal: 142ms\tremaining: 6.05s\n",
943 | "23:\tlearn: 0.2654353\ttest: 0.2855682\tbest: 0.2855682 (23)\ttotal: 146ms\tremaining: 5.93s\n",
944 | "24:\tlearn: 0.2590030\ttest: 0.2792839\tbest: 0.2792839 (24)\ttotal: 150ms\tremaining: 5.86s\n",
945 | "25:\tlearn: 0.2536191\ttest: 0.2739589\tbest: 0.2739589 (25)\ttotal: 154ms\tremaining: 5.75s\n",
946 | "26:\tlearn: 0.2493147\ttest: 0.2699481\tbest: 0.2699481 (26)\ttotal: 157ms\tremaining: 5.65s\n",
947 | "27:\tlearn: 0.2445817\ttest: 0.2658077\tbest: 0.2658077 (27)\ttotal: 160ms\tremaining: 5.56s\n",
948 | "28:\tlearn: 0.2389872\ttest: 0.2601796\tbest: 0.2601796 (28)\ttotal: 163ms\tremaining: 5.46s\n",
949 | "29:\tlearn: 0.2343432\ttest: 0.2557029\tbest: 0.2557029 (29)\ttotal: 166ms\tremaining: 5.38s\n",
950 | "30:\tlearn: 0.2306654\ttest: 0.2521669\tbest: 0.2521669 (30)\ttotal: 169ms\tremaining: 5.29s\n",
951 | "31:\tlearn: 0.2266281\ttest: 0.2486322\tbest: 0.2486322 (31)\ttotal: 172ms\tremaining: 5.21s\n",
952 | "32:\tlearn: 0.2238766\ttest: 0.2457850\tbest: 0.2457850 (32)\ttotal: 175ms\tremaining: 5.12s\n",
953 | "33:\tlearn: 0.2197826\ttest: 0.2419484\tbest: 0.2419484 (33)\ttotal: 178ms\tremaining: 5.05s\n",
954 | "34:\tlearn: 0.2164592\ttest: 0.2386139\tbest: 0.2386139 (34)\ttotal: 181ms\tremaining: 4.99s\n",
955 | "35:\tlearn: 0.2126253\ttest: 0.2345415\tbest: 0.2345415 (35)\ttotal: 184ms\tremaining: 4.93s\n",
956 | "36:\tlearn: 0.2091321\ttest: 0.2315528\tbest: 0.2315528 (36)\ttotal: 187ms\tremaining: 4.87s\n",
957 | "37:\tlearn: 0.2062252\ttest: 0.2289329\tbest: 0.2289329 (37)\ttotal: 190ms\tremaining: 4.8s\n",
958 | "38:\tlearn: 0.2037497\ttest: 0.2266702\tbest: 0.2266702 (38)\ttotal: 193ms\tremaining: 4.75s\n",
959 | "39:\tlearn: 0.2009984\ttest: 0.2239214\tbest: 0.2239214 (39)\ttotal: 196ms\tremaining: 4.69s\n",
960 | "40:\tlearn: 0.1989320\ttest: 0.2221738\tbest: 0.2221738 (40)\ttotal: 198ms\tremaining: 4.64s\n",
961 | "41:\tlearn: 0.1964169\ttest: 0.2201016\tbest: 0.2201016 (41)\ttotal: 201ms\tremaining: 4.59s\n",
962 | "42:\tlearn: 0.1946265\ttest: 0.2182654\tbest: 0.2182654 (42)\ttotal: 204ms\tremaining: 4.54s\n",
963 | "43:\tlearn: 0.1925528\ttest: 0.2164779\tbest: 0.2164779 (43)\ttotal: 207ms\tremaining: 4.5s\n",
964 | "44:\tlearn: 0.1906966\ttest: 0.2145020\tbest: 0.2145020 (44)\ttotal: 210ms\tremaining: 4.46s\n",
965 | "45:\tlearn: 0.1884299\ttest: 0.2120531\tbest: 0.2120531 (45)\ttotal: 213ms\tremaining: 4.42s\n",
966 | "46:\tlearn: 0.1871897\ttest: 0.2111221\tbest: 0.2111221 (46)\ttotal: 216ms\tremaining: 4.38s\n",
967 | "47:\tlearn: 0.1859624\ttest: 0.2102293\tbest: 0.2102293 (47)\ttotal: 219ms\tremaining: 4.35s\n",
968 | "48:\tlearn: 0.1844340\ttest: 0.2088910\tbest: 0.2088910 (48)\ttotal: 223ms\tremaining: 4.33s\n",
969 | "49:\tlearn: 0.1823357\ttest: 0.2068230\tbest: 0.2068230 (49)\ttotal: 226ms\tremaining: 4.29s\n",
970 | "50:\tlearn: 0.1806636\ttest: 0.2050706\tbest: 0.2050706 (50)\ttotal: 229ms\tremaining: 4.26s\n",
971 | "51:\tlearn: 0.1787616\ttest: 0.2029668\tbest: 0.2029668 (51)\ttotal: 231ms\tremaining: 4.22s\n",
972 | "52:\tlearn: 0.1773778\ttest: 0.2014285\tbest: 0.2014285 (52)\ttotal: 236ms\tremaining: 4.21s\n",
973 | "53:\tlearn: 0.1755366\ttest: 0.1996299\tbest: 0.1996299 (53)\ttotal: 239ms\tremaining: 4.18s\n",
974 | "54:\tlearn: 0.1741232\ttest: 0.1984643\tbest: 0.1984643 (54)\ttotal: 242ms\tremaining: 4.15s\n",
975 | "55:\tlearn: 0.1725883\ttest: 0.1966571\tbest: 0.1966571 (55)\ttotal: 245ms\tremaining: 4.12s\n",
976 | "56:\tlearn: 0.1710051\ttest: 0.1955276\tbest: 0.1955276 (56)\ttotal: 248ms\tremaining: 4.1s\n",
977 | "57:\tlearn: 0.1698981\ttest: 0.1942014\tbest: 0.1942014 (57)\ttotal: 251ms\tremaining: 4.08s\n",
978 | "58:\tlearn: 0.1684351\ttest: 0.1929899\tbest: 0.1929899 (58)\ttotal: 254ms\tremaining: 4.06s\n",
979 | "59:\tlearn: 0.1673545\ttest: 0.1921834\tbest: 0.1921834 (59)\ttotal: 258ms\tremaining: 4.04s\n",
980 | "60:\tlearn: 0.1661590\ttest: 0.1911501\tbest: 0.1911501 (60)\ttotal: 261ms\tremaining: 4.02s\n",
981 | "61:\tlearn: 0.1650260\ttest: 0.1899028\tbest: 0.1899028 (61)\ttotal: 264ms\tremaining: 4s\n",
982 | "62:\tlearn: 0.1638037\ttest: 0.1885839\tbest: 0.1885839 (62)\ttotal: 268ms\tremaining: 3.99s\n",
983 | "63:\tlearn: 0.1624676\ttest: 0.1877544\tbest: 0.1877544 (63)\ttotal: 272ms\tremaining: 3.98s\n",
984 | "64:\tlearn: 0.1610981\ttest: 0.1865590\tbest: 0.1865590 (64)\ttotal: 275ms\tremaining: 3.96s\n",
985 | "65:\tlearn: 0.1604604\ttest: 0.1865760\tbest: 0.1865590 (64)\ttotal: 278ms\tremaining: 3.94s\n",
986 | "66:\tlearn: 0.1589032\ttest: 0.1855309\tbest: 0.1855309 (66)\ttotal: 282ms\tremaining: 3.93s\n",
987 | "67:\tlearn: 0.1581581\ttest: 0.1847642\tbest: 0.1847642 (67)\ttotal: 285ms\tremaining: 3.91s\n",
988 | "68:\tlearn: 0.1572876\ttest: 0.1842998\tbest: 0.1842998 (68)\ttotal: 288ms\tremaining: 3.89s\n",
989 | "69:\tlearn: 0.1561982\ttest: 0.1833719\tbest: 0.1833719 (69)\ttotal: 291ms\tremaining: 3.87s\n",
990 | "70:\tlearn: 0.1550831\ttest: 0.1825030\tbest: 0.1825030 (70)\ttotal: 296ms\tremaining: 3.87s\n",
991 | "71:\tlearn: 0.1542308\ttest: 0.1817230\tbest: 0.1817230 (71)\ttotal: 298ms\tremaining: 3.85s\n",
992 | "72:\tlearn: 0.1535367\ttest: 0.1812467\tbest: 0.1812467 (72)\ttotal: 302ms\tremaining: 3.83s\n",
993 | "73:\tlearn: 0.1529628\ttest: 0.1810243\tbest: 0.1810243 (73)\ttotal: 305ms\tremaining: 3.81s\n",
994 | "74:\tlearn: 0.1520269\ttest: 0.1802059\tbest: 0.1802059 (74)\ttotal: 308ms\tremaining: 3.8s\n",
995 | "75:\tlearn: 0.1508695\ttest: 0.1787649\tbest: 0.1787649 (75)\ttotal: 311ms\tremaining: 3.78s\n",
996 | "76:\tlearn: 0.1498385\ttest: 0.1776925\tbest: 0.1776925 (76)\ttotal: 315ms\tremaining: 3.77s\n",
997 | "77:\tlearn: 0.1490603\ttest: 0.1768350\tbest: 0.1768350 (77)\ttotal: 318ms\tremaining: 3.76s\n",
998 | "78:\tlearn: 0.1482503\ttest: 0.1761571\tbest: 0.1761571 (78)\ttotal: 321ms\tremaining: 3.75s\n",
999 | "79:\tlearn: 0.1473683\ttest: 0.1753244\tbest: 0.1753244 (79)\ttotal: 324ms\tremaining: 3.73s\n",
1000 | "80:\tlearn: 0.1466100\ttest: 0.1745125\tbest: 0.1745125 (80)\ttotal: 327ms\tremaining: 3.71s\n",
1001 | "81:\tlearn: 0.1462307\ttest: 0.1742048\tbest: 0.1742048 (81)\ttotal: 330ms\tremaining: 3.69s\n",
1002 | "82:\tlearn: 0.1460834\ttest: 0.1743457\tbest: 0.1742048 (81)\ttotal: 333ms\tremaining: 3.67s\n",
1003 | "83:\tlearn: 0.1453986\ttest: 0.1739246\tbest: 0.1739246 (83)\ttotal: 335ms\tremaining: 3.66s\n",
1004 | "84:\tlearn: 0.1445232\ttest: 0.1731785\tbest: 0.1731785 (84)\ttotal: 338ms\tremaining: 3.64s\n",
1005 | "85:\tlearn: 0.1439600\ttest: 0.1732023\tbest: 0.1731785 (84)\ttotal: 342ms\tremaining: 3.63s\n",
1006 | "86:\tlearn: 0.1433421\ttest: 0.1728107\tbest: 0.1728107 (86)\ttotal: 345ms\tremaining: 3.62s\n",
1007 | "87:\tlearn: 0.1427270\ttest: 0.1723026\tbest: 0.1723026 (87)\ttotal: 348ms\tremaining: 3.6s\n",
1008 | "88:\tlearn: 0.1420955\ttest: 0.1717022\tbest: 0.1717022 (88)\ttotal: 351ms\tremaining: 3.59s\n",
1009 | "89:\tlearn: 0.1415378\ttest: 0.1715440\tbest: 0.1715440 (89)\ttotal: 354ms\tremaining: 3.58s\n",
1010 | "90:\tlearn: 0.1412110\ttest: 0.1715418\tbest: 0.1715418 (90)\ttotal: 357ms\tremaining: 3.56s\n",
1011 | "91:\tlearn: 0.1404274\ttest: 0.1708362\tbest: 0.1708362 (91)\ttotal: 360ms\tremaining: 3.55s\n",
1012 | "92:\tlearn: 0.1397125\ttest: 0.1707287\tbest: 0.1707287 (92)\ttotal: 364ms\tremaining: 3.55s\n",
1013 | "93:\tlearn: 0.1393044\ttest: 0.1705506\tbest: 0.1705506 (93)\ttotal: 367ms\tremaining: 3.53s\n",
1014 | "94:\tlearn: 0.1387496\ttest: 0.1699821\tbest: 0.1699821 (94)\ttotal: 369ms\tremaining: 3.51s\n",
1015 | "95:\tlearn: 0.1381773\ttest: 0.1697968\tbest: 0.1697968 (95)\ttotal: 372ms\tremaining: 3.5s\n",
1016 | "96:\tlearn: 0.1376471\ttest: 0.1693315\tbest: 0.1693315 (96)\ttotal: 375ms\tremaining: 3.49s\n",
1017 | "97:\tlearn: 0.1372260\ttest: 0.1694167\tbest: 0.1693315 (96)\ttotal: 378ms\tremaining: 3.48s\n",
1018 | "98:\tlearn: 0.1369108\ttest: 0.1691435\tbest: 0.1691435 (98)\ttotal: 380ms\tremaining: 3.46s\n",
1019 | "99:\tlearn: 0.1365988\ttest: 0.1690779\tbest: 0.1690779 (99)\ttotal: 383ms\tremaining: 3.45s\n",
1020 | "100:\tlearn: 0.1361507\ttest: 0.1685021\tbest: 0.1685021 (100)\ttotal: 386ms\tremaining: 3.44s\n",
1021 | "101:\tlearn: 0.1354671\ttest: 0.1681311\tbest: 0.1681311 (101)\ttotal: 389ms\tremaining: 3.42s\n",
1022 | "102:\tlearn: 0.1348790\ttest: 0.1677541\tbest: 0.1677541 (102)\ttotal: 393ms\tremaining: 3.42s\n",
1023 | "103:\tlearn: 0.1342635\ttest: 0.1669889\tbest: 0.1669889 (103)\ttotal: 395ms\tremaining: 3.41s\n",
1024 | "104:\tlearn: 0.1339653\ttest: 0.1668001\tbest: 0.1668001 (104)\ttotal: 400ms\tremaining: 3.41s\n",
1025 | "105:\tlearn: 0.1333219\ttest: 0.1665222\tbest: 0.1665222 (105)\ttotal: 403ms\tremaining: 3.4s\n",
1026 | "106:\tlearn: 0.1327973\ttest: 0.1662342\tbest: 0.1662342 (106)\ttotal: 408ms\tremaining: 3.4s\n",
1027 | "107:\tlearn: 0.1322664\ttest: 0.1661034\tbest: 0.1661034 (107)\ttotal: 410ms\tremaining: 3.39s\n",
1028 | "108:\tlearn: 0.1317347\ttest: 0.1658443\tbest: 0.1658443 (108)\ttotal: 413ms\tremaining: 3.38s\n",
1029 | "109:\tlearn: 0.1310987\ttest: 0.1652628\tbest: 0.1652628 (109)\ttotal: 416ms\tremaining: 3.37s\n",
1030 | "110:\tlearn: 0.1305917\ttest: 0.1649934\tbest: 0.1649934 (110)\ttotal: 419ms\tremaining: 3.36s\n",
1031 | "111:\tlearn: 0.1305812\ttest: 0.1650006\tbest: 0.1649934 (110)\ttotal: 422ms\tremaining: 3.34s\n",
1032 | "112:\tlearn: 0.1301225\ttest: 0.1647297\tbest: 0.1647297 (112)\ttotal: 425ms\tremaining: 3.34s\n",
1033 | "113:\tlearn: 0.1296899\ttest: 0.1643031\tbest: 0.1643031 (113)\ttotal: 428ms\tremaining: 3.32s\n",
1034 | "114:\tlearn: 0.1294766\ttest: 0.1642072\tbest: 0.1642072 (114)\ttotal: 431ms\tremaining: 3.32s\n",
1035 | "115:\tlearn: 0.1291889\ttest: 0.1640454\tbest: 0.1640454 (115)\ttotal: 434ms\tremaining: 3.3s\n",
1036 | "116:\tlearn: 0.1291278\ttest: 0.1640400\tbest: 0.1640400 (116)\ttotal: 436ms\tremaining: 3.29s\n",
1037 | "117:\tlearn: 0.1288480\ttest: 0.1639734\tbest: 0.1639734 (117)\ttotal: 439ms\tremaining: 3.28s\n",
1038 | "118:\tlearn: 0.1286497\ttest: 0.1639575\tbest: 0.1639575 (118)\ttotal: 442ms\tremaining: 3.27s\n",
1039 | "119:\tlearn: 0.1282663\ttest: 0.1637441\tbest: 0.1637441 (119)\ttotal: 445ms\tremaining: 3.27s\n",
1040 | "120:\tlearn: 0.1281113\ttest: 0.1637939\tbest: 0.1637441 (119)\ttotal: 449ms\tremaining: 3.26s\n",
1041 | "121:\tlearn: 0.1276841\ttest: 0.1633714\tbest: 0.1633714 (121)\ttotal: 452ms\tremaining: 3.25s\n",
1042 | "122:\tlearn: 0.1274058\ttest: 0.1633725\tbest: 0.1633714 (121)\ttotal: 455ms\tremaining: 3.24s\n",
1043 | "123:\tlearn: 0.1269155\ttest: 0.1634801\tbest: 0.1633714 (121)\ttotal: 458ms\tremaining: 3.23s\n",
1044 | "124:\tlearn: 0.1268380\ttest: 0.1635373\tbest: 0.1633714 (121)\ttotal: 462ms\tremaining: 3.23s\n",
1045 | "125:\tlearn: 0.1267048\ttest: 0.1635890\tbest: 0.1633714 (121)\ttotal: 465ms\tremaining: 3.23s\n",
1046 | "126:\tlearn: 0.1262448\ttest: 0.1629872\tbest: 0.1629872 (126)\ttotal: 470ms\tremaining: 3.23s\n",
1047 | "127:\tlearn: 0.1258275\ttest: 0.1625769\tbest: 0.1625769 (127)\ttotal: 473ms\tremaining: 3.22s\n",
1048 | "128:\tlearn: 0.1253859\ttest: 0.1621351\tbest: 0.1621351 (128)\ttotal: 477ms\tremaining: 3.22s\n",
1049 | "129:\tlearn: 0.1249703\ttest: 0.1619379\tbest: 0.1619379 (129)\ttotal: 481ms\tremaining: 3.22s\n",
1050 | "130:\tlearn: 0.1247342\ttest: 0.1616268\tbest: 0.1616268 (130)\ttotal: 485ms\tremaining: 3.22s\n",
1051 | "131:\tlearn: 0.1244258\ttest: 0.1614230\tbest: 0.1614230 (131)\ttotal: 488ms\tremaining: 3.21s\n",
1052 | "132:\tlearn: 0.1236921\ttest: 0.1611347\tbest: 0.1611347 (132)\ttotal: 492ms\tremaining: 3.2s\n",
1053 | "133:\tlearn: 0.1234099\ttest: 0.1610724\tbest: 0.1610724 (133)\ttotal: 495ms\tremaining: 3.2s\n",
1054 | "134:\tlearn: 0.1233218\ttest: 0.1609980\tbest: 0.1609980 (134)\ttotal: 497ms\tremaining: 3.19s\n",
1055 | "135:\tlearn: 0.1231105\ttest: 0.1609405\tbest: 0.1609405 (135)\ttotal: 501ms\tremaining: 3.18s\n",
1056 | "136:\tlearn: 0.1230489\ttest: 0.1608900\tbest: 0.1608900 (136)\ttotal: 503ms\tremaining: 3.17s\n",
1057 | "137:\tlearn: 0.1224238\ttest: 0.1606435\tbest: 0.1606435 (137)\ttotal: 507ms\tremaining: 3.17s\n",
1058 | "138:\tlearn: 0.1221154\ttest: 0.1606038\tbest: 0.1606038 (138)\ttotal: 510ms\tremaining: 3.16s\n",
1059 | "139:\tlearn: 0.1218823\ttest: 0.1604324\tbest: 0.1604324 (139)\ttotal: 514ms\tremaining: 3.16s\n",
1060 | "140:\tlearn: 0.1217761\ttest: 0.1602990\tbest: 0.1602990 (140)\ttotal: 517ms\tremaining: 3.15s\n",
1061 | "141:\tlearn: 0.1215440\ttest: 0.1601603\tbest: 0.1601603 (141)\ttotal: 520ms\tremaining: 3.14s\n",
1062 | "142:\tlearn: 0.1212282\ttest: 0.1599872\tbest: 0.1599872 (142)\ttotal: 523ms\tremaining: 3.14s\n",
1063 | "143:\tlearn: 0.1210965\ttest: 0.1599391\tbest: 0.1599391 (143)\ttotal: 526ms\tremaining: 3.13s\n",
1064 | "144:\tlearn: 0.1206828\ttest: 0.1596531\tbest: 0.1596531 (144)\ttotal: 530ms\tremaining: 3.12s\n",
1065 | "145:\tlearn: 0.1203065\ttest: 0.1597038\tbest: 0.1596531 (144)\ttotal: 533ms\tremaining: 3.12s\n",
1066 | "146:\tlearn: 0.1201046\ttest: 0.1594982\tbest: 0.1594982 (146)\ttotal: 536ms\tremaining: 3.11s\n",
1067 | "147:\tlearn: 0.1199300\ttest: 0.1595389\tbest: 0.1594982 (146)\ttotal: 540ms\tremaining: 3.1s\n",
1068 | "148:\tlearn: 0.1195486\ttest: 0.1594501\tbest: 0.1594501 (148)\ttotal: 542ms\tremaining: 3.1s\n",
1069 | "149:\tlearn: 0.1190920\ttest: 0.1597642\tbest: 0.1594501 (148)\ttotal: 546ms\tremaining: 3.09s\n",
1070 | "150:\tlearn: 0.1185979\ttest: 0.1595261\tbest: 0.1594501 (148)\ttotal: 548ms\tremaining: 3.08s\n",
1071 | "151:\tlearn: 0.1183069\ttest: 0.1592414\tbest: 0.1592414 (151)\ttotal: 551ms\tremaining: 3.07s\n",
1072 | "152:\tlearn: 0.1181374\ttest: 0.1592164\tbest: 0.1592164 (152)\ttotal: 554ms\tremaining: 3.06s\n",
1073 | "153:\tlearn: 0.1177571\ttest: 0.1589785\tbest: 0.1589785 (153)\ttotal: 557ms\tremaining: 3.06s\n",
1074 | "154:\tlearn: 0.1173531\ttest: 0.1588217\tbest: 0.1588217 (154)\ttotal: 561ms\tremaining: 3.06s\n",
1075 | "155:\tlearn: 0.1170582\ttest: 0.1587787\tbest: 0.1587787 (155)\ttotal: 563ms\tremaining: 3.05s\n",
1076 | "156:\tlearn: 0.1167705\ttest: 0.1586385\tbest: 0.1586385 (156)\ttotal: 569ms\tremaining: 3.05s\n",
1077 | "157:\tlearn: 0.1164068\ttest: 0.1585486\tbest: 0.1585486 (157)\ttotal: 572ms\tremaining: 3.04s\n",
1078 | "158:\tlearn: 0.1163820\ttest: 0.1585151\tbest: 0.1585151 (158)\ttotal: 574ms\tremaining: 3.03s\n",
1079 | "159:\tlearn: 0.1162444\ttest: 0.1585091\tbest: 0.1585091 (159)\ttotal: 577ms\tremaining: 3.03s\n",
1080 | "160:\tlearn: 0.1161463\ttest: 0.1584859\tbest: 0.1584859 (160)\ttotal: 579ms\tremaining: 3.02s\n",
1081 | "161:\tlearn: 0.1157391\ttest: 0.1580840\tbest: 0.1580840 (161)\ttotal: 582ms\tremaining: 3.01s\n",
1082 | "162:\tlearn: 0.1156908\ttest: 0.1580406\tbest: 0.1580406 (162)\ttotal: 586ms\tremaining: 3.01s\n",
1083 | "163:\tlearn: 0.1153443\ttest: 0.1578388\tbest: 0.1578388 (163)\ttotal: 588ms\tremaining: 3s\n",
1084 | "164:\tlearn: 0.1150407\ttest: 0.1576972\tbest: 0.1576972 (164)\ttotal: 591ms\tremaining: 2.99s\n",
1085 | "165:\tlearn: 0.1148386\ttest: 0.1577687\tbest: 0.1576972 (164)\ttotal: 595ms\tremaining: 2.99s\n",
1086 | "166:\tlearn: 0.1143723\ttest: 0.1572801\tbest: 0.1572801 (166)\ttotal: 599ms\tremaining: 2.99s\n",
1087 | "167:\tlearn: 0.1141279\ttest: 0.1574169\tbest: 0.1572801 (166)\ttotal: 602ms\tremaining: 2.98s\n",
1088 | "168:\tlearn: 0.1139445\ttest: 0.1573275\tbest: 0.1572801 (166)\ttotal: 605ms\tremaining: 2.97s\n",
1089 | "169:\tlearn: 0.1135710\ttest: 0.1570117\tbest: 0.1570117 (169)\ttotal: 608ms\tremaining: 2.97s\n",
1090 | "170:\tlearn: 0.1134066\ttest: 0.1569465\tbest: 0.1569465 (170)\ttotal: 612ms\tremaining: 2.96s\n",
1091 | "171:\tlearn: 0.1132348\ttest: 0.1569808\tbest: 0.1569465 (170)\ttotal: 615ms\tremaining: 2.96s\n",
1092 | "172:\tlearn: 0.1129800\ttest: 0.1570577\tbest: 0.1569465 (170)\ttotal: 618ms\tremaining: 2.95s\n",
1093 | "173:\tlearn: 0.1126281\ttest: 0.1568818\tbest: 0.1568818 (173)\ttotal: 623ms\tremaining: 2.96s\n",
1094 | "174:\tlearn: 0.1124995\ttest: 0.1568308\tbest: 0.1568308 (174)\ttotal: 625ms\tremaining: 2.95s\n",
1095 | "175:\tlearn: 0.1123167\ttest: 0.1567765\tbest: 0.1567765 (175)\ttotal: 628ms\tremaining: 2.94s\n",
1096 | "176:\tlearn: 0.1122578\ttest: 0.1568048\tbest: 0.1567765 (175)\ttotal: 633ms\tremaining: 2.94s\n",
1097 | "177:\tlearn: 0.1120069\ttest: 0.1567596\tbest: 0.1567596 (177)\ttotal: 636ms\tremaining: 2.94s\n",
1098 | "178:\tlearn: 0.1118183\ttest: 0.1567967\tbest: 0.1567596 (177)\ttotal: 640ms\tremaining: 2.93s\n",
1099 | "179:\tlearn: 0.1116753\ttest: 0.1566245\tbest: 0.1566245 (179)\ttotal: 645ms\tremaining: 2.94s\n",
1100 | "180:\tlearn: 0.1114712\ttest: 0.1565494\tbest: 0.1565494 (180)\ttotal: 648ms\tremaining: 2.93s\n",
1101 | "181:\tlearn: 0.1110647\ttest: 0.1563492\tbest: 0.1563492 (181)\ttotal: 653ms\tremaining: 2.94s\n",
1102 | "182:\tlearn: 0.1109584\ttest: 0.1563476\tbest: 0.1563476 (182)\ttotal: 656ms\tremaining: 2.93s\n",
1103 | "183:\tlearn: 0.1108534\ttest: 0.1563045\tbest: 0.1563045 (183)\ttotal: 661ms\tremaining: 2.93s\n",
1104 | "184:\tlearn: 0.1104813\ttest: 0.1561775\tbest: 0.1561775 (184)\ttotal: 664ms\tremaining: 2.92s\n",
1105 | "185:\tlearn: 0.1100818\ttest: 0.1557271\tbest: 0.1557271 (185)\ttotal: 666ms\tremaining: 2.92s\n",
1106 | "186:\tlearn: 0.1099489\ttest: 0.1557738\tbest: 0.1557271 (185)\ttotal: 669ms\tremaining: 2.91s\n",
1107 | "187:\tlearn: 0.1097746\ttest: 0.1559275\tbest: 0.1557271 (185)\ttotal: 672ms\tremaining: 2.9s\n",
1108 | "188:\tlearn: 0.1094178\ttest: 0.1557098\tbest: 0.1557098 (188)\ttotal: 676ms\tremaining: 2.9s\n",
1109 | "189:\tlearn: 0.1091134\ttest: 0.1553536\tbest: 0.1553536 (189)\ttotal: 678ms\tremaining: 2.89s\n",
1110 | "190:\tlearn: 0.1089026\ttest: 0.1553270\tbest: 0.1553270 (190)\ttotal: 681ms\tremaining: 2.88s\n",
1111 | "191:\tlearn: 0.1088370\ttest: 0.1552442\tbest: 0.1552442 (191)\ttotal: 684ms\tremaining: 2.88s\n",
1112 | "192:\tlearn: 0.1085390\ttest: 0.1552017\tbest: 0.1552017 (192)\ttotal: 687ms\tremaining: 2.87s\n",
1113 | "193:\tlearn: 0.1083417\ttest: 0.1552433\tbest: 0.1552017 (192)\ttotal: 689ms\tremaining: 2.86s\n",
1114 | "194:\tlearn: 0.1081520\ttest: 0.1553490\tbest: 0.1552017 (192)\ttotal: 693ms\tremaining: 2.86s\n",
1115 | "195:\tlearn: 0.1078712\ttest: 0.1555736\tbest: 0.1552017 (192)\ttotal: 696ms\tremaining: 2.85s\n",
1116 | "196:\tlearn: 0.1074892\ttest: 0.1556243\tbest: 0.1552017 (192)\ttotal: 699ms\tremaining: 2.85s\n",
1117 | "197:\tlearn: 0.1074879\ttest: 0.1556146\tbest: 0.1552017 (192)\ttotal: 701ms\tremaining: 2.84s\n",
1118 | "198:\tlearn: 0.1073743\ttest: 0.1555928\tbest: 0.1552017 (192)\ttotal: 703ms\tremaining: 2.83s\n",
1119 | "199:\tlearn: 0.1071665\ttest: 0.1554248\tbest: 0.1552017 (192)\ttotal: 706ms\tremaining: 2.82s\n",
1120 | "200:\tlearn: 0.1070061\ttest: 0.1554337\tbest: 0.1552017 (192)\ttotal: 709ms\tremaining: 2.82s\n",
1121 | "201:\tlearn: 0.1067106\ttest: 0.1552604\tbest: 0.1552017 (192)\ttotal: 712ms\tremaining: 2.81s\n",
1122 | "202:\tlearn: 0.1061867\ttest: 0.1547731\tbest: 0.1547731 (202)\ttotal: 715ms\tremaining: 2.81s\n",
1123 | "203:\tlearn: 0.1061642\ttest: 0.1547586\tbest: 0.1547586 (203)\ttotal: 716ms\tremaining: 2.79s\n",
1124 | "204:\tlearn: 0.1058801\ttest: 0.1544104\tbest: 0.1544104 (204)\ttotal: 719ms\tremaining: 2.79s\n",
1125 | "205:\tlearn: 0.1056995\ttest: 0.1543959\tbest: 0.1543959 (205)\ttotal: 722ms\tremaining: 2.78s\n",
1126 | "206:\tlearn: 0.1053829\ttest: 0.1540443\tbest: 0.1540443 (206)\ttotal: 725ms\tremaining: 2.78s\n",
1127 | "207:\tlearn: 0.1050172\ttest: 0.1537433\tbest: 0.1537433 (207)\ttotal: 730ms\tremaining: 2.78s\n",
1128 | "208:\tlearn: 0.1046315\ttest: 0.1538177\tbest: 0.1537433 (207)\ttotal: 735ms\tremaining: 2.78s\n",
1129 | "209:\tlearn: 0.1042301\ttest: 0.1540596\tbest: 0.1537433 (207)\ttotal: 738ms\tremaining: 2.78s\n",
1130 | "210:\tlearn: 0.1038595\ttest: 0.1542492\tbest: 0.1537433 (207)\ttotal: 741ms\tremaining: 2.77s\n",
1131 | "211:\tlearn: 0.1038453\ttest: 0.1542450\tbest: 0.1537433 (207)\ttotal: 744ms\tremaining: 2.76s\n",
1132 | "212:\tlearn: 0.1037689\ttest: 0.1542358\tbest: 0.1537433 (207)\ttotal: 748ms\tremaining: 2.76s\n",
1133 | "213:\tlearn: 0.1035594\ttest: 0.1541500\tbest: 0.1537433 (207)\ttotal: 752ms\tremaining: 2.76s\n",
1134 | "214:\tlearn: 0.1032968\ttest: 0.1541669\tbest: 0.1537433 (207)\ttotal: 756ms\tremaining: 2.76s\n",
1135 | "215:\tlearn: 0.1031981\ttest: 0.1541580\tbest: 0.1537433 (207)\ttotal: 759ms\tremaining: 2.75s\n",
1136 | "216:\tlearn: 0.1029143\ttest: 0.1540810\tbest: 0.1537433 (207)\ttotal: 763ms\tremaining: 2.75s\n",
1137 | "217:\tlearn: 0.1027793\ttest: 0.1541172\tbest: 0.1537433 (207)\ttotal: 766ms\tremaining: 2.75s\n",
1138 | "218:\tlearn: 0.1026013\ttest: 0.1541685\tbest: 0.1537433 (207)\ttotal: 772ms\tremaining: 2.75s\n",
1139 | "219:\tlearn: 0.1024922\ttest: 0.1540507\tbest: 0.1537433 (207)\ttotal: 774ms\tremaining: 2.75s\n",
1140 | "220:\tlearn: 0.1024123\ttest: 0.1540590\tbest: 0.1537433 (207)\ttotal: 777ms\tremaining: 2.74s\n",
1141 | "221:\tlearn: 0.1019649\ttest: 0.1539730\tbest: 0.1537433 (207)\ttotal: 780ms\tremaining: 2.73s\n",
1142 | "222:\tlearn: 0.1015602\ttest: 0.1540056\tbest: 0.1537433 (207)\ttotal: 783ms\tremaining: 2.73s\n",
1143 | "223:\tlearn: 0.1014548\ttest: 0.1539750\tbest: 0.1537433 (207)\ttotal: 787ms\tremaining: 2.73s\n",
1144 | "224:\tlearn: 0.1013972\ttest: 0.1538995\tbest: 0.1537433 (207)\ttotal: 790ms\tremaining: 2.72s\n",
1145 | "225:\tlearn: 0.1013235\ttest: 0.1538275\tbest: 0.1537433 (207)\ttotal: 793ms\tremaining: 2.72s\n",
1146 | "226:\tlearn: 0.1011391\ttest: 0.1538708\tbest: 0.1537433 (207)\ttotal: 796ms\tremaining: 2.71s\n",
1147 | "227:\tlearn: 0.1011154\ttest: 0.1538654\tbest: 0.1537433 (207)\ttotal: 800ms\tremaining: 2.71s\n",
1148 | "Stopped by overfitting detector (20 iterations wait)\n",
1149 | "\n",
1150 | "bestTest = 0.1537432969\n",
1151 | "bestIteration = 207\n",
1152 | "\n",
1153 | "Shrink model to first 208 iterations.\n"
1154 | ]
1155 | }
1156 | ],
1157 | "source": [
1158 | "if params['problem_type'] == 'regression':\n",
1159 | " from catboost import CatBoostRegressor, Pool\n",
1160 | "\n",
1161 | " model_catboost = CatBoostRegressor(n_estimators=1000, \n",
1162 | " early_stopping_rounds=20)\n",
1163 | " train_data = Pool(\n",
1164 | " data=X_train_raw,\n",
1165 | " label=y_train,\n",
1166 | " cat_features=categorical_feature_indices,\n",
1167 | " )\n",
1168 | "\n",
1169 | " eval_data = Pool(\n",
1170 | " data=X_valid_raw,\n",
1171 | " label=y_valid,\n",
1172 | " cat_features=categorical_feature_indices,\n",
1173 | " )\n",
1174 | "\n",
1175 | " model_catboost.fit(X=train_data, \n",
1176 | " eval_set=eval_data)\n",
1177 | "\n",
1178 | "\n",
1179 | "\n",
1180 | "\n",
1181 | " preds_catboost = model_catboost.predict(X_test_raw)\n",
1182 | "else:\n",
1183 | " from catboost import CatBoostClassifier, Pool\n",
1184 | "\n",
1185 | " model_catboost = CatBoostClassifier(n_estimators=1000, \n",
1186 | " early_stopping_rounds=20)\n",
1187 | " train_data = Pool(\n",
1188 | " data=X_train_raw,\n",
1189 | " label=y_train,\n",
1190 | " cat_features=categorical_feature_indices,\n",
1191 | " #weight=calculate_sample_weights(y_train)\n",
1192 | " )\n",
1193 | "\n",
1194 | " eval_data = Pool(\n",
1195 | " data=X_valid_raw,\n",
1196 | " label=y_valid,\n",
1197 | " cat_features=categorical_feature_indices,\n",
1198 | " #weight=calculate_sample_weights(y_valid),\n",
1199 | " )\n",
1200 | "\n",
1201 | " model_catboost.fit(X=train_data, \n",
1202 | " eval_set=eval_data)\n",
1203 | "\n",
1204 | "\n",
1205 | "\n",
1206 | " preds_catboost = model_catboost.predict_proba(X_test_raw)\n"
1207 | ]
1208 | },
1209 | {
1210 | "cell_type": "code",
1211 | "execution_count": 12,
1212 | "id": "21b9be01-ba30-4300-b3e0-f4fd43b55c87",
1213 | "metadata": {
1214 | "execution": {
1215 | "iopub.execute_input": "2024-03-15T11:57:52.429019Z",
1216 | "iopub.status.busy": "2024-03-15T11:57:52.428555Z",
1217 | "iopub.status.idle": "2024-03-15T11:57:52.454151Z",
1218 | "shell.execute_reply": "2024-03-15T11:57:52.453620Z",
1219 | "shell.execute_reply.started": "2024-03-15T11:57:52.428990Z"
1220 | }
1221 | },
1222 | "outputs": [
1223 | {
1224 | "name": "stdout",
1225 | "output_type": "stream",
1226 | "text": [
1227 | "Accuracy GRANDE: 0.954\n",
1228 | "F1 Score GRANDE: 0.8976494984825425\n",
1229 | "ROC AUC GRANDE: 0.922450889462646\n",
1230 | "\n",
1231 | "\n",
1232 | "Accuracy XGB: 0.952\n",
1233 | "F1 Score XGB: 0.8902857142857143\n",
1234 | "ROC AUC XGB: 0.9152775340703048\n",
1235 | "\n",
1236 | "\n",
1237 | "Accuracy CatBoost: 0.957\n",
1238 | "F1 Score CatBoost: 0.8999941857084715\n",
1239 | "ROC AUC CatBoost: 0.9137693329656831\n",
1240 | "\n",
1241 | "\n"
1242 | ]
1243 | }
1244 | ],
1245 | "source": [
1246 | "if params['problem_type'] == 'binary':\n",
1247 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_grande[:,1]))\n",
1248 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_grande[:,1]), average='macro')\n",
1249 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande[:,1], average='macro', multi_class='ovo')\n",
1250 | "\n",
1251 | " print('Accuracy GRANDE:', accuracy)\n",
1252 | " print('F1 Score GRANDE:', f1_score)\n",
1253 | " print('ROC AUC GRANDE:', roc_auc)\n",
1254 | " print('\\n')\n",
1255 | "\n",
1256 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_xgb[:,1]))\n",
1257 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_xgb[:,1]), average='macro')\n",
1258 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb[:,1], average='macro', multi_class='ovo')\n",
1259 | "\n",
1260 | " print('Accuracy XGB:', accuracy)\n",
1261 | " print('F1 Score XGB:', f1_score)\n",
1262 | " print('ROC AUC XGB:', roc_auc)\n",
1263 | " print('\\n')\n",
1264 | "\n",
1265 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds_catboost[:,1]))\n",
1266 | " f1_score = sklearn.metrics.f1_score(y_test, np.round(preds_catboost[:,1]), average='macro')\n",
1267 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost[:,1], average='macro', multi_class='ovo')\n",
1268 | "\n",
1269 | " print('Accuracy CatBoost:', accuracy)\n",
1270 | " print('F1 Score CatBoost:', f1_score)\n",
1271 | " print('ROC AUC CatBoost:', roc_auc)\n",
1272 | " print('\\n')\n",
1273 | "elif params['problem_type'] == 'multiclass':\n",
1274 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_grande, axis=1))\n",
1275 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_grande, axis=1), average='macro')\n",
1276 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_grande, average='macro', multi_class='ovo', labels=list(range(preds_grande.shape[1])))\n",
1277 | "\n",
1278 | " print('Accuracy GRANDE:', accuracy)\n",
1279 | " print('F1 Score GRANDE:', f1_score)\n",
1280 | " print('ROC AUC GRANDE:', roc_auc)\n",
1281 | " print('\\n')\n",
1282 | "\n",
1283 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_xgb, axis=1))\n",
1284 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_xgb, axis=1), average='macro')\n",
1285 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_xgb, average='macro', multi_class='ovo', labels=list(range(preds_xgb.shape[1])))\n",
1286 | "\n",
1287 | " print('Accuracy XGB:', accuracy)\n",
1288 | " print('F1 Score XGB:', f1_score)\n",
1289 | " print('ROC AUC XGB:', roc_auc)\n",
1290 | " print('\\n')\n",
1291 | "\n",
1292 | " accuracy = sklearn.metrics.accuracy_score(y_test, np.argmax(preds_catboost, axis=1))\n",
1293 | " f1_score = sklearn.metrics.f1_score(y_test, np.argmax(preds_catboost, axis=1), average='macro')\n",
1294 | " roc_auc = sklearn.metrics.roc_auc_score(y_test, preds_catboost, average='macro', multi_class='ovo', labels=list(range(preds_catboost.shape[1])))\n",
1295 | "\n",
1296 | " print('Accuracy CatBoost:', accuracy)\n",
1297 | " print('F1 Score CatBoost:', f1_score)\n",
1298 | " print('ROC AUC CatBoost:', roc_auc)\n",
1299 | " print('\\n')\n",
1300 | "else:\n",
1301 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_grande)\n",
1302 | " root_mean_squared_error = np.sqrt(sklearn.metrics.mean_squared_error(y_test, preds_grande))\n",
1303 | " r2_score = sklearn.metrics.r2_score(y_test, preds_grande)\n",
1304 | "\n",
1305 | " print('MAE GRANDE:', mean_absolute_error)\n",
1306 | " print('RMSE GRANDE:', root_mean_squared_error)\n",
1307 | " print('R2 Score GRANDE:', r2_score)\n",
1308 | " print('\\n')\n",
1309 | "\n",
1310 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_xgb)\n",
1311 | " root_mean_squared_error = np.sqrt(sklearn.metrics.mean_squared_error(y_test, preds_xgb))\n",
1312 | " r2_score = sklearn.metrics.r2_score(y_test, preds_xgb)\n",
1313 | "\n",
1314 | " print('MAE XGB:', mean_absolute_error)\n",
1315 | " print('RMSE XGB:', root_mean_squared_error)\n",
1316 | " print('R2 Score XGB:', r2_score)\n",
1317 | " print('\\n')\n",
1318 | "\n",
1319 | " mean_absolute_error = sklearn.metrics.mean_absolute_error(y_test, preds_catboost)\n",
1320 | " root_mean_squared_error = np.sqrt(sklearn.metrics.mean_squared_error(y_test, preds_catboost))\n",
1321 | " r2_score = sklearn.metrics.r2_score(y_test, preds_catboost)\n",
1322 | "\n",
1323 | " print('MAE CatBoost:', mean_absolute_error)\n",
1324 | " print('RMSE CatBoost:', root_mean_squared_error)\n",
1325 | " print('R2 Score CatBoost:', r2_score)\n",
1326 | " print('\\n')"
1327 | ]
1328 | },
1329 | {
1330 | "cell_type": "code",
1331 | "execution_count": null,
1332 | "id": "c7ef884c-1269-4e0b-8ca7-7ea74c6e380c",
1333 | "metadata": {},
1334 | "outputs": [],
1335 | "source": []
1336 | },
1337 | {
1338 | "cell_type": "code",
1339 | "execution_count": null,
1340 | "id": "2fc02cd3",
1341 | "metadata": {},
1342 | "outputs": [],
1343 | "source": []
1344 | }
1345 | ],
1346 | "metadata": {
1347 | "kernelspec": {
1348 | "display_name": "ReMeDe",
1349 | "language": "python",
1350 | "name": "python3"
1351 | },
1352 | "language_info": {
1353 | "codemirror_mode": {
1354 | "name": "ipython",
1355 | "version": 3
1356 | },
1357 | "file_extension": ".py",
1358 | "mimetype": "text/x-python",
1359 | "name": "python",
1360 | "nbconvert_exporter": "python",
1361 | "pygments_lexer": "ipython3",
1362 | "version": "3.12.9"
1363 | }
1364 | },
1365 | "nbformat": 4,
1366 | "nbformat_minor": 5
1367 | }
1368 |
--------------------------------------------------------------------------------
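The CatBoost log in the notebook above ends with `Stopped by overfitting detector (20 iterations wait)`, `bestIteration = 207` and `Shrink model to first 208 iterations`. With `early_stopping_rounds=20`, the test loss last improved at iteration 207, training continued until iteration 227 (207 + 20) without a new best, and the model was then truncated to trees 0 through 207, i.e. 208 trees. The block below is a simplified sketch of that stopping rule, assuming a lower-is-better metric and strict improvement; the function name `iter_early_stopping` is hypothetical, and this is not CatBoost's implementation, which may handle ties, `eval_period` or other detector types differently.

```python
from typing import Sequence, Tuple


def iter_early_stopping(val_losses: Sequence[float], od_wait: int) -> Tuple[int, int]:
    """Hypothetical sketch of an iteration-count early-stopping rule.

    Returns (best_iteration, last_iteration_run) for a lower-is-better
    validation metric: training halts once `od_wait` iterations have
    passed since the last strict improvement of the best loss.
    """
    best_iter = 0
    best_loss = float("inf")
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the wait counter.
            best_loss, best_iter = loss, i
        elif i - best_iter >= od_wait:
            # No improvement for `od_wait` iterations: the detector fires here.
            return best_iter, i
    # The detector never fired: all iterations were run.
    return best_iter, len(val_losses) - 1


# Synthetic losses (not taken from the notebook): best at index 5, then a plateau.
losses = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5] + [0.55] * 30
best, stopped_at = iter_early_stopping(losses, od_wait=20)
# best iteration, last iteration run, number of trees kept (best + 1)
print(best, stopped_at, best + 1)  # prints: 5 25 6
```

Applied to the `test` column of the CatBoost log, the same rule gives a best iteration of 207 and a last iteration of 227, which is why the final model keeps `bestIteration + 1 = 208` trees.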