├── tests ├── __init__.py ├── conftest.py ├── test_eventloop.py ├── test_ipc.py ├── test_trial.py ├── test_messages.py └── test_managers.py ├── optuna_distributed ├── py.typed ├── __init__.py ├── ipc │ ├── __init__.py │ ├── base.py │ ├── pipe.py │ └── queue.py ├── managers │ ├── __init__.py │ ├── base.py │ ├── local.py │ └── distributed.py ├── messages │ ├── heartbeat.py │ ├── shouldprune.py │ ├── report.py │ ├── __init__.py │ ├── response.py │ ├── property.py │ ├── base.py │ ├── pruned.py │ ├── setattr.py │ ├── failed.py │ ├── suggest.py │ └── completed.py ├── config.py ├── terminal.py ├── eventloop.py ├── trial.py └── study.py ├── setup.cfg ├── .github ├── ISSUE_TEMPLATE │ ├── general-question.md │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── format.yaml │ ├── release.yaml │ └── test.yaml ├── examples ├── visualization.py ├── disable_logging.py ├── quadratic_simple.py ├── simple_storages.py └── simple_pruning.py ├── LICENSE ├── pyproject.toml ├── README.md └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optuna_distributed/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optuna_distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from optuna_distributed import config 2 | from optuna_distributed.study import from_study 3 | 4 | 5 | __version__ = "0.7.0" 6 | __all__ = ["from_study", "config"] 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = D203, E701, E704 3 | exclude = 4 | .git, 5 | __pycache__, 6 | venv, 7 | build, 8 | dist 9 | max-complexity = 10 10 | max-line-length = 99 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/general-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General question 3 | about: Have a general question about Optuna-distributed? Feel free to ask 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /optuna_distributed/ipc/__init__.py: -------------------------------------------------------------------------------- 1 | from optuna_distributed.ipc.base import IPCPrimitive 2 | from optuna_distributed.ipc.pipe import Pipe 3 | from optuna_distributed.ipc.queue import Queue 4 | 5 | 6 | __all__ = ["IPCPrimitive", "Queue", "Pipe"] 7 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Additional context** 20 | Add any other context about the problem here. 
21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from dask.distributed import Client 2 | from dask.distributed import LocalCluster 3 | import optuna 4 | import pytest 5 | 6 | 7 | _test_cluster = LocalCluster(n_workers=1, threads_per_worker=1) 8 | _test_client = Client(_test_cluster.scheduler_address) 9 | 10 | 11 | @pytest.fixture 12 | def client() -> Client: 13 | return _test_client 14 | 15 | 16 | @pytest.fixture 17 | def study() -> optuna.Study: 18 | study = optuna.create_study() 19 | study.ask() 20 | return study 21 | -------------------------------------------------------------------------------- /optuna_distributed/managers/__init__.py: -------------------------------------------------------------------------------- 1 | from optuna_distributed.managers.base import DistributableFuncType 2 | from optuna_distributed.managers.base import ObjectiveFuncType 3 | from optuna_distributed.managers.base import OptimizationManager 4 | from optuna_distributed.managers.distributed import DistributedOptimizationManager 5 | from optuna_distributed.managers.local import LocalOptimizationManager 6 | 7 | 8 | __all__ = [ 9 | "OptimizationManager", 10 | "DistributedOptimizationManager", 11 | "LocalOptimizationManager", 12 | "DistributableFuncType", 13 | "ObjectiveFuncType", 14 | ] 15 | -------------------------------------------------------------------------------- /optuna_distributed/messages/heartbeat.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from optuna.study import Study 4 | 5 | from optuna_distributed.messages.base import Message 6 | 7 | 8 | if TYPE_CHECKING: 9 | from optuna_distributed.managers.base import OptimizationManager 10 | 11 | 12 | class HeartbeatMessage(Message): 13 | """A heartbeat message. 14 | 15 | Heartbeat messages do not carry any code or data. Their purpose 16 | is to generate some traffic on communication channels to confirm 17 | some things are still alive and well. 18 | """ 19 | 20 | closing = False 21 | 22 | def process(self, study: Study, manager: "OptimizationManager") -> None: ... 23 | -------------------------------------------------------------------------------- /examples/visualization.py: -------------------------------------------------------------------------------- 1 | import optuna 2 | 3 | import optuna_distributed 4 | 5 | 6 | def objective(trial): 7 | x = trial.suggest_float("x", -100, 100) 8 | y = trial.suggest_categorical("y", [-1, 0, 1]) 9 | return x**2 + y 10 | 11 | 12 | sampler = optuna.samplers.TPESampler(seed=10) 13 | distributed_study = optuna_distributed.from_study(optuna.create_study(sampler=sampler)) 14 | distributed_study.optimize(objective, n_trials=30) 15 | 16 | # Any plotting function from optuna.visualization module can be used with Optuna-distributed 17 | # thanks to .into_study() convenience method. 18 | optuna.visualization.plot_contour(distributed_study.into_study(), params=["x", "y"]).show() 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? 
Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/format.yaml: -------------------------------------------------------------------------------- 1 | name: Code formatting 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | format: 13 | name: code-formatting 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: '3.10' 23 | 24 | - name: Install Python dependencies 25 | run: | 26 | pip install -U . .[dev] .[test] 27 | 28 | - name: Black 29 | run: black . --check --diff 30 | 31 | - name: Flake8 32 | run: flake8 . 33 | 34 | - name: Isort 35 | run: isort . --check --diff 36 | 37 | - name: Mypy 38 | run: mypy . 39 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: PyPI release 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | release: 10 | name: PyPI-release 11 | if: github.repository == 'xadrianzetx/optuna-distributed' 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.10' 21 | 22 | - name: Install Python dependencies 23 | run: pip install -U build 24 | 25 | - name: Build a Python distribution 26 | run: python -m build 27 | 28 | - name: Publish a Python distribution to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.PYPI_API_TOKEN }} 33 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Unit tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test: 13 | name: unit-tests 14 | runs-on: ${{ matrix.os }} 15 | 16 | strategy: 17 | matrix: 18 | os: 19 | - 'ubuntu-latest' 20 | - 'windows-latest' 21 | - 'macos-latest' 22 | python-version: 23 | - '3.8' 24 | - '3.9' 25 | - '3.10' 26 | - '3.11' 27 | - '3.12' 28 | 29 | steps: 30 | - uses: actions/checkout@v3 31 | 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v4 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | - name: Install Python dependencies 38 | run: | 39 | pip install -U . 
.[test] 40 | 41 | - name: Run tests 42 | run: pytest -v 43 | -------------------------------------------------------------------------------- /examples/disable_logging.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import optuna 5 | 6 | import optuna_distributed 7 | 8 | 9 | def objective(trial): 10 | x = trial.suggest_float("x", -100, 100) 11 | y = trial.suggest_categorical("y", [-1, 0, 1]) 12 | time.sleep(random.uniform(0.0, 2.0)) 13 | return x**2 + y 14 | 15 | 16 | if __name__ == "__main__": 17 | optuna.logging.disable_default_handler() 18 | optuna_distributed.config.disable_logging() 19 | 20 | study = optuna_distributed.from_study(optuna.create_study()) 21 | print("Running 10 trials without logging...") 22 | study.optimize(objective, n_trials=10) 23 | print(f"Best value: {study.best_value} (params: {study.best_params})") 24 | 25 | optuna_distributed.config.enable_logging() 26 | print("Running 10 more trials with logging...") 27 | study.optimize(objective, n_trials=10) 28 | print(f"Best value: {study.best_value} (params: {study.best_params})") 29 | -------------------------------------------------------------------------------- /optuna_distributed/ipc/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from abc import ABC 3 | 4 | from optuna_distributed.messages.base import Message 5 | 6 | 7 | class IPCPrimitive(ABC): 8 | """An inter process communication primitive. 9 | 10 | This interface defines a common way to pass messages 11 | between processes hosted on the same machine or in cluster setups. 12 | """ 13 | 14 | @abc.abstractmethod 15 | def get(self) -> Message: 16 | """Retrieves a single message.""" 17 | raise NotImplementedError 18 | 19 | @abc.abstractmethod 20 | def put(self, message: Message) -> None: 21 | """Publishes a single message. 22 | 23 | Args: 24 | message: 25 | An instance of :class:'~optuna_distributed.messages.Message'. 26 | """ 27 | raise NotImplementedError 28 | 29 | @abc.abstractmethod 30 | def close(self) -> None: 31 | """Closes communication channel.""" 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /optuna_distributed/ipc/pipe.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.connection import Connection 2 | 3 | from optuna_distributed.ipc import IPCPrimitive 4 | from optuna_distributed.messages import Message 5 | 6 | 7 | class Pipe(IPCPrimitive): 8 | """IPC primitive based on multiprocessing Pipe. 9 | 10 | Forms a thin layer of abstraction over one end of multiprocessing 11 | pipe connection, with get/put semantics. 12 | 13 | Args: 14 | connection: 15 | One of the ends of multiprocessing pipe. 
16 | https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Pipe 17 | """ 18 | 19 | def __init__(self, connection: Connection) -> None: 20 | self._connection = connection 21 | 22 | def get(self) -> Message: 23 | return self._connection.recv() 24 | 25 | def put(self, message: Message) -> None: 26 | return self._connection.send(message) 27 | 28 | def close(self) -> None: 29 | return self._connection.close() 30 | -------------------------------------------------------------------------------- /optuna_distributed/messages/shouldprune.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from optuna.study import Study 4 | from optuna.trial import Trial 5 | 6 | from optuna_distributed.messages import Message 7 | from optuna_distributed.messages.response import ResponseMessage 8 | 9 | 10 | if TYPE_CHECKING: 11 | from optuna_distributed.managers import OptimizationManager 12 | 13 | 14 | class ShouldPruneMessage(Message): 15 | """A should prune trial message. 16 | 17 | This message is sent by :class:`~optuna_distributed.trial.DistributedTrial` to 18 | main process asking for whether trial should be pruned or not. 19 | 20 | Args: 21 | trial_id: 22 | Id of a trial to which the message is referring. 23 | """ 24 | 25 | closing = False 26 | 27 | def __init__(self, trial_id: int) -> None: 28 | self._trial_id = trial_id 29 | 30 | def process(self, study: Study, manager: "OptimizationManager") -> None: 31 | trial = Trial(study, self._trial_id) 32 | conn = manager.get_connection(self._trial_id) 33 | conn.put(ResponseMessage(self._trial_id, trial.should_prune())) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Adrian Zuber 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /optuna_distributed/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from rich.logging import RichHandler 6 | 7 | 8 | _default_handler: logging.Handler | None = None 9 | 10 | 11 | def _get_library_logger() -> logging.Logger: 12 | library_name = __name__.split(".")[0] 13 | return logging.getLogger(library_name) 14 | 15 | 16 | def _setup_logger() -> None: 17 | global _default_handler 18 | _default_handler = RichHandler(show_path=False) 19 | fmt = logging.Formatter(fmt="%(message)s", datefmt="[%X]") 20 | _default_handler.setFormatter(fmt) 21 | library_root_logger = _get_library_logger() 22 | library_root_logger.addHandler(_default_handler) 23 | library_root_logger.setLevel(logging.INFO) 24 | library_root_logger.propagate = False 25 | 26 | 27 | _setup_logger() 28 | 29 | 30 | def disable_logging() -> None: 31 | """Disables library level logger.""" 32 | assert _default_handler is not None 33 | _get_library_logger().removeHandler(_default_handler) 34 | 35 | 36 | def enable_logging() -> None: 37 | """Enables library level logger.""" 38 | assert _default_handler is not None 39 | _get_library_logger().addHandler(_default_handler) 40 | -------------------------------------------------------------------------------- /optuna_distributed/messages/report.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from optuna.study import Study 4 | from optuna.trial import Trial 5 | 6 | from optuna_distributed.messages import Message 7 | 8 | 9 | if TYPE_CHECKING: 10 | from optuna_distributed.managers import OptimizationManager 11 | 12 | 13 | class ReportMessage(Message): 14 | """Reports trial intermediate values. 15 | 16 | This message is sent by :class:`~optuna_distributed.trial.DistributedTrial` to 17 | main process reporting on intermediate value in trial. 18 | 19 | Args: 20 | trial_id: 21 | Id of a trial to which the message is referring. 22 | value: 23 | An intermediate value returned from the objective function. 24 | step: 25 | Step of the trial (e.g., Epoch of neural network training).
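    Example:
        A worker-side call such as ``trial.report(0.87, step=3)`` reaches the main
        process roughly as ``ReportMessage(trial_id, value=0.87, step=3)`` and is
        replayed there through the regular ``optuna.trial.Trial.report`` API. (The
        value ``0.87`` and step ``3`` are made up purely for illustration.)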
26 | """ 27 | 28 | closing = False 29 | 30 | def __init__(self, trial_id: int, value: float, step: int) -> None: 31 | self._trial_id = trial_id 32 | self._value = value 33 | self._step = step 34 | 35 | def process(self, study: Study, manager: "OptimizationManager") -> None: 36 | trial = Trial(study, self._trial_id) 37 | trial.report(self._value, self._step) 38 | -------------------------------------------------------------------------------- /optuna_distributed/messages/__init__.py: -------------------------------------------------------------------------------- 1 | from optuna_distributed.messages.base import Message 2 | from optuna_distributed.messages.completed import CompletedMessage 3 | from optuna_distributed.messages.failed import FailedMessage 4 | from optuna_distributed.messages.heartbeat import HeartbeatMessage 5 | from optuna_distributed.messages.property import TrialProperty 6 | from optuna_distributed.messages.property import TrialPropertyMessage 7 | from optuna_distributed.messages.pruned import PrunedMessage 8 | from optuna_distributed.messages.report import ReportMessage 9 | from optuna_distributed.messages.response import ResponseMessage 10 | from optuna_distributed.messages.setattr import AttributeType 11 | from optuna_distributed.messages.setattr import SetAttributeMessage 12 | from optuna_distributed.messages.shouldprune import ShouldPruneMessage 13 | from optuna_distributed.messages.suggest import SuggestMessage 14 | 15 | 16 | __all__ = [ 17 | "Message", 18 | "HeartbeatMessage", 19 | "ResponseMessage", 20 | "SuggestMessage", 21 | "CompletedMessage", 22 | "FailedMessage", 23 | "PrunedMessage", 24 | "ReportMessage", 25 | "ShouldPruneMessage", 26 | "SetAttributeMessage", 27 | "AttributeType", 28 | "TrialPropertyMessage", 29 | "TrialProperty", 30 | ] 31 | -------------------------------------------------------------------------------- /optuna_distributed/messages/response.py: -------------------------------------------------------------------------------- 1 | from typing import Generic 2 | from typing import TYPE_CHECKING 3 | from typing import TypeVar 4 | 5 | from optuna.study import Study 6 | 7 | from optuna_distributed.messages import Message 8 | 9 | 10 | if TYPE_CHECKING: 11 | from optuna_distributed.managers import OptimizationManager 12 | 13 | 14 | # FIXME: Would be nice to bound it to pickable interface but there is no such thing at the moment. 15 | # https://stackoverflow.com/questions/50328386/python-typing-pickle-and-serialisation 16 | T = TypeVar("T") 17 | 18 | 19 | class ResponseMessage(Message, Generic[T]): 20 | """A generic message. 21 | 22 | Response messages are used by client to pass data back to workers. 23 | These do not carry any code to execute, and should be used as a wrapper 24 | around data served as response to workers request. 25 | 26 | Args: 27 | trial_id: 28 | Id of a trial to which the message is referring. 29 | data: 30 | Serializable data that should be carried as a part of response. 31 | """ 32 | 33 | closing = False 34 | 35 | def __init__(self, trial_id: int, data: T) -> None: 36 | self.trial_id = trial_id 37 | self.data = data 38 | 39 | def process(self, study: Study, manager: "OptimizationManager") -> None: ... 
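# Illustrative sketch only (not part of the library): a worker pairs each request it
# publishes with the ResponseMessage that answers it, then unwraps the payload. The
# ``conn`` argument is assumed to be an IPCPrimitive (see optuna_distributed/ipc/base.py).
def _example_ask(conn, request: Message) -> object:
    conn.put(request)  # e.g. ShouldPruneMessage(trial_id) or SuggestMessage(...)
    reply = conn.get()  # the main process answers with a ResponseMessage
    assert isinstance(reply, ResponseMessage)
    return reply.data  # the wrapped payload: a bool, a suggested value, etc.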
40 | -------------------------------------------------------------------------------- /optuna_distributed/messages/property.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from typing import TYPE_CHECKING 3 | 4 | from optuna.study import Study 5 | from optuna.trial import Trial 6 | 7 | from optuna_distributed.messages import Message 8 | from optuna_distributed.messages.response import ResponseMessage 9 | 10 | 11 | if TYPE_CHECKING: 12 | from optuna_distributed.managers import OptimizationManager 13 | 14 | 15 | TrialProperty = Literal[ 16 | "params", "distributions", "user_attrs", "system_attrs", "datetime_start", "number" 17 | ] 18 | 19 | 20 | class TrialPropertyMessage(Message): 21 | """Requests one of trial properties. 22 | 23 | Args: 24 | trial_id: 25 | Id of a trial to which the message is referring. 26 | property: 27 | An option from :class:`~optuna_distributed.messages.TrialProperty` enum. 28 | """ 29 | 30 | closing = False 31 | 32 | def __init__(self, trial_id: int, property: TrialProperty) -> None: 33 | self._trial_id = trial_id 34 | self._property = property 35 | 36 | def process(self, study: Study, manager: "OptimizationManager") -> None: 37 | trial = Trial(study, self._trial_id) 38 | conn = manager.get_connection(self._trial_id) 39 | conn.put(ResponseMessage(self._trial_id, getattr(trial, self._property))) 40 | -------------------------------------------------------------------------------- /optuna_distributed/messages/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from abc import ABC 3 | from typing import TYPE_CHECKING 4 | 5 | from optuna.study import Study 6 | 7 | 8 | if TYPE_CHECKING: 9 | from optuna_distributed.managers.base import OptimizationManager 10 | 11 | 12 | class Message(ABC): 13 | """Base class for IPC messages. 14 | 15 | These messages are used to pass data and code between client and workers. 16 | """ 17 | 18 | @property 19 | @abc.abstractmethod 20 | def closing(self) -> bool: 21 | """Indicates last message generated by particular trial.""" 22 | raise NotImplementedError 23 | 24 | @abc.abstractmethod 25 | def process(self, study: Study, manager: "OptimizationManager") -> None: 26 | """Process a message data with context available in main process. 27 | 28 | Concrete implementations of this method should contain operations that 29 | worker wants to execute using resources available only to the main process. 30 | This means stuff like hyperparameter suggestions, prune commands and general 31 | data passing. 32 | 33 | Args: 34 | study: 35 | An instance of Optuna study. 36 | manager: 37 | :class:`~optuna_distributed.managers.Manager` providing additional 38 | execution context. 39 | """ 40 | raise NotImplementedError 41 | -------------------------------------------------------------------------------- /optuna_distributed/messages/pruned.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import TYPE_CHECKING 3 | 4 | from optuna.exceptions import TrialPruned 5 | from optuna.study import Study 6 | from optuna.trial import Trial 7 | from optuna.trial import TrialState 8 | 9 | from optuna_distributed.messages import Message 10 | 11 | 12 | if TYPE_CHECKING: 13 | from optuna_distributed.managers import OptimizationManager 14 | 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class PrunedMessage(Message): 20 | """A pruned trial message.
21 | 22 | This message is sent after :obj:`TrialPruned` exception has been raised 23 | in objective function and tells study to set associated trial to pruned state. 24 | 25 | Args: 26 | trial_id: 27 | Id of a trial to which the message is referring. 28 | exception: 29 | Instance of :obj:`TrialPruned` exception. 30 | """ 31 | 32 | closing = True 33 | 34 | def __init__(self, trial_id: int, exception: TrialPruned) -> None: 35 | self._trial_id = trial_id 36 | self._exception = exception 37 | 38 | def process(self, study: Study, manager: "OptimizationManager") -> None: 39 | trial = Trial(study, self._trial_id) 40 | frozen_trial = study.tell(trial, state=TrialState.PRUNED) 41 | manager.register_trial_exit(self._trial_id) 42 | _logger.info(f"Trial {frozen_trial.number} pruned. {repr(self._exception)}") 43 | -------------------------------------------------------------------------------- /optuna_distributed/messages/setattr.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from typing import Literal 3 | from typing import TYPE_CHECKING 4 | 5 | from optuna.study import Study 6 | 7 | from optuna_distributed.messages import Message 8 | 9 | 10 | if TYPE_CHECKING: 11 | from optuna_distributed.managers import OptimizationManager 12 | 13 | 14 | AttributeType = Literal["user", "system"] 15 | 16 | 17 | class SetAttributeMessage(Message): 18 | """Sets either user or system value on a trial. 19 | 20 | Args: 21 | trial_id: 22 | Id of a trial to which the message is referring. 23 | key: 24 | A key string of the attribute. 25 | value: 26 | A value of the attribute. The value should be able to serialize with pickle. 27 | kind: 28 | An option from :class:`~optuna_distributed.messages.AttributeType` enum. 29 | """ 30 | 31 | closing = False 32 | 33 | def __init__(self, trial_id: int, key: str, value: Any, *, kind: AttributeType) -> None: 34 | self._trial_id = trial_id 35 | self._kind = kind 36 | self._key = key 37 | self._value = value 38 | 39 | def process(self, study: Study, manager: "OptimizationManager") -> None: 40 | if self._kind == "user": 41 | study._storage.set_trial_user_attr(self._trial_id, self._key, self._value) 42 | elif self._kind == "system": 43 | study._storage.set_trial_system_attr(self._trial_id, self._key, self._value) 44 | else: 45 | assert False, "Should not reach." 46 | -------------------------------------------------------------------------------- /optuna_distributed/messages/failed.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | from typing import TYPE_CHECKING 4 | 5 | from optuna.study import Study 6 | from optuna.trial import Trial 7 | from optuna.trial import TrialState 8 | 9 | from optuna_distributed.messages import Message 10 | 11 | 12 | if TYPE_CHECKING: 13 | from optuna_distributed.managers import OptimizationManager 14 | 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class FailedMessage(Message): 20 | """A failed trial message. 21 | 22 | This message is sent after objective function has failed while being evaluated 23 | and tells study to fail associated trial. Also, if exception that caused objective 24 | function to fail is not explicitly ignored by user, it will be re-raised in main 25 | process, failing it entirely. 26 | 27 | Args: 28 | trial_id: 29 | Id of a trial to which the message is referring. 30 | exception: 31 | Instance of exception that was raised in objective function.
32 | exc_info: 33 | Information about exception that was raised in objective function. 34 | """ 35 | 36 | closing = True 37 | 38 | def __init__(self, trial_id: int, exception: Exception, exc_info: Any) -> None: 39 | self._trial_id = trial_id 40 | self._exception = exception 41 | self._exc_info = exc_info 42 | 43 | def process(self, study: Study, manager: "OptimizationManager") -> None: 44 | trial = Trial(study, self._trial_id) 45 | frozen_trial = study.tell(trial, state=TrialState.FAIL) 46 | manager.register_trial_exit(self._trial_id) 47 | _logger.warning( 48 | f"Trial {frozen_trial.number} failed with parameters: {frozen_trial.params} " 49 | f"because of the following error: {repr(self._exception)}.", 50 | exc_info=self._exc_info, 51 | ) 52 | raise self._exception 53 | -------------------------------------------------------------------------------- /optuna_distributed/terminal.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from rich.progress import BarColumn 4 | from rich.progress import Progress 5 | from rich.progress import TaskProgressColumn 6 | from rich.progress import TextColumn 7 | from rich.progress import TimeElapsedColumn 8 | from rich.status import Status 9 | from rich.style import Style 10 | 11 | 12 | class Terminal: 13 | """Provides styled terminal output. 14 | 15 | Args: 16 | show_progress_bar: 17 | Enables progress bar. 18 | n_trials: 19 | The number of trials to run in total. 20 | timeout: 21 | Stops study after the given number of second(s). 22 | """ 23 | 24 | def __init__( 25 | self, show_progress_bar: bool, n_trials: int, timeout: float | None = None 26 | ) -> None: 27 | self._timeout = timeout 28 | self._progbar = Progress( 29 | TextColumn("[progress.description]{task.description}"), 30 | BarColumn(complete_style=Style(color="light_coral")), 31 | TaskProgressColumn(), 32 | TimeElapsedColumn(), 33 | transient=True, 34 | ) 35 | 36 | self._task = self._progbar.add_task("[blue]Running trials...[/blue]", total=n_trials) 37 | if show_progress_bar: 38 | self._progbar.start() 39 | 40 | def update_progress_bar(self) -> None: 41 | """Advance progress bar by one trial.""" 42 | self._progbar.advance(self._task) 43 | 44 | def close_progress_bar(self) -> None: 45 | """Closes progress bar.""" 46 | self._progbar.stop() 47 | 48 | def spin_while_trials_interrupted(self) -> Status: 49 | """Renders spinner animation while trials are being interrupted.""" 50 | self._progbar.stop() 51 | return self._progbar.console.status( 52 | "[blue]Interrupting running trials...[/blue]", spinner_style=Style(color="blue") # type: ignore # noqa: E501 53 | ) 54 | -------------------------------------------------------------------------------- /tests/test_eventloop.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | import optuna 5 | import pytest 6 | 7 | from optuna_distributed.eventloop import EventLoop 8 | from optuna_distributed.managers import LocalOptimizationManager 9 | from optuna_distributed.terminal import Terminal 10 | from optuna_distributed.trial import DistributedTrial 11 | 12 | 13 | pytestmark = pytest.mark.skipif( 14 | sys.platform == "win32", reason="Local optimization not supported on Windows." 
15 | ) 16 | 17 | 18 | def _objective_raises(trial: DistributedTrial) -> float: 19 | raise ValueError() 20 | 21 | 22 | def test_raises_on_trial_exception() -> None: 23 | n_trials = 5 24 | study = optuna.create_study() 25 | manager = LocalOptimizationManager(n_trials, n_jobs=1) 26 | event_loop = EventLoop(study, manager, objective=_objective_raises, interrupt_patience=10.0) 27 | with pytest.raises(ValueError): 28 | event_loop.run(terminal=Terminal(show_progress_bar=False, n_trials=n_trials)) 29 | 30 | 31 | def test_catches_on_trial_exception() -> None: 32 | n_trials = 5 33 | study = optuna.create_study() 34 | manager = LocalOptimizationManager(n_trials, n_jobs=1) 35 | event_loop = EventLoop(study, manager, objective=_objective_raises, interrupt_patience=10.0) 36 | event_loop.run( 37 | terminal=Terminal(show_progress_bar=False, n_trials=n_trials), catch=(ValueError,) 38 | ) 39 | 40 | 41 | def _objective_sleeps(trial: DistributedTrial) -> float: 42 | uninterrupted_execution_time = 60.0 43 | time.sleep(uninterrupted_execution_time) 44 | return 1.0 45 | 46 | 47 | def test_stops_optimization_after_timeout() -> None: 48 | uninterrupted_execution_time = 60.0 49 | n_trials = 1 50 | study = optuna.create_study() 51 | manager = LocalOptimizationManager(n_trials, n_jobs=1) 52 | event_loop = EventLoop(study, manager, objective=_objective_sleeps, interrupt_patience=10.0) 53 | started_at = time.time() 54 | event_loop.run(terminal=Terminal(show_progress_bar=False, n_trials=n_trials), timeout=1.0) 55 | interrupted_execution_time = time.time() - started_at 56 | assert interrupted_execution_time < uninterrupted_execution_time 57 | -------------------------------------------------------------------------------- /examples/quadratic_simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example adds Optuna-distributed semantics on top of 3 | https://github.com/optuna/optuna-examples/blob/main/quadratic_simple.py 4 | 5 | Optuna example that optimizes a simple quadratic function. 6 | In this example, we optimize a simple quadratic function. We also demonstrate how to continue an 7 | optimization and to use timeouts. 8 | """ 9 | 10 | import random 11 | import socket 12 | import time 13 | 14 | import optuna 15 | 16 | import optuna_distributed 17 | 18 | 19 | # Define a simple 2-dimensional objective function whose minimum value is -1 when (x, y) = (0, -1). 20 | def objective(trial): 21 | x = trial.suggest_float("x", -100, 100) 22 | y = trial.suggest_categorical("y", [-1, 0, 1]) 23 | # Let's simulate long running job and identify worker doing the job. 24 | time.sleep(random.uniform(0.0, 2.0)) 25 | trial.set_user_attr("worker", socket.gethostname()) 26 | return x**2 + y 27 | 28 | 29 | if __name__ == "__main__": 30 | # By default, we are relying on process based parallelism to run 31 | # all trials on a single machine. However, with Dask client, we can easily scale up 32 | # to Dask cluster spanning multiple physical workers. To learn how to setup and use 33 | # Dask cluster, please refer to https://docs.dask.org/en/stable/deploying.html. 34 | # from dask.distributed import Client 35 | # client = Client() 36 | client = None 37 | 38 | # Optuna-distributed just wraps standard Optuna study. The resulting object behaves 39 | # just like regular study, but optimization process is asynchronous. 40 | study = optuna_distributed.from_study(optuna.create_study(), client=client) 41 | 42 | # And let's continue with original Optuna example from here. 
43 | # Let us minimize the objective function above. 44 | print("Running 10 trials...") 45 | study.optimize(objective, n_trials=10) 46 | worker = study.best_trial.user_attrs["worker"] 47 | print(f"Best value: {study.best_value} (params: {study.best_params}) calculated by {worker}\n") 48 | 49 | # We can continue the optimization as follows. 50 | print("Running 20 additional trials...") 51 | study.optimize(objective, n_trials=20) 52 | worker = study.best_trial.user_attrs["worker"] 53 | print(f"Best value: {study.best_value} (params: {study.best_params}) calculated by {worker}\n") 54 | 55 | # We can specify the timeout. 56 | print("Running additional trials in 2 seconds...") 57 | study.optimize(objective, n_trials=100, timeout=2.0) 58 | print("Best value: {} (params: {})\n".format(study.best_value, study.best_params)) 59 | -------------------------------------------------------------------------------- /examples/simple_storages.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example adds Optuna-distributed semantics on top of 3 | https://github.com/optuna/optuna-examples/blob/main/quadratic_simple.py 4 | 5 | Optuna example that optimizes a simple quadratic function. 6 | In this example, we optimize a simple quadratic function. We also demonstrate how to continue an 7 | optimization and to use timeouts. 8 | """ 9 | 10 | import random 11 | import socket 12 | import time 13 | 14 | import optuna 15 | from optuna.samplers import NSGAIISampler 16 | from optuna.storages import RDBStorage 17 | 18 | import optuna_distributed 19 | 20 | 21 | # Define a simple 2-dimensional objective function whose minimum value is -1 when (x, y) = (0, -1). 22 | def objective(trial): 23 | x = trial.suggest_float("x", -100, 100) 24 | y = trial.suggest_categorical("y", [-1, 0, 1]) 25 | # Let's simulate long running job and identify worker doing the job. 26 | time.sleep(random.uniform(0.0, 2.0)) 27 | trial.set_user_attr("worker", socket.gethostname()) 28 | return x**2 + y 29 | 30 | 31 | if __name__ == "__main__": 32 | # Using Dask client, we can easily scale up to multiple machines. 33 | # from dask.distributed import Client 34 | # client = Client() 35 | client = None 36 | 37 | # All standard Optuna storage, sampler and pruner options are supported. 38 | storage = RDBStorage("sqlite:///:memory:") 39 | sampler = NSGAIISampler() 40 | 41 | # Optuna-distributed just wraps standard Optuna study. The resulting object behaves 42 | # just like regular study, but optimization process is asynchronous. 43 | study = optuna_distributed.from_study( 44 | optuna.create_study(storage=storage, sampler=sampler), client=client 45 | ) 46 | 47 | # And let's continue with original Optuna example from here. 48 | # Let us minimize the objective function above. 49 | print("Running 10 trials...") 50 | study.optimize(objective, n_trials=10) 51 | worker = study.best_trial.user_attrs["worker"] 52 | print(f"Best value: {study.best_value} (params: {study.best_params}) calculated by {worker}\n") 53 | 54 | # We can continue the optimization as follows. 55 | print("Running 20 additional trials...") 56 | study.optimize(objective, n_trials=20) 57 | worker = study.best_trial.user_attrs["worker"] 58 | print(f"Best value: {study.best_value} (params: {study.best_params}) calculated by {worker}\n") 59 | 60 | # We can specify the timeout. 
61 | print("Running additional trials in 2 seconds...") 62 | study.optimize(objective, n_trials=100, timeout=2.0) 63 | print("Best value: {} (params: {})\n".format(study.best_value, study.best_params)) 64 | -------------------------------------------------------------------------------- /examples/simple_pruning.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example adds Optuna-distributed semantics on top of 3 | https://github.com/optuna/optuna-examples/blob/main/simple_pruning.py 4 | 5 | Optuna example that demonstrates a pruner. 6 | In this example, we optimize a classifier configuration using scikit-learn. Note that, to enable 7 | the pruning feature, the following 2 methods are invoked after each step of the iterative training. 8 | (1) :func:`optuna.trial.Trial.report` 9 | (2) :func:`optuna.trial.Trial.should_prune` 10 | You can run this example as follows: 11 | $ python simple_prunning.py 12 | """ 13 | 14 | import optuna 15 | from optuna.trial import TrialState 16 | import sklearn.datasets 17 | import sklearn.linear_model 18 | import sklearn.model_selection 19 | 20 | import optuna_distributed 21 | 22 | 23 | def objective(trial): 24 | iris = sklearn.datasets.load_iris() 25 | classes = list(set(iris.target)) 26 | train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split( 27 | iris.data, iris.target, test_size=0.25 28 | ) 29 | 30 | alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True) 31 | clf = sklearn.linear_model.SGDClassifier(alpha=alpha) 32 | 33 | for step in range(100): 34 | clf.partial_fit(train_x, train_y, classes=classes) 35 | 36 | # Report intermediate objective value. 37 | intermediate_value = clf.score(valid_x, valid_y) 38 | trial.report(intermediate_value, step) 39 | 40 | # Handle pruning based on the intermediate value. 41 | if trial.should_prune(): 42 | raise optuna.TrialPruned() 43 | 44 | return clf.score(valid_x, valid_y) 45 | 46 | 47 | if __name__ == "__main__": 48 | # Using Dask client, we can easily scale up to multiple machines. 
49 | # from dask.distributed import Client 50 | # client = Client() 51 | client = None 52 | 53 | study = optuna_distributed.from_study(optuna.create_study(direction="maximize"), client=client) 54 | study.optimize(objective, n_trials=100) 55 | 56 | pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) 57 | complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) 58 | 59 | print("Study statistics: ") 60 | print(" Number of finished trials: ", len(study.trials)) 61 | print(" Number of pruned trials: ", len(pruned_trials)) 62 | print(" Number of complete trials: ", len(complete_trials)) 63 | 64 | print("Best trial:") 65 | trial = study.best_trial 66 | 67 | print(" Value: ", trial.value) 68 | 69 | print(" Params: ") 70 | for key, value in trial.params.items(): 71 | print(" {}: {}".format(key, value)) 72 | -------------------------------------------------------------------------------- /optuna_distributed/messages/suggest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from optuna.distributions import BaseDistribution 6 | from optuna.distributions import CategoricalChoiceType 7 | from optuna.distributions import CategoricalDistribution 8 | from optuna.distributions import FloatDistribution 9 | from optuna.distributions import IntDistribution 10 | from optuna.study import Study 11 | from optuna.trial import Trial 12 | 13 | from optuna_distributed.messages import Message 14 | from optuna_distributed.messages.response import ResponseMessage 15 | 16 | 17 | if TYPE_CHECKING: 18 | from optuna_distributed.managers import OptimizationManager 19 | 20 | 21 | class SuggestMessage(Message): 22 | """A request for value suggestions. 23 | 24 | This message is sent by :class:`~optuna_distributed.trial.DistributedTrial` to 25 | main process asking for value suggestions. Main process provides them by 26 | using regular Optuna suggest APIs and responding via connection provided by worker. 27 | 28 | Args: 29 | trial_id: 30 | Id of a trial to which the message is referring. 31 | name: 32 | A parameter name. 33 | distribution: 34 | A parameter distribution. 35 | """ 36 | 37 | closing = False 38 | 39 | def __init__(self, trial_id: int, name: str, distribution: BaseDistribution) -> None: 40 | self._trial_id = trial_id 41 | self._name = name 42 | self._distribution = distribution 43 | 44 | def process(self, study: Study, manager: "OptimizationManager") -> None: 45 | trial = Trial(study, self._trial_id) 46 | value: float | int | CategoricalChoiceType 47 | if isinstance(self._distribution, FloatDistribution): 48 | value = trial.suggest_float( 49 | name=self._name, 50 | low=self._distribution.low, 51 | high=self._distribution.high, 52 | step=self._distribution.step, 53 | log=self._distribution.log, 54 | ) 55 | elif isinstance(self._distribution, IntDistribution): 56 | value = trial.suggest_int( 57 | name=self._name, 58 | low=self._distribution.low, 59 | high=self._distribution.high, 60 | step=self._distribution.step, 61 | log=self._distribution.log, 62 | ) 63 | elif isinstance(self._distribution, CategoricalDistribution): 64 | value = trial.suggest_categorical(name=self._name, choices=self._distribution.choices) 65 | else: 66 | assert False, "Should not reach." 
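        # For example, a worker-side trial.suggest_float("x", -100, 100) arrives here
        # roughly as SuggestMessage(trial_id, "x", FloatDistribution(low=-100, high=100)),
        # is answered by the matching branch above through the regular suggest API, and
        # the chosen value is shipped back below wrapped in a ResponseMessage. (The name
        # "x" and its range are only an illustration borrowed from the examples.)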
67 | 68 | conn = manager.get_connection(self._trial_id) 69 | conn.put(ResponseMessage(self._trial_id, value)) 70 | -------------------------------------------------------------------------------- /optuna_distributed/messages/completed.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sequence 4 | import io 5 | import logging 6 | from typing import TYPE_CHECKING 7 | 8 | from optuna.study import Study 9 | from optuna.trial import FrozenTrial 10 | from optuna.trial import Trial 11 | from optuna.trial import TrialState 12 | 13 | from optuna_distributed.messages import Message 14 | 15 | 16 | if TYPE_CHECKING: 17 | from optuna_distributed.managers import OptimizationManager 18 | 19 | 20 | _logger = logging.getLogger(__name__) 21 | 22 | 23 | class CompletedMessage(Message): 24 | """A completed trial message. 25 | 26 | This message is sent after objective function has been successfully evaluated 27 | and tells study about objective value (in case of single objective optimization) 28 | or sequence of objective values (in case of multi-objective optimization). 29 | 30 | Args: 31 | trial_id: 32 | Id of a trial to which the message is referring. 33 | value_or_values: 34 | Objective value or sequence of objective values. 35 | """ 36 | 37 | closing = True 38 | 39 | def __init__(self, trial_id: int, value_or_values: Sequence[float] | float) -> None: 40 | self._trial_id = trial_id 41 | self._value_or_values = value_or_values 42 | 43 | def process(self, study: Study, manager: "OptimizationManager") -> None: 44 | trial = Trial(study, self._trial_id) 45 | try: 46 | frozen_trial = study.tell(trial, self._value_or_values, skip_if_finished=True) 47 | 48 | except Exception: 49 | frozen_trial = study._storage.get_trial(self._trial_id) 50 | raise 51 | 52 | finally: 53 | manager.register_trial_exit(self._trial_id) 54 | if frozen_trial.state == TrialState.COMPLETE: 55 | self._log_completed_trial(study, trial=frozen_trial) 56 | else: 57 | # Tell failed to postprocess trial, so state has changed. 58 | _logger.warning( 59 | f"Trial {frozen_trial.number} failed because " 60 | "of the following error: STUDY_TELL_WARNING" 61 | ) 62 | 63 | def _log_completed_trial(self, study: Study, trial: FrozenTrial) -> None: 64 | buffer = io.StringIO() 65 | is_multiobjective = len(trial.values) != 1 66 | form = "values" if is_multiobjective else "value" 67 | 68 | buffer.write( 69 | f"Trial {trial.number} finished with {form}: " 70 | f"{self._value_or_values} and parameters: {trial.params}." 71 | ) 72 | if not is_multiobjective: 73 | buffer.write( 74 | f" Best is trial {study.best_trial.number} " 75 | f"with value: {study.best_trial.value}." 
76 | ) 77 | 78 | _logger.info(buffer.getvalue()) 79 | buffer.close() 80 | -------------------------------------------------------------------------------- /tests/test_ipc.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from multiprocessing import Process 3 | from multiprocessing.connection import Pipe as MultiprocessingPipe 4 | import time 5 | 6 | from dask.distributed import Client 7 | from dask.distributed import wait 8 | import pytest 9 | 10 | from optuna_distributed.ipc import IPCPrimitive 11 | from optuna_distributed.ipc import Pipe 12 | from optuna_distributed.ipc import Queue 13 | from optuna_distributed.messages import ResponseMessage 14 | 15 | 16 | def _ping_pong(conn: IPCPrimitive) -> None: 17 | msg = conn.get() 18 | assert isinstance(msg, ResponseMessage) 19 | assert msg.data == "ping" 20 | conn.put(ResponseMessage(0, "pong")) 21 | 22 | 23 | def test_pipe_ping_pong() -> None: 24 | a, b = MultiprocessingPipe() 25 | p = Process(target=_ping_pong, args=(Pipe(b),)) 26 | p.start() 27 | 28 | master = Pipe(a) 29 | master.put(ResponseMessage(0, "ping")) 30 | response = master.get() 31 | assert isinstance(response, ResponseMessage) 32 | assert response.data == "pong" 33 | p.join() 34 | assert p.exitcode == 0 35 | 36 | 37 | def test_queue_ping_pong(client: Client) -> None: 38 | public = "public" 39 | private = "private" 40 | future = client.submit(_ping_pong, Queue(public, private)) 41 | master = Queue(private, public) 42 | master.put(ResponseMessage(0, "ping")) 43 | response = master.get() 44 | assert isinstance(response, ResponseMessage) 45 | assert response.data == "pong" 46 | wait(future) 47 | assert future.done() 48 | assert future.status == "finished" 49 | 50 | 51 | def test_queue_publishing_only(client: Client) -> None: 52 | q = Queue("foo") 53 | with pytest.raises(RuntimeError): 54 | q.get() 55 | 56 | 57 | def test_queue_raises_on_timeout_and_backoff(client: Client) -> None: 58 | with pytest.raises(ValueError): 59 | Queue("foo", timeout=1, max_retries=1) 60 | 61 | 62 | def test_queue_raises_after_timeout(client: Client) -> None: 63 | q = Queue("foo", "bar", timeout=1) 64 | with pytest.raises(asyncio.TimeoutError): 65 | q.get() 66 | 67 | 68 | def test_queue_raises_after_retries(client: Client) -> None: 69 | q = Queue("foo", "bar", max_retries=1) 70 | with pytest.raises(asyncio.TimeoutError): 71 | q.get() 72 | 73 | 74 | def test_queue_get_delayed_message(client: Client) -> None: 75 | public = "public" 76 | private = "private" 77 | future = client.submit(_ping_pong, Queue(public, private, max_retries=5)) 78 | master = Queue(private, public) 79 | 80 | # With exponential timeout, attempts are made after 1, 3, 7, 15... seconds. 81 | # To ensure at least one retry, message should be delayed between 1 and 3 seconds. 
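    # (Queue.get() waits 2**attempt seconds per attempt, i.e. 1, 2, 4, 8..., which is
    # where the cumulative 1, 3, 7, 15... second marks above come from.)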
82 | time.sleep(2.0) 83 | master.put(ResponseMessage(0, "ping")) 84 | response = master.get() 85 | assert isinstance(response, ResponseMessage) 86 | assert response.data == "pong" 87 | wait(future) 88 | assert future.done() 89 | assert future.status == "finished" 90 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "optuna-distributed" 7 | description = "Distributed hyperparameter optimization made easy" 8 | authors = [ 9 | { name = "Adrian Zuber", email = "xadrianzetx@gmail.com" }, 10 | ] 11 | requires-python = ">=3.8" 12 | license = { text = "MIT" } 13 | classifiers = [ 14 | "Development Status :: 4 - Beta", 15 | "Intended Audience :: Science/Research", 16 | "Intended Audience :: Developers", 17 | "License :: OSI Approved :: MIT License", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Programming Language :: Python :: 3 :: Only", 25 | "Topic :: Scientific/Engineering", 26 | "Topic :: Scientific/Engineering :: Mathematics", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | "Topic :: Software Development", 29 | "Topic :: Software Development :: Libraries", 30 | "Topic :: Software Development :: Libraries :: Python Modules", 31 | "Topic :: System :: Distributed Computing", 32 | ] 33 | 34 | dependencies = [ 35 | "optuna>=3.1.0", 36 | "dask[distributed]", 37 | "rich", 38 | ] 39 | dynamic = ["version"] 40 | 41 | [project.readme] 42 | file = "README.md" 43 | content-type = "text/markdown" 44 | 45 | [project.optional-dependencies] 46 | dev = ["black", "isort", "flake8", "mypy", "pandas", "pandas-stubs"] 47 | test = ["pytest"] 48 | 49 | [project.urls] 50 | "Source" = "https://github.com/xadrianzetx/optuna-distributed" 51 | "Bug Tracker" = "https://github.com/xadrianzetx/optuna-distributed/issues" 52 | 53 | [tool.setuptools] 54 | packages = ["optuna_distributed"] 55 | 56 | [tool.setuptools.package-data] 57 | optuna_distributed = ["py.typed"] 58 | 59 | [tool.setuptools.dynamic] 60 | version = { attr = "optuna_distributed.__version__" } 61 | readme = { file = "README.md" } 62 | 63 | [tool.black] 64 | line-length = 99 65 | target-version = ["py310"] 66 | exclude = ''' 67 | /( 68 | \.eggs 69 | | \.git 70 | | \.mypy_cache 71 | | \.vscode 72 | | env 73 | | build 74 | | dist 75 | )/ 76 | ''' 77 | 78 | [tool.isort] 79 | profile = "black" 80 | src_paths = ["optuna_distributed", "tests", "examples"] 81 | line_length = 99 82 | lines_after_imports = 2 83 | force_single_line = true 84 | force_sort_within_sections = true 85 | order_by_type = true 86 | 87 | [tool.pytest.ini_options] 88 | testpaths = ["tests"] 89 | 90 | [tool.mypy] 91 | python_version = "3.10" 92 | warn_unused_configs = true 93 | disallow_untyped_defs = true 94 | disallow_incomplete_defs = true 95 | check_untyped_defs = true 96 | no_implicit_optional = true 97 | warn_redundant_casts = true 98 | strict_equality = true 99 | strict_concatenate = true 100 | exclude = ["env", "build", "examples"] 101 | -------------------------------------------------------------------------------- /optuna_distributed/ipc/queue.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import pickle 5 | 6 | from dask.distributed import Queue as DaskQueue 7 | 8 | from optuna_distributed.ipc import IPCPrimitive 9 | from optuna_distributed.messages import Message 10 | 11 | 12 | class Queue(IPCPrimitive): 13 | """IPC primitive based on dask distributed queue. 14 | 15 | All messages are pickled before sending and unpickled 16 | after recieving to ensure data is msgpack-encodable. 17 | 18 | Args: 19 | publishing: 20 | A name of the queue used to publish messages to. 21 | recieving: 22 | A name of the queue used to recieve messages from. 23 | timeout: 24 | Time (in seconds) to wait for message to be fetched 25 | before raising an exception. Should not be set if 26 | `max_retries` is used. 27 | max_retries: 28 | Specifies maximum number of attempts to fetch a message. 29 | After each attempt, timeout is extended exponentially. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | publishing: str, 35 | recieving: str | None = None, 36 | timeout: int | None = None, 37 | max_retries: int | None = None, 38 | ) -> None: 39 | self._publishing = publishing 40 | self._recieving = recieving 41 | 42 | if max_retries is not None and timeout is not None: 43 | raise ValueError("Exponentially growing timeout is used when `max_retries` is set.") 44 | 45 | self._timeout = timeout 46 | self._max_retries = max_retries 47 | self._publisher: DaskQueue | None = None 48 | self._subscriber: DaskQueue | None = None 49 | self._initialized = False 50 | 51 | def _initialize(self) -> None: 52 | if not self._initialized: 53 | # Lazy initialization, since we have to make sure 54 | # channels are opened on target machine. 55 | self._publisher = DaskQueue(self._publishing) 56 | if self._recieving is not None: 57 | self._subscriber = DaskQueue(self._recieving) 58 | self._initialized = True 59 | 60 | def get(self) -> Message: 61 | self._initialize() 62 | if self._subscriber is None: 63 | raise RuntimeError("Trying to get message with publish-only connection.") 64 | 65 | attempt = 0 66 | while True: 67 | try: 68 | timeout = self._timeout if self._max_retries is None else 2**attempt 69 | return pickle.loads(self._subscriber.get(timeout)) 70 | 71 | except asyncio.TimeoutError: 72 | attempt += 1 73 | if self._max_retries is None or attempt == self._max_retries: 74 | raise 75 | 76 | def put(self, message: Message) -> None: 77 | self._initialize() 78 | assert self._publisher is not None 79 | self._publisher.put(pickle.dumps(message)) 80 | 81 | def close(self) -> None: 82 | # Cleanup is handled by dask. 83 | # For us it's enough to just drop references to queue objects. 84 | ... 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # optuna-distributed 2 | 3 | An extension to [Optuna](https://github.com/optuna/optuna) which makes distributed hyperparameter optimization easy, and keeps all of the original Optuna semantics. Optuna-distributed can run locally, by default utilising all CPU cores, or can easily scale to many machines in [Dask cluster](https://docs.dask.org/en/stable/deploying.html). 4 | 5 | > **Note** 6 | > 7 | > Optuna-distributed is still in the early stages of development. 
While core Optuna functionality is supported, few missing APIs (especially around Optuna integrations) might prevent this extension from being entirely plug-and-play for some users. Bug reports, feature requests and PRs are more than welcome. 8 | 9 | ## Features 10 | 11 | * Asynchronous optimization by default. Scales from single machine to many machines in cluster. 12 | * Distributed study walks and quacks just like regular Optuna study, making it plug-and-play. 13 | * Compatible with all standard Optuna storages, samplers and pruners. 14 | * No need to modify existing objective functions. 15 | 16 | ## Installation 17 | 18 | ```sh 19 | pip install optuna-distributed 20 | ``` 21 | Optuna-distributed requires Python 3.8 or newer. 22 | 23 | ## Basic example 24 | Optuna-distributed wraps standard Optuna study. The resulting object behaves just like regular study, but optimization process is asynchronous. Depending on setup of [Dask client](https://docs.dask.org/en/stable/10-minutes-to-dask.html#scheduling), each trial is scheduled to run on available CPU core on local machine, or physical worker in cluster. 25 | 26 | > **Note** 27 | > 28 | > Running distributed optimization requires a Dask cluster with environment closely matching one on the client machine. For more information on cluster setup and configuration, please refer to https://docs.dask.org/en/stable/deploying.html. 29 | 30 | ```python 31 | import random 32 | import time 33 | 34 | import optuna 35 | import optuna_distributed 36 | from dask.distributed import Client 37 | 38 | 39 | def objective(trial): 40 | x = trial.suggest_float("x", -100, 100) 41 | y = trial.suggest_categorical("y", [-1, 0, 1]) 42 | # Some expensive model fit happens here... 43 | time.sleep(random.uniform(1.0, 2.0)) 44 | return x**2 + y 45 | 46 | 47 | if __name__ == "__main__": 48 | # client = Client("") # Enables distributed optimization. 49 | client = None # Enables local asynchronous optimization. 50 | study = optuna_distributed.from_study(optuna.create_study(), client=client) 51 | study.optimize(objective, n_trials=10) 52 | print(study.best_value) 53 | ``` 54 | 55 | But there's more! All of the core Optuna APIs, including [storages, samplers](https://github.com/xadrianzetx/optuna-distributed/blob/main/examples/simple_storages.py) and [pruners](https://github.com/xadrianzetx/optuna-distributed/blob/main/examples/simple_pruning.py) are supported! If you'd like to know how Optuna-distributed works, then check out [this article on Optuna blog](https://medium.com/optuna/running-distributed-hyperparameter-optimization-with-optuna-distributed-17bb2f7d422d). 56 | 57 | ## What's missing? 58 | * Support for callbacks and Optuna integration modules. 59 | * Study APIs such as [`study.stop`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.stop) can't be called from trial at the moment. 60 | * Local asynchronous optimization on Windows machines. Distributed mode is still available. 61 | * Support for [`optuna.terminator`](https://optuna.readthedocs.io/en/stable/reference/terminator.html). 
-------------------------------------------------------------------------------- /optuna_distributed/eventloop.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | 5 | from optuna.study import Study 6 | from optuna.trial import TrialState 7 | 8 | from optuna_distributed.managers import ObjectiveFuncType 9 | from optuna_distributed.managers import OptimizationManager 10 | from optuna_distributed.terminal import Terminal 11 | 12 | 13 | class EventLoop: 14 | """Collects and acts upon all that is happening in the optimization process. 15 | 16 | After trials are dispatched to run across many workers, all communication with 17 | them is held via a central point in the event loop. From here we can wait for requests 18 | made by workers (e.g. to suggest a hyperparameter value) and act upon them using local 19 | resources. This ensures sequential access to storages, samplers etc. 20 | 21 | Args: 22 | study: 23 | An instance of Optuna study. 24 | manager: 25 | An instance of :class:`~optuna_distributed.managers.OptimizationManager`. 26 | objective: 27 | An objective function to optimize. 28 | interrupt_patience: 29 | Specifies how many seconds to wait for trials to exit after an interrupt has been emitted. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | study: Study, 35 | manager: OptimizationManager, 36 | objective: ObjectiveFuncType, 37 | interrupt_patience: float, 38 | ) -> None: 39 | self.study = study 40 | self.manager = manager 41 | self.objective = objective 42 | self._interrupt_patience = interrupt_patience 43 | 44 | def run( 45 | self, 46 | terminal: Terminal, 47 | timeout: float | None = None, 48 | catch: tuple[type[Exception], ...] = (), 49 | ) -> None: 50 | """Starts the event loop. 51 | 52 | Args: 53 | terminal: 54 | An instance of :obj:`optuna_distributed.terminal.Terminal`. 55 | timeout: 56 | Stops the study after the given number of seconds. 57 | catch: 58 | A tuple of exceptions to ignore if any is raised while optimizing a function. 59 | """ 60 | time_start = datetime.now() 61 | self.manager.create_futures(self.study, self.objective) 62 | for message in self.manager.get_message(): 63 | try: 64 | message.process(self.study, self.manager) 65 | self.manager.after_message(self) 66 | 67 | except Exception as e: 68 | if not isinstance(e, catch): 69 | with terminal.spin_while_trials_interrupted(): 70 | self.manager.stop_optimization(patience=self._interrupt_patience) 71 | self._fail_unfinished_trials() 72 | raise 73 | 74 | elapsed = (datetime.now() - time_start).total_seconds() 75 | if timeout is not None and elapsed > timeout: 76 | with terminal.spin_while_trials_interrupted(): 77 | self.manager.stop_optimization(patience=self._interrupt_patience) 78 | break 79 | 80 | if message.closing: 81 | terminal.update_progress_bar() 82 | 83 | # TODO(xadrianzetx): Call callbacks here. 84 | if self.manager.should_end_optimization(): 85 | terminal.close_progress_bar() 86 | break 87 | 88 | def _fail_unfinished_trials(self) -> None: 89 | # TODO(xadrianzetx) Is there a better way to do this in Optuna?
90 | states = (TrialState.RUNNING, TrialState.WAITING) 91 | trials = self.study.get_trials(deepcopy=False, states=states) 92 | for trial in trials: 93 | self.study._storage.set_trial_state_values(trial._trial_id, TrialState.FAIL) 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Visual Studio Code 163 | .vscode/ 164 | -------------------------------------------------------------------------------- /optuna_distributed/managers/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | from abc import ABC 5 | from collections.abc import Generator 6 | from typing import Callable 7 | from typing import Sequence 8 | from typing import TYPE_CHECKING 9 | from typing import Union 10 | 11 | from optuna.study import Study 12 | 13 | from optuna_distributed.ipc import IPCPrimitive 14 | from optuna_distributed.messages import Message 15 | from optuna_distributed.trial import DistributedTrial 16 | 17 | 18 | if TYPE_CHECKING: 19 | from optuna_distributed.eventloop import EventLoop 20 | 21 | 22 | ObjectiveFuncType = Callable[[DistributedTrial], Union[float, Sequence[float]]] 23 | DistributableFuncType = Callable[[DistributedTrial], None] 24 | 25 | 26 | class OptimizationManager(ABC): 27 | """Controls and provides context in event loop. 28 | 29 | Managers serve as a layer of abstraction between main process event loop 30 | and distributed workers. They can provide workers with context necessary 31 | to do the job, and pump event loop with messages to process. 32 | """ 33 | 34 | @abc.abstractmethod 35 | def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None: 36 | """Spawns a set of workers to run objective function. 37 | 38 | Args: 39 | study: 40 | An instance of Optuna study. 41 | objective: 42 | User defined callable that implements objective function. Must be 43 | serializable and in distributed mode can only use resources available 44 | to all workers in cluster. 45 | """ 46 | raise NotImplementedError 47 | 48 | @abc.abstractmethod 49 | def get_message(self) -> Generator[Message, None, None]: 50 | """Fetches incoming messages from workers.""" 51 | raise NotImplementedError 52 | 53 | @abc.abstractmethod 54 | def after_message(self, event_loop: "EventLoop") -> None: 55 | """A hook allowing to run additional operations after recieved 56 | message is processed. 57 | 58 | Args: 59 | event_loop: 60 | An instance of :class:`~optuna_distributed.eventloop.EventLoop` 61 | providing context to study and manager. 
62 | """ 63 | raise NotImplementedError 64 | 65 | @abc.abstractmethod 66 | def get_connection(self, trial_id: int) -> IPCPrimitive: 67 | """Fetches a private connection to a worker. 68 | 69 | Args: 70 | trial_id: 71 | A connection to the worker running the trial with the 72 | specified id will be fetched. 73 | """ 74 | raise NotImplementedError 75 | 76 | @abc.abstractmethod 77 | def stop_optimization(self, patience: float) -> None: 78 | """Stops all running trials and sets their statuses to failed. 79 | 80 | Args: 81 | patience: 82 | Specifies how many seconds to wait for trials to exit 83 | after an interrupt has been emitted. 84 | """ 85 | raise NotImplementedError 86 | 87 | @abc.abstractmethod 88 | def should_end_optimization(self) -> bool: 89 | """Indicates whether the optimization process can be finished. 90 | 91 | Returns :obj:`True` when all workers have sent one of the closing 92 | messages, indicating completed, pruned or failed trials. 93 | """ 94 | raise NotImplementedError 95 | 96 | @abc.abstractmethod 97 | def register_trial_exit(self, trial_id: int) -> None: 98 | """Informs the manager about a worker finishing a trial. 99 | 100 | This should be called in one of the closing messages to indicate 101 | that the worker finished in an expected state. 102 | 103 | Args: 104 | trial_id: 105 | ID of the trial that was being run on the exiting worker. 106 | """ 107 | raise NotImplementedError 108 | -------------------------------------------------------------------------------- /tests/test_trial.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import deque 4 | 5 | from optuna.distributions import CategoricalDistribution 6 | from optuna.distributions import FloatDistribution 7 | from optuna.distributions import IntDistribution 8 | import pytest 9 | 10 | from optuna_distributed.ipc import IPCPrimitive 11 | from optuna_distributed.messages import Message 12 | from optuna_distributed.messages import ReportMessage 13 | from optuna_distributed.messages import ResponseMessage 14 | from optuna_distributed.messages import SetAttributeMessage 15 | from optuna_distributed.messages import ShouldPruneMessage 16 | from optuna_distributed.messages import SuggestMessage 17 | from optuna_distributed.messages import TrialPropertyMessage 18 | from optuna_distributed.trial import DistributedTrial 19 | 20 | 21 | class MockIPC(IPCPrimitive): 22 | def __init__(self) -> None: 23 | self.captured: list[Message] = [] 24 | self.responses: deque[ResponseMessage] = deque() 25 | 26 | def get(self) -> "Message": 27 | return self.responses.popleft() 28 | 29 | def put(self, message: "Message") -> None: 30 | self.captured.append(message) 31 | 32 | def close(self) -> None: ...
33 | 34 | def enqueue_response(self, response: ResponseMessage) -> None: 35 | self.responses.append(response) 36 | 37 | 38 | @pytest.fixture 39 | def connection() -> MockIPC: 40 | return MockIPC() 41 | 42 | 43 | def test_suggest_float(connection: MockIPC) -> None: 44 | connection.enqueue_response(ResponseMessage(0, data=0.0)) 45 | trial = DistributedTrial(0, connection) 46 | x = trial.suggest_float("x", low=0.0, high=1.0) 47 | assert x == 0.0 48 | captured = connection.captured[0] 49 | assert isinstance(captured, SuggestMessage) 50 | assert captured._trial_id == 0 51 | assert captured._name == "x" 52 | 53 | distribution = captured._distribution 54 | assert isinstance(distribution, FloatDistribution) 55 | assert distribution.low == 0.0 56 | assert distribution.high == 1.0 57 | assert not distribution.log 58 | assert distribution.step is None 59 | 60 | 61 | def test_suggest_int(connection: MockIPC) -> None: 62 | connection.enqueue_response(ResponseMessage(0, data=0)) 63 | trial = DistributedTrial(0, connection) 64 | x = trial.suggest_int("x", low=0, high=1) 65 | assert x == 0 66 | captured = connection.captured[0] 67 | assert isinstance(captured, SuggestMessage) 68 | assert captured._trial_id == 0 69 | assert captured._name == "x" 70 | 71 | distribution = captured._distribution 72 | assert isinstance(distribution, IntDistribution) 73 | assert distribution.low == 0 74 | assert distribution.high == 1 75 | assert distribution.step == 1 76 | assert not distribution.log 77 | 78 | 79 | def test_suggest_categorical(connection: MockIPC) -> None: 80 | connection.enqueue_response(ResponseMessage(0, data="foo")) 81 | trial = DistributedTrial(0, connection) 82 | x = trial.suggest_categorical("x", choices=["foo", "bar", "baz"]) 83 | assert x == "foo" 84 | captured = connection.captured[0] 85 | assert isinstance(captured, SuggestMessage) 86 | assert captured._trial_id == 0 87 | assert captured._name == "x" 88 | 89 | distribution = captured._distribution 90 | assert isinstance(distribution, CategoricalDistribution) 91 | assert distribution.choices == ("foo", "bar", "baz") 92 | 93 | 94 | def test_report(connection: MockIPC) -> None: 95 | trial = DistributedTrial(0, connection) 96 | trial.report(value=0.0, step=1) 97 | captured = connection.captured[0] 98 | assert isinstance(captured, ReportMessage) 99 | assert captured._trial_id == 0 100 | assert captured._step == 1 101 | assert captured._value == 0.0 102 | 103 | 104 | def test_should_prune(connection: MockIPC) -> None: 105 | connection.enqueue_response(ResponseMessage(0, data=False)) 106 | trial = DistributedTrial(0, connection) 107 | assert not trial.should_prune() 108 | captured = connection.captured[0] 109 | assert isinstance(captured, ShouldPruneMessage) 110 | assert captured._trial_id == 0 111 | 112 | 113 | def test_set_user_attr(connection: MockIPC) -> None: 114 | trial = DistributedTrial(0, connection) 115 | trial.set_user_attr(key="foo", value="bar") 116 | captured = connection.captured[0] 117 | assert isinstance(captured, SetAttributeMessage) 118 | assert captured._trial_id == 0 119 | assert captured._kind == "user" 120 | assert captured._key == "foo" 121 | assert captured._value == "bar" 122 | 123 | 124 | def test_set_system_attr(connection: MockIPC) -> None: 125 | trial = DistributedTrial(0, connection) 126 | trial.set_system_attr(key="foo", value="bar") 127 | captured = connection.captured[0] 128 | assert isinstance(captured, SetAttributeMessage) 129 | assert captured._trial_id == 0 130 | assert captured._kind == "system" 131 | assert 
captured._key == "foo" 132 | assert captured._value == "bar" 133 | 134 | 135 | @pytest.mark.parametrize( 136 | "property", 137 | ["params", "distributions", "user_attrs", "system_attrs", "datetime_start", "number"], 138 | ) 139 | def test_get_properties(connection: MockIPC, property: str) -> None: 140 | connection.enqueue_response(ResponseMessage(0, "foo")) 141 | trial = DistributedTrial(0, connection) 142 | assert getattr(trial, property) == "foo" 143 | captured = connection.captured[0] 144 | assert isinstance(captured, TrialPropertyMessage) 145 | assert captured._property == property 146 | -------------------------------------------------------------------------------- /optuna_distributed/managers/local.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Generator 4 | import multiprocessing 5 | from multiprocessing import Pipe as MultiprocessingPipe 6 | from multiprocessing import Process 7 | from multiprocessing.connection import Connection 8 | from multiprocessing.connection import wait 9 | import sys 10 | from typing import TYPE_CHECKING 11 | 12 | from optuna import Study 13 | from optuna.exceptions import TrialPruned 14 | 15 | from optuna_distributed.ipc import IPCPrimitive 16 | from optuna_distributed.ipc import Pipe 17 | from optuna_distributed.managers import ObjectiveFuncType 18 | from optuna_distributed.managers import OptimizationManager 19 | from optuna_distributed.messages import CompletedMessage 20 | from optuna_distributed.messages import FailedMessage 21 | from optuna_distributed.messages import HeartbeatMessage 22 | from optuna_distributed.messages import Message 23 | from optuna_distributed.messages import PrunedMessage 24 | from optuna_distributed.trial import DistributedTrial 25 | 26 | 27 | if TYPE_CHECKING: 28 | from optuna_distributed.eventloop import EventLoop 29 | 30 | 31 | class LocalOptimizationManager(OptimizationManager): 32 | """Controls optimization process on local machine. 33 | 34 | In contrast to Optuna, this implementation uses process based parallelism. 35 | 36 | Args: 37 | n_trials: 38 | Number of trials to run. 39 | n_jobs: 40 | Maximum number of processes allowed to run trials at the same time. 41 | If less or equal to 0, then this argument is overridden with CPU count. 42 | """ 43 | 44 | def __init__(self, n_trials: int, n_jobs: int) -> None: 45 | if n_jobs <= 0 or n_jobs > multiprocessing.cpu_count(): 46 | self._n_jobs = multiprocessing.cpu_count() 47 | else: 48 | self._n_jobs = n_jobs 49 | 50 | self._workers_to_spawn = min(self._n_jobs, n_trials) 51 | self._trials_remaining = n_trials - self._workers_to_spawn 52 | 53 | self._connections: dict[int, Connection] = {} 54 | self._processes: dict[int, Process] = {} 55 | 56 | def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None: 57 | trial_ids = [study.ask()._trial_id for _ in range(self._workers_to_spawn)] 58 | for trial_id in trial_ids: 59 | master, worker = MultiprocessingPipe() 60 | trial = DistributedTrial(trial_id, Pipe(worker)) 61 | p = Process(target=_trial_runtime, args=(objective, trial), daemon=True) 62 | p.start() 63 | worker.close() 64 | 65 | self._processes[trial_id] = p 66 | self._connections[trial_id] = master 67 | 68 | def get_message(self) -> Generator[Message, None, None]: 69 | while True: 70 | messages: list[Message] = [] 71 | for incoming in wait(self._connections.values(), timeout=10): 72 | # FIXME: This assertion is true only for Unix systems. 
73 | # Some refactoring is needed to support Windows as well. 74 | # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.connection.wait 75 | assert isinstance(incoming, Connection) 76 | try: 77 | message = incoming.recv() 78 | messages.append(message) 79 | 80 | except EOFError: 81 | self._close_connection(incoming) 82 | 83 | self._set_workers_to_spawn() 84 | if messages: 85 | yield from messages 86 | else: 87 | yield HeartbeatMessage() 88 | 89 | def after_message(self, event_loop: "EventLoop") -> None: 90 | if self._workers_to_spawn > 0: 91 | self._join_finished_processes() 92 | self.create_futures(event_loop.study, event_loop.objective) 93 | 94 | self._trials_remaining -= self._workers_to_spawn 95 | self._workers_to_spawn = 0 96 | 97 | def get_connection(self, trial_id: int) -> IPCPrimitive: 98 | return Pipe(self._connections[trial_id]) 99 | 100 | def stop_optimization(self, patience: float) -> None: 101 | for process in self._processes.values(): 102 | if process.is_alive(): 103 | process.kill() 104 | process.join(timeout=patience) 105 | 106 | def should_end_optimization(self) -> bool: 107 | return len(self._connections) == 0 and self._trials_remaining == 0 108 | 109 | def register_trial_exit(self, trial_id: int) -> None: 110 | # Noop, as worker informs us about exit by closing connection. 111 | ... 112 | 113 | def _close_connection(self, connection: Connection) -> None: 114 | for trial_id, open_connection in self._connections.items(): 115 | if connection == open_connection: 116 | break 117 | 118 | self._connections.pop(trial_id).close() 119 | 120 | def _set_workers_to_spawn(self) -> None: 121 | self._workers_to_spawn = min(self._n_jobs - len(self._connections), self._trials_remaining) 122 | 123 | def _join_finished_processes(self) -> None: 124 | for trial_id in [tid for tid, p in self._processes.items() if p.exitcode is not None]: 125 | self._processes.pop(trial_id).join() 126 | 127 | 128 | def _trial_runtime(func: ObjectiveFuncType, trial: DistributedTrial) -> None: 129 | message: Message 130 | try: 131 | value_or_values = func(trial) 132 | message = CompletedMessage(trial.trial_id, value_or_values) 133 | trial.connection.put(message) 134 | 135 | except TrialPruned as e: 136 | message = PrunedMessage(trial.trial_id, e) 137 | trial.connection.put(message) 138 | 139 | except Exception as e: 140 | exc_info = sys.exc_info() 141 | message = FailedMessage(trial.trial_id, e, exc_info) 142 | trial.connection.put(message) 143 | 144 | finally: 145 | trial.connection.close() 146 | -------------------------------------------------------------------------------- /tests/test_messages.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import logging 3 | from typing import Any 4 | from typing import Generator 5 | from unittest.mock import MagicMock 6 | 7 | from optuna.distributions import BaseDistribution 8 | from optuna.distributions import CategoricalDistribution 9 | from optuna.distributions import FloatDistribution 10 | from optuna.distributions import IntDistribution 11 | from optuna.exceptions import TrialPruned 12 | from optuna.study import Study 13 | from optuna.trial import TrialState 14 | import pytest 15 | 16 | import optuna_distributed 17 | from optuna_distributed.eventloop import EventLoop 18 | from optuna_distributed.ipc import IPCPrimitive 19 | from optuna_distributed.managers import ObjectiveFuncType 20 | from optuna_distributed.managers import OptimizationManager 21 | from 
optuna_distributed.messages import CompletedMessage 22 | from optuna_distributed.messages import FailedMessage 23 | from optuna_distributed.messages import HeartbeatMessage 24 | from optuna_distributed.messages import Message 25 | from optuna_distributed.messages import PrunedMessage 26 | from optuna_distributed.messages import ReportMessage 27 | from optuna_distributed.messages import ResponseMessage 28 | from optuna_distributed.messages import SetAttributeMessage 29 | from optuna_distributed.messages import ShouldPruneMessage 30 | from optuna_distributed.messages import SuggestMessage 31 | from optuna_distributed.messages import TrialProperty 32 | from optuna_distributed.messages import TrialPropertyMessage 33 | 34 | 35 | class MockConnection(IPCPrimitive): 36 | def __init__(self, manager: "MockOptimizationManager") -> None: 37 | self._manager = manager 38 | 39 | def get(self) -> "Message": 40 | return HeartbeatMessage() 41 | 42 | def put(self, message: Message) -> None: 43 | assert isinstance(message, ResponseMessage) 44 | self._manager.message_response = message.data 45 | 46 | def close(self) -> None: ... 47 | 48 | 49 | class MockOptimizationManager(OptimizationManager): 50 | def __init__(self) -> None: 51 | self.trial_exit_called = False 52 | self.message_response = None 53 | 54 | def create_futures(self, study: "Study", objective: ObjectiveFuncType) -> None: ... 55 | 56 | def before_message(self, event_loop: "EventLoop") -> None: ... 57 | 58 | def get_message(self) -> Generator["Message", None, None]: 59 | yield HeartbeatMessage() 60 | 61 | def after_message(self, event_loop: "EventLoop") -> None: ... 62 | 63 | def get_connection(self, trial_id: int) -> "IPCPrimitive": 64 | return MockConnection(self) 65 | 66 | def stop_optimization(self, patience: float) -> None: ... 
67 | 68 | def should_end_optimization(self) -> bool: 69 | return True 70 | 71 | def register_trial_exit(self, trial_id: int) -> None: 72 | self.trial_exit_called = True 73 | 74 | 75 | @pytest.fixture 76 | def manager() -> MockOptimizationManager: 77 | return MockOptimizationManager() 78 | 79 | 80 | @contextmanager 81 | def _forced_log_propagation(logger_name: str) -> Generator[None, None, None]: 82 | try: 83 | # Local fix for https://github.com/pytest-dev/pytest/issues/3697 84 | logging.getLogger(logger_name).propagate = True 85 | yield 86 | finally: 87 | logging.getLogger(logger_name).propagate = False 88 | 89 | 90 | def _message_responds_with(value: Any, manager: MockOptimizationManager) -> bool: 91 | return manager.message_response == value 92 | 93 | 94 | def test_completed_with_correct_value( 95 | study: Study, manager: MockOptimizationManager, caplog: pytest.LogCaptureFixture 96 | ) -> None: 97 | msg = CompletedMessage(0, 0.0) 98 | assert msg.closing 99 | with _forced_log_propagation(logger_name=optuna_distributed.__name__): 100 | msg.process(study, manager) 101 | assert manager.trial_exit_called 102 | assert len(caplog.records) == 1 103 | assert caplog.records[0].levelno == logging.INFO 104 | trial = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)) 105 | assert len(trial) == 1 106 | assert trial[0].value == 0.0 107 | 108 | 109 | def test_completed_with_incorrect_values(study: Study, manager: MockOptimizationManager) -> None: 110 | msg = CompletedMessage(0, "foo") # type: ignore 111 | assert msg.closing 112 | with pytest.warns(): 113 | msg.process(study, manager) 114 | assert manager.trial_exit_called 115 | 116 | 117 | def test_pruned( 118 | study: Study, manager: MockOptimizationManager, caplog: pytest.LogCaptureFixture 119 | ) -> None: 120 | msg = PrunedMessage(0, TrialPruned()) 121 | assert msg.closing 122 | with _forced_log_propagation(logger_name=optuna_distributed.__name__): 123 | msg.process(study, manager) 124 | assert manager.trial_exit_called 125 | assert len(caplog.records) == 1 126 | assert caplog.records[0].levelno == logging.INFO 127 | trial = study.get_trials(deepcopy=False, states=(TrialState.PRUNED,)) 128 | assert len(trial) == 1 129 | 130 | 131 | def test_failed( 132 | study: Study, manager: MockOptimizationManager, caplog: pytest.LogCaptureFixture 133 | ) -> None: 134 | exc = ValueError("foo") 135 | msg = FailedMessage(0, exc, exc_info=MagicMock()) 136 | assert msg.closing 137 | logger_name = optuna_distributed.__name__ 138 | with pytest.raises(ValueError), _forced_log_propagation(logger_name): 139 | msg.process(study, manager) 140 | assert manager.trial_exit_called 141 | assert len(caplog.records) == 1 142 | assert caplog.records[0].levelno == logging.WARNING 143 | trial = study.get_trials(deepcopy=False, states=(TrialState.FAIL,)) 144 | assert len(trial) == 1 145 | 146 | 147 | def test_heartbeat() -> None: 148 | msg = HeartbeatMessage() 149 | assert not msg.closing 150 | 151 | 152 | def test_response() -> None: 153 | msg = ResponseMessage(0, data="foo") 154 | assert not msg.closing 155 | assert msg.data == "foo" 156 | 157 | 158 | @pytest.mark.parametrize( 159 | "property", 160 | [ 161 | "params", 162 | "distributions", 163 | "user_attrs", 164 | "system_attrs", 165 | "datetime_start", 166 | "number", 167 | ], 168 | ) 169 | def test_trial_property( 170 | study: Study, manager: MockOptimizationManager, property: TrialProperty 171 | ) -> None: 172 | msg = TrialPropertyMessage(0, property) 173 | assert not msg.closing 174 | msg.process(study, manager) 
175 | expected = getattr(study.get_trials(deepcopy=False)[0], property) 176 | assert _message_responds_with(expected, manager=manager) 177 | 178 | 179 | def test_should_prune(study: Study, manager: MockOptimizationManager) -> None: 180 | msg = ShouldPruneMessage(0) 181 | assert not msg.closing 182 | msg.process(study, manager) 183 | 184 | assert _message_responds_with(False, manager=manager) 185 | trial = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,)) 186 | assert len(trial) == 1 187 | assert trial[0]._trial_id == 0 188 | 189 | 190 | def test_report_intermediate(study: Study, manager: MockOptimizationManager) -> None: 191 | msg = ReportMessage(0, value=0.0, step=1) 192 | assert not msg.closing 193 | 194 | msg.process(study, manager) 195 | trial = study.get_trials(deepcopy=False)[0] 196 | assert trial.intermediate_values[1] == 0.0 197 | 198 | 199 | def test_set_user_attributes(study: Study, manager: MockOptimizationManager) -> None: 200 | msg = SetAttributeMessage(0, key="foo", value=0, kind="user") 201 | assert not msg.closing 202 | 203 | msg.process(study, manager) 204 | trial = study.get_trials(deepcopy=False)[0] 205 | assert trial.user_attrs["foo"] == 0 206 | 207 | 208 | def test_set_system_attributes(study: Study, manager: MockOptimizationManager) -> None: 209 | msg = SetAttributeMessage(0, value=0, key="foo", kind="system") 210 | assert not msg.closing 211 | 212 | msg.process(study, manager) 213 | trial = study.get_trials(deepcopy=False)[0] 214 | assert trial.system_attrs["foo"] == 0 215 | 216 | 217 | @pytest.mark.parametrize( 218 | "distribution", 219 | [ 220 | FloatDistribution(low=0.0, high=1.0), 221 | IntDistribution(low=0, high=1), 222 | CategoricalDistribution(choices=["foo", "bar"]), 223 | ], 224 | ) 225 | def test_suggest( 226 | study: Study, manager: MockOptimizationManager, distribution: BaseDistribution 227 | ) -> None: 228 | msg = SuggestMessage(0, name="x", distribution=distribution) 229 | assert not msg.closing 230 | 231 | msg.process(study, manager) 232 | trial = study.get_trials(deepcopy=False)[0] 233 | assert "x" in trial.distributions 234 | assert trial.distributions["x"] == distribution 235 | assert "x" in trial.params 236 | assert _message_responds_with(trial.params["x"], manager=manager) 237 | -------------------------------------------------------------------------------- /optuna_distributed/managers/distributed.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | from collections.abc import Generator 5 | import ctypes 6 | from dataclasses import dataclass 7 | from enum import IntEnum 8 | import sys 9 | import threading 10 | from threading import Thread 11 | import time 12 | from typing import Callable 13 | from typing import TYPE_CHECKING 14 | import uuid 15 | 16 | from dask.distributed import Client 17 | from dask.distributed import Future 18 | from dask.distributed import Variable 19 | from optuna.exceptions import TrialPruned 20 | from optuna.study import Study 21 | 22 | from optuna_distributed.ipc import IPCPrimitive 23 | from optuna_distributed.ipc import Queue 24 | from optuna_distributed.managers import ObjectiveFuncType 25 | from optuna_distributed.managers import OptimizationManager 26 | from optuna_distributed.messages import CompletedMessage 27 | from optuna_distributed.messages import FailedMessage 28 | from optuna_distributed.messages import HeartbeatMessage 29 | from optuna_distributed.messages import Message 30 | from 
optuna_distributed.messages import PrunedMessage 31 | from optuna_distributed.trial import DistributedTrial 32 | 33 | 34 | if TYPE_CHECKING: 35 | from optuna_distributed.eventloop import EventLoop 36 | 37 | 38 | DistributableWithContext = Callable[["_TaskContext"], None] 39 | 40 | 41 | class WorkerInterrupted(Exception): ... 42 | 43 | 44 | class _TaskState(IntEnum): 45 | WAITING = 0 46 | RUNNING = 1 47 | FINISHED = 2 48 | 49 | 50 | @dataclass 51 | class _TaskContext: 52 | trial: DistributedTrial 53 | stop_flag: str 54 | state_id: str 55 | 56 | 57 | class _StateSynchronizer: 58 | def __init__(self) -> None: 59 | self._optimization_enabled = Variable() 60 | self._optimization_enabled.set(True) 61 | self._task_states: list[Variable] = [] 62 | 63 | @property 64 | def stop_flag(self) -> str: 65 | return self._optimization_enabled.name 66 | 67 | def set_initial_state(self) -> str: 68 | task_state = Variable() 69 | task_state.set(_TaskState.WAITING) 70 | self._task_states.append(task_state) 71 | return task_state.name 72 | 73 | def emit_stop_and_wait(self, patience: float) -> None: 74 | self._optimization_enabled.set(False) 75 | disabled_at = time.time() 76 | while any(_TaskState(state.get()) is _TaskState.RUNNING for state in self._task_states): 77 | if time.time() - disabled_at > patience: 78 | raise TimeoutError("Timed out while trying to interrupt running tasks.") 79 | time.sleep(0.1) 80 | 81 | 82 | class DistributedOptimizationManager(OptimizationManager): 83 | """Controls optimization process spanning multiple physical machines. 84 | 85 | This implementation uses dask as parallel computing backend. 86 | 87 | Args: 88 | client: 89 | An instance of dask client. 90 | n_trials: 91 | Number of trials to run. 92 | heartbeat_interval: 93 | Delay (in seconds) before 94 | :func:`optuna_distributed.managers.DistributedOptimizationManager.get_message` 95 | produces a heartbeat message if no other message is sent by the worker. 96 | """ 97 | 98 | def __init__(self, client: Client, n_trials: int, heartbeat_interval: int = 60) -> None: 99 | self._client = client 100 | self._n_trials = n_trials 101 | self._completed_trials = 0 102 | self._public_channel = str(uuid.uuid4()) 103 | self._synchronizer = _StateSynchronizer() 104 | 105 | # Manager has write access to its own message queue as a sort of health check. 106 | # Basically that means we can pump event loop from callbacks running in 107 | # main process with e.g. HeartbeatMessage. 108 | self._message_queue = Queue( 109 | publishing=self._public_channel, 110 | recieving=self._public_channel, 111 | timeout=heartbeat_interval, 112 | ) 113 | self._private_channels: dict[int, str] = {} 114 | self._futures: list[Future] = [] 115 | 116 | def _ensure_safe_exit(self, future: Future) -> None: 117 | if future.status in ["error", "cancelled"]: 118 | # FIXME: I'm not sure if there is a way to get 119 | # id of a trial that failed this way. 120 | self.register_trial_exit(-1) 121 | self._message_queue.put(HeartbeatMessage()) 122 | 123 | def _assign_private_channel(self, trial_id: int) -> "Queue": 124 | private_channel = str(uuid.uuid4()) 125 | self._private_channels[trial_id] = private_channel 126 | return Queue(self._public_channel, private_channel, max_retries=5) 127 | 128 | def _create_trials(self, study: Study) -> list[DistributedTrial]: 129 | # HACK: It's kinda naughty to access _trial_id, but this is gonna make 130 | # our lifes much easier in messaging system. 
131 | trial_ids = [study.ask()._trial_id for _ in range(self._n_trials)] 132 | return [DistributedTrial(tid, self._assign_private_channel(tid)) for tid in trial_ids] 133 | 134 | def _add_task_context(self, trials: list[DistributedTrial]) -> list[_TaskContext]: 135 | trials_with_context: list[_TaskContext] = [] 136 | for trial in trials: 137 | trials_with_context.append( 138 | _TaskContext( 139 | trial, 140 | stop_flag=self._synchronizer.stop_flag, 141 | state_id=self._synchronizer.set_initial_state(), 142 | ) 143 | ) 144 | 145 | return trials_with_context 146 | 147 | def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None: 148 | distributable = _distributable(objective) 149 | trials = self._add_task_context(self._create_trials(study)) 150 | self._futures = self._client.map(distributable, trials, pure=False) 151 | for future in self._futures: 152 | future.add_done_callback(self._ensure_safe_exit) 153 | 154 | def get_message(self) -> Generator[Message, None, None]: 155 | while True: 156 | try: 157 | # TODO(xadrianzetx) At some point we might need a mechanism 158 | # that allows workers to repeat messages to master. 159 | # A deduplication algorithm would go here then. 160 | yield self._message_queue.get() 161 | 162 | except asyncio.TimeoutError: 163 | # Pumping event loop with heartbeat messages on timeout 164 | # allows us to handle potential problems gracefully 165 | # e.g. in `after_message`. 166 | yield HeartbeatMessage() 167 | 168 | def after_message(self, event_loop: "EventLoop") -> None: ... 169 | 170 | def get_connection(self, trial_id: int) -> IPCPrimitive: 171 | return Queue(self._private_channels[trial_id]) 172 | 173 | def stop_optimization(self, patience: float) -> None: 174 | self._client.cancel(self._futures) 175 | self._synchronizer.emit_stop_and_wait(patience) 176 | 177 | def should_end_optimization(self) -> bool: 178 | return self._completed_trials == self._n_trials 179 | 180 | def register_trial_exit(self, trial_id: int) -> None: 181 | self._completed_trials += 1 182 | 183 | 184 | def _distributable(func: ObjectiveFuncType) -> DistributableWithContext: 185 | def _wrapper(context: _TaskContext) -> None: 186 | task_state = Variable(context.state_id) 187 | if _TaskState(task_state.get()) is not _TaskState.WAITING: 188 | return 189 | 190 | task_state.set(_TaskState.RUNNING) 191 | message: Message 192 | 193 | try: 194 | args = (threading.get_ident(), context) 195 | Thread(target=_task_supervisor, args=args, daemon=True).start() 196 | value_or_values = func(context.trial) 197 | message = CompletedMessage(context.trial.trial_id, value_or_values) 198 | context.trial.connection.put(message) 199 | 200 | except TrialPruned as e: 201 | message = PrunedMessage(context.trial.trial_id, e) 202 | context.trial.connection.put(message) 203 | 204 | except WorkerInterrupted: 205 | ... 
206 | 207 | except Exception as e: 208 | exc_info = sys.exc_info() 209 | message = FailedMessage(context.trial.trial_id, e, exc_info) 210 | context.trial.connection.put(message) 211 | 212 | finally: 213 | context.trial.connection.close() 214 | task_state.set(_TaskState.FINISHED) 215 | 216 | return _wrapper 217 | 218 | 219 | def _task_supervisor(thread_id: int, context: _TaskContext) -> None: 220 | optimization_enabled = Variable(context.stop_flag) 221 | task_state = Variable(context.state_id) 222 | while True: 223 | time.sleep(0.1) 224 | if _TaskState(task_state.get()) is _TaskState.FINISHED: 225 | break 226 | 227 | if not optimization_enabled.get(): 228 | # https://gist.github.com/liuw/2407154 229 | # https://distributed.dask.org/en/stable/worker-state.html#task-cancellation 230 | ctypes.pythonapi.PyThreadState_SetAsyncExc( 231 | ctypes.c_long(thread_id), ctypes.py_object(WorkerInterrupted) 232 | ) 233 | break 234 | -------------------------------------------------------------------------------- /optuna_distributed/trial.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Sequence 4 | import datetime 5 | from typing import Any 6 | from typing import TypeVar 7 | 8 | from optuna.distributions import BaseDistribution 9 | from optuna.distributions import CategoricalChoiceType 10 | from optuna.distributions import CategoricalDistribution 11 | from optuna.distributions import FloatDistribution 12 | from optuna.distributions import IntDistribution 13 | 14 | from optuna_distributed.ipc import IPCPrimitive 15 | from optuna_distributed.messages import ReportMessage 16 | from optuna_distributed.messages import ResponseMessage 17 | from optuna_distributed.messages import SetAttributeMessage 18 | from optuna_distributed.messages import ShouldPruneMessage 19 | from optuna_distributed.messages import SuggestMessage 20 | from optuna_distributed.messages import TrialProperty 21 | from optuna_distributed.messages import TrialPropertyMessage 22 | from optuna_distributed.messages.base import Message 23 | 24 | 25 | T = TypeVar("T", bound=CategoricalChoiceType) 26 | 27 | 28 | class DistributedTrial: 29 | """A trial is a process of evaluating an objective function. 30 | 31 | This is a version of Optuna trial designed to run in process or machine separate 32 | to the study and its resources. Communication with study is held via messaging 33 | system, allowing remote workers to use standard Optuna trial APIs. 34 | 35 | For complete documentation, please refer to: 36 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna-trial-trial 37 | 38 | Args: 39 | trial_id: 40 | A trial ID that is automatically generated. 41 | connection: 42 | An instance of :class:`~optuna_distributed.ipc.IPCPrimitive`. 
43 | """ 44 | 45 | def __init__(self, trial_id: int, connection: IPCPrimitive) -> None: 46 | self.trial_id = trial_id 47 | self.connection = connection 48 | 49 | def _suggest(self, name: str, distribution: BaseDistribution) -> Any: 50 | message = SuggestMessage(self.trial_id, name, distribution) 51 | return self._send_message_and_wait_response(message) 52 | 53 | def _get_property(self, property: TrialProperty) -> Any: 54 | message = TrialPropertyMessage(self.trial_id, property) 55 | return self._send_message_and_wait_response(message) 56 | 57 | def _send_message_and_wait_response(self, message: Message) -> Any: 58 | self.connection.put(message) 59 | response = self.connection.get() 60 | assert isinstance(response, ResponseMessage) 61 | return response.data 62 | 63 | def suggest_float( 64 | self, 65 | name: str, 66 | low: float, 67 | high: float, 68 | *, 69 | step: float | None = None, 70 | log: bool = False, 71 | ) -> float: 72 | """Suggest a value for the floating point parameter. 73 | 74 | For complete documentation, please refer to: 75 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_float 76 | 77 | Args: 78 | name: 79 | A parameter name. 80 | low: 81 | Lower endpoint of the range of suggested values. ``low`` is included in the range. 82 | ``low`` must be less than or equal to ``high``. If ``log`` is :obj:`True`, 83 | ``low`` must be larger than 0. 84 | high: 85 | Upper endpoint of the range of suggested values. ``high`` is included in the range. 86 | ``high`` must be greater than or equal to ``low``. 87 | step: 88 | A step of discretization. 89 | log: 90 | A flag to sample the value from the log domain or not. 91 | If ``log`` is true, the value is sampled from the range in the log domain. 92 | Otherwise, the value is sampled from the range in the linear domain. 93 | """ 94 | distribution = FloatDistribution(low, high, step=step, log=log) 95 | return self._suggest(name, distribution) 96 | 97 | def suggest_uniform(self, name: str, low: float, high: float) -> float: 98 | """Suggest a value for the continuous parameter. 99 | 100 | For complete documentation, please refer to: 101 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_uniform 102 | 103 | Args: 104 | name: 105 | A parameter name. 106 | low: 107 | Lower endpoint of the range of suggested values. ``low`` is included in the range. 108 | high: 109 | Upper endpoint of the range of suggested values. ``high`` is included in the range. 110 | """ 111 | return self.suggest_float(name, low, high) 112 | 113 | def suggest_loguniform(self, name: str, low: float, high: float) -> float: 114 | """Suggest a value for the continuous parameter. 115 | 116 | For complete documentation, please refer to: 117 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_loguniform 118 | 119 | Args: 120 | name: 121 | A parameter name. 122 | low: 123 | Lower endpoint of the range of suggested values. ``low`` is included in the range. 124 | high: 125 | Upper endpoint of the range of suggested values. ``high`` is included in the range. 126 | """ 127 | return self.suggest_float(name, low, high, log=True) 128 | 129 | def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float: 130 | """Suggest a value for the discrete parameter. 
131 | 132 | For complete documentation, please refer to: 133 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_discrete_uniform 134 | 135 | Args: 136 | name: 137 | A parameter name. 138 | low: 139 | Lower endpoint of the range of suggested values. ``low`` is included in the range. 140 | high: 141 | Upper endpoint of the range of suggested values. ``high`` is included in the range. 142 | q: 143 | A step of discretization. 144 | """ 145 | return self.suggest_float(name, low, high, step=q) 146 | 147 | def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int: 148 | """Suggest a value for the integer parameter. 149 | 150 | For complete documentation, please refer to: 151 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_int 152 | 153 | Args: 154 | name: 155 | A parameter name. 156 | low: 157 | Lower endpoint of the range of suggested values. ``low`` is included in the range. 158 | ``low`` must be less than or equal to ``high``. If ``log`` is :obj:`True`, 159 | ``low`` must be larger than 0. 160 | high: 161 | Upper endpoint of the range of suggested values. ``high`` is included in the range. 162 | ``high`` must be greater than or equal to ``low``. 163 | step: 164 | A step of discretization. 165 | log: 166 | A flag to sample the value from the log domain or not. 167 | """ 168 | distribution = IntDistribution(low, high, log=log, step=step) 169 | return self._suggest(name, distribution) 170 | 171 | def suggest_categorical(self, name: str, choices: Sequence[T]) -> T: 172 | """Suggest a value for the categorical parameter. 173 | 174 | For complete documentation, please refer to: 175 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_categorical 176 | 177 | Args: 178 | name: 179 | A parameter name. 180 | choices: 181 | Parameter value candidates. 182 | """ 183 | distribution = CategoricalDistribution(choices) 184 | return self._suggest(name, distribution) 185 | 186 | def report(self, value: float, step: int) -> None: 187 | """Report an objective function value for a given step. 188 | 189 | For complete documentation, please refer to: 190 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.report 191 | 192 | Args: 193 | value: 194 | An intermediate value returned from the objective function. 195 | step: 196 | Step of the trial (e.g., Epoch of neural network training). 197 | """ 198 | message = ReportMessage(self.trial_id, value, step) 199 | self.connection.put(message) 200 | 201 | def should_prune(self) -> bool: 202 | """Suggest whether the trial should be pruned or not. 203 | 204 | For complete documentation, please refer to: 205 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.should_prune 206 | """ 207 | message = ShouldPruneMessage(self.trial_id) 208 | return self._send_message_and_wait_response(message) 209 | 210 | def set_user_attr(self, key: str, value: Any) -> None: 211 | """Set user attributes to the trial. 212 | 213 | For complete documentation, please refer to: 214 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.set_user_attr 215 | 216 | Args: 217 | key: 218 | A key string of the attribute. 219 | value: 220 | A value of the attribute. The value should be able to serialize with pickle. 
221 | """ 222 | message = SetAttributeMessage(self.trial_id, key, value, kind="user") 223 | self.connection.put(message) 224 | 225 | def set_system_attr(self, key: str, value: Any) -> None: 226 | """set system attributes to the trial. 227 | 228 | For complete documentation, please refer to: 229 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.set_system_attr 230 | 231 | Args: 232 | key: 233 | A key string of the attribute. 234 | value: 235 | A value of the attribute. The value should be able to serialize with pickle. 236 | """ 237 | message = SetAttributeMessage(self.trial_id, key, value, kind="system") 238 | self.connection.put(message) 239 | 240 | @property 241 | def params(self) -> dict[str, Any]: 242 | """Return parameters to be optimized.""" 243 | return self._get_property("params") 244 | 245 | @property 246 | def distributions(self) -> dict[str, BaseDistribution]: 247 | """Return distributions of parameters to be optimized.""" 248 | return self._get_property("distributions") 249 | 250 | @property 251 | def user_attrs(self) -> dict[str, Any]: 252 | """Return user attributes.""" 253 | return self._get_property("user_attrs") 254 | 255 | @property 256 | def system_attrs(self) -> dict[str, Any]: 257 | """Return system attributes.""" 258 | return self._get_property("system_attrs") 259 | 260 | @property 261 | def datetime_start(self) -> datetime.datetime | None: 262 | """Return start datetime.""" 263 | return self._get_property("datetime_start") 264 | 265 | @property 266 | def number(self) -> int: 267 | """Return trial's number which is consecutive and unique in a study.""" 268 | return self._get_property("number") 269 | -------------------------------------------------------------------------------- /tests/test_managers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from dataclasses import dataclass 3 | import multiprocessing 4 | import sys 5 | import time 6 | from typing import Generator 7 | from unittest.mock import Mock 8 | import uuid 9 | 10 | from dask.distributed import Client 11 | from dask.distributed import Variable 12 | from dask.distributed import wait 13 | import optuna 14 | import pytest 15 | 16 | from optuna_distributed.managers import DistributedOptimizationManager 17 | from optuna_distributed.managers import LocalOptimizationManager 18 | from optuna_distributed.managers import ObjectiveFuncType 19 | from optuna_distributed.managers.distributed import _StateSynchronizer 20 | from optuna_distributed.managers.distributed import _TaskContext 21 | from optuna_distributed.managers.distributed import _TaskState 22 | from optuna_distributed.managers.distributed import _distributable 23 | from optuna_distributed.messages import CompletedMessage 24 | from optuna_distributed.messages import HeartbeatMessage 25 | from optuna_distributed.messages import ResponseMessage 26 | from optuna_distributed.trial import DistributedTrial 27 | 28 | 29 | def test_distributed_get_message(client: Client) -> None: 30 | n_trials = 5 31 | study = optuna.create_study() 32 | manager = DistributedOptimizationManager(client, n_trials) 33 | manager.create_futures(study, lambda trial: 0.0) 34 | completed = 0 35 | for message in manager.get_message(): 36 | assert isinstance(message, CompletedMessage) 37 | completed += 1 38 | if completed == n_trials: 39 | break 40 | 41 | 42 | def test_distributed_heartbeat_on_timeout(client: Client) -> None: 43 | def _objective(trial: 
DistributedTrial) -> float: 44 | time.sleep(2.0) 45 | return 0.0 46 | 47 | study = optuna.create_study() 48 | manager = DistributedOptimizationManager(client, n_trials=1, heartbeat_interval=1) 49 | manager.create_futures(study, _objective) 50 | start = time.time() 51 | for message in manager.get_message(): 52 | assert isinstance(message, HeartbeatMessage) 53 | assert 0.8 < time.time() - start < 1.2 54 | break 55 | 56 | wait(manager._futures) 57 | 58 | 59 | def test_distributed_should_end_optimization(client: Client) -> None: 60 | n_trials = 5 61 | study = optuna.create_study() 62 | manager = DistributedOptimizationManager(client, n_trials) 63 | manager.create_futures(study, lambda trial: 0.0) 64 | closing_messages_recieved = 0 65 | for message in manager.get_message(): 66 | assert not isinstance(message, HeartbeatMessage) 67 | if message.closing: 68 | closing_messages_recieved += 1 69 | manager.register_trial_exit(message._trial_id) # type: ignore 70 | 71 | if manager.should_end_optimization(): 72 | break 73 | 74 | assert closing_messages_recieved == n_trials 75 | 76 | 77 | def test_distributed_stops_optimization(client: Client) -> None: 78 | uninterrupted_execution_time = 100 79 | 80 | def _objective(trial: DistributedTrial) -> float: 81 | # Sleep needs to be fragemnted to read error indicator. 82 | for _ in range(uninterrupted_execution_time): 83 | time.sleep(1.0) 84 | return 0.0 85 | 86 | study = optuna.create_study() 87 | manager = DistributedOptimizationManager(client, n_trials=5) 88 | manager.create_futures(study, _objective) 89 | stopped_at = time.time() 90 | manager.stop_optimization(patience=10.0) 91 | interrupted_execution_time = time.time() - stopped_at 92 | assert interrupted_execution_time < uninterrupted_execution_time 93 | for future in manager._futures: 94 | assert future.cancelled() 95 | 96 | 97 | def test_distributed_connection_management(client: Client) -> None: 98 | def _objective(trial: DistributedTrial) -> float: 99 | requested = trial.connection.get() 100 | assert isinstance(requested, ResponseMessage) 101 | data = {"requested": requested.data, "actual": trial.trial_id} 102 | trial.connection.put(ResponseMessage(trial.trial_id, data)) 103 | return 0.0 104 | 105 | n_trials = 5 106 | study = optuna.create_study() 107 | manager = DistributedOptimizationManager(client, n_trials) 108 | manager.create_futures(study, _objective) 109 | for trial in study.get_trials(deepcopy=False): 110 | connection = manager.get_connection(trial._trial_id) 111 | connection.put(ResponseMessage(0, data=trial._trial_id)) 112 | 113 | for message in manager.get_message(): 114 | if message.closing: 115 | manager.register_trial_exit(message._trial_id) # type: ignore 116 | if isinstance(message, ResponseMessage): 117 | assert message.data["requested"] == message.data["actual"] 118 | if manager.should_end_optimization(): 119 | break 120 | 121 | 122 | def test_distributed_task_deduped(client: Client) -> None: 123 | def _objective(trial: DistributedTrial) -> float: 124 | run_count = Variable("run_count") 125 | run_count.set(run_count.get() + 1) 126 | return 0.0 127 | 128 | run_count = Variable("run_count") 129 | run_count.set(0) 130 | state_id = uuid.uuid4().hex 131 | Variable(state_id).set(_TaskState.WAITING) 132 | 133 | # Simulate scenario where task run was repeated. 
134 |     # https://stackoverflow.com/a/41965766
135 |     func = _distributable(_objective)
136 |     context = _TaskContext(DistributedTrial(0, Mock()), stop_flag="foo", state_id=state_id)
137 |     for _ in range(5):
138 |         client.submit(func, context).result()
139 | 
140 |     assert run_count.get() == 1
141 | 
142 | 
143 | def test_synchronizer_optimization_enabled() -> None:
144 |     synchronizer = _StateSynchronizer()
145 |     optimization_enabled = Variable(synchronizer.stop_flag)
146 |     assert optimization_enabled.get()
147 | 
148 | 
149 | def test_synchronizer_emits_stop() -> None:
150 |     synchronizer = _StateSynchronizer()
151 |     synchronizer.emit_stop_and_wait(1)
152 |     optimization_enabled = Variable(synchronizer.stop_flag)
153 |     assert not optimization_enabled.get()
154 | 
155 | 
156 | def test_synchronizer_states_created() -> None:
157 |     synchronizer = _StateSynchronizer()
158 |     states = [Variable(synchronizer.set_initial_state()) for _ in range(10)]
159 |     assert all(_TaskState(state.get()) is _TaskState.WAITING for state in states)
160 | 
161 | 
162 | def test_synchronizer_timeout() -> None:
163 |     synchronizer = _StateSynchronizer()
164 |     task_state = Variable(synchronizer.set_initial_state())
165 |     task_state.set(_TaskState.RUNNING)
166 |     with pytest.raises(TimeoutError):
167 |         synchronizer.emit_stop_and_wait(0)
168 | 
169 | 
170 | def _objective_local_get_message(trial: DistributedTrial) -> float:
171 |     trial.connection.put(ResponseMessage(0, data=None))
172 |     return 0.0
173 | 
174 | 
175 | @pytest.mark.skipif(sys.platform == "win32", reason="Local optimization not supported on Windows.")
176 | def test_local_get_message() -> None:
177 |     n_trials = 1
178 |     study = optuna.create_study()
179 |     manager = LocalOptimizationManager(n_trials, n_jobs=1)
180 |     manager.create_futures(study, _objective_local_get_message)
181 |     completed = 0
182 |     for message in manager.get_message():
183 |         assert isinstance(message, ResponseMessage)
184 |         completed += 1
185 |         if completed == n_trials:
186 |             break
187 | 
188 | 
189 | def _objective_local_should_end_optimization(trial: DistributedTrial) -> float:
190 |     return 0.0
191 | 
192 | 
193 | @pytest.mark.skipif(sys.platform == "win32", reason="Local optimization not supported on Windows.")
194 | def test_local_should_end_optimization() -> None:
195 |     n_trials = 1
196 |     study = optuna.create_study()
197 |     manager = LocalOptimizationManager(n_trials, n_jobs=1)
198 |     manager.create_futures(study, _objective_local_should_end_optimization)
199 |     closing_messages_received = 0
200 |     for message in manager.get_message():
201 |         if message.closing:
202 |             closing_messages_received += 1
203 |             manager.register_trial_exit(message._trial_id)  # type: ignore
204 | 
205 |         if manager.should_end_optimization():
206 |             break
207 | 
208 |     assert closing_messages_received == n_trials
209 | 
210 | 
211 | def _objective_local_stops_optimization(trial: DistributedTrial) -> float:
212 |     uninterrupted_execution_time = 5.0
213 |     time.sleep(uninterrupted_execution_time)
214 |     return 0.0
215 | 
216 | 
217 | @pytest.mark.skipif(sys.platform == "win32", reason="Local optimization not supported on Windows.")
218 | def test_local_stops_optimization() -> None:
219 |     uninterrupted_execution_time = 5.0
220 |     study = optuna.create_study()
221 |     manager = LocalOptimizationManager(n_trials=10, n_jobs=1)
222 |     manager.create_futures(study, _objective_local_stops_optimization)
223 |     stopped_at = time.time()
224 |     manager.stop_optimization(patience=10.0)
225 |     interrupted_execution_time = time.time() - stopped_at
226 |     assert interrupted_execution_time < uninterrupted_execution_time
227 |     for process in manager._processes.values():
228 |         assert not process.is_alive()
229 | 
230 | 
231 | def _objective_local_connection_management(trial: DistributedTrial) -> float:
232 |     requested = trial.connection.get()
233 |     assert isinstance(requested, ResponseMessage)
234 |     data = {"requested": requested.data, "actual": trial.trial_id}
235 |     trial.connection.put(ResponseMessage(trial.trial_id, data))
236 |     return 0.0
237 | 
238 | 
239 | @pytest.mark.skipif(sys.platform == "win32", reason="Local optimization not supported on Windows.")
240 | def test_local_connection_management() -> None:
241 |     n_trials = 1
242 |     study = optuna.create_study()
243 |     manager = LocalOptimizationManager(n_trials, n_jobs=1)
244 |     manager.create_futures(study, _objective_local_connection_management)
245 |     for trial in study.get_trials(deepcopy=False):
246 |         connection = manager.get_connection(trial._trial_id)
247 |         connection.put(ResponseMessage(0, data=trial._trial_id))
248 | 
249 |     received = 0
250 |     for message in manager.get_message():
251 |         assert isinstance(message, ResponseMessage)
252 |         assert message.data["requested"] == message.data["actual"]
253 |         received += 1
254 |         if received == n_trials:
255 |             break
256 | 
257 | 
258 | def _objective_local_worker_pool_management(trial: DistributedTrial) -> float:
259 |     return 0.0
260 | 
261 | 
262 | @pytest.mark.skipif(sys.platform == "win32", reason="Local optimization not supported on Windows.")
263 | def test_local_worker_pool_management() -> None:
264 |     @dataclass
265 |     class _MockEventLoop:
266 |         study: optuna.Study
267 |         objective: ObjectiveFuncType
268 | 
269 |     study = optuna.create_study()
270 |     manager = LocalOptimizationManager(n_trials=10, n_jobs=-1)
271 |     eventloop = _MockEventLoop(study, _objective_local_worker_pool_management)
272 | 
273 |     manager.create_futures(study, _objective_local_worker_pool_management)
274 |     for message in manager.get_message():
275 |         message.process(study, manager)
276 |         manager.after_message(eventloop)  # type: ignore
277 |         if not manager.should_end_optimization():
278 |             assert 0 < len(manager._connections) <= multiprocessing.cpu_count()
279 |         else:
280 |             break
281 | 
282 | 
283 | @pytest.mark.skipif(sys.platform == "win32", reason="No file descriptor limits on Windows.")
284 | def test_local_free_resources() -> None:
285 |     @contextmanager
286 |     def _limited_nofile(limit: int) -> Generator[None, None, None]:
287 |         import resource
288 | 
289 |         soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
290 |         resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))
291 |         yield
292 | 
293 |         resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))
294 | 
295 |     @dataclass
296 |     class _MockEventLoop:
297 |         study: optuna.Study
298 |         objective: ObjectiveFuncType
299 | 
300 |     study = optuna.create_study()
301 |     eventloop = _MockEventLoop(study, _objective_local_worker_pool_management)
302 | 
303 |     # Try to run more trials than there are available file descriptors. Incorrectly managed
304 |     # optimization will fail by exceeding this limit.
305 |     n_trials = 1024
306 |     with _limited_nofile(n_trials):
307 |         manager = LocalOptimizationManager(n_trials=n_trials + 1, n_jobs=5)
308 |         manager.create_futures(study, objective=_objective_local_worker_pool_management)
309 | 
310 |         try:
311 |             for message in manager.get_message():
312 |                 message.process(study, manager)
313 |                 manager.after_message(eventloop)  # type: ignore
314 |                 if manager.should_end_optimization():
315 |                     break
316 | 
317 |         except OSError:
318 |             pytest.fail("File descriptor limit reached")
319 | 
--------------------------------------------------------------------------------
/optuna_distributed/study.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from collections.abc import Callable
4 | from collections.abc import Container
5 | from collections.abc import Iterable
6 | from collections.abc import Sequence
7 | import sys
8 | from typing import Any
9 | from typing import TYPE_CHECKING
10 | 
11 | from dask.distributed import Client
12 | from dask.distributed import LocalCluster
13 | from optuna.distributions import BaseDistribution
14 | from optuna.study import Study
15 | from optuna.study import StudyDirection
16 | from optuna.trial import FrozenTrial
17 | from optuna.trial import Trial
18 | from optuna.trial import TrialState
19 | 
20 | from optuna_distributed.eventloop import EventLoop
21 | from optuna_distributed.managers import DistributedOptimizationManager
22 | from optuna_distributed.managers import LocalOptimizationManager
23 | from optuna_distributed.managers import ObjectiveFuncType
24 | from optuna_distributed.terminal import Terminal
25 | 
26 | 
27 | if TYPE_CHECKING:
28 |     import pandas as pd
29 | 
30 | 
31 | class DistributedStudy:
32 |     """Extends a regular Optuna study by distributing trials across multiple workers.
33 | 
34 |     This object behaves like a regular Optuna study, except trials will be evaluated in parallel
35 |     after :func:`optuna_distributed.DistributedStudy.optimize` is called. When :obj:`client`
36 |     is :obj:`None`, work is distributed among available CPU cores by using multiprocessing.
37 |     If a Dask client is specified, `optuna_distributed` can use it to distribute trials across
38 |     many physical workers in the cluster.
39 | 
40 |     .. note::
41 |         Using `optuna_distributed` in distributed mode requires a Dask cluster with matching
42 |         environment. To read more about the deployment and usage of Dask clusters, please refer
43 |         to https://docs.dask.org/en/stable/deploying.html.
44 | 
45 |     .. note::
46 |         Any APIs besides :func:`optuna_distributed.DistributedStudy.optimize` are just
47 |         a passthrough to the regular Optuna study and can be used in standard ways.
48 | 
49 |     .. note::
50 |         There are no known compatibility issues at the moment. All Optuna storages, samplers
51 |         and pruners can be used.
52 | 
53 |     Args:
54 |         study:
55 |             An instance of an Optuna study.
56 |         client:
57 |             A Dask client. When specified, all trials will be passed to
58 |             the Dask scheduler to distribute across available workers.
59 |             If :obj:`None`, multiprocessing backend is used for
60 |             process based parallelism.
61 | """ 62 | 63 | def __init__(self, study: Study, client: Client | None = None) -> None: 64 | self._study = study 65 | self._client = client 66 | 67 | @property 68 | def best_params(self) -> dict[str, Any]: 69 | """Return parameters of the best trial in the study.""" 70 | return self._study.best_params 71 | 72 | @property 73 | def best_value(self) -> float: 74 | """Return the best objective value in the study.""" 75 | return self._study.best_value 76 | 77 | @property 78 | def best_trial(self) -> FrozenTrial: 79 | """Return the best trial in the study.""" 80 | return self._study.best_trial 81 | 82 | @property 83 | def best_trials(self) -> list[FrozenTrial]: 84 | """Return trials located at the Pareto front in the study.""" 85 | return self._study.best_trials 86 | 87 | @property 88 | def direction(self) -> StudyDirection: 89 | """Return the direction of the study.""" 90 | return self._study.direction 91 | 92 | @property 93 | def directions(self) -> list[StudyDirection]: 94 | """Return the directions of the study.""" 95 | return self._study.directions 96 | 97 | @property 98 | def trials(self) -> list[FrozenTrial]: 99 | """Return all trials in the study.""" 100 | return self._study.trials 101 | 102 | @property 103 | def user_attrs(self) -> dict[str, Any]: 104 | """Return user attributes.""" 105 | return self._study.user_attrs 106 | 107 | @property 108 | def system_attrs(self) -> dict[str, Any]: 109 | """Return system attributes.""" 110 | return self._study.system_attrs 111 | 112 | def into_study(self) -> Study: 113 | """Returns regular Optuna study.""" 114 | return self._study 115 | 116 | def get_trials( 117 | self, deepcopy: bool = True, states: Container[TrialState] | None = None 118 | ) -> list[FrozenTrial]: 119 | """Return all trials in the study. 120 | 121 | For complete documentation, please refer to: 122 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.get_trials 123 | 124 | Args: 125 | deepcopy: 126 | Flag to control whether to apply ``copy.deepcopy()`` to the trials. 127 | states: 128 | Trial states to filter on. If :obj:`None`, include all states. 129 | """ 130 | return self._study.get_trials(deepcopy, states) 131 | 132 | def optimize( 133 | self, 134 | func: ObjectiveFuncType, 135 | n_trials: int | None = None, 136 | timeout: float | None = None, 137 | n_jobs: int = -1, 138 | catch: Iterable[type[Exception]] | type[Exception] = (), 139 | callbacks: list[Callable[["Study", FrozenTrial], None]] | None = None, 140 | show_progress_bar: bool = False, 141 | *args: Any, 142 | **kwargs: Any, 143 | ) -> None: 144 | """Optimize an objective function. 145 | 146 | Optimization is done by choosing a suitable set of hyperparameter values from a given 147 | range. If Dask client has been specified, evaluations of objective function (trials) 148 | will be distributed among available workers, otherwise parallelism is process based. 149 | 150 | For additional notes on some args, please refer to: 151 | https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize 152 | 153 | Args: 154 | func: 155 | A callable that implements objective function. 156 | n_trials: 157 | The number of trials to run in total. 158 | timeout: 159 | Stop study after the given number of second(s). 160 | n_jobs: 161 | The number of parallel jobs when using multiprocessing backend. Values less than 162 | one or greater than :obj:`multiprocessing.cpu_count()` will default to number of 163 | logical CPU cores available. 
164 |             catch:
165 |                 A study continues to run even when a trial raises one of the exceptions specified
166 |                 in this argument.
167 |             callbacks:
168 |                 List of callback functions that are invoked at the end of each trial. Currently
169 |                 not supported.
170 |             show_progress_bar:
171 |                 Flag to show progress bars or not. To disable the progress bar, set this to :obj:`False`.
172 |         """
173 |         if n_trials is None:
174 |             raise ValueError("Only finite number of trials supported at the moment.")
175 | 
176 |         terminal = Terminal(show_progress_bar, n_trials, timeout)
177 |         catch = tuple(catch) if isinstance(catch, Iterable) else (catch,)
178 |         manager = (
179 |             DistributedOptimizationManager(self._client, n_trials)
180 |             if self._client is not None and not isinstance(self._client.cluster, LocalCluster)
181 |             else LocalOptimizationManager(n_trials, n_jobs)
182 |         )
183 | 
184 |         if isinstance(manager, LocalOptimizationManager) and sys.platform == "win32":
185 |             raise ValueError(
186 |                 "Local asynchronous optimization is currently not supported on Windows. "
187 |                 "Please specify Dask client to continue in distributed mode."
188 |             )
189 | 
190 |         try:
191 |             event_loop = EventLoop(self._study, manager, objective=func, interrupt_patience=10.0)
192 |             event_loop.run(terminal, timeout, catch)
193 | 
194 |         except KeyboardInterrupt:
195 |             with terminal.spin_while_trials_interrupted():
196 |                 manager.stop_optimization(patience=10.0)
197 | 
198 |             states = (TrialState.RUNNING, TrialState.WAITING)
199 |             trials = self._study.get_trials(deepcopy=False, states=states)
200 |             for trial in trials:
201 |                 self._study._storage.set_trial_state_values(trial._trial_id, TrialState.FAIL)
202 |             raise
203 | 
204 |         finally:
205 |             self._study._storage.remove_session()
206 | 
207 |     def ask(self, fixed_distributions: dict[str, BaseDistribution] | None = None) -> Trial:
208 |         """Create a new trial from which hyperparameters can be suggested.
209 | 
210 |         For complete documentation, please refer to:
211 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.ask
212 | 
213 |         Args:
214 |             fixed_distributions:
215 |                 A dictionary containing the parameter names and parameter's distributions.
216 |         """
217 |         return self._study.ask(fixed_distributions)
218 | 
219 |     def tell(
220 |         self,
221 |         trial: Trial | int,
222 |         values: float | Sequence[float] | None = None,
223 |         state: TrialState | None = None,
224 |         skip_if_finished: bool = False,
225 |     ) -> FrozenTrial:
226 |         """Finish a trial created with :func:`~optuna_distributed.study.DistributedStudy.ask`.
227 | 
228 |         For complete documentation, please refer to:
229 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.tell
230 | 
231 |         Args:
232 |             trial:
233 |                 A :obj:`optuna.trial.Trial` object or a trial number.
234 |             values:
235 |                 Optional objective value or a sequence of such values in case the study is used
236 |                 for multi-objective optimization.
237 |             state:
238 |                 State to be reported.
239 |             skip_if_finished:
240 |                 Flag to control whether an exception should be raised when values for an already
241 |                 finished trial are told.
242 |         """
243 |         return self._study.tell(trial, values, state, skip_if_finished)
244 | 
245 |     def set_user_attr(self, key: str, value: Any) -> None:
246 |         """Set a user attribute to the study.
247 | 
248 |         For complete documentation, please refer to:
249 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.set_user_attr
250 | 
251 |         Args:
252 |             key: A key string of the attribute.
253 |             value: A value of the attribute. The value should be JSON serializable.
254 |         """
255 |         self._study.set_user_attr(key, value)
256 | 
257 |     def set_system_attr(self, key: str, value: Any) -> None:
258 |         """Set a system attribute to the study.
259 | 
260 |         For complete documentation, please refer to:
261 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.set_system_attr
262 | 
263 |         Args:
264 |             key: A key string of the attribute.
265 |             value: A value of the attribute. The value should be JSON serializable.
266 |         """
267 |         self._study.set_system_attr(key, value)
268 | 
269 |     def trials_dataframe(
270 |         self,
271 |         attrs: tuple[str, ...] = (
272 |             "number",
273 |             "value",
274 |             "datetime_start",
275 |             "datetime_complete",
276 |             "duration",
277 |             "params",
278 |             "user_attrs",
279 |             "system_attrs",
280 |             "state",
281 |         ),
282 |         multi_index: bool = False,
283 |     ) -> "pd.DataFrame":
284 |         """Export trials as a pandas DataFrame.
285 | 
286 |         For complete documentation, please refer to:
287 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.trials_dataframe
288 | 
289 |         Args:
290 |             attrs:
291 |                 Specifies field names of :obj:`optuna.trial.FrozenTrial` to include in a
292 |                 DataFrame of trials.
293 |             multi_index:
294 |                 Specifies whether the returned DataFrame employs MultiIndex or not.
295 |         """
296 |         return self._study.trials_dataframe(attrs, multi_index)
297 | 
298 |     def stop(self) -> None:
299 |         """Exit from the current optimization loop after the running trials finish.
300 | 
301 |         This method is effectively a no-op, since there is no way to reach the study from the
302 |         objective function at the moment. TODO(xadrianzetx) Implement this.
303 |         """
304 |         self._study.stop()
305 | 
306 |     def enqueue_trial(
307 |         self,
308 |         params: dict[str, Any],
309 |         user_attrs: dict[str, Any] | None = None,
310 |         skip_if_exists: bool = False,
311 |     ) -> None:
312 |         """Enqueue a trial with given parameter values.
313 | 
314 |         For complete documentation, please refer to:
315 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.enqueue_trial
316 | 
317 |         Args:
318 |             params:
319 |                 Parameter values to pass to your objective function.
320 |             user_attrs:
321 |                 A dictionary of user-specific attributes other than :obj:`params`.
322 |             skip_if_exists:
323 |                 When :obj:`True`, prevents duplicate trials from being enqueued again.
324 |         """
325 |         self._study.enqueue_trial(params, user_attrs, skip_if_exists)
326 | 
327 |     def add_trial(self, trial: FrozenTrial) -> None:
328 |         """Add trial to study.
329 | 
330 |         For complete documentation, please refer to:
331 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.add_trial
332 | 
333 |         Args:
334 |             trial: Trial to add.
335 |         """
336 |         self._study.add_trial(trial)
337 | 
338 |     def add_trials(self, trials: Iterable[FrozenTrial]) -> None:
339 |         """Add trials to study.
340 | 
341 |         For complete documentation, please refer to:
342 |         https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.add_trials
343 | 
344 |         Args:
345 |             trials: Trials to add.
346 |         """
347 |         self._study.add_trials(trials)
348 | 
349 | 
350 | def from_study(study: Study, client: Client | None = None) -> DistributedStudy:
351 |     """Takes a regular Optuna study and extends it to :class:`~optuna_distributed.DistributedStudy`.
352 | 
353 |     This creates an object which behaves like a regular Optuna study, except trials
354 |     will be evaluated in parallel after :func:`optuna_distributed.DistributedStudy.optimize`
355 |     is called. When :obj:`client` is :obj:`None`, work is distributed among available CPU cores
356 |     by using multiprocessing. If a Dask client is specified, `optuna_distributed` can use it to
357 |     distribute trials across many physical workers in the cluster.
358 | 
359 |     .. note::
360 |         Using `optuna_distributed` in distributed mode requires a Dask cluster with matching
361 |         environment. To read more about the deployment and usage of Dask clusters, please refer
362 |         to https://docs.dask.org/en/stable/deploying.html.
363 | 
364 |     .. note::
365 |         Any APIs besides :func:`optuna_distributed.DistributedStudy.optimize` are just
366 |         a passthrough to the regular Optuna study and can be used in standard ways.
367 | 
368 |     .. note::
369 |         There are no known compatibility issues at the moment. All Optuna storages, samplers
370 |         and pruners can be used.
371 | 
372 |     Args:
373 |         study:
374 |             A regular Optuna study instance.
375 |         client:
376 |             Dask client, as described in https://distributed.dask.org/en/stable/client.html#client
377 |     """
378 |     return DistributedStudy(study, client)
379 | 
--------------------------------------------------------------------------------
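
The docstrings of `from_study` and `DistributedStudy.optimize` above describe the package's public entry points, but the snapshot ends without an end-to-end invocation. Below is a minimal usage sketch assembled from those docstrings; the quadratic objective, the trial budget, and the scheduler address are illustrative assumptions, not part of the source.

import optuna

import optuna_distributed


def objective(trial):
    # A plain Optuna-style objective; suggestions go through the usual trial API.
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2


if __name__ == "__main__":
    # Without a client, trials are parallelized with the multiprocessing backend
    # (per DistributedStudy.optimize, this path is not supported on Windows).
    study = optuna_distributed.from_study(optuna.create_study())
    study.optimize(objective, n_trials=20, n_jobs=-1)

    # With a Dask client attached to a real (non-local) cluster, the same call hands
    # trials to the distributed backend instead. The address below is a placeholder.
    # from dask.distributed import Client
    # client = Client("tcp://scheduler:8786")
    # optuna_distributed.from_study(optuna.create_study(), client).optimize(objective, n_trials=20)

Note that `optimize` selects `DistributedOptimizationManager` only when the supplied client is not backed by a `LocalCluster` (see the manager selection logic around line 178 of study.py); a client created over a local cluster still falls back to `LocalOptimizationManager`, so genuine distribution requires connecting to an external Dask cluster.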