├── .gitignore
├── COPYING
├── README.md
├── example-regression.py
└── grid_search
    ├── MPINestedGridSearchCV.py
    └── __init__.py

/.gitignore:
--------------------------------------------------------------------------------
*~
*.py[co]
*.tmp
*.swp
__pycache__
.idea
--------------------------------------------------------------------------------
/COPYING:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2015 Sebastian Pölsterl.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Nested Cross-Validation for scikit-learn using MPI

This package provides nested cross-validation similar to scikit-learn's
[GridSearchCV](http://scikit-learn.org/stable/modules/generated/sklearn.grid_search.GridSearchCV.html),
but uses the Message Passing Interface (MPI) for parallel computing.

## Requirements

* [scikit-learn](http://scikit-learn.org) 0.16.0 or later
* [mpi4py](http://mpi4py.scipy.org)
* [pandas](http://pandas.pydata.org)

## Example

```python
from mpi4py import MPI
import numpy
from sklearn.datasets import load_boston
from sklearn.svm import SVR

from grid_search import NestedGridSearchCV

data = load_boston()
X = data['data']
y = data['target']

estimator = SVR(max_iter=1000, tol=1e-5)

param_grid = {'C': 2. ** numpy.arange(-5, 15, 2),
              'gamma': 2. ** numpy.arange(3, -15, -2),
              'kernel': ['poly', 'rbf']}

nested_cv = NestedGridSearchCV(estimator, param_grid, 'mean_absolute_error',
                               cv=5, inner_cv=3)
nested_cv.fit(X, y)

if MPI.COMM_WORLD.Get_rank() == 0:
    for i, scores in enumerate(nested_cv.grid_scores_):
        scores.to_csv('grid-scores-%d.csv' % (i + 1), index=False)

    print(nested_cv.best_params_)
```
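Scripts using `NestedGridSearchCV` must be started through an MPI launcher so
that multiple processes are available. The exact command depends on your MPI
installation; with a standard `mpiexec`, and assuming the example above is
saved as `example.py`, it could look like the following. Note that the final
testing phase requires at least one process more than the number of outer
folds, so six processes for `cv=5`:

```sh
mpiexec -n 6 python example.py
```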

The result should look like this:

| fold | score (Validation) |     C |    gamma | kernel | score (Test) |
| ---- | -----------------: | ----: | -------: | ------ | -----------: |
| 1    |          -7.252490 |   0.5 | 0.000122 | rbf    |    -4.178257 |
| 2    |          -5.662221 | 128.0 | 0.000122 | rbf    |    -5.445915 |
| 3    |          -5.582780 |  32.0 | 0.000122 | rbf    |    -7.066123 |
| 4    |          -6.306561 |   0.5 | 0.000122 | rbf    |    -6.059503 |
| 5    |          -6.174779 | 128.0 | 0.000122 | rbf    |    -6.606218 |
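
Both cross-validation loops can be customised: `cv` accepts a scikit-learn
cross-validation object, and `inner_cv` accepts a callable with signature
`inner_cv_func(X, y)` that returns one, as `example-regression.py`
demonstrates. Continuing the example above:

```python
from sklearn.cross_validation import KFold

nested_cv = NestedGridSearchCV(
    estimator, param_grid, 'mean_absolute_error',
    cv=KFold(len(y), n_folds=5),
    inner_cv=lambda _x, _y: KFold(len(_y), n_folds=3))
```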
--------------------------------------------------------------------------------
/example-regression.py:
--------------------------------------------------------------------------------
import logging

from mpi4py import MPI
import numpy
from sklearn.cross_validation import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_boston
from sklearn.svm import SVR

from grid_search import NestedGridSearchCV

data = load_boston()
X = data['data']
y = data['target']
X = StandardScaler().fit_transform(X)

estimator = SVR(max_iter=1000, tol=1e-5)

param_grid = {'C': 2. ** numpy.arange(-5, 15, 2),
              'gamma': 2. ** numpy.arange(3, -15, -2),
              'kernel': ['poly', 'rbf']}

# KFold rather than StratifiedKFold: the Boston housing target is
# continuous, so there are no class labels to stratify on.
kfold_cv = KFold(len(y), n_folds=5)

logging.basicConfig(level=logging.INFO)

nested_cv = NestedGridSearchCV(estimator, param_grid, 'mean_absolute_error',
                               cv=kfold_cv,
                               inner_cv=lambda _x, _y: KFold(len(_y), n_folds=3))
nested_cv.fit(X, y)

if MPI.COMM_WORLD.Get_rank() == 0:
    for i, scores in enumerate(nested_cv.grid_scores_):
        scores.to_csv('grid-scores-%d.csv' % (i + 1), index=False)
    print("______________")
    print(nested_cv.best_params_)
--------------------------------------------------------------------------------
/grid_search/MPINestedGridSearchCV.py:
--------------------------------------------------------------------------------
import logging

import numpy
from mpi4py import MPI
import pandas
from sklearn.base import BaseEstimator, clone
from sklearn.cross_validation import _check_cv, check_scoring, is_classifier, _fit_and_score
from sklearn.grid_search import ParameterGrid, _check_param_grid
from sklearn.utils import check_X_y

__all__ = ['NestedGridSearchCV']

LOG = logging.getLogger(__package__)

# message tags
MPI_TAG_RESULT = 3
MPI_TAG_TRAIN_TEST_DATA = 5

# task identifiers broadcast by the root node
MPI_MSG_TERMINATE = 0
MPI_MSG_CV = 1
MPI_MSG_TEST = 2

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
comm_rank = comm.Get_rank()


def _get_best_parameters(fold_results, param_names):
    """Get the best parameter setting from a grid search.

    Parameters
    ----------
    fold_results : pandas.DataFrame
        Contains performance measures as well as hyper-parameters
        as columns. Must contain a column 'fold'.

    param_names : list
        Names of the hyper-parameters. Each name should be a column
        in ``fold_results``.

    Returns
    -------
    max_performance : pandas.Series
        Maximum performance and its hyper-parameters.
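
    Examples
    --------
    An illustration with hypothetical values: given results for two
    folds and ``param_names=['C']``, ::

        results = pandas.DataFrame({'C': [1., 1., 10., 10.],
                                    'fold': [1, 2, 1, 2],
                                    'score': [0.8, 0.6, 0.9, 0.7]})
        _get_best_parameters(results, ['C'])

    returns a series with ``score = 0.8`` and ``C = 10.0``, because
    ``C=10`` has the highest score averaged across folds,
    (0.9 + 0.7) / 2 = 0.8, compared to 0.7 for ``C=1``.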
    """
    if pandas.isnull(fold_results.loc[:, 'score']).all():
        raise ValueError("Results are all NaN")

    # average score across inner folds
    grouped = fold_results.drop('fold', axis=1).groupby(param_names)
    mean_performance = grouped.mean()
    # parameter setting with the highest average score
    max_idx = mean_performance.loc[:, 'score'].idxmax()

    # best parameters
    max_performance = pandas.Series({'score': mean_performance.loc[max_idx, 'score']})
    if len(param_names) == 1:
        key = param_names[0]
        max_performance[key] = max_idx
    else:
        # index has multiple levels
        for i, name in enumerate(mean_performance.index.names):
            max_performance[name] = max_idx[i]

    return max_performance


class MPIBatchWorker(object):
    """Base class to fit and score an estimator"""

    def __init__(self, estimator, scorer, fit_params, verbose=False):
        self.estimator = estimator
        self.scorer = scorer
        self.verbose = verbose
        self.fit_params = fit_params

        # the first item denotes the ID of the fold, the second item encodes
        # a message that tells slaves what to do
        self._task_desc = numpy.empty(2, dtype=int)
        # stores data that the root node broadcasts
        self._data_X = None
        self._data_y = None

    def process_batch(self, work_batch):
        fit_params = self.fit_params if self.fit_params is not None else {}

        LOG.debug("Node %d received %d work items", comm_rank, len(work_batch))

        results = []
        for fold_id, train_index, test_index, parameters in work_batch:
            ret = _fit_and_score(clone(self.estimator),
                                 self._data_X, self._data_y,
                                 self.scorer, train_index, test_index,
                                 self.verbose, parameters, fit_params)

            # _fit_and_score returns the test score, the number of test
            # samples, and the time it took to fit and score the estimator
            result = parameters.copy()
            result['score'] = ret[0]
            result['n_samples_test'] = ret[1]
            result['scoring_time'] = ret[2]
            result['fold'] = fold_id
            results.append(result)

        # guard against empty batches, where fold_id would be undefined
        if work_batch:
            LOG.debug("Node %d is done with fold %d", comm_rank, fold_id)
        return results


class MPISlave(MPIBatchWorker):
    """Receives tasks from the root node and sends results back"""

    def __init__(self, estimator, scorer, fit_params):
        super(MPISlave, self).__init__(estimator, scorer, fit_params)

    def _run_grid_search(self):
        # get data
        self._data_X, self._data_y = comm.bcast(None, root=0)
        # get batch
        work_batch = comm.scatter(None, root=0)

        results = self.process_batch(work_batch)
        # send results
        comm.gather(results, root=0)

    def _run_train_test(self):
        # get data
        self._data_X, self._data_y = comm.bcast(None, root=0)

        work_item = comm.recv(None, source=0, tag=MPI_TAG_TRAIN_TEST_DATA)
        fold_id = work_item[0]
        if fold_id == MPI_MSG_TERMINATE:
            return

        LOG.debug("Node %d is running testing for fold %d", comm_rank, fold_id)

        test_results = self.process_batch([work_item])

        comm.send((fold_id, test_results[0]['score']), dest=0, tag=MPI_TAG_RESULT)

    def run(self):
        """Wait for tasks until the node receives MPI_MSG_TERMINATE or MPI_MSG_TEST.

        In the beginning, the node waits for work batches distributed by
        :class:`MPIGridSearchCVMaster`. After the grid search has been
        completed, the node either receives data from
        :func:`_fit_and_score_with_parameters` to evaluate the estimator
        with the parameters determined during the grid search, or is asked
        to terminate.
        """
        task_desc = self._task_desc

        while True:
            comm.Bcast([task_desc, MPI.INT], root=0)
            if task_desc[1] == MPI_MSG_TERMINATE:
                LOG.debug("Node %d received terminate message", comm_rank)
                return
            if task_desc[1] == MPI_MSG_CV:
                self._run_grid_search()
            elif task_desc[1] == MPI_MSG_TEST:
                self._run_train_test()
                break
            else:
                raise ValueError('unknown task with id %d' % task_desc[1])

        LOG.debug("Node %d is terminating", comm_rank)


class MPIGridSearchCVMaster(MPIBatchWorker):
    """Runs on the root node and distributes work across slaves"""

    def __init__(self, param_grid, cv_iter, estimator, scorer, fit_params):
        super(MPIGridSearchCVMaster, self).__init__(estimator,
                                                    scorer, fit_params)
        self.param_grid = param_grid
        self.cv_iter = cv_iter

    def _create_batches(self):
        param_iter = ParameterGrid(self.param_grid)

        # divide work into as many batches as there are MPI processes
        work_batches = [[] for _ in range(comm_size)]
        i = 0
        for fold_id, (train_index, test_index) in enumerate(self.cv_iter):
            for parameters in param_iter:
                work_batches[i % comm_size].append((fold_id + 1, train_index, test_index, parameters))
                i += 1

        return work_batches

    def _scatter_work(self):
        work_batches = self._create_batches()

        LOG.debug("Distributed items into %d batches of size %d", comm_size, len(work_batches[0]))

        # distribute batches across all nodes
        root_work_batch = comm.scatter(work_batches, root=0)
        # the root node also receives one batch it has to process
        root_result_batch = self.process_batch(root_work_batch)
        return root_result_batch

    def _gather_work(self, root_result_batch):
        # collect results: list of lists of dicts of parameters and performance measures
        result_batches = comm.gather(root_result_batch, root=0)

        out = []
        for result_batch in result_batches:
            if result_batch is None:
                continue
            for result_item in result_batch:
                out.append(result_item)
        LOG.debug("Received %d valid results", len(out))

        return pandas.DataFrame(out)

    def run(self, train_X, train_y):
        # tell slaves that they should do a hyper-parameter search
        self._task_desc[0] = 0
        self._task_desc[1] = MPI_MSG_CV

        comm.Bcast([self._task_desc, MPI.INT], root=0)
        comm.bcast((train_X, train_y), root=0)

        self._data_X = train_X
        self._data_y = train_y

        root_result_batch = self._scatter_work()
        return self._gather_work(root_result_batch)


def _fit_and_score_with_parameters(X, y, cv, best_parameters):
    """Distributes the work of non-nested cross-validation across slave nodes"""

    # tell slaves that the testing phase is next
    _task_desc = numpy.empty(2, dtype=int)
    _task_desc[1] = MPI_MSG_TEST

    comm.Bcast([_task_desc, MPI.INT], root=0)
    comm.bcast((X, y), root=0)

    # fold k is sent to the slave with rank k, so one process more than
    # the number of folds is required (rank 0 is the root node)
    assert comm_size > len(cv)

    for i, (train_index, test_index) in enumerate(cv):
        fold_id = i + 1
        LOG.info("Testing fold %d", fold_id)

        parameters = best_parameters.loc[fold_id, :].to_dict()
        work_item = (fold_id, train_index, test_index, parameters)

        comm.send(work_item, dest=fold_id, tag=MPI_TAG_TRAIN_TEST_DATA)

    scores = {}
    for i in range(len(cv)):
        fold_id, test_result = comm.recv(source=MPI.ANY_SOURCE, tag=MPI_TAG_RESULT)
        scores[fold_id] = test_result

    # tell all remaining nodes, which did not receive a fold, to terminate
    # (ranks 1 to len(cv) already received their fold above)
    for i in range(len(cv) + 1, comm_size):
        comm.send((MPI_MSG_TERMINATE, None), dest=i, tag=MPI_TAG_TRAIN_TEST_DATA)

    return pandas.Series(scores)
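
# A summary of the MPI flow implemented above: the root node broadcasts a
# task identifier and the data to all processes. For the inner grid search
# it scatters one work batch per process and gathers the results; for the
# final testing phase it sends each outer fold, together with the
# parameters selected for it, to one slave and collects the test scores.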

class NestedGridSearchCV(BaseEstimator):
    """Cross-validation with nested hyper-parameter search for each training fold.

    The data is first split into ``cv`` train and test sets. For each training set,
    a grid search over the specified set of parameters is performed (inner cross-validation).
    The set of parameters that achieved the highest average score across all inner folds
    is used to re-fit a model on the entire training set of the outer cross-validation loop.
    Finally, results on the test set of the outer loop are reported.

    Parameters
    ----------
    estimator : object type that implements the "fit" and "predict" methods
        An object of that type is instantiated for each grid point.

    param_grid : dict or list of dictionaries
        Dictionary with parameter names (string) as keys and lists of
        parameter settings to try as values, or a list of such
        dictionaries, in which case the grids spanned by each dictionary
        in the list are explored. This enables searching over any sequence
        of parameter settings.

    scoring : string, callable or None, optional, default: None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.
        See sklearn.metrics.get_scorer for details.

    fit_params : dict, optional
        Parameters to pass to the fit method.

    cv : integer or cross-validation generator, default=3
        If an integer is passed, it is the number of folds.
        Specific cross-validation objects can be passed, see the
        sklearn.cross_validation module for the list of possible objects.

    inner_cv : integer or callable, default=3
        If an integer is passed, it is the number of folds.
        If callable, the function must have the signature ``inner_cv_func(X, y)``
        and return a cross-validation object, see the sklearn.cross_validation
        module for the list of possible objects.

    multi_output : boolean, default=False
        Allow multi-output y, as for multivariate regression.

    Attributes
    ----------
    best_params_ : pandas.DataFrame
        Contains the selected parameter settings for each fold.
        The validation score refers to the average score across all folds of
        the inner cross-validation, the test score to the score on the test
        set of the outer cross-validation loop.

    grid_scores_ : list of pandas.DataFrame
        Contains the full results of the grid search for each training set
        of the outer cross-validation loop.

    scorer_ : function
        Scorer function used on the held out data to choose the best
        parameters for the model.
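
    Examples
    --------
    A sketch of typical use; see the README and ``example-regression.py``
    for complete scripts. Every MPI process must call ``fit``, and the
    fitted attributes are only available on the process with rank 0::

        nested_cv = NestedGridSearchCV(estimator, param_grid,
                                       'mean_absolute_error',
                                       cv=5, inner_cv=3)
        nested_cv.fit(X, y)

        if MPI.COMM_WORLD.Get_rank() == 0:
            print(nested_cv.best_params_)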
    """

    def __init__(self, estimator, param_grid, scoring=None, fit_params=None, cv=None,
                 inner_cv=None, multi_output=False):
        self.estimator = estimator
        self.param_grid = param_grid
        self.scoring = scoring
        self.fit_params = fit_params
        self.cv = cv
        self.inner_cv = inner_cv
        self.multi_output = multi_output

    def _grid_search(self, train_X, train_y):
        if callable(self.inner_cv):
            inner_cv = self.inner_cv(train_X, train_y)
        else:
            inner_cv = _check_cv(self.inner_cv, train_X, train_y, classifier=is_classifier(self.estimator))

        master = MPIGridSearchCVMaster(self.param_grid, inner_cv, self.estimator, self.scorer_, self.fit_params)
        return master.run(train_X, train_y)

    def _fit_master(self, X, y, cv):
        param_names = list(self.param_grid.keys())

        best_parameters = []
        grid_search_results = []
        for i, (train_index, test_index) in enumerate(cv):
            LOG.info("Training fold %d", i + 1)

            train_X = X[train_index, :]
            train_y = y[train_index]

            grid_results = self._grid_search(train_X, train_y)
            grid_search_results.append(grid_results)

            max_performance = _get_best_parameters(grid_results, param_names)
            LOG.info("Best performance for fold %d:\n%s", i + 1, max_performance)
            max_performance['fold'] = i + 1
            best_parameters.append(max_performance)

        best_parameters = pandas.DataFrame(best_parameters)
        best_parameters.set_index('fold', inplace=True)
        # placeholder so the column exists before test scores are filled in
        best_parameters['score (Test)'] = 0.0
        best_parameters.rename(columns={'score': 'score (Validation)'}, inplace=True)

        scores = _fit_and_score_with_parameters(X, y, cv, best_parameters.loc[:, param_names])
        best_parameters['score (Test)'] = scores

        self.best_params_ = best_parameters
        self.grid_scores_ = grid_search_results

    def _fit_slave(self):
        slave = MPISlave(self.estimator, self.scorer_, self.fit_params)
        slave.run()

    def fit(self, X, y):
        X, y = check_X_y(X, y, force_all_finite=False, multi_output=self.multi_output)
        _check_param_grid(self.param_grid)

        cv = _check_cv(self.cv, X, y, classifier=is_classifier(self.estimator))

        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        if comm_rank == 0:
            self._fit_master(X, y, cv)
        else:
            self._fit_slave()

        return self
--------------------------------------------------------------------------------
/grid_search/__init__.py:
--------------------------------------------------------------------------------
from .MPINestedGridSearchCV import NestedGridSearchCV
--------------------------------------------------------------------------------