├── .circleci
│   └── config.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE.txt
├── MANIFEST.in
├── README.rst
├── ci
│   └── environment-3.7.yaml
├── dask_xgboost
│   ├── __init__.py
│   ├── core.py
│   ├── tests
│   │   └── test_core.py
│   └── tracker.py
├── requirements.txt
├── setup.cfg
└── setup.py

/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | jobs:
3 |   build:
4 |     docker:
5 |       - image: continuumio/miniconda:latest
6 |         environment:
7 |           PYTHON: "3.7"
8 |     steps:
9 |       - checkout
10 |       - run:
11 |           name: configure conda
12 |           command: |
13 |             conda config --set always_yes true --set changeps1 false
14 |             conda update -q conda
15 |             conda install conda-build anaconda-client --yes
16 |             conda config --add channels conda-forge
17 |             conda env create -f ci/environment-${PYTHON}.yaml
18 |             source activate dask-xgboost-test
19 |             pip install --no-deps -e .
20 |             conda list dask-xgboost-test
21 |       - run:
22 |           # TODO: Check on the conda-forge recipe for why this is necessary
23 |           command: |
24 |             source activate dask-xgboost-test
25 |       - run:
26 |           command: |
27 |             source activate dask-xgboost-test
28 |             pytest -v -s dask_xgboost
29 |       - run:
30 |           command: |
31 |             source activate dask-xgboost-test
32 |             flake8 dask_xgboost
33 |       - run:
34 |           command: |
35 |             source activate dask-xgboost-test
36 |             black .
37 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.py~
3 | *.egg-info
4 | docs/build
5 | dask-worker-space
6 | build
7 | dist
8 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/python/black
3 |     rev: 19.10b0
4 |     hooks:
5 |       - id: black
6 |         language_version: python3.7
7 |   - repo: https://gitlab.com/pycqa/flake8
8 |     rev: 3.7.9
9 |     hooks:
10 |       - id: flake8
11 |         language_version: python3.7
12 |   - repo: https://github.com/pre-commit/mirrors-isort
13 |     rev: v4.3.21
14 |     hooks:
15 |       - id: isort
16 | 
17 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Dask is a community maintained project. We welcome contributions in the form of bug reports, documentation, code, design proposals, and more.
2 | 
3 | For general information on how to contribute see https://docs.dask.org/en/latest/develop.html.
4 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Continuum Analytics, Inc. and contributors
2 | All rights reserved.
3 | 
4 | Redistribution and use in source and binary forms, with or without modification,
5 | are permitted provided that the following conditions are met:
6 | 
7 | Redistributions of source code must retain the above copyright notice,
8 | this list of conditions and the following disclaimer.
9 | 
10 | Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 | 
14 | Neither the name of Continuum Analytics nor the names of any contributors
15 | may be used to endorse or promote products derived from this software
16 | without specific prior written permission.
17 | 
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
28 | THE POSSIBILITY OF SUCH DAMAGE.
29 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include dask_xgboost *.py
2 | 
3 | include requirements.txt
4 | include setup.py
5 | include README.rst
6 | include LICENSE.txt
7 | include MANIFEST.in
8 | 
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | Dask-XGBoost
2 | ============
3 | 
4 | .. warning::
5 | 
6 |    Dask-XGBoost has been deprecated and is no longer maintained. The functionality
7 |    of this project has been included directly in XGBoost. To use Dask and XGBoost
8 |    together, please use ``xgboost.dask`` instead
9 |    https://xgboost.readthedocs.io/en/latest/tutorials/dask.html.
10 | 
11 | Distributed training with XGBoost and Dask.distributed
12 | 
13 | This repository offers a legacy option to perform distributed training
14 | with XGBoost on Dask.array and Dask.dataframe collections.
15 | 
16 | ::
17 | 
18 |    pip install dask-xgboost
19 | 
20 | Please note that XGBoost now includes a Dask API as part of its official Python package.
21 | That API is independent of ``dask-xgboost`` and is now the recommended way to use Dask
22 | and XGBoost together. See
23 | `the xgb.dask documentation here <https://xgboost.readthedocs.io/en/latest/tutorials/dask.html>`_
24 | for more details on the new API.
25 | 
26 | 
27 | 
28 | Example
29 | -------
30 | 
31 | .. code-block:: python
32 | 
33 |    from dask.distributed import Client
34 |    client = Client('scheduler-address:8786')  # connect to cluster
35 | 
36 |    import dask.dataframe as dd
37 |    df = dd.read_csv('...')  # use dask.dataframe to load and
38 |    df_train = ...           # preprocess data
39 |    labels_train = ...
40 | 
41 |    import dask_xgboost as dxgb
42 |    params = {'objective': 'binary:logistic', ...}  # use normal xgboost params
43 |    bst = dxgb.train(client, params, df_train, labels_train)
44 | 
45 |    >>> bst  # Get back normal XGBoost result
46 | 
47 | 
48 |    predictions = dxgb.predict(client, bst, data_test)
49 | 
50 | 
51 | How this works
52 | --------------
53 | 
54 | For more information on using Dask.dataframe for preprocessing see the
55 | `Dask.dataframe documentation <https://docs.dask.org/en/latest/dataframe.html>`_.
56 | 
57 | Once you have created suitable data and labels we are ready for distributed
58 | training with XGBoost. Every Dask worker sets up an XGBoost slave and gives
59 | them enough information to find each other.
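In code terms, each worker runs roughly the following, a simplified sketch of the
``train_part`` function in ``dask_xgboost/core.py`` (sample weights, evaluation sets
and error handling are omitted; ``concat`` is the small helper defined alongside it):

.. code-block:: python

    import xgboost as xgb

    def train_part(env, params, list_of_parts, **kwargs):
        # Concatenate this worker's local chunks into single pandas/NumPy objects
        data, labels = zip(*list_of_parts)
        dtrain = xgb.DMatrix(concat(data), concat(labels))

        # Join the Rabit network using the tracker address handed out by the scheduler
        xgb.rabit.init([("%s=%s" % item).encode() for item in env.items()])
        try:
            bst = xgb.train(params, dtrain, **kwargs)
            result = bst if xgb.rabit.get_rank() == 0 else None  # only rank 0 reports back
        finally:
            xgb.rabit.finalize()
        return result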
Then Dask workers hand their 60 | in-memory Pandas dataframes to XGBoost (one Dask dataframe is just many Pandas 61 | dataframes spread around the memory of many machines). XGBoost handles 62 | distributed training on its own without Dask interference. XGBoost then hands 63 | back a single ``xgboost.Booster`` result object. 64 | 65 | 66 | Larger Example 67 | -------------- 68 | 69 | For a more serious example see 70 | 71 | - `This blogpost `_ 72 | - `This notebook `_ 73 | - `This screencast `_ 74 | 75 | History 76 | ------- 77 | 78 | Conversation during development happened at `dmlc/xgboost #2032 79 | `_ 80 | -------------------------------------------------------------------------------- /ci/environment-3.7.yaml: -------------------------------------------------------------------------------- 1 | name: dask-xgboost-test 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - black 7 | - coverage 8 | - codecov 9 | - dask 10 | - dask-glm >=0.2.0 11 | - distributed 12 | - flake8 13 | - isort==4.3.21 14 | - multipledispatch >=0.4.9 15 | - mypy 16 | - numba 17 | - numpy >=1.16.3 18 | - numpydoc 19 | - packaging 20 | - pandas 21 | - psutil 22 | - pytest 23 | - pytest-cov 24 | - pytest-mock 25 | - pytest-xdist 26 | - python=3.7.* 27 | - scikit-learn>=0.23.0 28 | - scipy 29 | - sparse 30 | - toolz 31 | - xgboost=0.90 32 | -------------------------------------------------------------------------------- /dask_xgboost/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from .core import XGBClassifier, XGBRegressor, _train, predict, train # noqa 4 | 5 | __version__ = "0.2.0" 6 | 7 | warnings.warn( 8 | "Dask-XGBoost has been deprecated and is no longer maintained. The functionality " 9 | "of this project has been included directly in XGBoost. To use Dask and XGBoost " 10 | "together, please use ``xgboost.dask`` instead " 11 | "https://xgboost.readthedocs.io/en/latest/tutorials/dask.html." 
12 | ) -------------------------------------------------------------------------------- /dask_xgboost/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | from collections import defaultdict 4 | from threading import Thread 5 | 6 | import dask.array as da 7 | import dask.dataframe as dd 8 | import numpy as np 9 | import pandas as pd 10 | import xgboost as xgb 11 | from dask import delayed, is_dask_collection 12 | from dask.distributed import default_client, wait 13 | from toolz import assoc, first 14 | from tornado import gen 15 | 16 | from .tracker import RabitTracker, get_host_ip 17 | 18 | try: 19 | import sparse 20 | except ImportError: 21 | sparse = False 22 | 23 | try: 24 | import scipy.sparse as ss 25 | except ImportError: 26 | ss = False 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def parse_host_port(address): 32 | if "://" in address: 33 | address = address.rsplit("://", 1)[1] 34 | host, port = address.split(":") 35 | port = int(port) 36 | return host, port 37 | 38 | 39 | def start_tracker(host, n_workers, default_host=None): 40 | """ Start Rabit tracker """ 41 | if host is None: 42 | try: 43 | host = get_host_ip("auto") 44 | except socket.gaierror: 45 | if default_host is not None: 46 | host = default_host 47 | else: 48 | raise 49 | 50 | env = {"DMLC_NUM_WORKER": n_workers} 51 | rabit = RabitTracker(hostIP=host, nslave=n_workers) 52 | env.update(rabit.slave_envs()) 53 | 54 | rabit.start(n_workers) 55 | logger.info("Starting Rabit Tracker") 56 | thread = Thread(target=rabit.join) 57 | thread.daemon = True 58 | thread.start() 59 | return env 60 | 61 | 62 | def concat(L): 63 | if isinstance(L[0], np.ndarray): 64 | return np.concatenate(L, axis=0) 65 | elif isinstance(L[0], (pd.DataFrame, pd.Series)): 66 | return pd.concat(L, axis=0) 67 | elif ss and isinstance(L[0], ss.csr_matrix): 68 | return ss.vstack(L, format="csr") 69 | elif sparse and isinstance(L[0], sparse.SparseArray): 70 | return sparse.concatenate(L, axis=0) 71 | else: 72 | raise TypeError( 73 | "Data must be either numpy arrays or pandas dataframes" 74 | ". 
Got %s" % type(L[0]) 75 | ) 76 | 77 | 78 | def train_part( 79 | env, 80 | param, 81 | list_of_parts, 82 | dmatrix_kwargs=None, 83 | eval_set=None, 84 | missing=None, 85 | n_jobs=None, 86 | sample_weight_eval_set=None, 87 | **kwargs 88 | ): 89 | """ 90 | Run part of XGBoost distributed workload 91 | 92 | This starts an xgboost.rabit slave, trains on provided data, and then shuts 93 | down the xgboost.rabit slave 94 | 95 | Returns 96 | ------- 97 | model if rank zero, None otherwise 98 | """ 99 | data, labels, sample_weight = zip(*list_of_parts) # Prepare data 100 | data = concat(data) # Concatenate many parts into one 101 | labels = concat(labels) 102 | sample_weight = concat(sample_weight) if np.all(sample_weight) else None 103 | 104 | if dmatrix_kwargs is None: 105 | dmatrix_kwargs = {} 106 | 107 | dmatrix_kwargs["feature_names"] = getattr(data, "columns", None) 108 | dtrain = xgb.DMatrix(data, labels, weight=sample_weight, **dmatrix_kwargs) 109 | 110 | evals = _package_evals( 111 | eval_set, 112 | sample_weight_eval_set=sample_weight_eval_set, 113 | missing=missing, 114 | n_jobs=n_jobs, 115 | ) 116 | 117 | args = [("%s=%s" % item).encode() for item in env.items()] 118 | xgb.rabit.init(args) 119 | try: 120 | local_history = {} 121 | logger.info("Starting Rabit, Rank %d", xgb.rabit.get_rank()) 122 | bst = xgb.train( 123 | param, dtrain, evals=evals, evals_result=local_history, **kwargs 124 | ) 125 | 126 | if xgb.rabit.get_rank() == 0: # Only return from one worker 127 | result = bst 128 | evals_result = local_history 129 | else: 130 | result = None 131 | evals_result = None 132 | finally: 133 | logger.info("Finalizing Rabit, Rank %d", xgb.rabit.get_rank()) 134 | xgb.rabit.finalize() 135 | return result, evals_result 136 | 137 | 138 | def _package_evals(eval_set, sample_weight_eval_set=None, missing=None, n_jobs=None): 139 | if eval_set is not None: 140 | if sample_weight_eval_set is None: 141 | sample_weight_eval_set = [None] * len(eval_set) 142 | evals = list( 143 | xgb.DMatrix( 144 | data, label=label, missing=missing, weight=weight, nthread=n_jobs, 145 | ) 146 | for ((data, label), weight) in zip(eval_set, sample_weight_eval_set) 147 | ) 148 | evals = list(zip(evals, ["validation_{}".format(i) for i in range(len(evals))])) 149 | else: 150 | evals = () 151 | return evals 152 | 153 | 154 | def _has_dask_collections(list_of_collections, message): 155 | list_of_collections = list_of_collections or [] 156 | if any( 157 | is_dask_collection(collection) 158 | for collections in list_of_collections 159 | for collection in collections 160 | ): 161 | raise TypeError(message) 162 | 163 | 164 | @gen.coroutine 165 | def _train( 166 | client, 167 | params, 168 | data, 169 | labels, 170 | dmatrix_kwargs={}, 171 | evals_result=None, 172 | sample_weight=None, 173 | **kwargs 174 | ): 175 | """ 176 | Asynchronous version of train 177 | 178 | See Also 179 | -------- 180 | train 181 | """ 182 | # Break apart Dask.array/dataframe into chunks/parts 183 | data_parts = data.to_delayed() 184 | label_parts = labels.to_delayed() 185 | if isinstance(data_parts, np.ndarray): 186 | assert data_parts.shape[1] == 1 187 | data_parts = data_parts.flatten().tolist() 188 | if isinstance(label_parts, np.ndarray): 189 | assert label_parts.ndim == 1 or label_parts.shape[1] == 1 190 | label_parts = label_parts.flatten().tolist() 191 | if sample_weight is not None: 192 | sample_weight_parts = sample_weight.to_delayed() 193 | if isinstance(sample_weight_parts, np.ndarray): 194 | assert sample_weight_parts.ndim == 1 or 
sample_weight_parts.shape[1] == 1 195 | sample_weight_parts = sample_weight_parts.flatten().tolist() 196 | else: 197 | # If sample_weight is None construct a list of Nones to keep 198 | # the structure of parts consistent. 199 | sample_weight_parts = [None] * len(data_parts) 200 | 201 | # Check that data, labels, and sample_weights are the same length 202 | lists = [data_parts, label_parts, sample_weight_parts] 203 | if len(set([len(l) for l in lists])) > 1: 204 | raise ValueError( 205 | "data, label, and sample_weight parts/chunks must have same length." 206 | ) 207 | 208 | # Arrange parts into triads. This enforces co-locality 209 | parts = list(map(delayed, zip(data_parts, label_parts, sample_weight_parts))) 210 | parts = client.compute(parts) # Start computation in the background 211 | yield wait(parts) 212 | 213 | for part in parts: 214 | if part.status == "error": 215 | yield part # trigger error locally 216 | 217 | _has_dask_collections( 218 | kwargs.get("eval_set", []), "Evaluation set must not contain dask collections." 219 | ) 220 | _has_dask_collections( 221 | kwargs.get("sample_weight_eval_set", []), 222 | "Sample weight evaluation set must not contain dask collections.", 223 | ) 224 | 225 | # Because XGBoost-python doesn't yet allow iterative training, we need to 226 | # find the locations of all chunks and map them to particular Dask workers 227 | key_to_part_dict = dict([(part.key, part) for part in parts]) 228 | who_has = yield client.scheduler.who_has(keys=[part.key for part in parts]) 229 | worker_map = defaultdict(list) 230 | for key, workers in who_has.items(): 231 | worker_map[first(workers)].append(key_to_part_dict[key]) 232 | 233 | ncores = yield client.scheduler.ncores() # Number of cores per worker 234 | 235 | default_host, _ = parse_host_port(client.scheduler.address) 236 | # Start the XGBoost tracker on the Dask scheduler 237 | env = yield client._run_on_scheduler( 238 | start_tracker, None, len(worker_map), default_host=default_host 239 | ) 240 | 241 | # Tell each worker to train on the chunks/parts that it has locally 242 | futures = [ 243 | client.submit( 244 | train_part, 245 | env, 246 | assoc(params, "nthread", ncores[worker]), 247 | list_of_parts, 248 | workers=worker, 249 | dmatrix_kwargs=dmatrix_kwargs, 250 | **kwargs 251 | ) 252 | for worker, list_of_parts in worker_map.items() 253 | ] 254 | 255 | # Get the results, only one will be non-None 256 | results = yield client._gather(futures) 257 | result, _evals_result = [v for v in results if v.count(None) != len(v)][0] 258 | 259 | if evals_result is not None: 260 | evals_result.update(_evals_result) 261 | 262 | num_class = params.get("num_class") 263 | if num_class: 264 | result.set_attr(num_class=str(num_class)) 265 | raise gen.Return(result) 266 | 267 | 268 | def train( 269 | client, 270 | params, 271 | data, 272 | labels, 273 | dmatrix_kwargs={}, 274 | evals_result=None, 275 | sample_weight=None, 276 | **kwargs 277 | ): 278 | """ Train an XGBoost model on a Dask Cluster 279 | 280 | This starts XGBoost on all Dask workers, moves input data to those workers, 281 | and then calls ``xgboost.train`` on the inputs. 
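The collections are never gathered on the client: each worker trains only on the chunks of ``data`` and ``labels`` it already holds, and the workers coordinate with each other through a Rabit tracker started on the scheduler.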
282 | 283 | Parameters 284 | ---------- 285 | client: dask.distributed.Client 286 | params: dict 287 | Parameters to give to XGBoost (see xgb.Booster.train) 288 | data: dask array or dask.dataframe 289 | labels: dask.array or dask.dataframe 290 | dmatrix_kwargs: Keywords to give to Xgboost DMatrix 291 | evals_result: dict, optional 292 | Stores the evaluation result history of all the items in the eval_set 293 | by mutating evals_result in place. 294 | sample_weight : array_like, optional 295 | instance weights 296 | **kwargs: Keywords to give to XGBoost train 297 | 298 | Examples 299 | -------- 300 | >>> client = Client('scheduler-address:8786') # doctest: +SKIP 301 | >>> data = dd.read_csv('s3://...') # doctest: +SKIP 302 | >>> labels = data['outcome'] # doctest: +SKIP 303 | >>> del data['outcome'] # doctest: +SKIP 304 | >>> train(client, params, data, labels, **normal_kwargs) # doctest: +SKIP 305 | 306 | 307 | See Also 308 | -------- 309 | predict 310 | """ 311 | return client.sync( 312 | _train, 313 | client, 314 | params, 315 | data, 316 | labels, 317 | dmatrix_kwargs, 318 | evals_result, 319 | sample_weight, 320 | **kwargs 321 | ) 322 | 323 | 324 | def _predict_part(part, model=None): 325 | xgb.rabit.init() 326 | try: 327 | dm = xgb.DMatrix(part) 328 | result = model.predict(dm) 329 | finally: 330 | xgb.rabit.finalize() 331 | 332 | if isinstance(part, pd.DataFrame): 333 | if model.attr("num_class"): 334 | result = pd.DataFrame(result, index=part.index) 335 | else: 336 | result = pd.Series(result, index=part.index, name="predictions") 337 | return result 338 | 339 | 340 | def predict(client, model, data): 341 | """ Distributed prediction with XGBoost 342 | 343 | Parameters 344 | ---------- 345 | client: dask.distributed.Client 346 | model: xgboost.Booster 347 | data: dask array or dataframe 348 | 349 | Examples 350 | -------- 351 | >>> client = Client('scheduler-address:8786') # doctest: +SKIP 352 | >>> test_data = dd.read_csv('s3://...') # doctest: +SKIP 353 | >>> model 354 | 355 | 356 | >>> predictions = predict(client, model, test_data) # doctest: +SKIP 357 | 358 | Returns 359 | ------- 360 | Dask.dataframe or dask.array, depending on the input data type 361 | 362 | See Also 363 | -------- 364 | train 365 | """ 366 | if isinstance(data, dd._Frame): 367 | result = data.map_partitions(_predict_part, model=model) 368 | result = result.values 369 | elif isinstance(data, da.Array): 370 | num_class = model.attr("num_class") or 2 371 | num_class = int(num_class) 372 | 373 | if num_class > 2: 374 | kwargs = dict(drop_axis=None, chunks=(data.chunks[0], (num_class,))) 375 | else: 376 | kwargs = dict(drop_axis=1) 377 | result = data.map_blocks(_predict_part, model=model, dtype=np.float32, **kwargs) 378 | else: 379 | model = model.result() # Future to concrete 380 | if not isinstance(data, xgb.DMatrix): 381 | data = xgb.DMatrix(data) 382 | result = model.predict(data) 383 | 384 | return result 385 | 386 | 387 | class XGBRegressor(xgb.XGBRegressor): 388 | def fit( 389 | self, 390 | X, 391 | y=None, 392 | eval_set=None, 393 | sample_weight=None, 394 | sample_weight_eval_set=None, 395 | eval_metric=None, 396 | early_stopping_rounds=None, 397 | ): 398 | """Fit the gradient boosting model 399 | 400 | Parameters 401 | ---------- 402 | X : array-like [n_samples, n_features] 403 | y : array-like 404 | 405 | Returns 406 | ------- 407 | self : the fitted Regressor 408 | 409 | Notes 410 | ----- 411 | This differs from the XGBoost version not supporting the ``eval_set``, 412 | ``eval_metric``, 
``early_stopping_rounds`` and ``verbose`` fit 413 | kwargs. 414 | eval_set : list, optional 415 | A list of (X, y) tuple pairs to use as validation sets, for which 416 | metrics will be computed. 417 | Validation metrics will help us track the performance of the model. 418 | sample_weight : array_like, optional 419 | instance weights 420 | sample_weight_eval_set : list, optional 421 | A list of the form [L_1, L_2, ..., L_n], where each L_i is a list 422 | of instance weights on the i-th validation set. 423 | eval_metric : str, list of str, or callable, optional 424 | If a str, should be a built-in evaluation metric to use. See 425 | `doc/parameter.rst `_. # noqa: E501 426 | If a list of str, should be the list of multiple built-in 427 | evaluation metrics to use. 428 | If callable, a custom evaluation metric. The call 429 | signature is ``func(y_predicted, y_true)`` where ``y_true`` will 430 | be a DMatrix object such that you may need to call the 431 | ``get_label`` method. It must return a str, value pair where 432 | the str is a name for the evaluation and value is the value of 433 | the evaluation function. The callable custom objective is always 434 | minimized. 435 | early_stopping_rounds : int 436 | Activates early stopping. Validation metric needs to improve at 437 | least once in every **early_stopping_rounds** round(s) to continue 438 | training. 439 | Requires at least one item in **eval_set**. 440 | The method returns the model from the last iteration (not the best 441 | one). 442 | If there's more than one item in **eval_set**, the last entry will 443 | be used for early stopping. 444 | If there's more than one metric in **eval_metric**, the last 445 | metric will be used for early stopping. 446 | If early stopping occurs, the model will have three additional 447 | fields: 448 | ``clf.best_score``, ``clf.best_iteration`` and 449 | ``clf.best_ntree_limit``. 450 | """ 451 | client = default_client() 452 | xgb_options = self.get_xgb_params() 453 | 454 | if eval_metric is not None: 455 | if callable(eval_metric): 456 | eval_metric = None 457 | else: 458 | xgb_options.update({"eval_metric": eval_metric}) 459 | 460 | self.evals_result_ = {} 461 | self._Booster = train( 462 | client, 463 | xgb_options, 464 | X, 465 | y, 466 | num_boost_round=self.n_estimators, 467 | eval_set=eval_set, 468 | sample_weight=sample_weight, 469 | sample_weight_eval_set=sample_weight_eval_set, 470 | missing=self.missing, 471 | n_jobs=self.n_jobs, 472 | early_stopping_rounds=early_stopping_rounds, 473 | evals_result=self.evals_result_, 474 | ) 475 | 476 | if early_stopping_rounds is not None: 477 | self.best_score = self._Booster.best_score 478 | self.best_iteration = self._Booster.best_iteration 479 | self.best_ntree_limit = self._Booster.best_ntree_limit 480 | return self 481 | 482 | def predict(self, X): 483 | client = default_client() 484 | return predict(client, self._Booster, X) 485 | 486 | 487 | class XGBClassifier(xgb.XGBClassifier): 488 | def fit( 489 | self, 490 | X, 491 | y=None, 492 | classes=None, 493 | eval_set=None, 494 | sample_weight=None, 495 | sample_weight_eval_set=None, 496 | eval_metric=None, 497 | early_stopping_rounds=None, 498 | ): 499 | """Fit a gradient boosting classifier 500 | 501 | Parameters 502 | ---------- 503 | X : array-like [n_samples, n_features] 504 | Feature Matrix. 
May be a dask.array or dask.dataframe 505 | y : array-like 506 | Labels 507 | eval_set : list, optional 508 | A list of (X, y) tuple pairs to use as validation sets, for which 509 | metrics will be computed. 510 | Validation metrics will help us track the performance of the model. 511 | sample_weight : array_like, optional 512 | instance weights 513 | sample_weight_eval_set : list, optional 514 | A list of the form [L_1, L_2, ..., L_n], where each L_i is a list 515 | of instance weights on the i-th validation set. 516 | eval_metric : str, list of str, or callable, optional 517 | If a str, should be a built-in evaluation metric to use. See 518 | `doc/parameter.rst `_. # noqa: E501 519 | If a list of str, should be the list of multiple built-in 520 | evaluation metrics to use. 521 | If callable, a custom evaluation metric. The call 522 | signature is ``func(y_predicted, y_true)`` where ``y_true`` will 523 | be a DMatrix object such that you may need to call the 524 | ``get_label`` method. It must return a str, value pair where 525 | the str is a name for the evaluation and value is the value of 526 | the evaluation function. The callable custom objective is always 527 | minimized. 528 | early_stopping_rounds : int 529 | Activates early stopping. Validation metric needs to improve at 530 | least once in every **early_stopping_rounds** round(s) to continue 531 | training. 532 | Requires at least one item in **eval_set**. 533 | The method returns the model from the last iteration (not the best 534 | one). 535 | If there's more than one item in **eval_set**, the last entry will 536 | be used for early stopping. 537 | If there's more than one metric in **eval_metric**, the last 538 | metric will be used for early stopping. 539 | If early stopping occurs, the model will have three additional 540 | fields: 541 | ``clf.best_score``, ``clf.best_iteration`` and 542 | ``clf.best_ntree_limit``. 543 | classes : sequence, optional 544 | The unique values in `y`. If no specified, this will be 545 | eagerly computed from `y` before training. 546 | 547 | Returns 548 | ------- 549 | self : XGBClassifier 550 | 551 | Notes 552 | ----- 553 | This differs from the XGBoost version in three ways 554 | 555 | 1. The ``verbose`` fit kwargs are not supported. 556 | 2. The labels are not automatically label-encoded 557 | 3. The ``classes_`` and ``n_classes_`` attributes are not learned 558 | """ 559 | client = default_client() 560 | 561 | if classes is None: 562 | if isinstance(y, da.Array): 563 | classes = da.unique(y) 564 | else: 565 | classes = y.unique() 566 | classes = classes.compute() 567 | else: 568 | classes = np.asarray(classes) 569 | self.classes_ = classes 570 | self.n_classes_ = len(self.classes_) 571 | 572 | xgb_options = self.get_xgb_params() 573 | 574 | if eval_metric is not None: 575 | if callable(eval_metric): 576 | eval_metric = None 577 | else: 578 | xgb_options.update({"eval_metric": eval_metric}) 579 | 580 | if self.n_classes_ > 2: 581 | # xgboost just ignores the user-provided objective 582 | # We only overwrite if it's the default... 583 | if xgb_options["objective"] == "binary:logistic": 584 | xgb_options["objective"] = "multi:softprob" 585 | 586 | xgb_options.setdefault("num_class", self.n_classes_) 587 | 588 | # xgboost sets this to self.objective, which I think is wrong 589 | # hyper-parameters should not be updated during fit. 
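        # We mirror the upstream sklearn wrapper here anyway so the two estimators
        # behave the same after fit(): e.g. a three-class problem constructed with
        # the default "binary:logistic" ends up with objective="multi:softprob" and
        # num_class=3, and predict() later reads the "num_class" attribute back off
        # the Booster to shape the per-class probabilities.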
590 | self.objective = xgb_options["objective"] 591 | 592 | # TODO: auto label-encode y 593 | # that will require a dependency on dask-ml 594 | 595 | self.evals_result_ = {} 596 | self._Booster = train( 597 | client, 598 | xgb_options, 599 | X, 600 | y, 601 | num_boost_round=self.n_estimators, 602 | eval_set=eval_set, 603 | sample_weight=sample_weight, 604 | sample_weight_eval_set=sample_weight_eval_set, 605 | missing=self.missing, 606 | n_jobs=self.n_jobs, 607 | early_stopping_rounds=early_stopping_rounds, 608 | evals_result=self.evals_result_, 609 | ) 610 | 611 | if early_stopping_rounds is not None: 612 | self.best_score = self._Booster.best_score 613 | self.best_iteration = self._Booster.best_iteration 614 | self.best_ntree_limit = self._Booster.best_ntree_limit 615 | return self 616 | 617 | def predict(self, X): 618 | client = default_client() 619 | class_probs = predict(client, self._Booster, X) 620 | if class_probs.ndim > 1: 621 | cidx = da.argmax(class_probs, axis=1) 622 | else: 623 | cidx = (class_probs > 0.5).astype(np.int64) 624 | return cidx 625 | 626 | def predict_proba(self, data, ntree_limit=None): 627 | client = default_client() 628 | if ntree_limit is not None: 629 | raise NotImplementedError("'ntree_limit' is not currently " "supported.") 630 | class_probs = predict(client, self._Booster, data) 631 | return class_probs 632 | -------------------------------------------------------------------------------- /dask_xgboost/tests/test_core.py: -------------------------------------------------------------------------------- 1 | # Workaround for conflict with distributed 1.23.0 2 | # https://github.com/dask/dask-xgboost/pull/27#issuecomment-417474734 3 | from concurrent.futures import ThreadPoolExecutor 4 | 5 | import dask 6 | import dask.array as da 7 | import dask.dataframe as dd 8 | import distributed.comm.utils 9 | import numpy as np 10 | import pandas as pd 11 | import pytest 12 | import scipy.sparse 13 | import xgboost as xgb 14 | from dask.array.utils import assert_eq 15 | from dask.distributed import Client 16 | from distributed.utils_test import cluster, gen_cluster, loop # noqa 17 | from sklearn.datasets import load_digits, load_iris 18 | from sklearn.model_selection import train_test_split 19 | 20 | import dask_xgboost as dxgb 21 | from dask_xgboost.core import _package_evals 22 | 23 | distributed.comm.utils._offload_executor = ThreadPoolExecutor(max_workers=2) 24 | 25 | 26 | df = pd.DataFrame( 27 | {"x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "y": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],} 28 | ) 29 | labels = pd.Series([1, 0, 1, 0, 1, 0, 1, 1, 1, 1]) 30 | 31 | param = { 32 | "max_depth": 2, 33 | "eta": 1, 34 | "silent": 1, 35 | "objective": "binary:logistic", 36 | } 37 | 38 | X = df.values 39 | y = labels.values 40 | 41 | 42 | def test_classifier(loop): # noqa 43 | digits = load_digits(2) 44 | X = digits["data"] 45 | y = digits["target"] 46 | 47 | with cluster() as (s, [a, b]): 48 | with Client(s["address"], loop=loop): 49 | a = dxgb.XGBClassifier() 50 | X2 = da.from_array(X) 51 | y2 = da.from_array(y) 52 | a.fit(X2, y2) 53 | p1 = a.predict(X2) 54 | 55 | b = xgb.XGBClassifier() 56 | b.fit(X, y) 57 | np.testing.assert_array_almost_equal(a.feature_importances_, b.feature_importances_) 58 | assert_eq(p1, b.predict(X)) 59 | 60 | 61 | def test_classifier_different_chunks(loop): # noqa 62 | with cluster() as (s, [a, b]): 63 | with Client(s["address"], loop=loop): 64 | a = dxgb.XGBClassifier() 65 | X2 = da.from_array(X, 5) 66 | y2 = da.from_array(y, 4) 67 | 68 | with 
pytest.raises(ValueError): 69 | a.fit(X2, y2) 70 | 71 | 72 | def test_multiclass_classifier(loop): # noqa 73 | # data 74 | iris = load_iris() 75 | X, y = iris.data, iris.target 76 | dX = da.from_array(X, 5) 77 | dy = da.from_array(y, 5) 78 | df = pd.DataFrame(X, columns=iris.feature_names) 79 | labels = pd.Series(y, name="target") 80 | 81 | ddf = dd.from_pandas(df, 2) 82 | dlabels = dd.from_pandas(labels, 2) 83 | # model 84 | a = xgb.XGBClassifier() # array 85 | b = dxgb.XGBClassifier() 86 | c = xgb.XGBClassifier() # frame 87 | d = dxgb.XGBClassifier() 88 | 89 | with cluster() as (s, [_, _]): 90 | with Client(s["address"], loop=loop): 91 | # fit 92 | a.fit(X, y) # array 93 | b.fit(dX, dy, classes=[0, 1, 2]) 94 | c.fit(df, labels) # frame 95 | d.fit(ddf, dlabels, classes=[0, 1, 2]) 96 | 97 | # check 98 | da.utils.assert_eq(a.predict(X), b.predict(dX)) 99 | da.utils.assert_eq(a.predict_proba(X), b.predict_proba(dX)) 100 | da.utils.assert_eq(c.predict(df), d.predict(ddf)) 101 | da.utils.assert_eq(c.predict_proba(df), d.predict_proba(ddf)) 102 | 103 | 104 | def test_classifier_early_stopping(loop): # noqa 105 | # data 106 | digits = load_digits(2) 107 | X = digits["data"] 108 | y = digits["target"] 109 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) 110 | 111 | dX_train = da.from_array(X_train) 112 | dy_train = da.from_array(y_train) 113 | 114 | clf1 = dxgb.XGBClassifier() 115 | clf2 = dxgb.XGBClassifier() 116 | clf3 = dxgb.XGBClassifier() 117 | with cluster() as (s, [_, _]): 118 | with Client(s["address"], loop=loop): 119 | clf1.fit( 120 | dX_train, 121 | dy_train, 122 | early_stopping_rounds=5, 123 | eval_metric="auc", 124 | eval_set=[(X_test, y_test)], 125 | ) 126 | clf2.fit( 127 | dX_train, 128 | dy_train, 129 | early_stopping_rounds=4, 130 | eval_metric="auc", 131 | eval_set=[(X_test, y_test)], 132 | ) 133 | 134 | # should be the same 135 | assert clf1.best_score == clf2.best_score 136 | assert clf1.best_score != 1 137 | 138 | # check overfit 139 | clf3.fit( 140 | dX_train, 141 | dy_train, 142 | early_stopping_rounds=10, 143 | eval_metric="auc", 144 | eval_set=[(X_test, y_test)], 145 | ) 146 | assert clf3.best_score == 1 147 | 148 | 149 | def test_package_evals(): 150 | # data 151 | digits = load_digits(2) 152 | X = digits["data"] 153 | y = digits["target"] 154 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) 155 | 156 | evals = _package_evals([(X_test, y_test), (X, y_test)]) 157 | 158 | assert len(evals) == 2 159 | 160 | evals = _package_evals( 161 | [(X_test, y_test), (X, y_test)], sample_weight_eval_set=[[1], [2]] 162 | ) 163 | 164 | assert len(evals) == 2 165 | 166 | evals = _package_evals( 167 | [(X_test, y_test), (X, y_test)], sample_weight_eval_set=[[1]] 168 | ) 169 | 170 | assert len(evals) == 1 171 | 172 | 173 | def test_validation_weights_xgbclassifier(loop): # noqa 174 | from sklearn.datasets import make_hastie_10_2 175 | 176 | # prepare training and test data 177 | X, y = make_hastie_10_2(n_samples=2000, random_state=42) 178 | labels, y = np.unique(y, return_inverse=True) 179 | 180 | param_dist = { 181 | "objective": "binary:logistic", 182 | "n_estimators": 2, 183 | "random_state": 123, 184 | } 185 | 186 | with cluster() as (s, [a, b]): 187 | with Client(s["address"], loop=loop): 188 | X_train, X_test = X[:1600], X[1600:] 189 | y_train, y_test = y[:1600], y[1600:] 190 | 191 | dX_train = da.from_array(X_train) 192 | dy_train = da.from_array(y_train) 193 | 194 | # instantiate model 195 | clf = 
dxgb.XGBClassifier(**param_dist) 196 | 197 | # train it using instance weights only in the training set 198 | weights_train = np.random.choice([1, 2], len(X_train)) 199 | weights_train = da.from_array(weights_train) 200 | clf.fit( 201 | dX_train, 202 | dy_train, 203 | sample_weight=weights_train, 204 | eval_set=[(X_test, y_test)], 205 | eval_metric="logloss", 206 | ) 207 | 208 | # evaluate logloss metric on test set *without* using weights 209 | evals_result_without_weights = clf.evals_result() 210 | logloss_without_weights = evals_result_without_weights["validation_0"][ 211 | "logloss" 212 | ] 213 | 214 | # now use weights for the test set 215 | np.random.seed(0) 216 | weights_test = np.random.choice([1, 2], len(X_test)) 217 | clf.fit( 218 | dX_train, 219 | dy_train, 220 | sample_weight=weights_train, 221 | eval_set=[(X_test, y_test)], 222 | sample_weight_eval_set=[weights_test], 223 | eval_metric="logloss", 224 | ) 225 | evals_result_with_weights = clf.evals_result() 226 | logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"] 227 | 228 | # check that the logloss in the test set is actually different 229 | # when using weights than when not using them 230 | assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1])) 231 | 232 | 233 | @pytest.mark.parametrize("kind", ["array", "dataframe"]) 234 | def test_classifier_multi(kind, loop): # noqa: F811 235 | 236 | if kind == "array": 237 | X2 = da.from_array(X, 5) 238 | y2 = da.from_array(np.array([0, 1, 2, 0, 1, 2, 0, 0, 0, 1]), chunks=5) 239 | else: 240 | X2 = dd.from_pandas(df, npartitions=2) 241 | y2 = dd.from_pandas(labels, npartitions=2) 242 | 243 | with cluster() as (s, [a, b]): 244 | with Client(s["address"], loop=loop): 245 | a = dxgb.XGBClassifier( 246 | num_class=3, n_estimators=10, objective="multi:softprob" 247 | ) 248 | a.fit(X2, y2) 249 | p1 = a.predict(X2) 250 | 251 | assert dask.is_dask_collection(p1) 252 | 253 | if kind == "array": 254 | assert p1.shape == (10,) 255 | 256 | result = p1.compute() 257 | assert result.shape == (10,) 258 | 259 | # proba 260 | p2 = a.predict_proba(X2) 261 | assert dask.is_dask_collection(p2) 262 | 263 | if kind == "array": 264 | assert p2.shape == (10, 3) 265 | assert p2.compute().shape == (10, 3) 266 | 267 | 268 | def test_regressor(loop): # noqa 269 | with cluster() as (s, [a, b]): 270 | with Client(s["address"], loop=loop): 271 | a = dxgb.XGBRegressor() 272 | X2 = da.from_array(X, 5) 273 | y2 = da.from_array(y, 5) 274 | a.fit(X2, y2) 275 | p1 = a.predict(X2) 276 | 277 | b = xgb.XGBRegressor() 278 | b.fit(X, y) 279 | assert_eq(p1, b.predict(X)) 280 | 281 | 282 | def test_regressor_with_early_stopping(loop): # noqa 283 | with cluster() as (s, [a, b]): 284 | with Client(s["address"], loop=loop): 285 | a = dxgb.XGBRegressor() 286 | X2 = da.from_array(X, 5) 287 | y2 = da.from_array(y, 5) 288 | a.fit( 289 | X2, y2, early_stopping_rounds=4, eval_metric="rmse", eval_set=[(X, y)], 290 | ) 291 | p1 = a.predict(X2) 292 | 293 | b = xgb.XGBRegressor() 294 | b.fit(X, y, early_stopping_rounds=4, eval_metric="rmse", eval_set=[(X, y)]) 295 | assert_eq(p1, b.predict(X)) 296 | assert_eq(a.best_score, b.best_score) 297 | 298 | 299 | def test_validation_weights_xgbregressor(loop): # noqa 300 | from sklearn.datasets import make_regression 301 | from sklearn.metrics import mean_squared_error 302 | 303 | # prepare training and test data 304 | X, y = make_regression(n_samples=2000, random_state=42) 305 | 306 | with cluster() as (s, [a, b]): 307 | with Client(s["address"], 
loop=loop): 308 | X_train, X_test = X[:1600], X[1600:] 309 | y_train, y_test = y[:1600], y[1600:] 310 | 311 | dX_train = da.from_array(X_train) 312 | dy_train = da.from_array(y_train) 313 | dX_test = da.from_array(X_test) 314 | 315 | reg = dxgb.XGBRegressor() 316 | 317 | reg.fit( 318 | dX_train, dy_train, # sample_weight=weights_train, 319 | ) 320 | preds = reg.predict(dX_test) 321 | 322 | rng = np.random.RandomState(0) 323 | weights_train = 100.0 + rng.rand(len(X_train)) 324 | weights_train = da.from_array(weights_train) 325 | weights_test = 100.0 + rng.rand(len(X_test)) 326 | 327 | reg.fit( 328 | dX_train, 329 | dy_train, 330 | sample_weight=weights_train, 331 | sample_weight_eval_set=[weights_test], 332 | ) 333 | preds2 = reg.predict(dX_test) 334 | 335 | err = mean_squared_error(preds, y_test) 336 | err2 = mean_squared_error(preds2, y_test) 337 | assert err != err2 338 | 339 | 340 | @gen_cluster(client=True, timeout=None) 341 | def test_basic(c, s, a, b): 342 | dtrain = xgb.DMatrix(df, label=labels) 343 | bst = xgb.train(param, dtrain) 344 | 345 | ddf = dd.from_pandas(df, npartitions=4) 346 | dlabels = dd.from_pandas(labels, npartitions=4) 347 | dbst = yield dxgb.train(c, param, ddf, dlabels) 348 | dbst = yield dxgb.train(c, param, ddf, dlabels) # we can do this twice 349 | 350 | result = bst.predict(dtrain) 351 | dresult = dbst.predict(dtrain) 352 | 353 | correct = (result > 0.5) == labels 354 | dcorrect = (dresult > 0.5) == labels 355 | assert dcorrect.sum() >= correct.sum() 356 | 357 | predictions = dxgb.predict(c, dbst, ddf) 358 | assert isinstance(predictions, da.Array) 359 | predictions = yield c.compute(predictions)._result() 360 | assert isinstance(predictions, np.ndarray) 361 | 362 | assert ((predictions > 0.5) != labels).sum() < 2 363 | 364 | 365 | @gen_cluster(client=True, timeout=None) 366 | def test_dmatrix_kwargs(c, s, a, b): 367 | xgb.rabit.init() # workaround for "Doing rabit call after Finalize" 368 | dX = da.from_array(X, chunks=(2, 2)) 369 | dy = da.from_array(y, chunks=(2,)) 370 | dbst = yield dxgb.train(c, param, dX, dy, dmatrix_kwargs={"missing": 0.0}) 371 | 372 | # Distributed model matches local model with dmatrix kwargs 373 | dtrain = xgb.DMatrix(X, label=y, missing=0.0) 374 | bst = xgb.train(param, dtrain) 375 | result = bst.predict(dtrain) 376 | dresult = dbst.predict(dtrain) 377 | assert np.abs(result - dresult).sum() < 0.02 378 | 379 | # Distributed model gives bad predictions without dmatrix kwargs 380 | dtrain_incompat = xgb.DMatrix(X, label=y) 381 | dresult_incompat = dbst.predict(dtrain_incompat) 382 | assert np.abs(result - dresult_incompat).sum() > 0.02 383 | 384 | 385 | def _test_container(dbst, predictions, X_type): 386 | dtrain = xgb.DMatrix(X_type(X), label=y) 387 | bst = xgb.train(param, dtrain) 388 | 389 | result = bst.predict(dtrain) 390 | dresult = dbst.predict(dtrain) 391 | 392 | correct = (result > 0.5) == y 393 | dcorrect = (dresult > 0.5) == y 394 | 395 | assert dcorrect.sum() >= correct.sum() 396 | assert isinstance(predictions, np.ndarray) 397 | assert ((predictions > 0.5) != labels).sum() < 2 398 | 399 | 400 | @gen_cluster(client=True, timeout=None) 401 | def test_numpy(c, s, a, b): 402 | xgb.rabit.init() # workaround for "Doing rabit call after Finalize" 403 | dX = da.from_array(X, chunks=(2, 2)) 404 | dy = da.from_array(y, chunks=(2,)) 405 | dbst = yield dxgb.train(c, param, dX, dy) 406 | dbst = yield dxgb.train(c, param, dX, dy) # we can do this twice 407 | 408 | predictions = dxgb.predict(c, dbst, dX) 409 | assert 
isinstance(predictions, da.Array) 410 | predictions = yield c.compute(predictions) 411 | _test_container(dbst, predictions, np.array) 412 | 413 | 414 | @gen_cluster(client=True, timeout=None) 415 | def test_scipy_sparse(c, s, a, b): 416 | xgb.rabit.init() # workaround for "Doing rabit call after Finalize" 417 | dX = da.from_array(X, chunks=(2, 2)).map_blocks(scipy.sparse.csr_matrix) 418 | dy = da.from_array(y, chunks=(2,)) 419 | dbst = yield dxgb.train(c, param, dX, dy) 420 | dbst = yield dxgb.train(c, param, dX, dy) # we can do this twice 421 | 422 | predictions = dxgb.predict(c, dbst, dX) 423 | assert isinstance(predictions, da.Array) 424 | 425 | predictions_result = yield c.compute(predictions) 426 | _test_container(dbst, predictions_result, scipy.sparse.csr_matrix) 427 | 428 | 429 | @gen_cluster(client=True, timeout=None) 430 | def test_sparse(c, s, a, b): 431 | xgb.rabit.init() # workaround for "Doing rabit call after Finalize" 432 | dX = da.from_array(X, chunks=(2, 2)).map_blocks(scipy.sparse.csr_matrix) 433 | dy = da.from_array(y, chunks=(2,)) 434 | dbst = yield dxgb.train(c, param, dX, dy) 435 | dbst = yield dxgb.train(c, param, dX, dy) # we can do this twice 436 | 437 | predictions = dxgb.predict(c, dbst, dX) 438 | assert isinstance(predictions, da.Array) 439 | 440 | predictions_result = yield c.compute(predictions) 441 | _test_container(dbst, predictions_result, scipy.sparse.csr_matrix) 442 | 443 | 444 | def test_synchronous_api(loop): # noqa 445 | dtrain = xgb.DMatrix(df, label=labels) 446 | bst = xgb.train(param, dtrain) 447 | 448 | ddf = dd.from_pandas(df, npartitions=4) 449 | dlabels = dd.from_pandas(labels, npartitions=4) 450 | 451 | with cluster() as (s, [a, b]): 452 | with Client(s["address"], loop=loop) as c: 453 | 454 | dbst = dxgb.train(c, param, ddf, dlabels) 455 | 456 | result = bst.predict(dtrain) 457 | dresult = dbst.predict(dtrain) 458 | 459 | correct = (result > 0.5) == labels 460 | dcorrect = (dresult > 0.5) == labels 461 | assert dcorrect.sum() >= correct.sum() 462 | 463 | 464 | @gen_cluster(client=True, timeout=None) 465 | def test_errors(c, s, a, b): 466 | def f(part): 467 | raise Exception("foo") 468 | 469 | df = dd.demo.make_timeseries() 470 | df = df.map_partitions(f, meta=df._meta) 471 | 472 | with pytest.raises(Exception) as info: 473 | yield dxgb.train(c, param, df, df.x) 474 | 475 | assert "foo" in str(info.value) 476 | 477 | 478 | @gen_cluster(client=True, timeout=None) 479 | @pytest.mark.asyncio 480 | async def test_predict_proba(c, s, a, b): 481 | X = da.random.random((50, 2), chunks=25) 482 | y = da.random.randint(0, 2, size=50, chunks=25) 483 | X_ = await c.compute(X) 484 | 485 | # array 486 | clf = dxgb.XGBClassifier() 487 | clf.fit(X, y, classes=[0, 1]) 488 | booster = await clf._Booster 489 | 490 | result = clf.predict_proba(X_) 491 | expected = booster.predict(xgb.DMatrix(X_)) 492 | np.testing.assert_array_equal(result, expected) 493 | 494 | # dataframe 495 | XX = dd.from_dask_array(X, columns=["A", "B"]) 496 | yy = dd.from_dask_array(y) 497 | XX_ = await c.compute(XX) 498 | 499 | clf = dxgb.XGBClassifier() 500 | clf.fit(XX, yy, classes=[0, 1]) 501 | booster = await clf._Booster 502 | 503 | result = clf.predict_proba(XX_) 504 | expected = booster.predict(xgb.DMatrix(XX_)) 505 | np.testing.assert_array_equal(result, expected) 506 | 507 | 508 | def test_regressor_evals_result(loop): # noqa 509 | with cluster() as (s, [a, b]): 510 | with Client(s["address"], loop=loop): 511 | a = dxgb.XGBRegressor() 512 | X2 = da.from_array(X, 5) 513 | y2 = 
da.from_array(y, 5) 514 | a.fit(X2, y2, eval_metric="rmse", eval_set=[(X, y)]) 515 | evals_result = a.evals_result() 516 | 517 | b = xgb.XGBRegressor() 518 | b.fit(X, y, eval_metric="rmse", eval_set=[(X, y)]) 519 | assert_eq(evals_result, b.evals_result()) 520 | 521 | 522 | def test_classifier_evals_result(loop): # noqa 523 | with cluster() as (s, [a, b]): 524 | with Client(s["address"], loop=loop): 525 | a = dxgb.XGBClassifier() 526 | X2 = da.from_array(X, 5) 527 | y2 = da.from_array(y, 5) 528 | a.fit(X2, y2, eval_metric="rmse", eval_set=[(X, y)]) 529 | evals_result = a.evals_result() 530 | 531 | b = xgb.XGBClassifier() 532 | b.fit(X, y, eval_metric="rmse", eval_set=[(X, y)]) 533 | assert_eq(evals_result, b.evals_result()) 534 | 535 | 536 | @gen_cluster(client=True, timeout=None) 537 | def test_eval_set_dask_collection_exception(c, s, a, b): 538 | ddf = dd.from_pandas(df, npartitions=4) 539 | dlabels = dd.from_pandas(labels, npartitions=4) 540 | 541 | X2 = da.from_array(X, 5) 542 | y2 = da.from_array(y, 5) 543 | 544 | with pytest.raises(TypeError) as info: 545 | yield dxgb.train(c, param, ddf, dlabels, eval_set=[(X2, y2)]) 546 | 547 | assert "Evaluation set must not contain dask collections." in str(info.value) 548 | 549 | 550 | @gen_cluster(client=True, timeout=None) 551 | def test_sample_weight_eval_set_dask_collection_exception(c, s, a, b): 552 | ddf = dd.from_pandas(df, npartitions=4) 553 | dlabels = dd.from_pandas(labels, npartitions=4) 554 | 555 | X2 = da.from_array(X, 5) 556 | y2 = da.from_array(y, 5) 557 | 558 | with pytest.raises(TypeError) as info: 559 | yield dxgb.train(c, param, ddf, dlabels, sample_weight_eval_set=[(X2, y2)]) 560 | 561 | assert "Sample weight evaluation set must not contain dask collections." in str( 562 | info.value 563 | ) 564 | -------------------------------------------------------------------------------- /dask_xgboost/tracker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tracker script for DMLC 3 | Implements the tracker control protocol 4 | - start dmlc jobs 5 | - start ps scheduler and rabit tracker 6 | - help nodes to establish links with each other 7 | 8 | Tianqi Chen 9 | 10 | Notes from Matthew Rocklin 11 | -------------------------- 12 | 13 | This was taken from 14 | https://github.com/dmlc/dmlc-core/blob/master/tracker/dmlc_tracker/tracker.py 15 | See LICENSE here 16 | https://github.com/dmlc/dmlc-core/blob/master/LICENSE 17 | 18 | No code modified or added except for this explanatory comment. 
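Within dask-xgboost, ``core.start_tracker`` runs a ``RabitTracker`` on the Dask
scheduler and forwards the resulting ``DMLC_TRACKER_URI``/``DMLC_TRACKER_PORT``
environment variables to every worker, where ``train_part`` passes them to
``xgb.rabit.init``.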
19 | """ 20 | # pylint: disable=invalid-name, missing-docstring, too-many-arguments 21 | # pylint: disable=too-many-locals 22 | # pylint: disable=too-many-branches, too-many-statements 23 | from __future__ import absolute_import 24 | 25 | import argparse 26 | import logging 27 | import os 28 | import socket 29 | import struct 30 | import subprocess 31 | import sys 32 | import time 33 | from threading import Thread 34 | 35 | 36 | class ExSocket(object): 37 | """ 38 | Extension of socket to handle recv and send of special data 39 | """ 40 | 41 | def __init__(self, sock): 42 | self.sock = sock 43 | 44 | def recvall(self, nbytes): 45 | res = [] 46 | nread = 0 47 | while nread < nbytes: 48 | chunk = self.sock.recv(min(nbytes - nread, 1024)) 49 | nread += len(chunk) 50 | res.append(chunk) 51 | return b"".join(res) 52 | 53 | def recvint(self): 54 | return struct.unpack("@i", self.recvall(4))[0] 55 | 56 | def sendint(self, n): 57 | self.sock.sendall(struct.pack("@i", n)) 58 | 59 | def sendstr(self, s): 60 | self.sendint(len(s)) 61 | self.sock.sendall(s.encode()) 62 | 63 | def recvstr(self): 64 | slen = self.recvint() 65 | return self.recvall(slen).decode() 66 | 67 | 68 | # magic number used to verify existence of data 69 | kMagic = 0xFF99 70 | 71 | 72 | def get_some_ip(host): 73 | return socket.getaddrinfo(host, None)[0][4][0] 74 | 75 | 76 | def get_family(addr): 77 | return socket.getaddrinfo(addr, None)[0][0] 78 | 79 | 80 | class SlaveEntry(object): 81 | def __init__(self, sock, s_addr): 82 | slave = ExSocket(sock) 83 | self.sock = slave 84 | self.host = get_some_ip(s_addr[0]) 85 | magic = slave.recvint() 86 | assert magic == kMagic, "invalid magic number=%d from %s" % (magic, self.host,) 87 | slave.sendint(kMagic) 88 | self.rank = slave.recvint() 89 | self.world_size = slave.recvint() 90 | self.jobid = slave.recvstr() 91 | self.cmd = slave.recvstr() 92 | self.wait_accept = 0 93 | self.port = None 94 | 95 | def decide_rank(self, job_map): 96 | if self.rank >= 0: 97 | return self.rank 98 | if self.jobid != "NULL" and self.jobid in job_map: 99 | return job_map[self.jobid] 100 | return -1 101 | 102 | def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): 103 | self.rank = rank 104 | nnset = set(tree_map[rank]) 105 | rprev, rnext = ring_map[rank] 106 | self.sock.sendint(rank) 107 | # send parent rank 108 | self.sock.sendint(parent_map[rank]) 109 | # send world size 110 | self.sock.sendint(len(tree_map)) 111 | self.sock.sendint(len(nnset)) 112 | # send the rprev and next link 113 | for r in nnset: 114 | self.sock.sendint(r) 115 | # send prev link 116 | if rprev != -1 and rprev != rank: 117 | nnset.add(rprev) 118 | self.sock.sendint(rprev) 119 | else: 120 | self.sock.sendint(-1) 121 | # send next link 122 | if rnext != -1 and rnext != rank: 123 | nnset.add(rnext) 124 | self.sock.sendint(rnext) 125 | else: 126 | self.sock.sendint(-1) 127 | while True: 128 | ngood = self.sock.recvint() 129 | goodset = set([]) 130 | for _ in range(ngood): 131 | goodset.add(self.sock.recvint()) 132 | assert goodset.issubset(nnset) 133 | badset = nnset - goodset 134 | conset = [] 135 | for r in badset: 136 | if r in wait_conn: 137 | conset.append(r) 138 | self.sock.sendint(len(conset)) 139 | self.sock.sendint(len(badset) - len(conset)) 140 | for r in conset: 141 | self.sock.sendstr(wait_conn[r].host) 142 | self.sock.sendint(wait_conn[r].port) 143 | self.sock.sendint(r) 144 | nerr = self.sock.recvint() 145 | if nerr != 0: 146 | continue 147 | self.port = self.sock.recvint() 148 | rmset = [] 149 | # all 
connection was successuly setup 150 | for r in conset: 151 | wait_conn[r].wait_accept -= 1 152 | if wait_conn[r].wait_accept == 0: 153 | rmset.append(r) 154 | for r in rmset: 155 | wait_conn.pop(r, None) 156 | self.wait_accept = len(badset) - len(conset) 157 | return rmset 158 | 159 | 160 | class RabitTracker(object): 161 | """ 162 | tracker for rabit 163 | """ 164 | 165 | def __init__(self, hostIP, nslave, port=9091, port_end=9999): 166 | sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) 167 | for port in range(port, port_end): 168 | try: 169 | logging.info("Binding Rabit tracker %s:%d", hostIP, port) 170 | sock.bind((hostIP, port)) 171 | self.port = port 172 | break 173 | except socket.error as e: 174 | if e.errno in [98, 48]: 175 | continue 176 | else: 177 | logging.error(e, exc_info=True) 178 | raise 179 | sock.listen(256) 180 | self.sock = sock 181 | self.hostIP = hostIP 182 | self.thread = None 183 | self.start_time = None 184 | self.end_time = None 185 | self.nslave = nslave 186 | logging.info("start listen on %s:%d", hostIP, self.port) 187 | 188 | def __del__(self): 189 | self.sock.close() 190 | 191 | @staticmethod 192 | def get_neighbor(rank, nslave): 193 | rank = rank + 1 194 | ret = [] 195 | if rank > 1: 196 | ret.append(rank // 2 - 1) 197 | if rank * 2 - 1 < nslave: 198 | ret.append(rank * 2 - 1) 199 | if rank * 2 < nslave: 200 | ret.append(rank * 2) 201 | return ret 202 | 203 | def slave_envs(self): 204 | """ 205 | get enviroment variables for slaves 206 | can be passed in as args or envs 207 | """ 208 | return { 209 | "DMLC_TRACKER_URI": self.hostIP, 210 | "DMLC_TRACKER_PORT": self.port, 211 | } 212 | 213 | def get_tree(self, nslave): 214 | tree_map = {} 215 | parent_map = {} 216 | for r in range(nslave): 217 | tree_map[r] = self.get_neighbor(r, nslave) 218 | parent_map[r] = (r + 1) // 2 - 1 219 | return tree_map, parent_map 220 | 221 | def find_share_ring(self, tree_map, parent_map, r): 222 | """ 223 | get a ring structure that tends to share nodes with the tree 224 | return a list starting from r 225 | """ 226 | nset = set(tree_map[r]) 227 | cset = nset - set([parent_map[r]]) 228 | if len(cset) == 0: 229 | return [r] 230 | rlst = [r] 231 | cnt = 0 232 | for v in cset: 233 | vlst = self.find_share_ring(tree_map, parent_map, v) 234 | cnt += 1 235 | if cnt == len(cset): 236 | vlst.reverse() 237 | rlst += vlst 238 | return rlst 239 | 240 | def get_ring(self, tree_map, parent_map): 241 | """ 242 | get a ring connection used to recover local data 243 | """ 244 | assert parent_map[0] == -1 245 | rlst = self.find_share_ring(tree_map, parent_map, 0) 246 | assert len(rlst) == len(tree_map) 247 | ring_map = {} 248 | nslave = len(tree_map) 249 | for r in range(nslave): 250 | rprev = (r + nslave - 1) % nslave 251 | rnext = (r + 1) % nslave 252 | ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) 253 | return ring_map 254 | 255 | def get_link_map(self, nslave): 256 | """ 257 | get the link map, this is a bit hacky, call for better algorithm 258 | to place similar nodes together 259 | """ 260 | tree_map, parent_map = self.get_tree(nslave) 261 | ring_map = self.get_ring(tree_map, parent_map) 262 | rmap = {0: 0} 263 | k = 0 264 | for i in range(nslave - 1): 265 | k = ring_map[k][1] 266 | rmap[k] = i + 1 267 | 268 | ring_map_ = {} 269 | tree_map_ = {} 270 | parent_map_ = {} 271 | for k, v in ring_map.items(): 272 | ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) 273 | for k, v in tree_map.items(): 274 | tree_map_[rmap[k]] = [rmap[x] for x in v] 275 | for k, v in parent_map.items(): 276 
| if k != 0: 277 | parent_map_[rmap[k]] = rmap[v] 278 | else: 279 | parent_map_[rmap[k]] = -1 280 | return tree_map_, parent_map_, ring_map_ 281 | 282 | def accept_slaves(self, nslave): 283 | # set of nodes that finishs the job 284 | shutdown = {} 285 | # set of nodes that is waiting for connections 286 | wait_conn = {} 287 | # maps job id to rank 288 | job_map = {} 289 | # list of workers that is pending to be assigned rank 290 | pending = [] 291 | # lazy initialize tree_map 292 | tree_map = None 293 | 294 | while len(shutdown) != nslave: 295 | fd, s_addr = self.sock.accept() 296 | s = SlaveEntry(fd, s_addr) 297 | if s.cmd == "print": 298 | msg = s.sock.recvstr() 299 | logging.info(msg.strip()) 300 | continue 301 | if s.cmd == "shutdown": 302 | assert s.rank >= 0 and s.rank not in shutdown 303 | assert s.rank not in wait_conn 304 | shutdown[s.rank] = s 305 | logging.debug("Recieve %s signal from %d", s.cmd, s.rank) 306 | continue 307 | assert s.cmd == "start" or s.cmd == "recover" 308 | # lazily initialize the slaves 309 | if tree_map is None: 310 | assert s.cmd == "start" 311 | if s.world_size > 0: 312 | nslave = s.world_size 313 | tree_map, parent_map, ring_map = self.get_link_map(nslave) 314 | # set of nodes that is pending for getting up 315 | todo_nodes = list(range(nslave)) 316 | else: 317 | assert s.world_size == -1 or s.world_size == nslave 318 | if s.cmd == "recover": 319 | assert s.rank >= 0 320 | 321 | rank = s.decide_rank(job_map) 322 | # batch assignment of ranks 323 | if rank == -1: 324 | assert len(todo_nodes) != 0 325 | pending.append(s) 326 | if len(pending) == len(todo_nodes): 327 | pending.sort(key=lambda x: x.host) 328 | for s in pending: 329 | rank = todo_nodes.pop(0) 330 | if s.jobid != "NULL": 331 | job_map[s.jobid] = rank 332 | s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) 333 | if s.wait_accept > 0: 334 | wait_conn[rank] = s 335 | logging.debug( 336 | "Recieve %s signal from %s; " "assign rank %d", 337 | s.cmd, 338 | s.host, 339 | s.rank, 340 | ) 341 | if len(todo_nodes) == 0: 342 | logging.info("@tracker All of %d nodes getting started", nslave) 343 | self.start_time = time.time() 344 | else: 345 | s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) 346 | logging.debug("Recieve %s signal from %d", s.cmd, s.rank) 347 | if s.wait_accept > 0: 348 | wait_conn[rank] = s 349 | logging.info("@tracker All nodes finishes job") 350 | self.end_time = time.time() 351 | logging.info( 352 | "@tracker %s secs between node start and job finish", 353 | str(self.end_time - self.start_time), 354 | ) 355 | 356 | def start(self, nslave): 357 | def run(): 358 | self.accept_slaves(nslave) 359 | 360 | self.thread = Thread(target=run, args=()) 361 | self.thread.setDaemon(True) 362 | self.thread.start() 363 | 364 | def join(self): 365 | while self.thread.isAlive(): 366 | self.thread.join(100) 367 | 368 | 369 | class PSTracker(object): 370 | """ 371 | Tracker module for PS 372 | """ 373 | 374 | def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None): 375 | """ 376 | Starts the PS scheduler 377 | """ 378 | self.cmd = cmd 379 | if cmd is None: 380 | return 381 | envs = {} if envs is None else envs 382 | self.hostIP = hostIP 383 | sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) 384 | for port in range(port, port_end): 385 | try: 386 | sock.bind(("", port)) 387 | self.port = port 388 | sock.close() 389 | break 390 | except socket.error: 391 | continue 392 | env = os.environ.copy() 393 | 394 | env["DMLC_ROLE"] = "scheduler" 395 | 
395 |         env["DMLC_PS_ROOT_URI"] = str(self.hostIP)
396 |         env["DMLC_PS_ROOT_PORT"] = str(self.port)
397 |         for k, v in envs.items():
398 |             env[k] = str(v)
399 |         self.thread = Thread(
400 |             target=(lambda: subprocess.check_call(self.cmd, env=env, shell=True)),
401 |             args=(),
402 |         )
403 |         self.thread.daemon = True
404 |         self.thread.start()
405 |
406 |     def join(self):
407 |         if self.cmd is not None:
408 |             while self.thread.is_alive():
409 |                 self.thread.join(100)
410 |
411 |     def slave_envs(self):
412 |         if self.cmd is None:
413 |             return {}
414 |         else:
415 |             return {
416 |                 "DMLC_PS_ROOT_URI": self.hostIP,
417 |                 "DMLC_PS_ROOT_PORT": self.port,
418 |             }
419 |
420 |
421 | def get_host_ip(hostIP=None):
422 |     if hostIP is None or hostIP == "auto":
423 |         hostIP = "ip"
424 |
425 |     if hostIP == "dns":
426 |         hostIP = socket.getfqdn()
427 |     elif hostIP == "ip":
428 |         from socket import gaierror
429 |
430 |         try:
431 |             hostIP = socket.gethostbyname(socket.getfqdn())
432 |         except gaierror:
433 |             logging.warning(
434 |                 "gethostbyname(socket.getfqdn()) failed... trying hostname() instead"
435 |             )
436 |             hostIP = socket.gethostbyname(socket.gethostname())
437 |         if hostIP.startswith("127."):
438 |             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
439 |             # doesn't have to be reachable
440 |             s.connect(("10.255.255.255", 1))
441 |             hostIP = s.getsockname()[0]
442 |     return hostIP
443 |
444 |
445 | def submit(nworker, nserver, fun_submit, hostIP="auto", pscmd=None):
446 |     if nserver == 0:
447 |         pscmd = None
448 |
449 |     envs = {"DMLC_NUM_WORKER": nworker, "DMLC_NUM_SERVER": nserver}
450 |     hostIP = get_host_ip(hostIP)
451 |
452 |     if nserver == 0:
453 |         rabit = RabitTracker(hostIP=hostIP, nslave=nworker)
454 |         envs.update(rabit.slave_envs())
455 |         rabit.start(nworker)
456 |     else:
457 |         pserver = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs)
458 |         envs.update(pserver.slave_envs())
459 |     fun_submit(nworker, nserver, envs)
460 |
461 |     if nserver == 0:
462 |         rabit.join()
463 |     else:
464 |         pserver.join()
465 |
466 |
467 | def start_rabit_tracker(args):
468 |     """Standalone function to start rabit tracker.
469 |
470 |     Parameters
471 |     ----------
472 |     args: arguments to start the rabit tracker.
473 |     """
474 |     envs = {
475 |         "DMLC_NUM_WORKER": args.num_workers,
476 |         "DMLC_NUM_SERVER": args.num_servers,
477 |     }
478 |     rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers)
479 |     envs.update(rabit.slave_envs())
480 |     rabit.start(args.num_workers)
481 |     sys.stdout.write("DMLC_TRACKER_ENV_START\n")
482 |     # simply write configuration to stdout
483 |     for k, v in envs.items():
484 |         sys.stdout.write("%s=%s\n" % (k, str(v)))
485 |     sys.stdout.write("DMLC_TRACKER_ENV_END\n")
486 |     sys.stdout.flush()
487 |     rabit.join()
488 |
489 |
490 | def main():
491 |     """Main function if tracker is executed in standalone mode."""
492 |     parser = argparse.ArgumentParser(description="Rabit Tracker start.")
493 |     parser.add_argument(
494 |         "--num-workers",
495 |         required=True,
496 |         type=int,
497 |         help="Number of worker processes to be launched.",
498 |     )
499 |     parser.add_argument(
500 |         "--num-servers",
501 |         default=0,
502 |         type=int,
503 |         help="Number of server processes to be launched. Only used in PS jobs.",
504 |     )
505 |     parser.add_argument(
506 |         "--host-ip",
507 |         default=None,
508 |         type=str,
509 |         help=(
510 |             "Host IP address; this is only needed "
511 |             + "if the host IP cannot be automatically guessed."
512 |         ),
513 |     )
514 |     parser.add_argument(
515 |         "--log-level",
516 |         default="INFO",
517 |         type=str,
518 |         choices=["INFO", "DEBUG"],
519 |         help="Logging level of the logger.",
520 |     )
521 |     args = parser.parse_args()
522 |
523 |     fmt = "%(asctime)s %(levelname)s %(message)s"
524 |     if args.log_level == "INFO":
525 |         level = logging.INFO
526 |     elif args.log_level == "DEBUG":
527 |         level = logging.DEBUG
528 |     else:
529 |         raise RuntimeError("Unknown logging level %s" % args.log_level)
530 |
531 |     logging.basicConfig(format=fmt, level=level)
532 |
533 |     if args.num_servers == 0:
534 |         start_rabit_tracker(args)
535 |     else:
536 |         raise RuntimeError("Starting a PS tracker standalone is not yet supported.")
537 |
538 |
539 | if __name__ == "__main__":
540 |     main()
541 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | xgboost <=0.90
2 | dask
3 | distributed >= 1.15.2
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=1
3 |
4 | [flake8]
5 | exclude = tests/data,docs,benchmarks,scripts,.tox,env,.eggs,build
6 | max-line-length = 88
7 | ignore =
8 |     # Assigning lambda expression
9 |     E731
10 |     # Ambiguous variable names
11 |     E741
12 |     # line break before binary operator
13 |     W503
14 |     # whitespace before :
15 |     E203
16 |     # whitespace after ','
17 |     E231
18 |
19 | [tool:pytest]
20 | addopts = -rsx -v
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | from setuptools import setup
5 |
6 | requires = open("requirements.txt").read().strip().split("\n")
7 | install_requires = []
8 | extras_require = {"sparse": ["sparse", "scipy"]}
9 | for r in requires:
10 |     if ";" in r:
11 |         # requirements.txt conditional dependencies need to be reformatted for
12 |         # wheels to the form: `'[extra_name]:condition' : ['requirements']`
13 |         req, cond = r.split(";", 1)
14 |         cond = ":" + cond
15 |         cond_reqs = extras_require.setdefault(cond, [])
16 |         cond_reqs.append(req)
17 |     else:
18 |         install_requires.append(r)
19 |
20 | setup(
21 |     name="dask-xgboost",
22 |     version="0.2.0",
23 |     description="Interactions between Dask and XGBoost",
24 |     maintainer="Matthew Rocklin",
25 |     maintainer_email="mrocklin@continuum.io",
26 |     url="https://github.com/dask/dask-xgboost",
27 |     license="BSD",
28 |     install_requires=install_requires,
29 |     extras_require=extras_require,
30 |     packages=["dask_xgboost"],
31 |     long_description=(
32 |         open("README.rst").read() if os.path.exists("README.rst") else ""
33 |     ),
34 |     zip_safe=False,
35 | )
36 |
--------------------------------------------------------------------------------
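
Note on usage: the tracker module listed above is normally driven internally by dask_xgboost.core, but the same objects can be exercised by hand. The snippet below is a minimal sketch, not part of the repository; the worker count of 2 and the placeholder comment about launching workers are illustrative assumptions, while RabitTracker, get_host_ip, slave_envs, start, and join are taken directly from the code shown above.

# Minimal sketch (illustrative only) of driving the RabitTracker manually.
from dask_xgboost.tracker import RabitTracker, get_host_ip

host = get_host_ip("auto")                     # pick a reachable local IP
tracker = RabitTracker(hostIP=host, nslave=2)  # expect two rabit workers (assumption)
envs = tracker.slave_envs()                    # DMLC_TRACKER_URI / DMLC_TRACKER_PORT
tracker.start(2)                               # accept_slaves() runs in a daemon thread

# ... launch two xgboost workers with `envs` exported in their environment ...

tracker.join()                                 # returns once every worker has sent "shutdown"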