├── submitit ├── py.typed ├── auto │ ├── __init__.py │ ├── test_auto.py │ └── auto.py ├── core │ ├── __init__.py │ ├── _submit.py │ ├── logger.py │ ├── test_async.py │ ├── submission.py │ ├── test_utils.py │ ├── plugins.py │ ├── test_plugins.py │ ├── test_core.py │ ├── job_environment.py │ └── utils.py ├── local │ ├── __init__.py │ ├── _local.py │ ├── test_debug.py │ ├── debug.py │ ├── test_local.py │ └── local.py ├── slurm │ ├── __init__.py │ └── _sbatch_test_record.txt ├── __init__.py ├── conftest.py ├── test_pickle.py ├── test_documentation.py ├── test_helpers.py └── helpers.py ├── .pre-commit-config.yaml ├── docs ├── plugins.md ├── examples │ └── torch_distributed.py ├── tips.md ├── checkpointing.md ├── nevergrad.md ├── structure.md ├── mnist.py └── examples.md ├── LICENSE ├── .github ├── CONTRIBUTING.md ├── workflows │ └── ci.yaml └── CODE_OF_CONDUCT.md ├── .gitignore ├── pyproject.toml ├── Makefile ├── integration └── preemption.py └── README.md /submitit/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /submitit/auto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | -------------------------------------------------------------------------------- /submitit/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # 6 | -------------------------------------------------------------------------------- /submitit/local/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | -------------------------------------------------------------------------------- /submitit/slurm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.2.3 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-added-large-files 8 | -------------------------------------------------------------------------------- /submitit/core/_submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | from submitit.core.submission import submitit_main 8 | 9 | if __name__ == "__main__": 10 | # This script is called by Executor.submit 11 | submitit_main() 12 | -------------------------------------------------------------------------------- /submitit/local/_local.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
""""Python 3.8+ toolbox for submitting jobs to Slurm""" → """Python 3.8+ toolbox for submitting jobs to Slurm"""
(the module docstring in submitit/__init__.py opens with four double quotes, leaving a stray `"` at the start of the docstring text; three quotes is the correct delimiter)
import helpers as helpers 10 | from .auto.auto import AutoExecutor as AutoExecutor 11 | from .core.core import Executor as Executor 12 | from .core.core import Job as Job 13 | from .core.job_environment import JobEnvironment as JobEnvironment 14 | from .local.debug import DebugExecutor as DebugExecutor 15 | from .local.debug import DebugJob as DebugJob 16 | from .local.local import LocalExecutor as LocalExecutor 17 | from .local.local import LocalJob as LocalJob 18 | from .slurm.slurm import SlurmExecutor as SlurmExecutor 19 | from .slurm.slurm import SlurmJob as SlurmJob 20 | 21 | __version__ = "1.5.4" 22 | -------------------------------------------------------------------------------- /submitit/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import time 8 | from pathlib import Path 9 | 10 | import pytest 11 | 12 | from .local.local import LocalExecutor 13 | 14 | 15 | @pytest.fixture() 16 | def executor(tmp_path: Path) -> LocalExecutor: 17 | return LocalExecutor(tmp_path) 18 | 19 | 20 | @pytest.fixture(params=["a_0", "a 0", 'a"=0"', "a'; echo foo", r"a\=0", r"a\=", "a\n0"]) 21 | def weird_tmp_path(request, tmp_path: Path) -> Path: 22 | return tmp_path / request.param 23 | 24 | 25 | @pytest.fixture() 26 | def fast_forward_clock(monkeypatch): 27 | """Allows to go in the future.""" 28 | clock_time = [time.time()] 29 | 30 | monkeypatch.setattr(time, "time", lambda: clock_time[0]) 31 | 32 | def _fast_forward(minutes: float): 33 | clock_time[0] += minutes * 60 34 | 35 | return _fast_forward 36 | -------------------------------------------------------------------------------- /docs/plugins.md: -------------------------------------------------------------------------------- 1 | # Plugins 2 | 3 | In order to switch between 
- `JobEnvironment`: sets up signal handlers and requeuing to behave nicely in the cluster
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /submitit/core/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import logging.config 8 | import os 9 | from typing import Union 10 | 11 | # provide a way to change level through SUBMITIT_LOG_LEVEL environment variable: 12 | # level "CRITICAL" (50) or more (eg.: "100") will deactivate submitit logger 13 | # "NOCONFIG" will avoid configuration 14 | LOG_VARNAME = "SUBMITIT_LOG_LEVEL" 15 | level_str = os.environ.get(LOG_VARNAME, "INFO").upper() 16 | level: Union[int, str] = level_str if not level_str.isdigit() else int(level_str) 17 | 18 | 19 | CONFIG = { 20 | "version": 1, 21 | "disable_existing_loggers": False, 22 | "formatters": {"submitit_basic": {"format": "%(name)s %(levelname)s (%(asctime)s) - %(message)s"}}, 23 | "handlers": { 24 | "submitit_out": { 25 | "class": "logging.StreamHandler", 26 | "level": "DEBUG", 27 | "formatter": "submitit_basic", 28 | "stream": "ext://sys.stdout", 29 | }, 30 | "submitit_err": { 31 | "class": "logging.StreamHandler", 32 | "level": "WARNING", 33 | "formatter": "submitit_basic", 34 | "stream": "ext://sys.stderr", 35 | }, 36 | }, 37 | "loggers": {"submitit": {"handlers": ["submitit_err", "submitit_out"], 
"level": level}}, 38 | } 39 | 40 | 41 | if level != "NOCONFIG": 42 | logging.config.dictConfig(CONFIG) 43 | 44 | 45 | def get_logger() -> logging.Logger: 46 | return logging.getLogger("submitit") 47 | 48 | 49 | def exception(*args: str) -> None: 50 | get_logger().exception(*args) 51 | 52 | 53 | def warning(*args: str) -> None: 54 | get_logger().warning(*args) 55 | -------------------------------------------------------------------------------- /submitit/core/test_async.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import asyncio 8 | from pathlib import Path 9 | 10 | import pytest 11 | 12 | from . import submission, utils 13 | from .test_core import FakeExecutor, _three_time 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_result(tmp_path: Path) -> None: 18 | event_loop = asyncio.get_running_loop() 19 | executor = FakeExecutor(folder=tmp_path) 20 | job = executor.submit(_three_time, 8) 21 | result_task = event_loop.create_task(job.awaitable().result()) 22 | with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)): 23 | submission.process_job(folder=job.paths.folder) 24 | result = await result_task 25 | assert result == 24 26 | 27 | 28 | @pytest.mark.asyncio 29 | async def test_results_single(tmp_path: Path) -> None: 30 | event_loop = asyncio.get_running_loop() 31 | executor = FakeExecutor(folder=tmp_path) 32 | job = executor.submit(_three_time, 8) 33 | result_task = event_loop.create_task(job.awaitable().results()) 34 | with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)): 35 | submission.process_job(folder=job.paths.folder) 36 | result = await result_task 37 | assert result == [24] 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def 
All bug tracking and feature planning is public.
_submitit_ will be updated to keep up with Slurm versions and to fix bugs,
but we don't have any major feature planned ahead.
26 | 27 | Complete your CLA here: 28 | 29 | ## Issues 30 | We use GitHub issues to track public bugs. Please ensure your description is 31 | clear and has sufficient instructions to be able to reproduce the issue. 32 | 33 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 34 | disclosure of security bugs. In those cases, please go through the process 35 | outlined on that page and do not file a public issue. 36 | 37 | ## Coding Style 38 | We use black coding style with a generous 110 line length. 39 | 40 | ## License 41 | By contributing to _submitit_, you agree that your contributions will be licensed 42 | under the LICENSE file in the root directory of this source tree. 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific folders 2 | /docs/mnist_logs/ 3 | /integration/logs/ 4 | 5 | # Byte-compiled / optimized / DLL files / tmp 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | *.swp 10 | 11 | # C extensions 12 | *.so 13 | 14 | # OS specific files 15 | .DS_Store 16 | 17 | # Distribution / packaging / data storage 18 | data/ 19 | outputs/ 20 | nevergrad_repository/ 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | /pip-wheel-metadata/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | test_results/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # dotenv 96 | .env 97 | 98 | # virtualenv 99 | .venv 100 | venv/ 101 | ENV/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | 116 | # vscode 117 | .vscode/ 118 | 119 | # pytest 120 | .pytest_cache/ 121 | 122 | # pycharm 123 | .idea 124 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Build, and run tests lint and format 2 | env: 3 | IN_GITHUB_ACTION: 1 4 | 5 | on: [push] 6 | 7 | jobs: 8 | build-linux: 9 | # require 8-core machines (Github Actions Larger Runners) to have more than 14GB disk space 10 | runs-on: 8-core-ubuntu # ubuntu-latest 11 | strategy: 12 | max-parallel: 5 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 3.8 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.8' 20 | 21 | # Building/caching the environment 22 | 23 | - name: Add conda to system path 24 | run: | 25 | # $CONDA is an 
environment variable pointing to the root of the miniconda directory 26 | echo $CONDA 27 | echo $CONDA/bin >> $GITHUB_PATH 28 | echo $CONDA_PREFIX 29 | 30 | - name: Cache conda env 31 | id: cache-conda 32 | uses: actions/cache@v4 33 | env: 34 | # change name here (only) to invalidate cache 35 | cache-name: cache-conda-env-v0 36 | with: 37 | key: ${{ env.cache-name }}-${{ hashFiles('pyproject.toml') }} 38 | path: ./ci_env 39 | 40 | - name: Create conda env & Install dependencies 41 | run: | 42 | sudo apt-get update 43 | sudo apt-get install rsync 44 | if [ ! -d "./ci_env" ]; then \ 45 | # creates the env if it does not exist (not loaded from cache) 46 | conda create -p ./ci_env python=3.10 ipython -y 47 | fi 48 | source activate ./ci_env 49 | pip install --progress-bar off --upgrade pip 50 | pip install --progress-bar off -U -e .[dev] 51 | 52 | - name: Print installed packages 53 | run: | 54 | source activate ./ci_env 55 | pip freeze 56 | 57 | # start checks 58 | 59 | - name: Test lint 60 | run: | 61 | source activate ./ci_env 62 | pip show mypy 63 | make use_venv=0 lint 64 | 65 | - name: Test coverage 66 | run: | 67 | source activate ./ci_env 68 | make use_venv=0 test_coverage 69 | 70 | - name: Test format 71 | run: | 72 | source activate ./ci_env 73 | make use_venv=0 format 74 | 75 | 76 | -------------------------------------------------------------------------------- /submitit/test_pickle.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # 6 | 7 | import pickle 8 | from weakref import ref 9 | 10 | import pytest 11 | 12 | from .local.debug import DebugExecutor 13 | from .local.local import LocalExecutor 14 | 15 | 16 | def job_with_weakref(ex): 17 | class MyObject: 18 | hello = "world" 19 | 20 | a = MyObject() 21 | a_ref = ref(a) 22 | assert a_ref() is a 23 | 24 | def f(a_ref): 25 | a = a_ref() 26 | assert a is not None 27 | return a_ref().hello 28 | 29 | return ex.submit(f, ref(a)) 30 | 31 | 32 | @pytest.mark.xfail(reason="'a' is GC-ed before we call the function") 33 | def test_weakref_no_pickle(tmp_path): 34 | ex = DebugExecutor(tmp_path) 35 | assert job_with_weakref(ex).result() == "world" 36 | 37 | 38 | @pytest.mark.xfail(reason="'ref(a)' can't be pickled") 39 | def test_weakref_with_pickle(tmp_path): 40 | ex = LocalExecutor(tmp_path) 41 | assert job_with_weakref(ex).result() == "world" 42 | 43 | 44 | def hello_fn() -> None: 45 | print("hello world") 46 | 47 | 48 | def test_nested_pickling(tmp_path): 49 | def make_pickle() -> bytes: 50 | return pickle.dumps(hello_fn) 51 | 52 | pkl = make_pickle() 53 | assert bytes(__name__, "ascii") in pkl 54 | assert b"hello_fn" in pkl 55 | ex = LocalExecutor(tmp_path) 56 | j = ex.submit(make_pickle) 57 | assert j.result() == pkl 58 | 59 | 60 | @pytest.mark.xfail(reason="Submitit changes __main__") 61 | def test_submitit_respects_main(tmp_path): 62 | # TODO: I think this is the root cause of issue #11 63 | # https://github.com/facebookincubator/submitit/issues/11 64 | # Some programs like pytorch-lightning are dependent on the value of __main__ 65 | # See how `pdb` manage to restore the correct __main__: 66 | # https://sourcegraph.com/github.com/python/cpython/-/blob/Lib/pdb.py#L1549 67 | # But maybe we could fix #11 by just using 68 | # `from submitit.core.submission import submitit_main` 69 | # as in https://github.com/facebookincubator/submitit/issues/11#issuecomment-713148952 70 | 71 | def get_main() -> str: 72 | # pylint: 
disable=import-outside-toplevel 73 | import __main__ # type: ignore 74 | 75 | return getattr(__main__, "__file__", "") 76 | 77 | main = get_main() 78 | ex = LocalExecutor(tmp_path) 79 | j_main = ex.submit(get_main).result() 80 | assert main == j_main 81 | -------------------------------------------------------------------------------- /docs/examples/torch_distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import os 10 | import sys 11 | import time 12 | 13 | import torch 14 | 15 | import submitit 16 | 17 | NUM_NODES = 2 18 | NUM_TASKS_PER_NODE = 8 19 | 20 | 21 | NUM_CPUS_PER_TASK = 1 22 | PARTITION = "devlab" 23 | LOGS_DIR = "logs" 24 | 25 | 26 | def print_env(): 27 | for key in sorted(os.environ.keys()): 28 | if not ( 29 | key.startswith(("SLURM_", "SUBMITIT_")) 30 | or key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE") 31 | ): 32 | continue 33 | value = os.environ[key] 34 | print(f"{key}={value}") 35 | 36 | 37 | class Task: 38 | def __call__(self): 39 | # print_env() 40 | print("exporting PyTorch distributed environment variables") 41 | dist_env = submitit.helpers.TorchDistributedEnvironment().export() 42 | print(f"master: {dist_env.master_addr}:{dist_env.master_port}") 43 | print(f"rank: {dist_env.rank}") 44 | print(f"world size: {dist_env.world_size}") 45 | print(f"local rank: {dist_env.local_rank}") 46 | print(f"local world size: {dist_env.local_world_size}") 47 | # print_env() 48 | 49 | # Using the (default) env:// initialization method 50 | torch.distributed.init_process_group(backend="nccl") 51 | assert dist_env.rank == torch.distributed.get_rank() 52 | assert dist_env.world_size == torch.distributed.get_world_size() 53 | 54 | # Actual 
If you cannot, then an ugly hack consists in using `sys.path.append` lazily, i.e. *within* a function, and then importing the required module.
- Some non-picklable objects like locks cannot be submitted. This may cause issues if they are used as default arguments of a function.
24 | 25 | ## Debugging 26 | 27 | If you want to add breakpoints to a function run by `submitit`, the easiest way is to `AutoExecutor(cluster="debug")`. 28 | This will execute all the "submitted" jobs inside the main process, and your breakpoints will be hit normally. 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "submitit" 7 | readme = "README.md" 8 | authors = [{name = "Facebook AI Research"}] 9 | requires-python = ">=3.8" 10 | dynamic = ["version", "description"] 11 | 12 | dependencies = [ 13 | "cloudpickle>=1.2.1", 14 | "typing_extensions>=3.7.4.2" 15 | ] 16 | # zip_safe = false 17 | classifiers=[ 18 | "License :: OSI Approved :: MIT License", 19 | "Topic :: System :: Distributed Computing", 20 | "Development Status :: 5 - Production/Stable", 21 | ] 22 | 23 | [project.urls] 24 | Source = "https://github.com/facebookincubator/submitit" 25 | Tracker = "https://github.com/facebookincubator/submitit/issues" 26 | 27 | [project.optional-dependencies] 28 | dev = [ 29 | # Test 30 | "pytest>=7.4.2", 31 | "pytest-asyncio>=0.15.0", 32 | "pytest-cov>=4.1.0", 33 | "coverage[toml]>=5.1", 34 | # Format 35 | "black==23.3.0", 36 | "isort==5.11.5", 37 | "pre-commit>=1.15.2", 38 | # Linters 39 | "mypy>=1.4.1", 40 | "pylint>=3.0.0", 41 | # Release 42 | "flit>=3.5.1" 43 | ] 44 | 45 | [tool.black] 46 | line-length = 110 47 | exclude = ''' 48 | /( 49 | | \.git 50 | | \.mypy_cache 51 | | venv 52 | )/ 53 | ''' 54 | 55 | [tool.isort] 56 | profile = "black" 57 | line_length = 110 58 | skip_gitignore = true 59 | 60 | 61 | [tool.pylint] 62 | [tool.pylint."MESSAGES CONTROL"] 63 | # disabled messages 64 | # * no-member has a lot of false positive, mypy does it better 65 | disable = """ 66 | broad-except, 67 | fixme, 68 | 
invalid-name, 69 | logging-fstring-interpolation, 70 | missing-docstring, 71 | no-else-return, 72 | no-member, 73 | protected-access, 74 | too-few-public-methods, 75 | useless-import-alias, 76 | unspecified-encoding, 77 | too-many-positional-arguments, 78 | """ 79 | [tool.pylint.DESIGN] 80 | max-args = 6 81 | 82 | [tool.pylint.FORMAT] 83 | max-line-length = "140" 84 | 85 | [tool.pylint.SIMILARITIES] 86 | ignore-imports = "yes" 87 | 88 | 89 | [tool.coverage] 90 | [tool.coverage.run] 91 | omit = ["*/test_*.py", "/tmp/pytest*/*"] 92 | data_file = "test_results/coverage/coverage.bin" 93 | 94 | [tool.coverage.html] 95 | directory = "test_results/coverage_html" 96 | 97 | [tool.coverage.xml] 98 | output = "test_results/coverage/coverage.xml" 99 | 100 | [tool.coverage.report] 101 | fail_under = 90 102 | exclude_lines = [ 103 | "pragma: no cover", # Re-enable the standard pragma 104 | "raise NotImplementedError", 105 | "^\\s+\\.\\.\\.$", 106 | ] 107 | 108 | [tool.mypy] 109 | show_error_codes = true 110 | 111 | [[tool.mypy.overrides]] 112 | module = ['cloudpickle', 'ipdb', 'pytest', 'setuptools'] 113 | ignore_missing_imports = true 114 | -------------------------------------------------------------------------------- /submitit/core/submission.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import argparse 8 | import os 9 | import time 10 | import traceback 11 | from pathlib import Path 12 | from typing import Union 13 | 14 | try: # loading numpy before loading the pickle, to avoid unexpected interactions 15 | # pylint: disable=unused-import 16 | import numpy # type: ignore # noqa 17 | except ImportError: 18 | pass 19 | 20 | from . 
def process_job(folder: Union[Path, str]) -> None:
    """Loads a pickled job, runs it and pickles the output.

    Parameter
    ---------
    folder: Path/str
        path of the folder where the job pickle is stored (with a name containing its uuid)

    Side-effect
    -----------
    Creates a pickled output file next to the job file, containing either
    ("success", result) or ("error", traceback-string).
    """
    os.environ["SUBMITIT_FOLDER"] = str(folder)
    env = job_environment.JobEnvironment()
    paths = env.paths
    logger = get_logger()
    logger.info(f"Starting with {env}")
    logger.info(f"Loading pickle: {paths.submitted_pickle}")
    # Poll for up to wait_time seconds before giving up — presumably to tolerate
    # slow shared-filesystem propagation of the submitted pickle (TODO confirm).
    wait_time = 60
    for _ in range(wait_time):
        if paths.submitted_pickle.exists():
            break
        time.sleep(1)
    if not paths.submitted_pickle.exists():
        raise RuntimeError(
            f"Waited for {wait_time} seconds but could not find submitted jobs in path:\n{paths.submitted_pickle}"
        )
    try:
        delayed = utils.DelayedSubmission.load(paths.submitted_pickle)
        # Re-instantiate the environment and register signal handlers (e.g. for
        # preemption/timeout) now that the actual submission is known.
        env = job_environment.JobEnvironment()
        env._handle_signals(paths, delayed)
        result = delayed.result()
        logger.info("Job completed successfully")
        del delayed  # if it blocks here, you have a race condition that must be solved!
        with utils.temporary_save_path(paths.result_pickle) as tmppath:  # save somewhere else, and move
            utils.cloudpickle_dump(("success", result), tmppath)
        del result
        logger.info("Exiting after successful completion")
    except Exception as error:  # TODO: check pickle methods for capturing traceback; pickling and raising
        # Best effort: record the traceback in the result pickle so the caller
        # can re-raise it; if even that fails, only log and re-raise the original.
        try:
            with utils.temporary_save_path(paths.result_pickle) as tmppath:
                utils.cloudpickle_dump(("error", traceback.format_exc()), tmppath)
        except Exception as dumperror:
            logger.error(f"Could not dump error:\n{error}\n\nbecause of {dumperror}")
        logger.error("Submitted job triggered an exception")
        raise error
30 | return True 31 | link = self.link.split("#")[0] 32 | if not link: 33 | return False 34 | fullpath = self.root / self.file.parent / link 35 | return fullpath.exists() 36 | 37 | def __repr__(self) -> str: 38 | return f"[{self.link}]({self.name}) in file {self.file}" 39 | 40 | 41 | def _get_root() -> Path: 42 | root = Path(__file__).parent.parent.absolute() 43 | assert (root / "pyproject.toml").exists(), f"Wrong root folder: {root}" 44 | return root 45 | 46 | 47 | def _get_markdown_files(root: Path) -> tp.List[Path]: 48 | return [md for pattern in ("*.md", "submitit/**/*.md", "docs/**/*.md") for md in root.glob(pattern)] 49 | 50 | 51 | def _get_all_markdown_links(root: Path, files: tp.List[Path]) -> tp.List[MarkdownLink]: 52 | """Returns a list of all existing markdown links""" 53 | pattern = MarkdownLink.regex 54 | links = [] 55 | for file in files: 56 | for match in pattern.finditer(file.read_text()): 57 | links.append(MarkdownLink(root, file, match.group("name"), match.group("link"))) 58 | return links 59 | 60 | 61 | def test_assert_markdown_links_not_broken() -> None: 62 | root = _get_root() 63 | files = _get_markdown_files(root) 64 | assert len(files) > 3 65 | 66 | links = _get_all_markdown_links(root, files) 67 | assert len(links) > 5, "There should be several hyperlinks!" 
def _replace_relative_links(regex: tp.Match[str]) -> str:
    """Converts a matched markdown link with a relative path into an absolute
    GitHub URL pinned to the current release tag, so that links in the PyPI
    long description resolve correctly.
    """
    string: str = regex.group()
    link = regex.group("link")
    name = regex.group("name")
    version = submitit.__version__
    # Only rewrite links that are relative AND point to an existing local file;
    # external http(s) urls and dangling paths are left untouched.
    if not link.startswith("http") and Path(link).exists():
        github_url = f"github.com/facebookincubator/submitit/blob/{version}"
        string = f"[{name}](https://{github_url}/{link})"
    return string
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Run the linter and tests suite of this repo. 
2 | 3 | # `make use_venv=0` will use the default `python` otherwise uses the one in `.venv/` 4 | use_venv=1 5 | ifeq ($(use_venv),0) 6 | BIN?= 7 | else 8 | BIN=venv/bin/ 9 | endif 10 | 11 | CODE=submitit 12 | CODE_AND_DOCS=$(CODE) docs/ integration/ 13 | 14 | all: integration 15 | 16 | which: 17 | which $(BIN)python 18 | $(BIN)python --version 19 | 20 | test: 21 | $(BIN)pytest $(CODE) 22 | 23 | test_coverage: 24 | $(BIN)pytest \ 25 | -v \ 26 | --cov=submitit --cov-report=html --cov-report=term \ 27 | --durations=10 \ 28 | --junitxml=test_results/pytest/results.xml \ 29 | $(CODE) 30 | 31 | format: 32 | $(BIN)python -m pre_commit 33 | $(BIN)isort $(CODE_AND_DOCS) 34 | $(BIN)black $(CODE_AND_DOCS) 35 | 36 | check_format: 37 | # also formats docs 38 | $(BIN)isort --check --diff $(CODE_AND_DOCS) 39 | $(BIN)black --check --diff $(CODE_AND_DOCS) 40 | 41 | mypy: 42 | $(BIN)mypy --version 43 | $(BIN)mypy --junit-xml=test_results/pytest/results.xml $(CODE) 44 | 45 | pylint: 46 | $(BIN)pylint --version 47 | $(BIN)pylint $(CODE) 48 | 49 | 50 | lint: mypy pylint 51 | 52 | venv: venv/pyproject.toml 53 | 54 | venv/pyproject.toml: pyproject.toml 55 | python3 -m venv venv 56 | venv/bin/pip install --progress-bar off --upgrade pip 57 | venv/bin/pip install --progress-bar off -U -e .[dev] 58 | cp $^ $@ 59 | 60 | installable: installable_local installable_wheel 61 | 62 | installable_local: venv 63 | (. ./venv/bin/activate ; cd /tmp ; python -c "import submitit") 64 | 65 | BUILD=dev0$(CIRCLE_BUILD_NUM) 66 | USER_VENV=/tmp/submitit_user_venv/ 67 | CURRENT_VERSION=`grep -e '__version__' ./submitit/__init__.py | sed 's/__version__ = //' | sed 's/"//g'` 68 | TEST_PYPI=--index-url 'https://test.pypi.org/simple/' --no-cache-dir --no-deps --progress-bar off 69 | 70 | installable_wheel: 71 | [ ! 
-d dist ] || rm -r dist 72 | # Append .$(BUILD) to the current version 73 | sed -i -e 's/__version__ = "[0-9].[0-9].[0-9]/&.$(BUILD)/' ./submitit/__init__.py 74 | grep -e '__version__' ./submitit/__init__.py | sed 's/__version__ = //' | sed 's/"//g' 75 | $(BIN)python -m flit build --setup-py 76 | git checkout HEAD -- ./submitit/__init__.py 77 | 78 | [ ! -d $(USER_VENV) ] || rm -r $(USER_VENV) 79 | python3 -m venv $(USER_VENV) 80 | $(USER_VENV)/bin/pip install --progress-bar off dist/submitit-*any.whl 81 | # Check that importing works 82 | $(USER_VENV)/bin/python -c "import submitit" 83 | 84 | clean: 85 | rm -r venv 86 | 87 | clean_cache: 88 | # Invalidates `make venv` and therefore trigger a new `pip install --upgrade`. 89 | rm venv/pyproject.toml 90 | 91 | pre_commit: format lint 92 | 93 | register_pre_commit: venv 94 | (grep -e "^make pre_commit$$" .git/hooks/pre-commit) || (echo "make pre_commit" >> .git/hooks/pre-commit) 95 | chmod +x .git/hooks/pre-commit 96 | 97 | integration: clean_cache venv check_format lint installable test_coverage 98 | # Runs the same tests than on CI. 99 | # clean_cache will make sure we download the last versions of linters. 100 | # Use `make -k integration` to run all checks even if previous fails. 101 | 102 | release: integration 103 | echo "Releasing submitit $(CURRENT_VERSION)" 104 | [ ! -d dist ] || rm -r dist 105 | # Make sure the repo is in a clean state 106 | git diff --exit-code 107 | $(BIN)python submitit/test_documentation.py 108 | # --setup-py generates a setup.py file to allow user with old 109 | # versions of pip to install it without flit. 
class CheckFunction:
    """Callable fixture that sums two values while asserting that each one
    comes from its own reference dataset (used to validate map/submit arrays).
    """

    def __init__(self, n: int) -> None:
        # Two aligned reference datasets: 0..n-1 and 10..10+n-1.
        self.data1 = [k for k in range(n)]
        self.data2 = [10 + k for k in range(n)]

    def __call__(self, x: float, y: float) -> float:
        # Inputs must originate from the reference data, otherwise the
        # executor shuffled or duplicated arguments somewhere.
        assert x in self.data1
        assert y in self.data2
        return x + y
def test_debug_triggered(tmp_path: Path) -> None:
    """A debug job is lazy: each result-accessing method must trigger execution."""
    def get_result(job: Job) -> Tuple[bool, Any]:
        # Peek at the internal submission state without triggering execution.
        assert isinstance(job, debug.DebugJob)
        return (job._submission._done, job._submission._result)

    executor = debug.DebugExecutor(tmp_path)
    for trigger in ("wait", "done", "exception", "results"):
        job = executor.submit(f_42)
        # Nothing runs at submit time.
        assert job.state == "QUEUED"
        assert get_result(job) == (False, None)
        # Any of these accessors must force the (in-process) execution.
        getattr(job, trigger)()
        assert job.state == "DONE"
        assert get_result(job) == (True, 42)
cancelled"): 103 | job.result() 104 | 105 | 106 | def test_job_environment(tmp_path: Path) -> None: 107 | executor = debug.DebugExecutor(tmp_path) 108 | 109 | def use_env(): 110 | env = JobEnvironment() 111 | assert env.num_nodes == 1 112 | assert env.num_tasks == 1 113 | assert env.node == 0 114 | assert env.global_rank == 0 115 | assert env.local_rank == 0 116 | assert "DEBUG" in env.job_id 117 | 118 | job = executor.submit(use_env) 119 | job.result() 120 | # Check that we clean up the env after us. 121 | assert "SUBMITIT_DEBUG_JOB_ID" not in os.environ 122 | -------------------------------------------------------------------------------- /submitit/core/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import os 8 | import shutil 9 | import sys 10 | from pathlib import Path 11 | from typing import Optional 12 | 13 | import pytest 14 | 15 | from . 
import utils 16 | 17 | 18 | @pytest.mark.parametrize("existing_content", [None, "blublu"]) # type: ignore 19 | def test_temporary_save_path(tmp_path: Path, existing_content: Optional[str]) -> None: 20 | filepath = tmp_path / "save_and_move_test.txt" 21 | if existing_content: 22 | filepath.write_text(existing_content) 23 | with utils.temporary_save_path(filepath) as tmp: 24 | assert str(tmp).endswith(".txt.save_tmp") 25 | tmp.write_text("12") 26 | if existing_content: 27 | assert filepath.read_text() == existing_content 28 | assert filepath.read_text() == "12" 29 | 30 | 31 | def test_temporary_save_path_error() -> None: 32 | with pytest.raises(FileNotFoundError): 33 | with utils.temporary_save_path("save_and_move_test"): 34 | pass 35 | 36 | 37 | def _three_time(x: int) -> int: 38 | return 3 * x 39 | 40 | 41 | def test_delayed(tmp_path: Path) -> None: 42 | delayed = utils.DelayedSubmission(_three_time, 4) 43 | assert not delayed.done() 44 | assert delayed.result() == 12 45 | assert delayed.done() 46 | delayed_pkl = tmp_path / "test_delayed.pkl" 47 | delayed.dump(delayed_pkl) 48 | delayed2 = utils.DelayedSubmission.load(delayed_pkl) 49 | assert delayed2.done() 50 | 51 | 52 | def test_environment_variable_context() -> None: 53 | name = "ENV_VAR_TEST" 54 | assert name not in os.environ 55 | with utils.environment_variables(ENV_VAR_TEST="blublu"): 56 | assert os.environ[name] == "blublu" 57 | with utils.environment_variables(ENV_VAR_TEST="blublu2"): 58 | assert os.environ[name] == "blublu2" 59 | assert os.environ[name] == "blublu" 60 | assert name not in os.environ 61 | 62 | 63 | def test_slurmpaths_id_independent() -> None: 64 | path = "test/truc/machin_%j/name" 65 | output = utils.JobPaths.get_first_id_independent_folder(path) 66 | assert output.name == "truc" 67 | 68 | 69 | def test_archive_dev_folders(tmp_path: Path) -> None: 70 | utils.archive_dev_folders([Path(__file__).parent], outfile=tmp_path.with_suffix(".tar.gz")) 71 | 
def test_jobpaths(tmp_path: Path) -> None:
    """Stdout paths embed the job id and task rank, and %A/%a placeholders in
    the folder are substituted with the array job id and array index.
    """
    assert utils.JobPaths(tmp_path, "123").stdout == tmp_path / "123_0_log.out"
    # A non-zero task rank shows up in the file name.
    assert utils.JobPaths(tmp_path, "123", 1).stdout == tmp_path / "123_1_log.out"
    # Job id "456_3" means array job 456, index 3: %A -> 456, %a -> 3.
    assert (
        utils.JobPaths(tmp_path / "array-%A-index-%a", "456_3").stdout
        == tmp_path / "array-456-index-3" / "456_3_0_log.out"
    )
def clock(partition: str, duration: int):
    """Tick once a minute for ``duration - 5`` minutes, then exit peacefully.

    Returns ``duration`` on normal completion. If interrupted (e.g. by a
    preemption signal), logs a timestamped marker and re-raises so that the
    orchestrating test can grep stderr for "Interrupted".

    Parameters
    ----------
    partition: str
        slurm partition name, only used to tag the logger.
    duration: int
        job timeout in minutes; the clock deliberately stops 5 minutes early.
    """
    log = logging.getLogger(f"preemption_{partition}")
    tick_tack = ["tick", "tack"]
    try:
        for minute in range(duration - 5):
            log.info(tick_tack[minute % 2])
            time.sleep(60)
        # Bug fix: these warnings previously went through the root logger
        # (`logging.warning`), bypassing the per-partition logger created above.
        log.warning("*** Exited peacefully ***")
        return duration
    # Explicit BaseException instead of a bare `except:`: preemption may surface
    # as SystemExit/KeyboardInterrupt and must still be logged before re-raising.
    except BaseException:
        # Message format is parsed by preemption(): keep "!!! Interrupted on: ".
        log.warning(f"!!! Interrupted on: {datetime.now().isoformat()}")
        raise
def preemption():
    """End-to-end preemption check: run a long low-priority job, force a
    high-priority job onto the same node, and verify the first was preempted
    exactly once and requeued (PENDING).
    """
    job = pascal_job("learnlab", timeout_min=2 * 60)
    log.info(f"Scheduled {job}, {job.paths.stdout}")
    # log.info(job.paths.submission_file.read_text())

    wait_job_is_running(job)
    node = job.get_info()["NodeList"]
    log.info(f"{job} ({job.state}) is running on {node} !")
    # Schedule another pascal job on the same node, with high priority.
    # Both jobs request 50 of the node's 80 cpus, so they cannot coexist.
    priority_job = pascal_job("devlab", timeout_min=15, node=node)
    log.info(f"Scheduled {priority_job} ({job.state}) on {node} with high priority.")
    wait_job_is_running(priority_job)

    # If priority_job is running, then job should have been preempted.
    preempted_stderr = job.stderr()
    assert preempted_stderr is not None, job.paths.stderr

    log.info(
        f"Job {priority_job} ({priority_job.state}) started, "
        f"job {job} ({job.state}) should have been preempted: {preempted_stderr}"
    )
    interruptions = [line for line in preempted_stderr.splitlines() if "Interrupted" in line]
    assert len(interruptions) == 1, interruptions
    # Bug fix: `job.state in ("PENDING")` was a substring test against the
    # string "PENDING" (the parentheses did not make a tuple); the preempted
    # job must be requeued, i.e. exactly in the PENDING state.
    assert job.state == "PENDING", job.state

    interrupted_ts = interruptions[0].split("!!! Interrupted on: ")[-1]
    # Parsing raises if clock() did not log a valid ISO timestamp.
    interrupted = datetime.fromisoformat(interrupted_ts)

    priority_job.result()
    print("Preemption test succeeded ✅")
def get_job_environment() -> "JobEnvironment":
    """Return the JobEnvironment the current job is running in.

    The ``_TEST_CLUSTER_`` environment variable bypasses auto-detection
    (useful for testing). Otherwise the first plugin environment reporting
    itself as activated wins.

    Raises
    ------
    RuntimeError
        if no known environment is activated.
    """
    # Don't cache this function. It makes testing harder.
    # The slow part is the plugin discovery anyway.
    envs = get_job_environments()
    # bypassing can be helpful for testing
    if "_TEST_CLUSTER_" in os.environ:
        c = os.environ["_TEST_CLUSTER_"]
        # Join the keys for a readable message (consistent with the error below).
        assert c in envs, f"Unknown $_TEST_CLUSTER_='{c}', available: {', '.join(envs.keys())}."
        return envs[c]
    for env in envs.values():
        # TODO? handle the case where several envs are valid
        if env.activated():
            return env
    raise RuntimeError(
        # Typo fix: "runnning" -> "running".
        f"Could not figure out which environment the job is running in. Known environments: {', '.join(envs.keys())}."
    )
import auto 15 | 16 | 17 | def test_slurm_executor(tmp_path: Path, monkeypatch) -> None: 18 | monkeypatch.setattr(debug.DebugExecutor, "_valid_parameters", lambda: {"blabla"}) 19 | with test_slurm.mocked_slurm(): 20 | executor = auto.AutoExecutor(folder=tmp_path) 21 | assert executor.cluster == "slurm" 22 | 23 | # local_xxx parameter is ignored 24 | executor.update_parameters(mem_gb=2, name="machin", debug_blabla="blublu") 25 | params = executor._executor.parameters 26 | assert params == {"mem": "2GB", "job_name": "machin"} 27 | 28 | # shared parameter with wrong type 29 | with pytest.raises(AssertionError): 30 | executor.update_parameters(mem_gb="2.0GB") # should be int 31 | # unknown shared parameter 32 | with pytest.raises(NameError): 33 | executor.update_parameters(blublu=2.0) 34 | # unknown slurm parameter 35 | with pytest.raises(NameError): 36 | executor.update_parameters(slurm_host_filter="blublu") 37 | # check that error message contains all 38 | with pytest.raises(NameError, match=r"debug_blublu.*\n.*local_num_threads"): 39 | executor.update_parameters(debug_blublu=2.0, local_num_threads=4) 40 | 41 | 42 | def test_local_executor(tmp_path: Path) -> None: 43 | with test_slurm.mocked_slurm(): 44 | executor = auto.AutoExecutor(folder=tmp_path, cluster="local") 45 | assert executor.cluster == "local" 46 | 47 | 48 | def test_max_pickle_size_gb_in_auto(tmp_path: Path) -> None: 49 | ex = auto.AutoExecutor(folder=tmp_path, cluster="local", local_max_pickle_size_gb=0.12) 50 | assert ex._executor.max_pickle_size_gb == 0.12 # type: ignore 51 | 52 | 53 | def test_python_executor(tmp_path: Path) -> None: 54 | executor = auto.AutoExecutor(folder=tmp_path, cluster="local", local_python=sys.executable) 55 | job = executor.submit(lambda: 12) 56 | assert job.result() == 12 57 | 58 | 59 | def test_executor_argument(tmp_path: Path) -> None: 60 | with test_slurm.mocked_slurm(): 61 | executor = auto.AutoExecutor(folder=tmp_path, slurm_max_num_timeout=22) 62 | assert 
getattr(executor._executor, "max_num_timeout", None) == 22 63 | 64 | # Local executor 65 | executor = auto.AutoExecutor(folder=tmp_path, cluster="local", slurm_max_num_timeout=22) 66 | assert getattr(executor._executor, "max_num_timeout", None) != 22 67 | 68 | 69 | def test_executor_unknown_argument(tmp_path: Path) -> None: 70 | with test_slurm.mocked_slurm(): 71 | with pytest.raises(TypeError): 72 | auto.AutoExecutor(folder=tmp_path, slurm_foobar=22) 73 | 74 | 75 | def test_executor_deprecated_arguments(tmp_path: Path) -> None: 76 | with test_slurm.mocked_slurm(): 77 | with pytest.warns(UserWarning, match="slurm_max_num_timeout"): 78 | auto.AutoExecutor(folder=tmp_path, max_num_timeout=22) 79 | 80 | 81 | def test_deprecated_argument(tmp_path: Path, monkeypatch) -> None: 82 | monkeypatch.setattr(debug.DebugExecutor, "_valid_parameters", lambda: {"blabla"}) 83 | with test_slurm.mocked_slurm(): 84 | executor = auto.AutoExecutor(folder=tmp_path) 85 | assert executor.cluster == "slurm" 86 | 87 | # debug 'blabla' parameter is ignored 88 | with pytest.warns(UserWarning, match=r"blabla.*debug_blabla"): 89 | executor.update_parameters(mem_gb=2, blabla="blublu") 90 | 91 | 92 | def test_overriden_arguments(tmp_path: Path) -> None: 93 | with test_slurm.mocked_slurm(): 94 | slurm_ex = auto.AutoExecutor(folder=tmp_path, cluster="slurm") 95 | 96 | slurm_ex.update_parameters( 97 | timeout_min=60, slurm_timeout_min=120, tasks_per_node=2, slurm_ntasks_per_node=3 98 | ) 99 | slurm_params = slurm_ex._executor.parameters 100 | # slurm use time 101 | assert slurm_params == {"time": 120, "ntasks_per_node": 3} 102 | 103 | # others use timeout_min 104 | local_ex = auto.AutoExecutor(folder=tmp_path, cluster="local") 105 | local_ex.update_parameters(timeout_min=60, slurm_time=120) 106 | 107 | 108 | def test_auto_batch_watcher(tmp_path: Path) -> None: 109 | with test_slurm.mocked_slurm(): 110 | executor = auto.AutoExecutor(folder=tmp_path) 111 | with executor.batch(): 112 | job = 
executor.submit(print, "hi") 113 | assert not job.done() 114 | 115 | 116 | def test_redirect_stdout_stderr(executor) -> None: 117 | def log_to_stderr_and_stdout(): 118 | print("hello") 119 | print("world", file=sys.stderr) 120 | 121 | executor.update_parameters(stderr_to_stdout=True) 122 | job = executor.submit(log_to_stderr_and_stdout) 123 | job.wait() 124 | assert job.stderr() is None 125 | stdout = job.stdout() 126 | assert "hello" in stdout 127 | assert "world" in stdout 128 | 129 | executor.update_parameters(stderr_to_stdout=False) 130 | job = executor.submit(log_to_stderr_and_stdout) 131 | job.wait() 132 | assert "world" in job.stderr() 133 | assert "hello" in job.stdout() 134 | -------------------------------------------------------------------------------- /docs/checkpointing.md: -------------------------------------------------------------------------------- 1 | # Checkpointing 2 | 3 | ## The basics of checkpointing with `submitit` 4 | 5 | Checkpointing is trickier and requires a precise understanding of the inner working of the job pickling. 6 | 7 | At the time we need to requeue a job (after preemption or timeout), we can edit the submitted task according to the current state of the computation. Since in a standard function, the state cannot be accessed, we need to submit a callable (an instance of a class with a `__call__` method) instead of a mere function. 8 | 9 | In practice, when requeuing, `submitit` will check if the callable has a `__submitit_checkpoint__` or `checkpoint` method. 10 | If so, it will send it the initial arguments and the `checkpoint` method takes care of preparing the new submission. 11 | The `checkpoint` method must therefore have a signature able to receive all parameters from the `__call__` function of your callable. 12 | It must return a `DelayedSubmission` which acts exactly as `executor.submit`: you can provide any function and arguments. Alternatively, it can return `None` if for some reason it does not want to be requeued. 
13 | 14 | **Important note**: for preemptions to be recognized as such, the cluster needs to be configured by an admin with the parameter `SlurmctldParameters=preempt_send_user_signal`. 15 | 16 | ## Minimal example 17 | 18 | Typically, in most cases you would just resubmit the current callable at its current state with the same initial arguments, so adding the 19 | following generic `checkpoint` method to your callable may work just fine: 20 | ```python 21 | def checkpoint(self, *args: Any, **kwargs: Any) -> submitit.helpers.DelayedSubmission: 22 | return submitit.helpers.DelayedSubmission(self, *args, **kwargs) # submits to requeuing 23 | ``` 24 | If this kind of checkpoint is sufficient for you, you can derive your callable from `submitit.helpers.Checkpointable` which implements this very function. 25 | 26 | Generally checkpointing requires a modification of your code to skip the parts that have been done before being rescheduled. 27 | You can look at [the MNIST example](./mnist.py) to see what it looks like in practice. 28 | 29 | You may however submit something completely different if you wish. This can happen for instance if: 30 | - you want to restart with different parameters. 31 | - you do not want all the attributes to be pickled. Typically, you may want to dump a neural network in a separate file 32 | in a standard format and set the corresponding argument to None in order to avoid relying on pickle for saving the 33 | network. 34 | 35 | For a basic example of a checkpointable callable, checkout the code from `submitit.helpers.FunctionSequence` in [helpers.py](../submitit/helpers.py). 36 | 37 | 38 | ## Example - Training and checkpointing a model 39 | 40 | The following example provides a recipe for checkpointing the training of a model. It is more complex because we do not want to rely on `submitit` to pickle the model. 
This recipe has not been foolproofed yet; I am happy to help if you encounter any issue ;)
41 | 
42 | ```python
43 | from pathlib import Path
44 | import submitit
45 | 
46 | class NetworkTraining:
47 | 
48 |     def __init__(self):
49 |         # this is the "state" which we will be able to access when checkpointing:
50 |         self.model = None
51 | 
52 |     def __call__(self, checkpointpath: str):
53 |         if not Path(checkpointpath).exists():
54 |             self.model = ...  # initialize your model
55 |         else:
56 |             self.model = ...  # load your model
57 |         # train your model
58 |         ...
59 | 
60 |     def checkpoint(self, checkpointpath: str) -> submitit.helpers.DelayedSubmission:
61 |         # the checkpoint method is called asynchronously when the slurm manager
62 |         # sends a preemption signal, with the same arguments as the __call__ method
63 |         # "self" is your callable, at its current state.
64 |         # "self" therefore holds the current version of the model:
65 |         # do whatever you need to do to dump it properly
66 |         self.model.dump(checkpointpath)  # this is an example that probably does not work
67 |         ...
68 |         # create a new, clean (= no loaded model) NetworkTraining instance which
69 |         # will be loaded when the job resumes, and will fetch the dumped model
70 |         # (creating a new instance is not necessary but can avoid weird interactions
71 |         # with the current instance)
72 |         training_callable = NetworkTraining()
73 |         # Resubmission to the queue is performed through the DelayedSubmission object
74 |         return submitit.helpers.DelayedSubmission(training_callable, checkpointpath)
75 | ```
76 | 
77 | When you want to train your model, you just have to run the following code, and it will be
78 | submitted to a slurm job, which will be checkpointed and requeued at most `slurm_max_num_timeout=3` times if timed out
79 | (and any number of times if preempted):
80 | ```python
81 | import submitit
82 | from .network import NetworkTraining  # must be defined in an importable module!
83 | executor = submitit.AutoExecutor(folder="logs_training", slurm_max_num_timeout=3) 84 | executor.update_parameters(timeout_min=30, slurm_partition="your_partition", 85 | gpus_per_node=1, cpus_per_task=2) 86 | training_callable = NetworkTraining() 87 | job = executor.submit(training_callable, "some/path/for/checkpointing/your/network") 88 | ``` 89 | 90 | On Slurm cluster, you can trigger a fake preemption or timeout in order to test your checkpointing by using the job method: `job._interrupt(timeout=)`. 91 | -------------------------------------------------------------------------------- /submitit/test_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import os 8 | import time 9 | import typing as tp 10 | from pathlib import Path 11 | 12 | import pytest 13 | 14 | from . import helpers 15 | from .core import core, utils 16 | 17 | 18 | def _three_time(x: int) -> int: 19 | return 3 * x 20 | 21 | 22 | requires_rsync = pytest.mark.skipif( 23 | not helpers.RsyncSnapshot.available(), reason="Rsync is required for snapshotting" 24 | ) 25 | 26 | 27 | def test_function_sequence_checkpoint(tmp_path: Path) -> None: 28 | file = tmp_path / "test_funcseq.pkl" 29 | fs0 = helpers.FunctionSequence(verbose=True) 30 | fs0.add(_three_time, 4) 31 | fs0.add(_three_time, 5) 32 | assert len(fs0) == 2 33 | assert sum(x.done() for x in fs0) == 0 34 | utils.cloudpickle_dump(fs0, file) 35 | fs1 = utils.pickle_load(file) 36 | assert sum(x.done() for x in fs1) == 0 37 | assert fs1() == [12, 15] 38 | assert sum(x.done() for x in fs1) == 2 39 | 40 | 41 | def test_as_completed(executor) -> None: 42 | def f(x: float) -> float: 43 | time.sleep(x) 44 | return x 45 | 46 | # slow need to be > 1.5s otherwise it might finish before we start polling. 
47 | slow, fast = 1.5, 0.1 48 | # One slow job and two fast jobs. 49 | jobs = executor.map_array(f, [slow, fast, fast]) 50 | start = time.time() 51 | finished_jobs = [] 52 | for n, j in enumerate(helpers.as_completed(jobs, poll_frequency=0.1)): 53 | elapsed = time.time() - start 54 | if n < 2: 55 | # we start getting result before the slow job finished. 56 | assert elapsed < slow 57 | finished_jobs.append(j) 58 | # We get fast job results first, then result of the slow job. 59 | assert [fast, fast, slow] == [j.result() for j in finished_jobs] 60 | assert jobs[0] is finished_jobs[-1] 61 | 62 | 63 | @requires_rsync 64 | def test_snapshot(tmp_path: Path) -> None: 65 | cwd = Path.cwd() 66 | with helpers.RsyncSnapshot(tmp_path): 67 | assert Path.cwd() == tmp_path 68 | assert (tmp_path / "submitit/test_helpers.py").exists() 69 | assert Path.cwd() == cwd 70 | 71 | 72 | @requires_rsync 73 | def test_snapshot_excludes(tmp_path: Path) -> None: 74 | exclude = ["submitit/test_*"] 75 | with helpers.RsyncSnapshot(snapshot_dir=tmp_path, exclude=exclude): 76 | assert (tmp_path / "submitit/helpers.py").exists() 77 | assert not (tmp_path / "submitit/test_helpers.py").exists() 78 | 79 | 80 | @requires_rsync 81 | def test_job_use_snapshot_cwd(executor, tmp_path: Path) -> None: 82 | with helpers.RsyncSnapshot(snapshot_dir=tmp_path): 83 | job = executor.submit(os.getcwd) 84 | assert Path(job.result()) == tmp_path 85 | 86 | 87 | @requires_rsync 88 | def test_job_use_snapshot_modules(executor, tmp_path: Path) -> None: 89 | with helpers.RsyncSnapshot(snapshot_dir=tmp_path): 90 | 91 | def submitit_file() -> Path: 92 | # pylint: disable=import-outside-toplevel 93 | import submitit 94 | 95 | return Path(submitit.__file__) 96 | 97 | job = executor.submit(submitit_file) 98 | # Here we load the normal submitit 99 | assert submitit_file() == Path(__file__).parent / "__init__.py" 100 | # In the job we should import submitit from the snapshot dir 101 | assert job.result() == tmp_path / 
"submitit/__init__.py" 102 | 103 | 104 | class FakeInfoWatcherWithTimer(core.InfoWatcher): 105 | # pylint: disable=abstract-method 106 | def __init__(self, delay_s: int = 60, time_change: float = 0.02): 107 | super().__init__(delay_s) 108 | self.start_timer = time.time() 109 | self.time_change = time_change 110 | 111 | def get_state(self, job_id: str, mode: str = "standard") -> str: 112 | duration = time.time() - self.start_timer 113 | if duration < self.time_change: 114 | return "pending" 115 | elif 2 * self.time_change > duration > self.time_change: 116 | return "running" 117 | if job_id == "failed": 118 | return "failed" 119 | return "done" 120 | 121 | 122 | class FakeJobWithTimer(core.Job[core.R]): 123 | watcher = FakeInfoWatcherWithTimer() 124 | 125 | 126 | def test_monitor_jobs(tmp_path: Path) -> None: 127 | job: FakeJobWithTimer[int] = FakeJobWithTimer(job_id="failed", folder=tmp_path) 128 | job2: FakeJobWithTimer[int] = FakeJobWithTimer(job_id="succeeded", folder=tmp_path) 129 | jobs = [job, job2] 130 | helpers.monitor_jobs(jobs, 0.02, test_mode=True) 131 | assert all(j for j in jobs if j.done()) 132 | assert set(j for j in jobs if j.state.upper() == "FAILED") == {job} 133 | 134 | 135 | def _get_env() -> tp.Dict[str, str]: 136 | return {x: y for x, y in os.environ.items() if x.startswith(("SLURM_", "SUBMITIT_"))} 137 | 138 | 139 | def test_clean_env() -> None: 140 | base = _get_env() 141 | with utils.environment_variables(SLURM_BLUBLU=12, SUBMITIT_BLUBLU=12): 142 | assert len(_get_env()) == len(base) + 2 143 | with helpers.clean_env(): 144 | assert not _get_env() 145 | assert len(_get_env()) == len(base) + 2 146 | assert _get_env() == base 147 | 148 | with utils.environment_variables(MASTER_PORT=42, BLABLA=314): 149 | with helpers.clean_env(extra_names=("BLABLA",)): 150 | assert "MASTER_PORT" not in os.environ 151 | assert "BLABLA" not in os.environ 152 | -------------------------------------------------------------------------------- /docs/nevergrad.md: 
--------------------------------------------------------------------------------
1 | # Using `nevergrad` with `submitit`
2 | 
3 | `nevergrad` is a derivative-free optimization toolbox developed at FAIR which can be used to tune network hyperparameters.
4 | These algorithms can be competitive over random search if you have around 10 parameters or more.
5 | 
6 | ## Basics of `nevergrad`
7 | 
8 | The following is a simplified version of the tutorial in [`nevergrad`'s repository](https://github.com/facebookresearch/nevergrad/blob/main/README.md); you can find more details there.
9 | 
10 | ### Example of optimization
11 | 
12 | For the sake of this example, we'll define a function to optimize:
13 | ```python
14 | def myfunction(x, y=12):
15 |     return sum((x - .5)**2) + abs(y)
16 | ```
17 | 
18 | Before we can perform optimization, we must define how this function is instrumented, i.e. the values that the parameters
19 | of the function can take. This is done through the `Instrumentation` class.
20 | The following states the first parameter of the function is an array of size 2, and the second a float:
21 | 
22 | ```python
23 | import nevergrad as ng
24 | instrum = ng.p.Instrumentation(
25 |     ng.p.Array(shape=(2,)), # first parameter
26 |     y=ng.p.Scalar()) # second (named) parameter
27 | ```
28 | 
29 | 
30 | Then you can initialize an algorithm (here `TwoPointsDE`; a full list can be obtained with `list(ng.optimizers.registry.keys())`) with
31 | this instrumentation and the budget (number of iterations) it can spend, and run the optimization:
32 | ```python
33 | import nevergrad as ng
34 | optimizer = ng.optimizers.TwoPointsDE(parametrization=instrum, budget=100)
35 | recommendation = optimizer.minimize(myfunction)
36 | print(recommendation.value)
37 | >>> (array([0.500, 0.499]),), {y: -0.012}
38 | ```
39 | `recommendation` holds the optimal attributes `args` and `kwargs` found by the optimizer for the provided function.
40 | The optimal value is obtained through `recommendation.value`.
41 | 
42 | ### Instrumentation
43 | 
44 | 
45 | 5 base types of variables are currently provided for instrumentation:
46 | - `Choice`: for unordered categorical values.
47 | - `TransitionChoice`: for ordered categorical values.
48 | - `Array`: for array parameters, possibly bounded
49 | - `Scalar`: for standard scalar parameters
50 | - `Log`: for log-distributed scalar parameters
51 | 
52 | Here is a basic example:
53 | ```python
54 | import nevergrad as ng
55 | 
56 | arg1 = ng.p.TransitionChoice(["a", "b"]) # either "a" or "b" (ordered)
57 | arg2 = ng.p.Choice(["a", "c", "e"]) # "a", "c" or "e" (unordered)
58 | arg4 = ng.p.Array(shape=(4, 3)).set_bounds(lower=0, upper=1) # an array of size (4, 3) with values between 0 and 1
59 | 
60 | # the following instrumentation uses these variables (and keeps the 3rd argument constant)
61 | instrum = ng.p.Instrumentation(arg1, arg2, "constant_argument", arg4)
62 | ```
63 | 
64 | And this is a more realistic instrumentation example for a neural network training:
65 | ```python
66 | instru = ng.p.Instrumentation(
67 |     dataset=ng.p.Choice(['wikilarge', 'allnli']),
68 |     # Arch
69 |     architecture=ng.p.Choice(['fconv', 'lightconv', 'fconv_self_att', 'lstm', 'transformer']),
70 |     dropout=ng.p.Scalar(lower=0, upper=1),
71 |     # Training
72 |     max_tokens=ng.p.TransitionChoice(np.arange(500, 20001, 100).tolist()),
73 |     max_epoch=ng.p.Scalar(lower=1, upper=50).set_integer_casting(),
74 |     # Optimization
75 |     lr=ng.p.Log(lower=0.001, upper=1.0),
76 | )
77 | ```
78 | 
79 | 
80 | ## Working asynchronously with submitit
81 | 
82 | To speed up optimization, you probably want to run several function evaluations concurrently.
83 | To do this, you need to notify `nevergrad` at the optimizer initialization that you will have several workers (example: 32 here):
84 | 
85 | ```python
86 | optimizer = ng.optimizers.TwoPointsDE(parametrization=instru, budget=8192, num_workers=32)
87 | ```
88 | This way, the optimizer will be prepared to provide several points to evaluate at once. You then have 2 ways to handle these evaluations:
89 | 
90 | 
91 | ### Through the optimize method
92 | 
93 | The `minimize` method takes an executor-like object which is compatible with `submitit` (and `concurrent.futures` and `dask`).
94 | With the following, nevergrad will take care of submitting a job per function evaluation, with at most 32 jobs in parallel:
95 | ```python
96 | executor = AutoExecutor(folder=my_folder)
97 | executor.update_parameters(timeout_min=60, gpus_per_node=1, cpus_per_task=2)
98 | recommendation = optimizer.minimize(my_function, executor=executor, verbosity=2)
99 | ```
100 | 
101 | ### Using the ask and tell interface
102 | 
103 | `nevergrad` also provides an ask and tell interface:
104 | - `ask()` provides a candidate point to evaluate.
105 | - `tell(x, value)` is used to feed the function evaluation back to `nevergrad`.
106 | 
107 | In this case you are the one responsible for properly running the optimization procedure. A naive loop with batch evaluation could look like this:
108 | ```python
109 | remaining = optimizer.budget - optimizer.num_ask
110 | while remaining:
111 |     candidates = [optimizer.ask() for k in range(remaining)]
112 |     jobs = [executor.submit(my_function, *c.args, **c.kwargs) for c in candidates]
113 |     for candidate, job in zip(candidates, jobs):
114 |         optimizer.tell(candidate, job.result())
115 |     remaining = optimizer.budget - optimizer.num_ask
116 | recommendation = optimizer.provide_recommendation()
117 | ```
118 | 
119 | ### Gotcha
120 | 
121 | Since a job will be submitted for each function evaluation, using the `submitit` executor in `nevergrad` is suitable for evaluations which take at least tens of minutes, not for small evaluations which will overload the cluster and spend more time pending than running.
--------------------------------------------------------------------------------
/submitit/local/debug.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | # 6 | 7 | import logging 8 | import os 9 | import typing as tp 10 | from pathlib import Path 11 | from typing import Dict, List, Optional, Union 12 | 13 | from ..core.core import Executor, InfoWatcher, Job, R 14 | from ..core.job_environment import JobEnvironment 15 | from ..core.utils import DelayedSubmission, UncompletedJobError 16 | 17 | 18 | class DebugInfoWatcher(InfoWatcher): 19 | # pylint: disable=abstract-method 20 | def register_job(self, job_id: str) -> None: 21 | pass 22 | 23 | 24 | class DebugJobEnvironment(JobEnvironment): 25 | _env = { 26 | "job_id": "SUBMITIT_DEBUG_JOB_ID", 27 | # We don't set those, and rely on the default values from JobEnvironment 28 | "num_nodes": "SUBMITIT_DEBUG_NOT_SET", 29 | "num_tasks": "SUBMITIT_DEBUG_NOT_SET", 30 | "node": "SUBMITIT_DEBUG_NOT_SET", 31 | "global_rank": "SUBMITIT_DEBUG_NOT_SET", 32 | "local_rank": "SUBMITIT_DEBUG_NOT_SET", 33 | } 34 | 35 | def activated(self) -> bool: 36 | return "SUBMITIT_DEBUG_JOB_ID" in os.environ 37 | 38 | def _requeue(self, countdown: int) -> None: 39 | pass 40 | 41 | 42 | class DebugJob(Job[R]): 43 | watcher = DebugInfoWatcher() 44 | 45 | def __init__(self, folder: Path, submission: DelayedSubmission) -> None: 46 | job_id = f"DEBUG_{id(submission)}" 47 | super().__init__(folder=folder, job_id=job_id) 48 | self._submission = submission 49 | self.cancelled = False 50 | self.environ = dict(os.environ) 51 | self.environ["SUBMITIT_DEBUG_JOB_ID"] = self.job_id 52 | 53 | def submission(self) -> DelayedSubmission: 54 | return self._submission 55 | 56 | @property 57 | def num_tasks(self) -> int: 58 | return 1 59 | 60 | def cancel(self, check: bool = True) -> None: # pylint: disable=unused-argument 61 | self.cancelled = True 62 | 63 | def _check_not_cancelled(self) -> None: 64 | if self.cancelled: 65 | raise UncompletedJobError(f"Job {self} was cancelled.") 66 | 67 | def results(self) -> List[R]: 68 | self._check_not_cancelled() 69 | if self._submission.done(): 70 | return 
[self._submission._result] 71 | 72 | environ_backup = dict(os.environ) 73 | # Restore os.environ from job creation time. 74 | os.environ.clear() 75 | os.environ.update(self.environ) 76 | 77 | root_logger = logging.getLogger("") 78 | self.paths.stdout.parent.mkdir(exist_ok=True, parents=True) 79 | stdout_handler = logging.FileHandler(self.paths.stdout) 80 | stdout_handler.setLevel(logging.DEBUG) 81 | stderr_handler = logging.FileHandler(self.paths.stderr) 82 | stderr_handler.setLevel(logging.WARNING) 83 | root_logger.addHandler(stdout_handler) 84 | root_logger.addHandler(stderr_handler) 85 | root_logger.warning( 86 | f"Logging is written both to stderr/stdout and to {self.paths.stdout}/err. " 87 | "But call to print will only appear in the console." 88 | ) 89 | try: 90 | return [self._submission.result()] 91 | except Exception as e: 92 | print(e) 93 | # Try to mimic `breakpoint()` behavior 94 | # pylint: disable=import-outside-toplevel 95 | if os.environ.get("PYTHONBREAKPOINT", "").startswith("ipdb"): 96 | import ipdb # pylint: disable=import-error 97 | 98 | ipdb.post_mortem() 99 | else: 100 | import pdb 101 | 102 | pdb.post_mortem() 103 | raise 104 | finally: 105 | os.environ.clear() 106 | os.environ.update(environ_backup) 107 | root_logger.removeHandler(stdout_handler) 108 | root_logger.removeHandler(stderr_handler) 109 | 110 | def exception(self) -> Optional[BaseException]: # type: ignore 111 | self._check_not_cancelled() 112 | try: 113 | self._submission.result() 114 | return None 115 | except Exception as e: 116 | # Note that we aren't wrapping the error contrary to what is done in 117 | # other Executors. It makes the stacktrace smaller and debugging easier. 118 | return e 119 | 120 | def wait(self) -> None: 121 | # forces execution. 122 | self.results() 123 | 124 | def done(self, force_check: bool = False) -> bool: # pylint: disable=unused-argument 125 | # forces execution, in case the client is waiting on it to become True. 
126 | self.results() 127 | return self._submission.done() 128 | 129 | @property 130 | def state(self) -> str: 131 | if self._submission.done(): 132 | return "DONE" 133 | if self.cancelled: 134 | return "CANCELLED" 135 | return "QUEUED" 136 | 137 | def get_info(self, mode: str = "force") -> Dict[str, str]: # pylint: disable=unused-argument 138 | return {"STATE": self.state} 139 | 140 | def __del__(self) -> None: 141 | # Skip parent code 142 | return 143 | 144 | 145 | class DebugExecutor(Executor): 146 | job_class = DebugJob 147 | 148 | def __init__(self, folder: Union[str, Path]): 149 | super().__init__(folder) 150 | 151 | def _internal_process_submissions( 152 | self, delayed_submissions: tp.List[DelayedSubmission] 153 | ) -> tp.List[Job[tp.Any]]: 154 | return [DebugJob(self.folder, ds) for ds in delayed_submissions] 155 | -------------------------------------------------------------------------------- /submitit/core/test_plugins.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import importlib 8 | import logging 9 | import re 10 | import typing as tp 11 | from pathlib import Path 12 | 13 | import pytest 14 | 15 | from . 
class PluginCreator:
    """Test helper that fabricates an installable submitit plugin on disk.

    It writes a package directory plus a matching ``.dist-info`` folder into
    ``tmp_path`` so that ``importlib.metadata`` discovers the package's
    ``[submitit]`` entry points once ``tmp_path`` is prepended to ``sys.path``.
    Used as a context manager so the plugin caches are cleared on both entry
    and exit and the fake plugin cannot leak into other tests.
    """

    def __init__(self, tmp_path: Path, monkeypatch):
        self.tmp_path = tmp_path
        self.monkeypatch = monkeypatch

    def add_plugin(self, name: str, entry_points: str, init: str):
        """Create package *name* with the given ``entry_points.txt`` content and ``__init__.py`` body."""
        # Extract version from init string if available
        version = "0.0.0"  # default fallback - this bit doesn't matter for testing
        version_match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', init)
        if version_match:
            version = version_match.group(1)

        pkg_dir = self.tmp_path / name
        pkg_dir.mkdir(mode=0o777)
        (pkg_dir / "__init__.py").write_text(init)

        # The .dist-info directory is what makes the package (and its entry
        # points) visible to importlib.metadata.
        dist = self.tmp_path / f"{name}-{version}.dist-info"
        dist.mkdir(mode=0o777)
        (dist / "METADATA").write_text(f"Name: {name}\nVersion: {version}\n")
        (dist / "entry_points.txt").write_text(entry_points)

        # Make sure Python and metadata see the new files
        importlib.invalidate_caches()

    def __enter__(self) -> None:
        _clear_plugin_cache()
        self.monkeypatch.syspath_prepend(self.tmp_path)

    def __exit__(self, *exception: tp.Any) -> None:
        _clear_plugin_cache()
def test_skip_bad_plugin(caplog, plugin_creator: PluginCreator) -> None:
    """A plugin whose entry points fail to load must be skipped with logs, not crash discovery."""
    caplog.set_level(logging.WARNING, logger="submitit")
    # The fake plugin is broken in three distinct ways: a missing attribute,
    # an environment whose constructor raises, and an unsupported entry-point key.
    plugin_creator.add_plugin(
        "submitit_bad",
        entry_points="""[submitit]
executor = submitit_bad:NonExisitingExecutor
job_environment = submitit_bad:BadEnvironment
unsupported_key = submitit_bad:SomethingElse
""",
        init="""
import submitit

class BadEnvironment:
    name = "bad"

    def __init__(self):
        raise Exception("this is a bad environment")
""",
    )

    executors = plugins.get_executors().keys()
    assert {"slurm", "local", "debug"} == set(executors)
    assert "bad" not in executors
    # One log record is expected per failure, in discovery order.
    expected = [
        (logging.ERROR, r"'submitit_bad'.*no attribute 'NonExisitingExecutor'"),
        (logging.ERROR, r"'submitit_bad'.*this is a bad environment"),
        (logging.WARNING, "unsupported_key = submitit_bad:SomethingElse"),
    ]
    assert len(caplog.records) == len(expected)
    for record, ex_record in zip(caplog.records, expected):
        assert record.name == "submitit"
        assert record.levelno == ex_record[0]
        assert re.search(ex_record[1], record.getMessage())
The computation in the cluster will use the current conda environment, so make sure everything you need is installed (including `submitit`). 9 | 10 | If the computation failed and we are able to catch the error, then the trace will be dumped and available to you through the Job instance as well. 11 | However, if it could not be catched, you will be notified as well, but will probably need to look into the logs to understand what happened. 12 | 13 | For each job, you will therefore usually end up with a task file `_submitted.pkl`, an output file `_result.pkl`, 14 | a batch file `batchfile_.sh`, a stdout log file `__log.out` and a stderr log file `__log.err`, where the uuid 15 | is created by `submitit`, and the id is the job id from slurm. The Job instance helps you link all of this together (see `job.job_id`). 16 | 17 | ## Main objects 18 | 19 | Here are some information about the main objects defined in `submitit`, but you can always refer to the docstrings if you need details. 20 | 21 | ### Executor 22 | 23 | The executor is your interface for submitting jobs. Its role is to save/dump the job you want to submit, 24 | then, in Slurm case for instance, to create the sbatch file for running submitting the job. 25 | Its main methods are: 26 | - `submit(function, *args, **kwargs)`: submits a function for run on the cluster with given parameters, and returns 27 | a `Job` instance. 28 | - `map_array(function, *iterables)`: submits a function several times for run on the cluster with different parameters 29 | (pulled from the iterables), and returns a list of `Job` instances. On Slurm, this uses [job arrays](https://slurm.schedmd.com/job_array.html), 30 | which are the preferred options for submitting large number of jobs in parallel, since they are better handled by the scheduler. 31 | The `slurm_array_parallelism` parameter of `executor.update_parameters` controls how many jobs will be able to run in parallel on Slurm cluster. 
32 | - `update_parameters(**kwargs)`: sets or updates parameters for the cluster. Only a subset is implemented but 33 | it can be easily improved with more parameters. We have homogenized some parameter names, to use the same 34 | parameters for slurm and other clusters (eg, use `gpus_per_node=2`, that historically corresponds to `--gres=gpu:2` for slurm, but is now `--gpus-per-node` as well). 35 | If you misspell a name, the function will raise an exception with all allowed parameters (this can be useful if you are looking for 36 | an argument ;) ) 37 | 38 | `submitit` has a plugin system so that several executor implementations can be provided. There are currently several implementations: 39 | - `AutoExecutor` which **we advise to always use** for submititting to clusters. This executor chooses the best available plugin to use depending on your environment. The aim is to be able to use the same code an several clusters. 40 | - `SlurmExecutor` which only works for slurm, and should be used through `AutoExecutor`. 41 | - `LocalExecutor` which provides a way to test job submission locally through multiprocessing. 42 | - `DebugExecutor` which mocks job submission and does all the computation in the same process. 43 | 44 | 45 | ### Job 46 | 47 | Jobs are processes running on the cluster. The `Job` class implements methods for checking the state and the results, as well as 48 | raised exceptions within the job. This class tends to replicate the main element of the `concurrent.Future` API and adds some more. 49 | Its main methods and attributes are: 50 | - `job_id`: ID of the job in slurm (`str`). 51 | - `state`: the current state of the job (Eg: `RUNNING`, `PENDING`). 52 | - `done`: whether your job is finished. 53 | - `result()`: waits for completion and returns the result of the computation, or raises an exception if it failed. 54 | - `cancel()`: cancels the job 55 | - `()`: returns the text contained in stdout/stderr logs. 
56 | - `submission()`: returns the content of your submission (see `DelayedSubmission` object below) 57 | 58 | ### Job environment 59 | 60 | `submitit.JobEnvironment` is a handle to access information relevant to the current job such as its id. It therefore has the following attributes: 61 | `job_id`, `num_tasks`, `num_nodes`, `node`, `global_rank`, `local_rank`. 62 | 63 | ### helpers 64 | 65 | This module implements convenient functions/classes for use with `submitit`: 66 | - `CommandFunction`: a class transforming a shell command into a function, so as to be able to submit it as well (see examples below). 67 | - `Checkpointable`: base class implementing a very basic example of checkpointing (`checkpoint` method). More on this on the [Checkpointing section](https://github.com/facebookincubator/submitit/blob/main/docs/checkpointing.md). 68 | - `FunctionSequence`: A function that computes sequentially the output of other functions. This can be used 69 | to compute several independent results sequentially on a unique job, and it implements checkpointing for free. 70 | - `RsyncSnapshot`: A context manager that creates a snapshot of the git repository that the script lives in 71 | when creating the snapshot. This is useful for ensuring that remote jobs that get launched don't accidentally 72 | pick up unintended local changes. 73 | 74 | 75 | ### DelayedSubmission 76 | 77 | This is the class which contains all information about the job. You will only have to deal with it if you do 78 | custom checkpointing (see below). 
Its main attributes are: 79 | - `function`: the function (or callable) to be called 80 | - `args`: the positional arguments 81 | - `kwargs`: the keyword arguments 82 | It is basically used exactly as the `submit` method of an `executor`: `DelayedSubmission(func, arg1, arg2, kwarg1=12)` 83 | -------------------------------------------------------------------------------- /docs/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Arthur Mensch 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the BSD 3-clauses license. 5 | # Original at https://scikit-learn.org/stable/auto_examples/linear_model/plot_sparse_logistic_regression_mnist.html 6 | # 7 | 8 | import functools 9 | import pickle 10 | import time 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | from sklearn.datasets import fetch_openml 15 | from sklearn.linear_model import LogisticRegression 16 | from sklearn.model_selection import train_test_split 17 | from sklearn.preprocessing import StandardScaler 18 | from sklearn.utils import check_random_state 19 | 20 | import submitit 21 | 22 | 23 | class MnistTrainer(submitit.helpers.Checkpointable): 24 | """ 25 | This shows how to rewrite a monolith function so that it can handle preemption nicely, 26 | and not restart from scratch everytime it's preempted. 
27 | """ 28 | 29 | def __init__(self, clf): 30 | # This is the state that will be saved by `checkpoint` 31 | self.train_test = None 32 | self.scaler = None 33 | self.clf = clf 34 | self.trained_clf = False 35 | self.stage = "0" 36 | 37 | def __call__(self, train_samples: int, model_path: Path = None): 38 | # `train_samples` and `model_path` will also be saved 39 | log = functools.partial(print, flush=True) 40 | log(f"*** Starting from stage '{self.stage}' ***") 41 | 42 | if self.train_test is None: 43 | self.stage = "Data Loading" 44 | t0 = time.time() 45 | log(f"*** Entering stage '{self.stage}' ***") 46 | # Load data from https://www.openml.org/d/554 47 | X, y = fetch_openml("mnist_784", version=1, return_X_y=True) 48 | X, y = X.numpy(), y.numpy() 49 | 50 | random_state = check_random_state(0) 51 | permutation = random_state.permutation(X.shape[0]) 52 | X = X[permutation] 53 | y = y[permutation] 54 | X = X.reshape((X.shape[0], -1)) 55 | 56 | # Checkpoint 1: save the train/test splits 57 | X_train, X_test, y_train, y_test = train_test_split( 58 | X, y, train_size=train_samples, test_size=10000 59 | ) 60 | self.train_test = X_train, X_test, y_train, y_test 61 | log(f"Loaded data, shuffle and split in {time.time() - t0:.1f}s") 62 | 63 | X_train, X_test, y_train, y_test = self.train_test 64 | if self.scaler is None: 65 | self.stage = "Data Cleaning" 66 | t0 = time.time() 67 | log(f"*** Entering stage '{self.stage}' ***") 68 | scaler = StandardScaler() 69 | X_train = scaler.fit_transform(X_train) 70 | X_test = scaler.transform(X_test) 71 | # Scaling is actual pretty fast, make it a bit slower to allow preemption to happen here 72 | time.sleep(10) 73 | # Checkpoint 2: save the scaler and the preprocessed data 74 | self.scaler = scaler 75 | self.train_test = X_train, X_test, y_train, y_test 76 | log(f"Scaled the data took {time.time() - t0:.0f}s") 77 | 78 | if not self.trained_clf: 79 | self.stage = "Model Training" 80 | t0 = time.time() 81 | log(f"*** Entering stage 
'{self.stage}' ***") 82 | self.clf.C = 50 / train_samples 83 | self.clf.fit(X_train, y_train) 84 | # Checkpoint 3: mark the classifier as trained 85 | self.trained_clf = True 86 | log(f"Training took {time.time() - t0:.0f}s") 87 | 88 | sparsity = np.mean(self.clf.coef_ == 0) * 100 89 | score = self.clf.score(X_test, y_test) 90 | log(f"Sparsity with L1 penalty: {sparsity / 100:.2%}") 91 | log(f"Test score with L1 penalty: {score:.4f}") 92 | 93 | if model_path: 94 | self.save(model_path) 95 | return score 96 | 97 | def checkpoint(self, *args, **kwargs): 98 | print(f"Checkpointing at stage '{self.stage}'") 99 | return super().checkpoint(*args, **kwargs) 100 | 101 | def save(self, model_path: Path): 102 | with open(model_path, "wb") as o: 103 | pickle.dump((self.scaler, self.clf), o, pickle.HIGHEST_PROTOCOL) 104 | 105 | 106 | def main(): 107 | t0 = time.time() 108 | # Cleanup log folder. 109 | # This folder may grow rapidly especially if you have large checkpoints, 110 | # or submit lot of jobs. You should think about an automated way of cleaning it. 111 | folder = Path(__file__).parent / "mnist_logs" 112 | if folder.exists(): 113 | for file in folder.iterdir(): 114 | file.unlink() 115 | 116 | ex = submitit.AutoExecutor(folder) 117 | if ex.cluster == "slurm": 118 | print("Executor will schedule jobs on Slurm.") 119 | else: 120 | print(f"!!! Slurm executable `srun` not found. Will execute jobs on '{ex.cluster}'") 121 | 122 | model_path = folder / "model.pkl" 123 | trainer = MnistTrainer(LogisticRegression(penalty="l1", solver="saga", tol=0.1, multi_class="auto")) 124 | 125 | # Specify the job requirements. 126 | # Reserving only as much resource as you need ensure the cluster resource are 127 | # efficiently allocated. 128 | ex.update_parameters(mem_gb=1, cpus_per_task=4, timeout_min=5) 129 | job = ex.submit(trainer, 5000, model_path=model_path) 130 | 131 | print(f"Scheduled {job}.") 132 | 133 | # Wait for the job to be running. 
134 | while job.state != "RUNNING": 135 | time.sleep(1) 136 | 137 | print("Run the following command to see what's happening") 138 | print(f" less +F {job.paths.stdout}") 139 | 140 | # Simulate preemption. 141 | # Tries to stop the job after the first stage. 142 | # If the job is preempted before the end of the first stage, try to increase it. 143 | # If the job is not preempted, try to decrease it. 144 | time.sleep(25) 145 | print(f"preempting {job} after {time.time() - t0:.0f}s") 146 | job._interrupt() 147 | 148 | score = job.result() 149 | print(f"Finished training. Final score: {score}.") 150 | print(f"---------------- Job output ---------------------") 151 | print(job.stdout()) 152 | print(f"-------------------------------------------------") 153 | 154 | assert model_path.exists() 155 | with open(model_path, "rb") as f: 156 | (scaler, clf) = pickle.load(f) 157 | sparsity = np.mean(clf.coef_ == 0) * 100 158 | print(f"Sparsity with L1 penalty: {sparsity / 100:.2%}") 159 | 160 | 161 | if __name__ == "__main__": 162 | main() 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![CircleCI](https://circleci.com/gh/facebookincubator/submitit.svg?style=svg)](https://circleci.com/gh/facebookincubator/workflows/submitit) 2 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 3 | [![Pypi](https://img.shields.io/pypi/v/submitit)](https://pypi.org/project/submitit/) 4 | [![conda-forge](https://img.shields.io/conda/vn/conda-forge/submitit)](https://anaconda.org/conda-forge/submitit) 5 | # Submit it! 6 | 7 | ## What is submitit? 8 | 9 | Submitit is a lightweight tool for submitting Python functions for computation within a Slurm cluster. 10 | It basically wraps submission and provide access to results, logs and more. 
11 | [Slurm](https://slurm.schedmd.com/quickstart.html) is an open source, fault-tolerant, and highly scalable cluster management and job scheduling system for large and small Linux clusters. 12 | Submitit allows to switch seamlessly between executing on Slurm or locally. 13 | 14 | ### An example is worth a thousand words: performing an addition 15 | 16 | From inside an environment with `submitit` installed: 17 | 18 | ```python 19 | import submitit 20 | 21 | def add(a, b): 22 | return a + b 23 | 24 | # executor is the submission interface (logs are dumped in the folder) 25 | executor = submitit.AutoExecutor(folder="log_test") 26 | # set timeout in min, and partition for running the job 27 | executor.update_parameters(timeout_min=1, slurm_partition="dev") 28 | job = executor.submit(add, 5, 7) # will compute add(5, 7) 29 | print(job.job_id) # ID of your job 30 | 31 | output = job.result() # waits for completion and returns output 32 | assert output == 12 # 5 + 7 = 12... your addition was computed in the cluster 33 | ``` 34 | 35 | The `Job` class also provides tools for reading the log files (`job.stdout()` and `job.stderr()`). 36 | 37 | If what you want to run is a command, turn it into a Python function using `submitit.helpers.CommandFunction`, then submit it. 38 | By default stdout is silenced in `CommandFunction`, but it can be unsilenced with `verbose=True`. 39 | 40 | **Find more examples [here](docs/examples.md)!!!** 41 | 42 | Submitit is a Python 3.8+ toolbox for submitting jobs to Slurm. 43 | It aims at running python function from python code. 
44 | 45 | 46 | ## Install 47 | 48 | Quick install, in a virtualenv/conda environment where `pip` is installed (check `which pip`): 49 | - stable release: 50 | ``` 51 | pip install submitit 52 | ``` 53 | - stable release using __conda__: 54 | ``` 55 | conda install -c conda-forge submitit 56 | ``` 57 | - main branch: 58 | ``` 59 | pip install git+https://github.com/facebookincubator/submitit@main#egg=submitit 60 | ``` 61 | 62 | You can try running the [MNIST example](docs/mnist.py) to check that everything is working as expected (requires sklearn). 63 | 64 | 65 | ## Documentation 66 | 67 | See the following pages for more detailled information: 68 | 69 | - [Examples](docs/examples.md): for a bunch of examples dealing with errors, concurrency, multi-tasking etc... 70 | - [Structure and main objects](docs/structure.md): to get a better understanding of how `submitit` works, which files are created for each job, and the main objects you will interact with. 71 | - [Checkpointing](docs/checkpointing.md): to understand how you can configure your job to get checkpointed when preempted and/or timed-out. 72 | - [Tips and caveats](docs/tips.md): for a bunch of information that can be handy when working with `submitit`. 73 | - [Hyperparameter search with nevergrad](docs/nevergrad.md): basic example of `nevergrad` usage and how it interfaces with `submitit`. 74 | 75 | 76 | ### Goals 77 | 78 | The aim of this Python3 package is to be able to launch jobs on Slurm painlessly from *inside Python*, using the same submission and job patterns than the standard library package `concurrent.futures`: 79 | 80 | Here are a few benefits of using this lightweight package: 81 | - submit any function, even lambda and script-defined functions. 82 | - raises an error with stack trace if the job failed. 
83 | - requeue preempted jobs (Slurm only) 84 | - swap between `submitit` executor and one of `concurrent.futures` executors in a line, so that it is easy to run your code either on slurm, or locally with multithreading for instance. 85 | - checkpoints stateful callables when preempted or timed-out and requeue from current state (advanced feature). 86 | - easy access to task local/global rank for multi-nodes/tasks jobs. 87 | - same code can work for different clusters thanks to a plugin system. 88 | 89 | Submitit is used by FAIR researchers on the FAIR cluster. 90 | The defaults are chosen to make their life easier, and might not be ideal for every cluster. 91 | 92 | ### Non-goals 93 | 94 | - a commandline tool for running slurm jobs. Here, everything happens inside Python. To this end, you can however use [Hydra](https://hydra.cc/)'s [submitit plugin](https://hydra.cc/docs/next/plugins/submitit_launcher) (version >= 1.0.0). 95 | - a task queue, this only implements the ability to launch tasks, but does not schedule them in any way. 96 | - being used in Python2! This is a Python3.8+ only package :) 97 | 98 | 99 | ### Comparison with dask.distributed 100 | 101 | [`dask`](https://distributed.dask.org/en/latest/) is a nice framework for distributed computing. `dask.distributed` provides the same `concurrent.futures` executor API as `submitit`: 102 | 103 | ```python 104 | from distributed import Client 105 | from dask_jobqueue import SLURMCluster 106 | cluster = SLURMCluster(processes=1, cores=2, memory="2GB") 107 | cluster.scale(2) # this may take a few seconds to launch 108 | executor = Client(cluster) 109 | executor.submit(...) 110 | ``` 111 | 112 | The key difference with `submitit` is that `dask.distributed` distributes the jobs to a pool of workers (see the `cluster` variable above) while `submitit` jobs are directly jobs on the cluster. 
In that sense `submitit` is a lower level interface than `dask.distributed` and you get more direct control over your jobs, including individual `stdout` and `stderr`, and possibly checkpointing in case of preemption and timeout. On the other hand, you should avoid submitting multiple small tasks with `submitit`, which would create many independent jobs and possibly overload the cluster, while you can do it without any problem through `dask.distributed`. 113 | 114 | 115 | ## Contributors 116 | 117 | By chronological order: Jérémy Rapin, Louis Martin, Lowik Chanussot, Lucas Hosseini, Fabio Petroni, Francisco Massa, Guillaume Wenzek, Thibaut Lavril, Vinayak Tantia, Andrea Vedaldi, Max Nickel, Quentin Duval, Rushil Patel (feel free to [contribute](.github/CONTRIBUTING.md) and add your name ;) ) 118 | 119 | ## License 120 | 121 | Submitit is released under the [MIT License](LICENSE). 122 | -------------------------------------------------------------------------------- /submitit/local/test_local.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import functools 8 | import os 9 | import pickle 10 | import re 11 | import signal 12 | import sys 13 | import time 14 | from pathlib import Path 15 | 16 | import pytest 17 | 18 | from submitit import AutoExecutor 19 | 20 | from .. import helpers 21 | from ..core import job_environment, test_core, utils 22 | from . 
def test_local_job(tmp_path: Path) -> None:
    """End-to-end checks of LocalExecutor: multi-task jobs, batch mode and pickling."""

    def func(p: int) -> int:
        # The result depends on the task rank so each task can be told apart.
        job_env = job_environment.JobEnvironment()
        return p * job_env.local_rank

    executor = local.LocalExecutor(tmp_path)
    executor.update_parameters(tasks_per_node=3, nodes=1)
    job1 = executor.submit(func, 1)

    executor.update_parameters(tasks_per_node=1)

    with executor.batch():
        # Parameters cannot be updated once the batch context is open.
        with pytest.raises(RuntimeError, match="with executor.batch"):
            executor.update_parameters(tasks_per_node=1)
        job2 = executor.submit(func, 2)
    assert job1.results() == [0, 1, 2]
    assert job1.task(1).result() == 1
    assert job1.task(2).result() == 2
    assert job1.task(2).result() == 2  # results are cached and can be read twice
    assert job1.exception() is None
    assert job1.done()

    with pytest.raises(ValueError, match="must be between 0 and 2"):
        job1.task(4).result()

    assert job2.results() == [0]
    assert job2.task(0).result() == 0
    # single task job is a regular job
    assert job2.task(0) is job2
    assert job2.done()
    # picklability
    b = pickle.dumps(job2)
    job3 = pickle.loads(b)
    assert job3.results() == [0]
    assert job3._process is not None
    del job2
    job3 = pickle.loads(b)
    # Fixed assertion message (was garbled: "should I removed finished job").
    assert job3._process is None, "garbage collection should have removed finished job"
test_local_error(tmp_path: Path) -> None: 81 | def failing_job() -> None: 82 | raise RuntimeError("Failed on purpose") 83 | 84 | executor = local.LocalExecutor(tmp_path) 85 | job = executor.submit(failing_job) 86 | exception = job.exception() 87 | assert isinstance(exception, utils.FailedJobError) 88 | traceback = exception.args[0] 89 | assert "Traceback" in traceback 90 | assert "Failed on purpose" in traceback 91 | 92 | 93 | def test_pickle_output_from_main(tmp_path: Path) -> None: 94 | class MyClass: 95 | pass 96 | 97 | executor = local.LocalExecutor(tmp_path) 98 | job = executor.submit(MyClass.__call__) 99 | assert isinstance(job.result(), MyClass) 100 | 101 | 102 | def test_get_first_task_error(tmp_path: Path) -> None: 103 | def flaky() -> None: 104 | job_env = job_environment.JobEnvironment() 105 | if job_env.local_rank > 0: 106 | raise RuntimeError(f"Failed on purpose: {job_env.local_rank}") 107 | 108 | executor = local.LocalExecutor(tmp_path) 109 | executor.update_parameters(tasks_per_node=3, nodes=1) 110 | job = executor.submit(flaky) 111 | exception = job.exception() 112 | assert isinstance(exception, utils.FailedJobError) 113 | traceback = exception.args[0] 114 | assert "Traceback" in traceback 115 | assert "Failed on purpose: 1" in traceback 116 | 117 | 118 | def test_stdout(tmp_path: Path) -> None: 119 | def hello() -> None: 120 | job_env = job_environment.JobEnvironment() 121 | print("hello from", job_env.local_rank) 122 | print("bye from", job_env.local_rank, file=sys.stderr) 123 | 124 | executor = local.LocalExecutor(tmp_path) 125 | executor.update_parameters(tasks_per_node=2, nodes=1) 126 | job = executor.submit(hello) 127 | 128 | job.wait() 129 | stdout = job.stdout() 130 | assert stdout is not None 131 | assert "hello from 0\n" in stdout 132 | assert "hello from 1\n" in stdout 133 | 134 | stderr = job.stderr() 135 | assert stderr is not None 136 | assert "bye from 0\n" in stderr 137 | assert "bye from 1\n" in stderr 138 | 139 | 140 | def 
test_killed(tmp_path: Path) -> None:
    # NOTE(review): the `def` line of test_killed lies outside this chunk.
    def failing_job() -> None:
        time.sleep(120)
        raise RuntimeError("Failed on purpose")

    executor = local.LocalExecutor(tmp_path)
    job = executor.submit(failing_job)
    assert job.state == "RUNNING"
    # Kill the worker process directly; the job must then report INTERRUPTED.
    job._process.send_signal(signal.SIGKILL)  # type: ignore
    time.sleep(1)
    assert job.state == "INTERRUPTED"


@pytest.mark.skipif(not os.environ.get("SUBMITIT_SLOW_TESTS", False), reason="slow")  # type: ignore
def test_long_running_job(tmp_path: Path) -> None:
    # Checks that a job sleeping 2 minutes still completes within a 5 minute timeout.
    def f(x: int, y: int, sleep: int = 120) -> int:
        time.sleep(sleep)
        return x + y

    executor = local.LocalExecutor(tmp_path)
    executor.update_parameters(timeout_min=5)
    job = executor.submit(f, 40, 2)
    assert job.result() == 42


def test_requeuing(tmp_path: Path) -> None:
    # Checks the checkpoint/requeue cycle: the job times out (3s), is requeued once
    # (max_num_timeout=1), resumes from a checkpoint, then gives up after too many timeouts.
    func = helpers.FunctionSequence(verbose=True)
    for x in range(20):
        func.add(test_core.do_nothing, x=x, sleep=1)
    executor = local.LocalExecutor(tmp_path, max_num_timeout=1)
    executor.update_parameters(timeout_min=3 / 60, signal_delay_s=1)
    job = executor.submit(func)
    job.wait()
    stdout = job.stdout()
    assert stdout is not None
    match = re.search(r"Starting from [123]/20", stdout)
    assert match, f"Should have resumed from a checkpoint:\n{stdout}"
    assert "timed-out too many times" in stdout, f"Unexpected stdout:\n{stdout}"
    assert "(0 remaining timeouts)" in stdout, f"Unexpected stdout:\n{stdout}"


def test_custom_checkpoint(tmp_path: Path) -> None:
    # Checks that a user-defined __submitit_checkpoint__ method can alter the
    # resubmission arguments (here: flip slack=True to slack=False on requeue).
    class Slacker(helpers.Checkpointable):
        def __call__(self, slack: bool = True):
            if slack:
                print("Slacking", flush=True)
                time.sleep(10)
                raise RuntimeError("I really don't want to work")
            print("Working hard", flush=True)
            return "worked hard"

        def __submitit_checkpoint__(self, slack: bool = True):
            if slack:
                print("Interrupted while slacking. I won't slack next time.", flush=True)
                return utils.DelayedSubmission(self, slack=False)

    executor = local.LocalExecutor(tmp_path, max_num_timeout=1)
    executor.update_parameters(timeout_min=2 / 60, signal_delay_s=1)
    job = executor.submit(Slacker(True))
    job.wait()
    stdout = job.stdout()
    assert stdout
    assert "I won't slack next time." in stdout


def test_make_subprocess(tmp_path: Path) -> None:
    # Checks that the controller subprocess exports its own PID as SUBMITIT_LOCAL_JOB_ID.
    process = local.start_controller(
        tmp_path, "python -c 'import os;print(os.environ[\"SUBMITIT_LOCAL_JOB_ID\"])'", timeout_min=1
    )
    paths = utils.JobPaths(tmp_path, str(process.pid), 0)
    pg = process.pid
    process.wait()
    stdout = paths.stdout.read_text()
    stderr = paths.stderr.read_text()
    assert stdout and int(stdout.strip()) == pg, f"PID link is broken (stderr: {stderr})"


def test_cancel(tmp_path: Path) -> None:
    # Checks explicit cancel() and cancel_at_deletion() (triggered via `del job`).
    executor = local.LocalExecutor(tmp_path)
    job = executor.submit(time.sleep, 10)
    assert job.state == "RUNNING"
    job.cancel()
    time.sleep(0.1)
    # Note: with local job we don't have a precise status.
    assert job.state == "INTERRUPTED"

    job = executor.submit(time.sleep, 10)
    process = job._process  # type: ignore
    job.cancel_at_deletion()
    assert job.state == "RUNNING"
    assert process.poll() is None
    del job
    time.sleep(0.1)
    # -2 is the negated SIGINT return code of the cancelled process.
    assert process.poll() == -2


def f66(x: int, y: int = 0) -> int:  # pylint: disable=unused-argument
    # Module-level on purpose: must be picklable/importable for test_load_submission.
    return 66


def test_setup(tmp_path: Path) -> None:
    # Checks that `local_setup` shell commands run before the job itself.
    executor = AutoExecutor(tmp_path, cluster="local")
    setup_file = tmp_path / "setup_done"
    executor.update_parameters(local_setup=[f"touch {setup_file}"])
    job = executor.submit(f66, 12)
    time.sleep(1)
    assert job.result() == 66
    assert setup_file.exists()


def test_load_submission(tmp_path: Path) -> None:
    """Check we can load submission just from a path and job id."""
    executor = local.LocalExecutor(tmp_path)
    job = executor.submit(f66, 67, y=68)

    submission = local.LocalJob(tmp_path, job.job_id).submission()
    # It's important that f66 isn't a local function for the equality to work
    assert submission.function is f66
    assert submission.args == (67,)
    assert submission.kwargs == {"y": 68}
    # Loading submission doesn't evaluate them.
    assert submission._result is None


def test_weird_dir(weird_tmp_path: Path) -> None:
    # Checks that a folder containing the %j job-id placeholder is handled.
    executor = local.LocalExecutor(weird_tmp_path / "%j")
    executor.submit(f66, 67, 68).result()
--------------------------------------------------------------------------------
/submitit/auto/auto.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import typing as tp
import warnings
from pathlib import Path
from typing import Any, List, Optional, Type, Union

from ..core import plugins
from ..core.core import Executor, Job
from ..core.utils import DelayedSubmission


def _convert_deprecated_args(kwargs: tp.Dict[str, Any], deprecated_args: tp.Mapping[str, str]) -> None:
    # Renames deprecated keys of `kwargs` in place (with a warning); other keys are untouched.
    for arg in list(kwargs):
        new_arg = deprecated_args.get(arg)
        if not new_arg:
            continue
        kwargs[new_arg] = kwargs.pop(arg)
        warnings.warn(f"Setting '{arg}' is deprecated. Use '{new_arg}' instead.")


class AutoExecutor(Executor):
    """Automatic job executor
    This class is used to hold the parameters to run a job on the cluster
    corresponding to the environment.
    It can also be used to run job locally or in debug mode.
    In practice, it will create a bash file in the specified directory for each job,
    and pickle the task function and parameters. At completion, the job will also pickle
    the output. Logs are also dumped in the same directory.

    Executor specific parameters must be specified by prefixing them with the name
    of the executor they refer to. eg:
    - 'chronos_conda_file' (internal)
    - 'slurm_max_num_timeout'
    See each executor documentation for the list of available parameters.

    Parameters
    ----------
    folder: Path/str
        folder for storing job submission/output and logs.
    warn_ignored: bool
        prints a warning each time a parameter is provided but ignored because it is only
        useful for the other cluster.
    cluster: str
        Forces AutoExecutor to use the given environment. Use "local" to run jobs locally,
        "debug" to run jobs in process.
    kwargs: other arguments must be prefixed by the name of the executor they refer to.
        {exname}_{argname}: see {argname} documentation in {Exname}Executor documentation.

    Note
    ----
    - be aware that the log/output folder will be full of logs and pickled objects very fast,
      it may need cleaning.
    - use update_parameters to specify custom parameters (gpus_per_node etc...). If you
      input erroneous parameters, an error will print all parameters available for you.
    """

    # Deprecated constructor keywords and their prefixed replacements.
    _ctor_deprecated_args = {"max_num_timeout": "slurm_max_num_timeout", "conda_file": "chronos_conda_file"}

    def __init__(self, folder: Union[str, Path], cluster: Optional[str] = None, **kwargs: Any) -> None:
        self.cluster = cluster or self.which()

        executors = plugins.get_executors()
        if self.cluster not in executors:
            raise ValueError(f"AutoExecutor doesn't know any executor named {self.cluster}")

        _convert_deprecated_args(kwargs, self._ctor_deprecated_args)
        # All remaining constructor kwargs must be "<executor>_<arg>" prefixed.
        err = "Extra arguments must be prefixed by executor named, received unknown arg"
        err_ex_list = f"Known executors: {', '.join(executors)}."
        for name in kwargs:
            assert "_" in name, f"{err} '{name}'. {err_ex_list}"
            prefix = name.split("_")[0]
            assert (
                prefix in executors
            ), f"{err} '{name}', and '{prefix}' executor is also unknown. {err_ex_list}"
        self._executor = flexible_init(executors[self.cluster], folder, **kwargs)

        valid = self._valid_parameters()
        # Map bare executor-specific argument names (e.g. "partition") to their
        # prefixed form (e.g. "slurm_partition") so update_parameters can warn on them.
        self._deprecated_args = {
            arg: f"{ex_name}_{arg}"
            for ex_name, ex in executors.items()
            for arg in ex._valid_parameters()
            if arg not in valid
        }
        super().__init__(self._executor.folder, self._executor.parameters)

    @staticmethod
    def which() -> str:
        """Returns what is the detected cluster."""
        # Each plugin advertises an affinity score; pick the highest one.
        executors = plugins.get_executors()
        best_ex = max(executors, key=lambda ex: executors[ex].affinity())

        if executors[best_ex].affinity() <= 0:
            raise RuntimeError(f"Did not found an available executor among {executors.keys()}.")

        return best_ex

    def register_dev_folders(self, folders: List[Union[str, Path]]) -> None:
        """Archive a list of folders to be untarred in the job working directory.
        This is only implemented for internal cluster, for running job on non-installed packages.
        This is not useful on slurm since the working directory of jobs is identical to
        your work station working directory.

        folders: list of paths
            The list of folders to archive and untar in the job working directory
        """
        register = getattr(self._executor, "register_dev_folders", None)
        if register is not None:
            register(folders)
        else:
            # TODO this should be done through update parameters
            warnings.warn(
                "Ignoring dev folder registration as it is only supported (and needed) for internal cluster"
            )

    @classmethod
    def _typed_parameters(cls) -> tp.Dict[str, Type]:
        # Shared (non-prefixed) parameters understood by every executor, and their expected types.
        return {
            "name": str,
            "timeout_min": int,
            "mem_gb": float,
            "nodes": int,
            "cpus_per_task": int,
            "gpus_per_node": int,
            "tasks_per_node": int,
            "stderr_to_stdout": bool,
        }

    @classmethod
    def _valid_parameters(cls) -> tp.Set[str]:
        return set(cls._typed_parameters().keys())

    def _internal_update_parameters(self, **kwargs: Any) -> None:
        """Updates submission parameters to srun/crun.

        Parameters
        ----------
        AutoExecutors provides shared parameters that are translated for each specific cluster.
        Those are: timeout_min (int), mem_gb (int), gpus_per_node (int), cpus_per_task (int),
        nodes (int), tasks_per_node (int) and name (str).
        Cluster specific parameters can be specified by prefixing them with the cluster name.

        Notes
        -----
        - Cluster specific parameters win over shared parameters.
          eg: if both `slurm_time` and `timeout_min` are provided, then:
          - `slurm_time` is used on the slurm cluster
          - `timeout_min` is used on other clusters
        """
        # We handle None as not set.
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # check type of replaced variables
        generics = AutoExecutor._typed_parameters()
        for name, expected_type in generics.items():
            if expected_type == float:
                # int is acceptable wherever a float is expected.
                expected_type = (int, float)  # type: ignore
            if name in kwargs:
                assert isinstance(kwargs[name], expected_type), (
                    f'Parameter "{name}" expected type {expected_type} ' f'(but value: "{kwargs[name]}")'
                )

        _convert_deprecated_args(kwargs, self._deprecated_args)
        # Remaining non-generic names are expected to look like "<executor>_<arg>".
        specific = [x.split("_", 1) for x in kwargs if x not in generics]

        # Collect all validation errors first, then raise once with the full list.
        invalid = []
        executors = plugins.get_executors()
        for ex_arg in specific:
            if len(ex_arg) != 2:
                invalid.append(f"Parameter '{ex_arg[0]}' need to be prefixed by an executor name.")
                continue
            ex, arg = ex_arg

            if ex not in executors:
                invalid.append(f"Unknown executor '{ex}' in parameter '{ex}_{arg}'.")
                continue

            valid = executors[ex]._valid_parameters()
            if arg not in valid and arg not in generics:
                invalid.append(
                    f"Unknown argument '{arg}' for executor '{ex}' in parameter '{ex}_{arg}'."
                    + " Valid arguments: "
                    + ", ".join(valid)
                )
                continue
        if invalid:
            invalid.extend(
                [
                    f"Known executors: {', '.join(executors.keys())}",
                    f"As a reminder, shared/generic (non-prefixed) parameters are: {generics}.",
                ]
            )
            raise NameError("\n".join(invalid))

        # add cluster specific generic overrides
        kwargs.update(
            **{
                arg: kwargs.pop(f"{ex}_{arg}")
                for ex, arg in specific
                if ex == self.cluster and arg in generics
            }
        )
        parameters = self._executor._convert_parameters({k: kwargs[k] for k in kwargs if k in generics})
        # update parameters in the core executor
        for ex, arg in specific:
            # update cluster specific non-generic arguments
            if arg not in generics and ex == self.cluster:
                parameters[arg] = kwargs[f"{ex}_{arg}"]

        self._executor._internal_update_parameters(**parameters)

    def _internal_process_submissions(
        self, delayed_submissions: tp.List[DelayedSubmission]
    ) -> tp.List[Job[tp.Any]]:
        # Delegates submission to the concrete (cluster-specific) executor.
        return self._executor._internal_process_submissions(delayed_submissions)


def flexible_init(cls: Type[Executor], folder: Union[str, Path], **kwargs: Any) -> Executor:
    # Strips the executor-name prefix (e.g. "slurm_") from kwargs and forwards only
    # the matching arguments to the executor constructor.
    prefix = cls.name() + "_"
    return cls(folder, **{k[len(prefix) :]: val for k, val in kwargs.items() if k.startswith(prefix)})
--------------------------------------------------------------------------------
/submitit/core/test_core.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

# pylint: disable=redefined-outer-name
import contextlib
import pickle
import subprocess
import sys
import time
import typing as tp
from pathlib import Path
from unittest.mock import patch

import pytest

from . import core, submission, utils


class MockedSubprocess:
    """Helper for mocking subprocess calls"""

    # sacct output templates: a header line and one 3-row entry per job.
    SACCT_HEADER = "JobID|State"
    SACCT_JOB = "{j}|{state}\n{j}.ext+|{state}\n{j}.0|{state}"

    def __init__(self, known_cmds: tp.Optional[tp.Sequence[str]] = None) -> None:
        self.job_sacct: tp.Dict[str, str] = {}
        self.last_job: str = ""
        # Keep the real check_output so "tail" can still run for real.
        self._subprocess_check_output = subprocess.check_output
        self.known_cmds = known_cmds or []
        self.job_count = 12

    def __call__(self, command: tp.Sequence[str], **kwargs: tp.Any) -> bytes:
        # Dispatches mocked slurm commands to the method of the same name.
        program = command[0]
        if program in ["sacct", "sbatch", "scancel"]:
            return getattr(self, program)(command[1:]).encode()
        elif program == "tail":
            return self._subprocess_check_output(command, **kwargs)
        else:
            raise ValueError(f'Unknown command to mock "{command}".')

    def sacct(self, _: tp.Sequence[str]) -> str:
        return "\n".join(self.job_sacct.values())

    def sbatch(self, args: tp.Sequence[str]) -> str:
        """Create a "RUNNING" job."""
        job_id = str(self.job_count)
        self.job_count += 1
        sbatch_file = Path(args[0])
        array = 0
        if sbatch_file.exists():
            # Parse the array size out of an "#SBATCH --array=0-N%M" line if present.
            array_lines = [l for l in sbatch_file.read_text().splitlines() if "--array" in l]
            if array_lines:
                # SBATCH --array=0-4%3
                array = int(array_lines[0].split("=0-")[-1].split("%")[0])
                array += 1
        self.set_job_state(job_id, "RUNNING", array)
        return f"Running job {job_id}\n"

    def scancel(self, _: tp.Sequence[str]) -> str:
        # TODO:should we call set_job_state ?
        return ""

    def set_job_state(self, job_id: str, state: str, array: int = 0) -> None:
        self.job_sacct[job_id] = self._sacct(state, job_id, array)
        self.last_job = job_id

    def _sacct(self, state: str, job_id: str, array: int) -> str:
        # Renders sacct-style output; array jobs expand to "<id>_<i>" entries.
        if array == 0:
            lines = self.SACCT_JOB.format(j=job_id, state=state)
        else:
            lines = "\n".join(self.SACCT_JOB.format(j=f"{job_id}_{i}", state=state) for i in range(array))
        return "\n".join((self.SACCT_HEADER, lines))

    def which(self, name: str) -> tp.Optional[str]:
        # Mock of shutil.which: only commands listed in known_cmds "exist".
        return "here" if name in self.known_cmds else None

    def mock_cmd_fn(self, *args, **_):
        # CommandFunction(cmd)() ~= subprocess.check_output(cmd)
        return lambda: self(*args)

    @contextlib.contextmanager
    def context(self) -> tp.Iterator[None]:
        # Patches every subprocess entry point submitit uses with this mock.
        with patch("submitit.core.utils.CommandFunction", new=self.mock_cmd_fn):
            with patch("subprocess.check_output", new=self):
                with patch("shutil.which", new=self.which):
                    with patch("subprocess.check_call", new=self):
                        yield None

    @contextlib.contextmanager
    def job_context(self, job_id: str) -> tp.Iterator[None]:
        # Simulates the environment variables seen from inside a slurm job.
        with utils.environment_variables(
            _USELESS_TEST_ENV_VAR_="1", SUBMITIT_EXECUTOR="slurm", SLURM_JOB_ID=str(job_id)
        ):
            yield None


class FakeInfoWatcher(core.InfoWatcher):
    # Watcher stub: every job is always reported as running.
    # pylint: disable=abstract-method
    def get_state(self, job_id: str, mode: str = "standard") -> str:
        return "running"


class FakeJob(core.Job[core.R]):
    # Minimal Job subclass wired to the fake watcher.
    watcher = FakeInfoWatcher()
    _cancel_at_deletion = False


class FakeExecutor(core.PicklingExecutor):
    # Executor stub whose "submission" is just `echo 12`; the echoed string becomes the job id.
    job_class = FakeJob

    @property
    def _submitit_command_str(self) -> str:
        return "echo 1"

    def _num_tasks(self) -> int:
        return 1

    def _make_submission_file_text(self, command: str, uid: str) -> str:  # pylint: disable=unused-argument
        """Creates the text of a file which will be created and run
        for the submission (for slurm, this is sbatch file).
        """
        return command + "2"  # this makes "echo 12"

    def _make_submission_command(self, submission_file_path: Path) -> tp.List[str]:
        """Create the submission command."""
        with submission_file_path.open("r") as f:
            text: str = f.read()
        return text.split()  # this makes ["echo", "12"]

    @staticmethod
    def _get_job_id_from_submission_command(string: tp.Union[bytes, str]) -> str:
        return string if isinstance(string, str) else string.decode()  # this returns "12"


def _three_time(x: int) -> int:
    return 3 * x


def do_nothing(*args: tp.Any, **kwargs: tp.Any) -> int:
    # Generic dummy task: optionally sleeps ("sleep" kwarg) or raises ("error" kwarg),
    # always returns 12 on success.
    print("my args", args, flush=True)
    print("my kwargs", kwargs, flush=True)
    if "sleep" in kwargs:
        print("Waiting", flush=True)
        time.sleep(int(kwargs["sleep"]))
    if kwargs.get("error", False):
        print("Raising", flush=True)
        raise ValueError("Too bad")
    print("Finishing", flush=True)
    return 12


def test_fake_job(tmp_path: Path) -> None:
    # Exercises the Job API surface (logs, result, exception) against files on disk.
    job: FakeJob[int] = FakeJob(job_id="12", folder=tmp_path)
    repr(job)
    assert not job.done(force_check=True)
    # logs
    assert job.stdout() is None
    assert job.stderr() is None
    with job.paths.stderr.open("w") as f:
        f.write("blublu")
    assert job.stderr() == "blublu"
    # result
    utils.cloudpickle_dump(("success", 12), job.paths.result_pickle)
    assert job.result() == 12
    # exception
    assert job.exception() is None
    utils.cloudpickle_dump(("error", "blublu"), job.paths.result_pickle)
    assert isinstance(job.exception(), Exception)
    with pytest.raises(core.utils.FailedJobError):
        job.result()


def test_fake_job_cancel_at_deletion(tmp_path: Path) -> None:
    # Deleting a job flagged with cancel_at_deletion() must issue exactly one cancel call.
    job: FakeJob[tp.Any] = FakeJob(job_id="12", folder=tmp_path).cancel_at_deletion()  # type: ignore
    with patch("subprocess.call", return_value=None) as mock:
        assert mock.call_count == 0
        del job
        assert mock.call_count == 1


def test_fake_executor(tmp_path: Path) -> None:
    # End-to-end submit -> process_job -> result with the fake executor.
    executor = FakeExecutor(folder=tmp_path)
    job = executor.submit(_three_time, 8)
    assert job.job_id == "12"
    assert job.paths.submission_file.exists()
    with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)):
        submission.process_job(folder=job.paths.folder)
    assert job.result() == 24


def test_fake_executor_batch(tmp_path: Path) -> None:
    # Checks the batch() context manager: delayed submission, error cases, nesting.
    executor = FakeExecutor(folder=tmp_path)
    with executor.batch():
        job = executor.submit(_three_time, 8)
        assert isinstance(job, core.DelayedJob)
    assert isinstance(job, FakeJob)
    with executor.batch():  # make sure we can send a new batch
        job = executor.submit(_three_time, 8)
        assert isinstance(job, core.DelayedJob)
    assert isinstance(job, FakeJob)
    # bad update
    with pytest.raises(RuntimeError):
        with executor.batch():
            executor.update_parameters(blublu=12)
    # bad access
    with pytest.raises(AttributeError):
        with executor.batch():
            job = executor.submit(_three_time, 8)
            assert isinstance(job, core.DelayedJob)
            job.job_id  # pylint: disable=pointless-statement
    assert isinstance(job, core.DelayedJob)

    with executor.batch(allow_implicit_submissions=True):
        job = executor.submit(_three_time, 8)
        assert isinstance(job, core.DelayedJob)
        job.job_id  # pylint: disable=pointless-statement
        assert isinstance(job, FakeJob)
        assert not executor._delayed_batch

    # empty context
    with pytest.warns(RuntimeWarning):
        with executor.batch():
            pass
    # multi context
    with pytest.raises(RuntimeError):
        with executor.batch():
            with executor.batch():
                job = executor.submit(_three_time, 8)
                assert isinstance(job, core.DelayedJob)
    assert isinstance(job, FakeJob)


def test_unpickling_watcher_registration(tmp_path: Path) -> None:
    # Unpickling a job must register its (possibly changed) id with the watcher.
    executor = FakeExecutor(folder=tmp_path)
    job = executor.submit(_three_time, 4)
    original_job_id = job._job_id
    job._job_id = "007"  # pylint: disable=attribute-defined-outside-init
    assert job.watcher._registered == {original_job_id}  # still holds the old job id
    pkl = pickle.dumps(job)
    newjob = pickle.loads(pkl)
    assert newjob.job_id == "007"
    assert newjob.watcher._registered == {original_job_id, "007"}


def test_max_pickle_size_gb(tmp_path: Path) -> None:
    # A zero pickle-size budget must make any submission fail.
    executor = FakeExecutor(folder=tmp_path, max_pickle_size_gb=0)
    with pytest.raises(RuntimeError):
        _ = executor.submit(_three_time, 4)


if __name__ == "__main__":
    args, kwargs = [], {}  # oversimplistic parser
    for argv in sys.argv[1:]:
        if "=" in argv:
            key, val = argv.split("=")
            kwargs[key.strip("-")] = val
        else:
            args.append(argv)
    do_nothing(*args, **kwargs)
--------------------------------------------------------------------------------
/submitit/core/job_environment.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import signal
import socket
import sys
import time
import types
import typing as tp
from pathlib import Path

from .
import logger, utils
from .utils import DelayedSubmission, JobPaths

# Env var allowing users to override which signal is treated as the preemption signal.
_PREEMPT_SIG_ENV = "SUBMITIT_PREEMPT_SIGNAL"


class JobEnvironment:
    """Describe the environment inside which the job is running.
    This includes job id, number of GPUs available, ...

    This class can only be instantiated from a running submitit job.

    @plugin-dev: default implementation look for information into environment variables.
    Override _env to map environment variable to each property.
    """

    # preemption signal uses USR2 as default, but this behavior
    # can be overridden (eg: export SUBMITIT_PREEMPT_SIGNAL=USR2)
    # CAUTION: NCCL may catch USR1 so it should be avoided
    USR_SIG = os.environ.get(_PREEMPT_SIG_ENV, "USR2")
    # Maps property names (eg: "job_id") to the env var holding their value; filled by plugins.
    _env: tp.ClassVar[tp.Dict[str, str]] = {}

    def __new__(cls, *args: tp.Any) -> "JobEnvironment":
        # Instantiating the base class dispatches to the plugin-detected environment.
        if cls is not JobEnvironment:
            return super().__new__(cls, *args)  # type: ignore

        from . import plugins  # pylint: disable=cyclic-import,import-outside-toplevel

        return plugins.get_job_environment()

    def __init__(self) -> None:
        self.cluster = self.name()

    @classmethod
    def name(cls) -> str:
        # Derives the cluster name from the class name (eg: SlurmJobEnvironment -> "slurm").
        n = cls.__name__
        if n.endswith("JobEnvironment"):
            n = n[: -len("JobEnvironment")]
        return n.lower()

    @property
    def paths(self) -> JobPaths:
        """Provides the paths used by submitit, including
        stdout, stderr, submitted_pickle and folder.
        """
        folder = os.environ["SUBMITIT_FOLDER"]
        return JobPaths(folder, job_id=self.job_id, task_id=self.global_rank)

    def activated(self) -> bool:
        """Tests if we are running inside this environment.

        @plugin-dev: assumes that the SUBMITIT_EXECUTOR variable has been
        set to the executor name
        """
        return os.environ.get("SUBMITIT_EXECUTOR", "") == self.name()

    @property
    def hostname(self) -> str:
        return socket.gethostname()

    @property
    def hostnames(self) -> tp.Sequence[str]:
        return [self.hostname]

    @property
    def job_id(self) -> str:
        # Array jobs expose a composite "<array_job_id>_<array_task_id>" id.
        if self.array_job_id:
            return f"{self.array_job_id}_{self.array_task_id}"
        else:
            return self.raw_job_id

    @property
    def raw_job_id(self) -> str:
        return os.environ[self._env["job_id"]]

    @property
    def array_job_id(self) -> tp.Optional[str]:
        n = "array_job_id"
        return None if n not in self._env else os.environ.get(self._env[n], None)

    @property
    def array_task_id(self) -> tp.Optional[str]:
        n = "array_task_id"
        return None if n not in self._env else os.environ.get(self._env[n], None)

    @property
    def num_tasks(self) -> int:
        """Total number of tasks for the job"""
        return int(os.environ.get(self._env["num_tasks"], 1))

    @property
    def num_nodes(self) -> int:
        """Total number of nodes for the job"""
        return int(os.environ.get(self._env["num_nodes"], 1))

    @property
    def node(self) -> int:
        """Id of the current node"""
        return int(os.environ.get(self._env["node"], 0))

    @property
    def global_rank(self) -> int:
        """Global rank of the task"""
        return int(os.environ.get(self._env["global_rank"], 0))

    @property
    def local_rank(self) -> int:
        """Local rank of the task, ie on the current node."""
        return int(os.environ.get(self._env["local_rank"], 0))

    def __repr__(self) -> str:
        # should look like this:
        # JobEnvironment(job_id=17015819, hostname=learnfair0218, local_rank=2(3), node=1(2), global_rank=5(6))
        info = [f"{n}={getattr(self, n)}" for n in ("job_id", "hostname")]
        names = ("local_rank", "node", "global_rank")
        totals = [self.num_tasks // self.num_nodes, self.num_nodes, self.num_tasks]
        info += [f"{n}={getattr(self, n)}({t})" for n, t in zip(names, totals)]
        info_str = ", ".join(info)
        return f"JobEnvironment({info_str})"

    @classmethod
    def _usr_sig(cls) -> tp.Any:
        # Resolves the configured preemption signal name (eg: "USR2") to a signal constant.
        name = "SIG" + cls.USR_SIG
        out = getattr(signal, name, None)
        if out is None:
            raise RuntimeError(
                f"Unknown signal {name}, you may need to unset or update env var {_PREEMPT_SIG_ENV} (Eg: USR2)"
            )
        return out

    def _handle_signals(self, paths: JobPaths, submission: DelayedSubmission) -> None:
        """Set up signals handler for the current executable.

        The default implementation checkpoint the given submission and requeues it.
        @plugin-dev: Should be adapted to the signals used in this cluster.
        """
        handler = SignalHandler(self, paths, submission)
        # A priori we don't need other signals anymore,
        # but still log them to make it easier to debug.
        signal.signal(signal.SIGTERM, handler.bypass)
        signal.signal(signal.SIGCONT, handler.bypass)
        # register user signal last just in case it overlaps with SIGTERM/SIGCONT
        signal.signal(self._usr_sig(), handler.checkpoint_and_try_requeue)

    # pylint: disable=unused-argument
    def _requeue(self, countdown: int) -> None:
        """Requeue the current job.

        @plugin-dev: Must be overridden by JobEnvironment implementations.
        Use self.job_id to find what need to be requeued.
        """


class SignalHandler:
    # Bundles the state needed to react to preemption/timeout signals for one job.
    def __init__(self, env: JobEnvironment, job_paths: JobPaths, delayed: DelayedSubmission) -> None:
        self.env = env
        self._job_paths = job_paths
        self._delayed = delayed
        self._logger = logger.get_logger()
        self._start_time = time.time()

    def has_timed_out(self) -> bool:
        # SignalHandler is created by submitit as soon as the process start,
        # so _start_time is an accurate measure of the global runtime of the job.
        walltime = time.time() - self._start_time
        max_walltime = self._delayed._timeout_min * 60
        # Treat the signal as a timeout once 80% of the walltime (or all but 10 min) elapsed.
        guaranteed_walltime = min(max_walltime * 0.8, max_walltime - 10 * 60)

        timed_out = walltime >= guaranteed_walltime
        if timed_out:
            self._logger.info(
                f"Job has timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes."
            )
        else:
            self._logger.info(
                f"Job has not timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes."
            )
        return timed_out

    # pylint:disable=unused-argument
    def bypass(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None:
        # Handler that only logs the signal, without acting on it.
        self._logger.warning(f"Bypassing signal {signal.Signals(signum).name}")

    # pylint:disable=unused-argument
    def checkpoint_and_try_requeue(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None:
        # Main preemption/timeout handler: checkpoint (if supported) then requeue,
        # unless the timeout budget is exhausted.
        timed_out = self.has_timed_out()
        case = "timed-out" if timed_out else "preempted"
        self._logger.warning(
            f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}: this job is {case}."
        )

        procid = self.env.global_rank
        if procid != 0:
            self._logger.info(f"Not checkpointing nor requeuing since I am a slave (procid={procid}).")
            # do not sys.exit, because it might kill the master task
            return

        delayed = self._delayed
        # bool arithmetic: a timeout consumes one unit of the requeue budget.
        countdown = delayed._timeout_countdown - timed_out
        no_requeue_reason = ""
        if hasattr(delayed.function, "checkpoint"):
            no_requeue_reason = _checkpoint(delayed, self._job_paths.submitted_pickle, countdown)
        elif timed_out:
            no_requeue_reason = "timed-out and not checkpointable"
        if countdown < 0:  # this is the end
            no_requeue_reason = "timed-out too many times"
        if no_requeue_reason:
            # raise an error so as to create "result_pickle" file which notifies the job is over
            # this is caught by the try/except in "process_job"
            message = f"Job not requeued because: {no_requeue_reason}."
            self._logger.info(message)
            raise utils.UncompletedJobError(message)
        # if everything went well, requeue!
        self.env._requeue(countdown)
        self._exit()

    # pylint:disable=unused-argument
    def checkpoint_and_exit(self, signum: int, frame: tp.Optional[types.FrameType] = None) -> None:
        # Note: no signal is actually bound to `checkpoint_and_exit` but this is used by plugins.
        self._logger.info(f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}")

        procid = self.env.global_rank
        if procid:
            self._logger.info(f"Not checkpointing since I am a slave (procid={procid}).")
            # do not sys.exit, because it might kill the master task
            return

        delayed = self._delayed
        if hasattr(delayed.function, "checkpoint"):
            _checkpoint(self._delayed, self._job_paths.submitted_pickle, self._delayed._timeout_countdown)
        self._exit()

    def _exit(self) -> None:
        # extracted for mocking
        self._logger.info("Exiting gracefully after preemption/timeout.")
        sys.exit(-1)


def _checkpoint(delayed: DelayedSubmission, filepath: Path, countdown: int) -> str:
    """Call the checkpoint method and dump the updated delayed.

    Returns
    -------
    no_requeue_reason: str
        a string explaining why there was no requeuing, else empty string if requeuing works
    """
    logger.get_logger().info("Calling checkpoint method.")
    ckpt_delayed = delayed._checkpoint_function()
    if ckpt_delayed is None:
        return "checkpoint function returned None"
    ckpt_delayed.set_timeout(delayed._timeout_min, countdown)
    # Atomic save: write to a temp path then move, so a preemption mid-dump
    # cannot corrupt the submitted pickle.
    with utils.temporary_save_path(filepath) as tmp:
        ckpt_delayed.dump(tmp)
    return ""  # requeues
--------------------------------------------------------------------------------
/docs/examples.md:
--------------------------------------------------------------------------------
# Examples

## Explained example - Initial "add" exemple with a few more comments:
```python
import submitit

def add(a, b):
    return a + b

# the AutoExecutor class is your interface for submitting function to a cluster or run them locally.
11 | # The specified folder is used to dump job information, logs and result when finished 12 | # %j is replaced by the job id at runtime 13 | log_folder = "log_test/%j" 14 | executor = submitit.AutoExecutor(folder=log_folder) 15 | # The AutoExecutor provides a simple abstraction over SLURM to simplify switching between local and slurm jobs (or other clusters if plugins are available). 16 | # specify sbatch parameters (here it will timeout after 4min, and run on dev) 17 | # This is where you would specify `gpus_per_node=1` for instance 18 | # Cluster specific options must be appended by the cluster name: 19 | # Eg.: slurm partition can be specified using `slurm_partition` argument. It 20 | # will be ignored on other clusters: 21 | executor.update_parameters(timeout_min=4, slurm_partition="dev") 22 | # The submission interface is identical to concurrent.futures.Executor 23 | job = executor.submit(add, 5, 7) # will compute add(5, 7) 24 | print(job.job_id) # ID of your job 25 | 26 | output = job.result() # waits for the submitted function to complete and returns its output 27 | # if ever the job failed, job.result() will raise an error with the corresponding trace 28 | assert output == 12 # 5 + 7 = 12... your addition was computed in the cluster 29 | ``` 30 | 31 | ## Job arrays 32 | 33 | `submitit` supports the submission of [Slurm job arrays](https://slurm.schedmd.com/job_array.html) through the `executor.map_array` method. 34 | 35 | If you want to submit many jobs at once, this is the **preferred way to go** because: 36 | - it can submit all jobs in only 1 call to slurm (avoids flooding it). 37 | - it is faster than submitting all jobs independently. 38 | - it lets you define a cap on how many jobs can run in parallel at any given time, so you can send thousands of jobs without breaking the scheduler, as long as you leave a reasonable value for this parallelism. 
39 | 
40 | Here is an example on how to submit 4 additions at once, with at most 2 jobs running in parallel at any given time:
41 | ```python
42 | a = [1, 2, 3, 4]
43 | b = [10, 20, 30, 40]
44 | executor = submitit.AutoExecutor(folder=log_folder)
45 | # the following line tells the scheduler to only run
46 | # at most 2 jobs at once. By default, this is several hundred
47 | executor.update_parameters(slurm_array_parallelism=2)
48 | jobs = executor.map_array(add, a, b)  # just a list of jobs
49 | ```
50 | 
51 | In comparison to standard jobs, job arrays have IDs formatted as `<job_id>_<task_id>` (e.g.: `17390420_15`) where the job id is
52 | common to all the submitted jobs, and the task id goes from 0 to `N - 1` where `N` is the number of submitted jobs.
53 | 
54 | **Note**: `map_array` has no equivalent in `concurrent.futures` (`map` is similar but has a different return type)
55 | 
56 | **Warning**: when running `map_array`, `submitit` will create one pickle per job.
57 | If you have big objects in your functions (like a full pytorch model) you should serialize them once
58 | and only pass their path to the submitted function.
59 | 
60 | ### Job arrays through a context manager
61 | 
62 | If you submit multiple jobs through a `for` loop like this one:
63 | ```python
64 | jobs = []
65 | for arg in whatever:
66 |     job = executor.submit(myfunc, arg)
67 |     jobs.append(job)
68 | ```
69 | You can easily update it to batch the jobs into one array with exactly one extra line, by adding a batch context manager:
70 | ```python
71 | jobs = []
72 | with executor.batch():
73 |     for arg in whatever:
74 |         job = executor.submit(myfunc, arg)
75 |         jobs.append(job)
76 | ```
77 | This way, adding the `with` context to any existing code will convert it to an array submission,
78 | the submission being triggered when leaving the context.
79 | 
80 | This allows submitting job arrays when the functions need many arguments and keyword arguments.
81 | 82 | **Disclaimers**: 83 | - within the context, you won't be allowed to interact with the jobs methods and attributes (nor even print it)! This is because the jobs are only submitted when leaving the context: inside the context, the jobs are like empty shells. You can however store the jobs in a list for instance, and access their attributes and methods after leaving the batch context. 84 | - within the context, you can't update the executor parameters either (since all jobs must be submitted with the same settings) 85 | - any error within the context will just cancel the whole submission. 86 | - this option is still experimental and may undergo some changes in the future. 87 | 88 | 89 | ## Concurrent jobs 90 | 91 | You can submit several jobs in parallel, and check their completion with the `done` method: 92 | ```python 93 | import submitit 94 | import time 95 | 96 | executor = submitit.AutoExecutor(folder="log_test") 97 | executor.update_parameters(timeout_min=1, slurm_partition="dev") 98 | jobs = [executor.submit(time.sleep, k) for k in range(1, 11)] 99 | 100 | # wait and check how many have finished 101 | time.sleep(5) 102 | num_finished = sum(job.done() for job in jobs) 103 | print(num_finished) # probably around 2 have finished, given the overhead 104 | 105 | # then you may want to wait until all jobs are completed: 106 | outputs = [job.result() for job in jobs] 107 | ``` 108 | 109 | Notice that this is straightforward to convert to multi-threading: 110 | ```python 111 | import time 112 | from concurrent import futures 113 | with futures.ThreadPoolExecutor(max_workers=10) as executor: # This is the only real difference 114 | jobs = [executor.submit(time.sleep, k) for k in range(1, 11)] 115 | time.sleep(5) 116 | print(sum(job.done() for job in jobs)) # around 4 or 5 should be over 117 | [job.result() for job in jobs] 118 | assert sum(job.done() for job in jobs) == 10 # all done 119 | ``` 120 | 121 | ## Asyncio 122 | 123 | You can also use the asyncio 
coroutines if you want 124 | 125 | ```python 126 | import asyncio 127 | import random 128 | import submitit 129 | import time 130 | 131 | def slow_multiplication(x, y): 132 | time.sleep(x*y) 133 | return x*y 134 | 135 | executor = submitit.AutoExecutor(folder="log_test") 136 | executor.update_parameters(timeout_min=1, slurm_partition="dev") 137 | 138 | # await a single result 139 | job = executor.submit(slow_multiplication, 10, 2) 140 | await job.awaitable().result() 141 | 142 | # print results as they become available 143 | jobs = [executor.submit(slow_multiplication, k, random.randint(1, 4)) for k in range(1, 5)] 144 | for aws in asyncio.as_completed([j.awaitable().result() for j in jobs]): 145 | result = await aws 146 | print(result) 147 | ``` 148 | 149 | Note that you can also use `submitit.helpers.as_completed` if you don't want to use coroutines 150 | 151 | ## Errors 152 | 153 | Errors are caught and their stacktrace is recorded. When calling `job.result()`, a `FailedJobError` is raised with the available information: 154 | ```python 155 | import submitit 156 | from operator import truediv 157 | 158 | executor = submitit.AutoExecutor(folder="log_test") 159 | executor.update_parameters(timeout_min=1, slurm_partition="dev") 160 | job = executor.submit(truediv, 1, 0) 161 | 162 | job.result() # will raise a FailedJobError stating the ZeroDivisionError with its stacktrace 163 | full_stderr = job.stderr() # recover the full stack trace if need be 164 | # the stderr log is written in file job.get_logs_path("stderr") 165 | ``` 166 | 167 | 168 | ## Working with commands 169 | 170 | You should preferably submit pure Python function whenever you can. This would probably save you a lot of hassle. 171 | Still, this is not always feasible. The class `submitit.helpers.CommandFunction` can help you in this case. It runs a 172 | command in a subprocess and returns its stdout. It's main benefit is to be able to deal with errors and provide explicit errors. 
173 | (Note: `CommandFunction` runs locally, so you still need to submit it with an `Executor`
174 | if you want to run it on slurm, see "Understanding the environment" below).
175 | Note however that, because we use `subprocess` with `shell=False` under the hood, the command must be provided as a list and not a string.
176 | 
177 | 
178 | By default, the function hides stdout and returns it at the end:
179 | ```python
180 | import submitit
181 | function = submitit.helpers.CommandFunction(["which", "python"])  # commands must be provided as a list!
182 | print(function())  # This returns your python path (which should be inside your virtualenv)
183 | # for me: /private/home/jrapin/.conda/envs/dfconda/bin/python
184 | ```
185 | 
186 | Some useful parameters of the `CommandFunction` class:
187 | - `cwd`: to choose from which directory the command is run.
188 | - `env`: to provide specific environment variables.
189 | - `verbose`: set to `False` if you do not want any logging.
190 | 
191 | As an experimental feature, you can also provide arguments when calling the instance:
192 | ```python
193 | print(submitit.helpers.CommandFunction(["which"])("pip"))  # will run "which pip"
194 | ```
195 | 
196 | 
197 | **Understanding the environment** - Make sure you have everything you need installed in your conda environment. Indeed, for its computation, Slurm uses
198 | the active conda environment to submit your job:
199 | ```python
200 | import submitit
201 | function = submitit.helpers.CommandFunction(["which", "python"])
202 | executor = submitit.AutoExecutor(folder="log_test")
203 | executor.update_parameters(timeout_min=1, slurm_partition="dev")
204 | job = executor.submit(function)
205 | 
206 | # The returned python path is the one used in slurm.
207 | # It should be the same as when running out of slurm!
208 | # This means that everything that is installed in your
209 | # conda environment should work just as well in the cluster
210 | print(job.result())
211 | ```
212 | 
213 | 
214 | ## Multi-task jobs
215 | 
216 | `submitit` supports multi-task jobs (on one or several nodes).
217 | You just need to use the `tasks_per_node` and `nodes` parameters.
218 | 
219 | ```python
220 | import submitit
221 | from operator import add
222 | executor = submitit.AutoExecutor(folder="log_test")
223 | # 3 * 2 = 6 tasks
224 | executor.update_parameters(tasks_per_node=3, nodes=2, timeout_min=1, slurm_partition="dev")
225 | job = executor.submit(add, 5, 7)  # will compute add(5, 7)
226 | print(job.result())  # returns [12, 12, 12, 12, 12, 12]
227 | ```
228 | 
229 | The same method will be executed in each task.
230 | The typical usage is to use the task rank inside your submitted Callable to chunk the inputs, and assign one chunk to each task.
231 | 
232 | We provide a `JobEnvironment` class, that gives access to this information (in a cluster-agnostic way).
233 | ```python
234 | import submitit
235 | from math import ceil
236 | 
237 | def my_func(inputs):
238 |     job_env = submitit.JobEnvironment()
239 |     print(f"There are {job_env.num_tasks} tasks in this job")
240 |     print(f"I'm the task #{job_env.local_rank} on the node {job_env.node}")
241 |     print(f"I'm the task #{job_env.global_rank} in the job")
242 |     num_items_per_task = int(ceil(len(inputs) / job_env.num_tasks))
243 |     r = job_env.local_rank
244 |     task_chunk = inputs[r * num_items_per_task: (r + 1) * num_items_per_task]
245 |     return process(task_chunk)  # process only this chunk.
246 | ```
247 | 
248 | You can use the `task` method of a `Job` instance to access task-specific information. A task is also a Job, so the Job's methods are available.
249 | 
250 | ```python
251 | import submitit
252 | 
253 | from operator import add
254 | executor = submitit.AutoExecutor(folder="log_test")
255 | # 3 * 2 = 6 tasks
256 | executor.update_parameters(tasks_per_node=3, nodes=2, timeout_min=1, slurm_partition="dev")
257 | job = executor.submit(add, 5, 7)  # will compute add(5, 7)
258 | print(job.task(2).result())  # Wait for task #2 result
259 | print(job.task(2).stdout())  # Show task #2 stdout
260 | print(job.result())  # Wait for all tasks and return a list of results.
261 | print(job.stdout())  # Concatenated stdout of all tasks
262 | ```
263 | 
264 | ## PyTorch distributed initialization
265 | 
266 | Call the `export()` method of the `submitit.helpers.TorchDistributedEnvironment` class to set up all the required environment variables for PyTorch distributed with the `env://` initialization method. See [this code example](examples/torch_distributed.py).
267 | 
268 | ## Even more examples
269 | 
270 | TODO: share more examples, e.g. grid search over CIFAR-10
271 | 
--------------------------------------------------------------------------------
/submitit/core/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | # 6 | 7 | import contextlib 8 | import io 9 | import itertools 10 | import os 11 | import pickle 12 | import select 13 | import shutil 14 | import subprocess 15 | import sys 16 | import tarfile 17 | import typing as tp 18 | from pathlib import Path 19 | 20 | import cloudpickle 21 | 22 | 23 | @contextlib.contextmanager 24 | def environment_variables(**kwargs: tp.Any) -> tp.Iterator[None]: 25 | backup = {x: os.environ[x] for x in kwargs if x in os.environ} 26 | os.environ.update({x: str(y) for x, y in kwargs.items()}) 27 | try: 28 | yield 29 | finally: 30 | for x in kwargs: 31 | del os.environ[x] 32 | os.environ.update(backup) 33 | 34 | 35 | class UncompletedJobError(RuntimeError): 36 | """Job is uncomplete: either unfinished or failed""" 37 | 38 | 39 | class FailedJobError(UncompletedJobError): 40 | """Job failed during processing""" 41 | 42 | 43 | class FailedSubmissionError(RuntimeError): 44 | """Job Submission failed""" 45 | 46 | 47 | class JobPaths: 48 | """Creates paths related to the slurm job and its submission""" 49 | 50 | def __init__( 51 | self, folder: tp.Union[Path, str], job_id: tp.Optional[str] = None, task_id: tp.Optional[int] = None 52 | ) -> None: 53 | self._folder = Path(folder).expanduser().absolute() 54 | self.job_id = job_id 55 | self.task_id = task_id or 0 56 | 57 | @property 58 | def folder(self) -> Path: 59 | return self._format_id(self._folder) 60 | 61 | @property 62 | def submission_file(self) -> Path: 63 | if self.job_id and "_" in self.job_id: 64 | # We only have one submission file per job array 65 | return self._format_id(self.folder / "%A_submission.sh") 66 | return self._format_id(self.folder / "%j_submission.sh") 67 | 68 | @property 69 | def submitted_pickle(self) -> Path: 70 | return self._format_id(self.folder / "%j_submitted.pkl") 71 | 72 | @property 73 | def result_pickle(self) -> Path: 74 | return self._format_id(self.folder / "%j_%t_result.pkl") 75 | 76 | @property 77 | def stderr(self) -> Path: 78 | return 
self._format_id(self.folder / "%j_%t_log.err") 79 | 80 | @property 81 | def stdout(self) -> Path: 82 | return self._format_id(self.folder / "%j_%t_log.out") 83 | 84 | def _format_id(self, path: tp.Union[Path, str]) -> Path: 85 | """Replace id tag by actual id if available""" 86 | if self.job_id is None: 87 | return Path(path) 88 | replaced_path = str(path).replace("%j", str(self.job_id)).replace("%t", str(self.task_id)) 89 | array_id, *array_index = str(self.job_id).split("_", 1) 90 | if "%a" in replaced_path: 91 | if len(array_index) != 1: 92 | raise ValueError("%a is in the folder path but this is not a job array") 93 | replaced_path = replaced_path.replace("%a", array_index[0]) 94 | return Path(replaced_path.replace("%A", array_id)) 95 | 96 | def move_temporary_file( 97 | self, tmp_path: tp.Union[Path, str], name: str, keep_as_symlink: bool = False 98 | ) -> None: 99 | self.folder.mkdir(parents=True, exist_ok=True) 100 | Path(tmp_path).rename(getattr(self, name)) 101 | if keep_as_symlink: 102 | Path(tmp_path).symlink_to(getattr(self, name)) 103 | 104 | @staticmethod 105 | def get_first_id_independent_folder(folder: tp.Union[Path, str]) -> Path: 106 | """Returns the closest folder which is id independent""" 107 | parts = Path(folder).expanduser().absolute().parts 108 | tags = ["%j", "%t", "%A", "%a"] 109 | indep_parts = itertools.takewhile(lambda x: not any(tag in x for tag in tags), parts) 110 | return Path(*indep_parts) 111 | 112 | def __repr__(self) -> str: 113 | return f"{self.__class__.__name__}({self.folder})" 114 | 115 | 116 | class DelayedSubmission: 117 | """Object for specifying the function/callable call to submit and process later. 118 | This is only syntactic sugar to make sure everything is well formatted: 119 | If what you want to compute later is func(*args, **kwargs), just instanciate: 120 | DelayedSubmission(func, *args, **kwargs). 121 | It also provides convenient tools for dumping and loading. 
122 | """ 123 | 124 | def __init__(self, function: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any) -> None: 125 | self.function = function 126 | self.args = args 127 | self.kwargs = kwargs 128 | self._result: tp.Any = None 129 | self._done = False 130 | self._timeout_min: int = 0 131 | self._timeout_countdown: int = 0 # controlled in submission and execution 132 | 133 | def result(self) -> tp.Any: 134 | if self._done: 135 | return self._result 136 | 137 | self._result = self.function(*self.args, **self.kwargs) 138 | self._done = True 139 | return self._result 140 | 141 | def done(self) -> bool: 142 | return self._done 143 | 144 | def dump(self, filepath: tp.Union[str, Path]) -> None: 145 | cloudpickle_dump(self, filepath) 146 | 147 | def set_timeout(self, timeout_min: int, max_num_timeout: int) -> None: 148 | self._timeout_min = timeout_min 149 | self._timeout_countdown = max_num_timeout 150 | 151 | @classmethod 152 | def load(cls: tp.Type["DelayedSubmission"], filepath: tp.Union[str, Path]) -> "DelayedSubmission": 153 | obj = pickle_load(filepath) 154 | # following assertion is relaxed compared to isinstance, to allow flexibility 155 | # (Eg: copying this class in a project to be able to have checkpointable jobs without adding submitit as dependency) 156 | assert obj.__class__.__name__ == cls.__name__, f"Loaded object is {type(obj)} but should be {cls}." 
157 | return obj # type: ignore 158 | 159 | def _checkpoint_function(self) -> tp.Optional["DelayedSubmission"]: 160 | checkpoint = getattr(self.function, "__submitit_checkpoint__", None) 161 | if checkpoint is None: 162 | checkpoint = getattr(self.function, "checkpoint", None) 163 | if checkpoint is None: 164 | return None 165 | return checkpoint(*self.args, **self.kwargs) # type: ignore 166 | 167 | 168 | @contextlib.contextmanager 169 | def temporary_save_path(filepath: tp.Union[Path, str]) -> tp.Iterator[Path]: 170 | """Yields a path where to save a file and moves it 171 | afterward to the provided location (and replaces any 172 | existing file) 173 | This is useful to avoid processes monitoring the filepath 174 | to break if trying to read when the file is being written. 175 | 176 | Note 177 | ---- 178 | The temporary path is the provided path appended with .save_tmp 179 | """ 180 | filepath = Path(filepath) 181 | tmppath = filepath.with_suffix(filepath.suffix + ".save_tmp") 182 | assert not tmppath.exists(), "A temporary saved file already exists." 
183 | yield tmppath 184 | if not tmppath.exists(): 185 | raise FileNotFoundError("No file was saved at the temporary path.") 186 | if filepath.exists(): 187 | os.remove(filepath) 188 | os.rename(tmppath, filepath) 189 | 190 | 191 | def archive_dev_folders( 192 | folders: tp.List[tp.Union[str, Path]], outfile: tp.Optional[tp.Union[str, Path]] = None 193 | ) -> Path: 194 | """Creates a tar.gz file with all provided folders""" 195 | assert isinstance(folders, (list, tuple)), "Only lists and tuples of folders are allowed" 196 | if outfile is None: 197 | outfile = "_dev_folders_.tar.gz" 198 | outfile = Path(outfile) 199 | assert str(outfile).endswith(".tar.gz"), "Archive file must have extension .tar.gz" 200 | with tarfile.TarFile(outfile, mode="w") as tf: 201 | for folder in folders: 202 | tf.add(str(folder), arcname=Path(folder).name) 203 | return outfile 204 | 205 | 206 | def copy_par_file(par_file: tp.Union[str, Path], folder: tp.Union[str, Path]) -> Path: 207 | """Copy the par (or xar) file in the folder 208 | 209 | Parameter 210 | --------- 211 | par_file: str/Path 212 | Par file generated by buck 213 | folder: str/Path 214 | folder where the par file must be copied 215 | 216 | Returns 217 | ------- 218 | Path 219 | Path of the copied .par file 220 | """ 221 | par_file = Path(par_file).expanduser().absolute() 222 | folder = Path(folder).expanduser().absolute() 223 | folder.mkdir(parents=True, exist_ok=True) 224 | dst_name = folder / par_file.name 225 | shutil.copy2(par_file, dst_name) 226 | return dst_name 227 | 228 | 229 | def pickle_load(filename: tp.Union[str, Path]) -> tp.Any: 230 | # this is used by cloudpickle as well 231 | with open(filename, "rb") as ifile: 232 | return pickle.load(ifile) 233 | 234 | 235 | def cloudpickle_dump(obj: tp.Any, filename: tp.Union[str, Path]) -> None: 236 | with open(filename, "wb") as ofile: 237 | cloudpickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL) 238 | 239 | 240 | # pylint: disable=too-many-locals 241 | def 
copy_process_streams( 242 | process: subprocess.Popen, stdout: io.StringIO, stderr: io.StringIO, verbose: bool = False 243 | ): 244 | """ 245 | Reads the given process stdout/stderr and write them to StringIO objects. 246 | Make sure that there is no deadlock because of pipe congestion. 247 | If `verbose` the process stdout/stderr are also copying to the interpreter stdout/stderr. 248 | """ 249 | 250 | def raw(stream: tp.Optional[tp.IO[bytes]]) -> tp.IO[bytes]: 251 | if stream is None: 252 | raise RuntimeError("Stream should not be None") 253 | if isinstance(stream, io.BufferedIOBase): 254 | stream = stream.raw # type: ignore 255 | return stream # type: ignore 256 | 257 | p_stdout, p_stderr = raw(process.stdout), raw(process.stderr) 258 | stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.IO[str]]] = { 259 | p_stdout.fileno(): (p_stdout, stdout, sys.stdout), 260 | p_stderr.fileno(): (p_stderr, stderr, sys.stderr), 261 | } 262 | fds = list(stream_by_fd.keys()) 263 | poller = select.poll() 264 | for fd in stream_by_fd: 265 | poller.register(fd, select.POLLIN | select.POLLPRI) 266 | while fds: 267 | # `poll` syscall will wait until one of the registered file descriptors has content. 268 | ready = poller.poll() 269 | for fd, _ in ready: 270 | p_stream, string, std = stream_by_fd[fd] 271 | raw_buf = p_stream.read(2**16) 272 | if not raw_buf: 273 | fds.remove(fd) 274 | poller.unregister(fd) 275 | continue 276 | buf = raw_buf.decode() 277 | string.write(buf) 278 | string.flush() 279 | if verbose: 280 | std.write(buf) 281 | std.flush() 282 | 283 | 284 | # used in "_core", so cannot be in "helpers" 285 | class CommandFunction: 286 | """Wraps a command as a function in order to make sure it goes through the 287 | pipeline and notify when it is finished. 288 | The output is a string containing everything that has been sent to stdout. 289 | WARNING: use CommandFunction only if you know the output won't be too big ! 
290 | Otherwise use subprocess.run() that also streams the outputto stdout/stderr. 291 | 292 | Parameters 293 | ---------- 294 | command: list 295 | command to run, as a list 296 | verbose: bool 297 | prints the command and stdout at runtime 298 | cwd: Path/str 299 | path to the location where the command must run from 300 | 301 | Returns 302 | ------- 303 | str 304 | Everything that has been sent to stdout 305 | """ 306 | 307 | def __init__( 308 | self, 309 | command: tp.List[str], 310 | verbose: bool = True, 311 | cwd: tp.Optional[tp.Union[str, Path]] = None, 312 | env: tp.Optional[tp.Dict[str, str]] = None, 313 | ) -> None: 314 | if not isinstance(command, list): 315 | raise TypeError("The command must be provided as a list") 316 | self.command = command 317 | self.verbose = verbose 318 | self.cwd = None if cwd is None else str(cwd) 319 | self.env = env 320 | 321 | def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> str: 322 | """Call the cammand line with addidional arguments 323 | The keyword arguments will be sent as --{key}={val} 324 | The logs bufferized. They will be printed if the job fails, or sent as output of the function 325 | Errors are provided with the internal stderr. 
326 | """ 327 | full_command = ( 328 | self.command + [str(x) for x in args] + [f"--{x}={y}" for x, y in kwargs.items()] 329 | ) # TODO bad parsing 330 | if self.verbose: 331 | print(f"The following command is sent: \"{' '.join(full_command)}\"") 332 | with subprocess.Popen( 333 | full_command, 334 | stdout=subprocess.PIPE, 335 | stderr=subprocess.PIPE, 336 | shell=False, 337 | cwd=self.cwd, 338 | env=self.env, 339 | ) as process: 340 | stdout_buffer = io.StringIO() 341 | stderr_buffer = io.StringIO() 342 | 343 | try: 344 | copy_process_streams(process, stdout_buffer, stderr_buffer, self.verbose) 345 | except Exception as e: 346 | process.kill() 347 | process.wait() 348 | raise FailedJobError("Job got killed for an unknown reason.") from e 349 | stdout = stdout_buffer.getvalue().strip() 350 | stderr = stderr_buffer.getvalue().strip() 351 | retcode = process.wait() 352 | if stderr and (retcode and not self.verbose): 353 | # We don't print is self.verbose, as it already happened before. 354 | print(stderr, file=sys.stderr) 355 | if retcode: 356 | subprocess_error = subprocess.CalledProcessError( 357 | retcode, process.args, output=stdout, stderr=stderr 358 | ) 359 | raise FailedJobError(stderr) from subprocess_error 360 | return stdout 361 | -------------------------------------------------------------------------------- /submitit/local/local.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # 6 | 7 | import os 8 | import shlex 9 | import signal 10 | import subprocess 11 | import sys 12 | import time 13 | import typing as tp 14 | from pathlib import Path 15 | 16 | from ..core import core, job_environment, logger, utils 17 | from ..core.core import R 18 | 19 | # pylint: disable-msg=too-many-arguments 20 | # VALID_KEYS = {"timeout_min", "gpus_per_node", "tasks_per_node", "signal_delay_s", "visible_gpus", "setup"} 21 | 22 | LOCAL_REQUEUE_RETURN_CODE = 144 23 | 24 | # global variable storing unfinished processes of pickled jobs 25 | # in case we need to reload them later 26 | _PROCESSES: tp.Dict[str, "subprocess.Popen['bytes']"] = {} 27 | 28 | 29 | class LocalJob(core.Job[R]): 30 | def __init__( 31 | self, 32 | folder: tp.Union[Path, str], 33 | job_id: str, 34 | tasks: tp.Sequence[int] = (0,), 35 | process: tp.Optional["subprocess.Popen['bytes']"] = None, 36 | ) -> None: 37 | super().__init__(folder, job_id, tasks) 38 | self._cancel_at_deletion = False 39 | # downcast sub-jobs to get proper typing 40 | self._sub_jobs: tp.Sequence["LocalJob[R]"] = self._sub_jobs 41 | # set process (to self and subjobs) 42 | self._process = process 43 | for sjob in self._sub_jobs: 44 | sjob._process = process 45 | 46 | def done(self, force_check: bool = False) -> bool: # pylint: disable=unused-argument 47 | """Override to avoid using the watcher""" 48 | state = self.get_info()["jobState"] 49 | return state != "RUNNING" 50 | 51 | @property 52 | def state(self) -> str: 53 | """State of the job""" 54 | try: 55 | return self.get_info().get("jobState", "unknown") 56 | # I don't what is the exception returned and it's hard to reproduce 57 | except Exception: # pylint: disable=broad-except 58 | return "UNKNOWN" 59 | 60 | def get_info(self, mode: str = "force") -> tp.Dict[str, str]: # pylint: disable=unused-argument 61 | """Returns information about the job as a dict.""" 62 | if self._process is None: 63 | state = "NO PROCESS AND NO RESULT" 64 | if 
self.paths.result_pickle.exists(): 65 | state = "FINISHED" 66 | return {"jobState": state} 67 | poll = self._process.poll() 68 | if poll is None: 69 | state = "RUNNING" 70 | elif poll < 0: 71 | state = "INTERRUPTED" 72 | else: 73 | state = "FINISHED" 74 | return {"jobState": state} 75 | 76 | def cancel(self, check: bool = True) -> None: # pylint: disable=unused-argument 77 | if self._process is not None: 78 | self._process.send_signal(signal.SIGINT) 79 | 80 | def _interrupt(self) -> None: 81 | """Sends preemption / timeout signal to the job (for testing purpose)""" 82 | if self._process is not None: 83 | self._process.send_signal(LocalJobEnvironment._usr_sig()) 84 | 85 | def __del__(self) -> None: 86 | if self._cancel_at_deletion: 87 | if not self.get_info().get("jobState") == "FINISHED": 88 | self.cancel(check=False) 89 | # let's clear the process dict if we know it's finished 90 | if self.paths.result_pickle.exists(): 91 | _PROCESSES.pop(self.job_id, None) 92 | 93 | # # # # # pickling below # # # # # 94 | 95 | def __getstate__(self) -> tp.Any: 96 | out = dict(self.__dict__) 97 | out["_process"] = None 98 | if self._process is not None: 99 | _PROCESSES[self.job_id] = self._process 100 | return out 101 | 102 | def __setstate__(self, state: tp.Any) -> None: 103 | # Restore instance attributes 104 | self.__dict__.update(state) 105 | # recover process if it still exists 106 | self._process = _PROCESSES.get(self.job_id, None) 107 | 108 | 109 | class LocalJobEnvironment(job_environment.JobEnvironment): 110 | _env = { 111 | "job_id": "SUBMITIT_LOCAL_JOB_ID", 112 | "num_tasks": "SUBMITIT_LOCAL_NTASKS", 113 | "num_nodes": "SUBMITIT_LOCAL_JOB_NUM_NODES", 114 | "node": "SUBMITIT_LOCAL_NODEID", 115 | "global_rank": "SUBMITIT_LOCAL_GLOBALID", 116 | "local_rank": "SUBMITIT_LOCAL_LOCALID", 117 | } 118 | 119 | def _requeue(self, countdown: int) -> None: 120 | jid = self.job_id 121 | logger.get_logger().info(f"Requeued job {jid} ({countdown} remaining timeouts)") 122 | 
sys.exit(LOCAL_REQUEUE_RETURN_CODE) # should help noticing if need requeuing 123 | 124 | 125 | class LocalExecutor(core.PicklingExecutor): 126 | """Local job executor 127 | This class is used to hold the parameters to run a job locally. 128 | In practice, it will create a bash file in the specified directory for each job, 129 | and pickle the task function and parameters. At completion, the job will also pickle 130 | the output. Logs are also dumped in the same directory. 131 | 132 | The submission file spawn several processes (one per task), with a timeout. 133 | 134 | 135 | Parameters 136 | ---------- 137 | folder: Path/str 138 | folder for storing job submission/output and logs. 139 | 140 | Note 141 | ---- 142 | - be aware that the log/output folder will be full of logs and pickled objects very fast, 143 | it may need cleaning. 144 | - use update_parameters to specify custom parameters (n_gpus etc...). 145 | """ 146 | 147 | job_class = LocalJob 148 | 149 | def __init__( 150 | self, 151 | folder: tp.Union[str, Path], 152 | max_num_timeout: int = 3, 153 | max_pickle_size_gb: float = 1.0, 154 | python: tp.Optional[str] = None, 155 | ) -> None: 156 | super().__init__( 157 | folder, 158 | max_pickle_size_gb=max_pickle_size_gb, 159 | max_num_timeout=max_num_timeout, 160 | ) 161 | self.python = shlex.quote(sys.executable) if python is None else python 162 | # preliminary check 163 | indep_folder = utils.JobPaths.get_first_id_independent_folder(self.folder) 164 | indep_folder.mkdir(parents=True, exist_ok=True) 165 | 166 | @classmethod 167 | def _valid_parameters(cls) -> tp.Set[str]: 168 | """Parameters that can be set through update_parameters""" 169 | return {"setup"} 170 | 171 | def _internal_update_parameters(self, **kwargs: tp.Any) -> None: 172 | """Update the parameters of the Executor. 173 | 174 | Valid parameters are: 175 | - timeout_min (float) 176 | - gpus_per_node (int) 177 | - visible_gpus (Sequence[int]) 178 | - tasks_per_node (int) 179 | - nodes (int). 
Must be 1 if specified 180 | - signal_delay_s (int): signal (lately: USR2) delay before timeout 181 | 182 | Other parameters are ignored 183 | """ 184 | if kwargs.get("nodes", 0) > 1: 185 | raise ValueError("LocalExecutor can use only one node. Use nodes=1") 186 | gpus_requested = kwargs.get("gpus_per_node", 0) 187 | visible_gpus = kwargs.get("visible_gpus", ()) 188 | if not isinstance(visible_gpus, tp.Sequence): 189 | raise ValueError(f"Provided visible_gpus={visible_gpus} is not an instance of Sequence.") 190 | if not all(isinstance(x, int) for x in visible_gpus): 191 | raise ValueError(f"Provided visible_gpus={visible_gpus} contains an element that is not an int.") 192 | if len(visible_gpus) > 0 and gpus_requested > len(visible_gpus): 193 | raise ValueError( 194 | f"{gpus_requested} gpus requested, but only {visible_gpus} were specified visible." 195 | ) 196 | super()._internal_update_parameters(**kwargs) 197 | 198 | def _submit_command(self, command: str) -> LocalJob[R]: 199 | # Override this, because the implementation is simpler than for clusters like Slurm 200 | # Only one node is supported for local executor. 
201 | ntasks = self.parameters.get("tasks_per_node", 1) 202 | n_gpus = self.parameters.get("gpus_per_node", 0) 203 | visible_gpus = self.parameters.get("visible_gpus", ()) 204 | gpus = range(n_gpus) if visible_gpus == () else visible_gpus[:n_gpus] 205 | process = start_controller( 206 | folder=self.folder, 207 | command=command, 208 | tasks_per_node=ntasks, 209 | cuda_devices=",".join(str(k) for k in gpus), 210 | timeout_min=self.parameters.get("timeout_min", 2.0), 211 | signal_delay_s=self.parameters.get("signal_delay_s", 30), 212 | stderr_to_stdout=self.parameters.get("stderr_to_stdout", False), 213 | setup=self.parameters.get("setup", ()), 214 | ) 215 | job: LocalJob[R] = LocalJob( 216 | folder=self.folder, job_id=str(process.pid), process=process, tasks=list(range(ntasks)) 217 | ) 218 | return job 219 | 220 | @property 221 | def _submitit_command_str(self) -> str: 222 | return " ".join([self.python, "-u -m submitit.core._submit", shlex.quote(str(self.folder))]) 223 | 224 | def _num_tasks(self) -> int: 225 | nodes: int = 1 226 | tasks_per_node: int = self.parameters.get("tasks_per_node", 1) 227 | return nodes * tasks_per_node 228 | 229 | def _make_submission_file_text(self, command: str, uid: str) -> str: 230 | return "" 231 | 232 | @staticmethod 233 | def _get_job_id_from_submission_command(string: tp.Union[bytes, str]) -> str: 234 | # Not used, but need an implementation 235 | return "0" 236 | 237 | def _make_submission_command(self, submission_file_path: Path) -> tp.List[str]: 238 | # Not used, but need an implementation 239 | return [] 240 | 241 | 242 | def start_controller( 243 | folder: Path, 244 | command: str, 245 | tasks_per_node: int = 1, 246 | cuda_devices: str = "", 247 | timeout_min: float = 5.0, 248 | signal_delay_s: int = 30, 249 | stderr_to_stdout: bool = False, 250 | setup: tp.Sequence[str] = (), 251 | ) -> "subprocess.Popen['bytes']": 252 | """Starts a job controller, which is expected to survive the end of the python session.""" 253 | env = 
dict(os.environ) 254 | env.update( 255 | SUBMITIT_LOCAL_NTASKS=str(tasks_per_node), 256 | SUBMITIT_LOCAL_COMMAND=command, 257 | SUBMITIT_LOCAL_TIMEOUT_S=str(int(60 * timeout_min)), 258 | SUBMITIT_LOCAL_SIGNAL_DELAY_S=str(int(signal_delay_s)), 259 | SUBMITIT_LOCAL_NODEID="0", 260 | SUBMITIT_LOCAL_JOB_NUM_NODES="1", 261 | SUBMITIT_STDERR_TO_STDOUT="1" if stderr_to_stdout else "", 262 | SUBMITIT_EXECUTOR="local", 263 | CUDA_VISIBLE_DEVICES=cuda_devices, 264 | SUBMITIT_LOCAL_WITH_SHELL="1" if setup else "", 265 | ) 266 | # The LocalJob will be responsible to polling and ending this process. 267 | # pylint: disable=consider-using-with 268 | proc_cmd: tp.Any = [sys.executable, "-m", "submitit.local._local", str(folder)] 269 | need_shell = bool(setup) 270 | if need_shell: 271 | proc_cmd = " && ".join(list(setup) + [shlex.join(proc_cmd)]) 272 | process = subprocess.Popen(proc_cmd, shell=need_shell, env=env) 273 | return process 274 | 275 | 276 | class Controller: # pragma: no cover 277 | """This controls a job: 278 | - instantiate each of the tasks 279 | - sends timeout signal 280 | - stops all tasks if one of them finishes 281 | - cleans up the tasks/closes log files when deleted 282 | """ 283 | 284 | # pylint: disable=too-many-instance-attributes 285 | 286 | def __init__(self, folder: Path): 287 | self.ntasks = int(os.environ["SUBMITIT_LOCAL_NTASKS"]) 288 | self.command = shlex.split(os.environ["SUBMITIT_LOCAL_COMMAND"]) 289 | self.timeout_s = int(os.environ["SUBMITIT_LOCAL_TIMEOUT_S"]) 290 | self.signal_delay_s = int(os.environ["SUBMITIT_LOCAL_SIGNAL_DELAY_S"]) 291 | self.stderr_to_stdout = bool(os.environ["SUBMITIT_STDERR_TO_STDOUT"]) 292 | self.tasks: tp.List[subprocess.Popen] = [] # type: ignore 293 | self.stdouts: tp.List[tp.IO[tp.Any]] = [] 294 | self.stderrs: tp.List[tp.IO[tp.Any]] = [] 295 | with_shell = bool(os.environ["SUBMITIT_LOCAL_WITH_SHELL"]) 296 | self.pid = str(os.getppid() if with_shell else os.getpid()) 297 | self.folder = Path(folder) 298 | 
signal.signal(signal.SIGTERM, self._forward_signal) # type: ignore 299 | 300 | # pylint:disable=unused-argument 301 | def _forward_signal(self, signum: signal.Signals, *args: tp.Any) -> None: 302 | for task in self.tasks: 303 | try: 304 | task.send_signal(signum) # sending kill signal to make sure everything finishes 305 | except Exception: 306 | pass 307 | 308 | def start_tasks(self) -> None: 309 | self.folder.mkdir(exist_ok=True) 310 | paths = [utils.JobPaths(self.folder, self.pid, k) for k in range(self.ntasks)] 311 | self.stdouts = [p.stdout.open("a") for p in paths] 312 | self.stderrs = self.stdouts if self.stderr_to_stdout else [p.stderr.open("a") for p in paths] 313 | for k in range(self.ntasks): 314 | env = dict(os.environ) 315 | env.update( 316 | SUBMITIT_LOCAL_LOCALID=str(k), SUBMITIT_LOCAL_GLOBALID=str(k), SUBMITIT_LOCAL_JOB_ID=self.pid 317 | ) 318 | self.tasks.append( 319 | subprocess.Popen( # pylint: disable=consider-using-with 320 | self.command, 321 | shell=False, 322 | env=env, 323 | stderr=self.stderrs[k], 324 | stdout=self.stdouts[k], 325 | encoding="utf-8", 326 | ) 327 | ) 328 | 329 | def kill_tasks(self) -> None: 330 | # try and be progressive in deletion... 331 | for sig in [signal.SIGINT, signal.SIGKILL]: 332 | self._forward_signal(sig) 333 | # if one is still alive after sigterm and sigint, try sigkill after 1s 334 | if sig == signal.SIGINT and any(t.poll() is None for t in self.tasks): 335 | time.sleep(0.001) 336 | if any(t.poll() is None for t in self.tasks): 337 | time.sleep(1.0) # wait a bit more 338 | self.tasks = [] 339 | files = self.stdouts + self.stderrs 340 | self.stdouts, self.stderrs = [], [] # remove all instance references 341 | for f in files: 342 | f.close() 343 | 344 | def wait(self, freq: int = 24) -> tp.Sequence[tp.Optional[int]]: 345 | """Waits for all tasks to finish or to time-out. 346 | 347 | Returns 348 | ------- 349 | Sequence[Optional[int]]: 350 | Exit codes of each task. 
351 | Some tasks might still have not exited, but they will have received the "timed-out" signal. 352 | """ 353 | assert self.tasks, "Nothing to do!" 354 | timeout = freq * self.timeout_s 355 | almost_timeout = freq * (self.timeout_s - self.signal_delay_s) 356 | 357 | # safer to keep a for loop :) 358 | for step in range(timeout): 359 | exit_codes = [t.poll() for t in self.tasks] 360 | if all(e is not None for e in exit_codes): 361 | return exit_codes 362 | 363 | if step == almost_timeout: 364 | self._forward_signal(LocalJobEnvironment._usr_sig()) 365 | 366 | time.sleep(1.0 / freq) 367 | return [t.poll() for t in self.tasks] 368 | 369 | def run(self, max_retry: int = 6) -> None: 370 | # max_retry is a safety measure, the submission also have a timeout_countdown, 371 | # and will fail if it times out too many times. 372 | for _ in range(max_retry): 373 | try: 374 | self.start_tasks() 375 | exit_codes = self.wait() 376 | requeue = any(e == LOCAL_REQUEUE_RETURN_CODE for e in exit_codes) 377 | if not requeue: 378 | break 379 | finally: 380 | self.kill_tasks() 381 | -------------------------------------------------------------------------------- /submitit/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # 6 | 7 | import collections 8 | import contextlib 9 | import datetime 10 | import itertools 11 | import os 12 | import random 13 | import shutil 14 | import subprocess 15 | import tempfile 16 | import time 17 | import typing as tp 18 | from pathlib import Path 19 | 20 | # pylint: disable=unused-import 21 | # import DelayedSubmission and CommandFunction to populate helpers namespace 22 | from .core import core 23 | from .core.job_environment import JobEnvironment 24 | from .core.utils import CommandFunction as CommandFunction # noqa 25 | from .core.utils import DelayedSubmission as DelayedSubmission # noqa 26 | from .core.utils import environment_variables as environment_variables # noqa 27 | 28 | 29 | class Checkpointable: 30 | """Derived callable classes are requeued after timeout with their current 31 | state dumped at checkpoint. 32 | 33 | __call__ method must be implemented to make your class a callable. 34 | 35 | Note 36 | ---- 37 | The following implementation of the checkpoint method resubmits the full current 38 | state of the callable (self) with the initial argument. You may want to replace the method to 39 | curate the state (dump a neural network to a standard format and remove it from 40 | the state so that not to pickle it) and change/remove the initial parameters. 41 | """ 42 | 43 | # pylint: disable=unused-argument 44 | def __new__(cls, *args, **kwargs): 45 | instance = super().__new__(cls) 46 | assert callable( 47 | instance 48 | ), f"Class {cls.__name__} is marked as Checkpointable but doesn't have a __call__ method. Please add a __call__ method." 
49 | return instance 50 | 51 | def checkpoint(self, *args: tp.Any, **kwargs: tp.Any) -> DelayedSubmission: 52 | """Resubmits the same callable with the same arguments""" 53 | # The DelayedSubmission class goal is only to register and format 54 | # the arguments of the call "self(*args, **kwargs)" for submission to slurm 55 | return DelayedSubmission(self, *args, **kwargs) # type: ignore 56 | 57 | 58 | class FunctionSequence(Checkpointable): 59 | """This is for gathering several estimations into one function, which 60 | will return the sequence of outputs. 61 | Also this "function" is stateful, hence it can be stopped, and recovered, 62 | which is useful when job can be preempted. 63 | 64 | Usage 65 | ----- 66 | func = FunctionSequence() 67 | func.add(my_function1, arg1, kwarg1=value_kwarg1) 68 | func.add(my_function2, arg1, arg2) 69 | result1, result2 = func() 70 | 71 | Note 72 | ---- 73 | This function is checkpointable because: 74 | - it derives from Checkpointable 75 | - it keeps DelayedSubmission objects as attribute, which in turn store the 76 | results of the computation in memory once they are computed. So at checkpoint 77 | time, those results will be saved, and only the non-computed results 78 | will be computed once the job restarts. 
79 | """ 80 | 81 | def __init__(self, verbose: bool = False) -> None: 82 | self.verbose = verbose 83 | self.delayed_functions: tp.List[DelayedSubmission] = [] 84 | 85 | def add(self, func: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any) -> None: 86 | self.delayed_functions.append(DelayedSubmission(func, *args, **kwargs)) 87 | 88 | def __len__(self) -> int: 89 | return len(self.delayed_functions) 90 | 91 | def __iter__(self) -> tp.Iterator[DelayedSubmission]: 92 | return iter(self.delayed_functions) 93 | 94 | def __call__(self) -> tp.List[tp.Any]: # pylint: disable=arguments-differ 95 | if self.verbose: 96 | done = sum(f.done() for f in self) # those were computed before checkpoint 97 | print(f"Starting from {done}/{len(self.delayed_functions)}", flush=True) 98 | return [ 99 | f.result() for f in self.delayed_functions 100 | ] # results all results one by one (by running the functions if not already done) 101 | 102 | 103 | def as_completed( 104 | jobs: tp.Sequence[core.Job[core.R]], 105 | timeout: tp.Optional[tp.Union[int, float]] = None, 106 | poll_frequency: float = 10, 107 | ) -> tp.Iterator[core.Job[core.R]]: 108 | """ 109 | Yields jobs as they complete (finished, failed or were cancelled). 110 | Raises a TimeoutError if the result isn’t available after timeout seconds. 111 | timeout can be an int or float. If timeout is not specified or None, there is no 112 | limit to the wait time. 113 | 114 | Parameters 115 | ---------- 116 | jobs: list 117 | Jobs instances 118 | 119 | timeout: int/float 120 | Maximum time (in sec) to wait for jobs completion 121 | 122 | poll_frequency: float 123 | Frequency in second at which we check job status. 
124 | 125 | Yields 126 | ------ 127 | Job 128 | The next completed job 129 | """ 130 | start = time.time() 131 | jobs_done: tp.Set[int] = set() 132 | while True: 133 | if timeout is not None and time.time() - start > timeout: 134 | raise TimeoutError 135 | for i, job in enumerate(jobs): 136 | if i in jobs_done: 137 | continue 138 | if job.done(): 139 | jobs_done.add(i) 140 | yield job 141 | if len(jobs_done) == len(jobs): 142 | break 143 | time.sleep(poll_frequency) 144 | 145 | 146 | def run_cmd(str_args, **kwargs): 147 | return subprocess.check_output(str_args, **kwargs).decode("utf-8").strip() 148 | 149 | 150 | class RsyncSnapshot: 151 | """Takes a snapshot of the git repository that the script lives in. 152 | 153 | This ensures that remote jobs always use the code from when they are scheduled 154 | and not the code from when they are launched / re-started. 155 | 156 | 157 | Parameters 158 | ---------- 159 | snapshot_dir: Path 160 | A path to where the snapshot should be created 161 | with_submodules: bool 162 | Whether or not submodules should be included in the snapshot 163 | exclude: Sequence[str] 164 | An optional list of patterns to exclude from the snapshot 165 | include: Sequence[str] 166 | A list of relative file names to include from the snapshot. 167 | Useful for .so or other build artifacts that are genarally not tracked by git. 168 | 169 | Note 170 | ---- 171 | - Only files that are checked in to the repository are included in the snapshot. 172 | If you have experimental code that you would like to include in the snapshot, 173 | you'll need to `git add` the file first for it to be included, or use `include` arg. 
174 | """ 175 | 176 | def __init__( 177 | self, 178 | snapshot_dir: Path, 179 | root_dir: tp.Optional[Path] = None, 180 | with_submodules: bool = False, 181 | exclude: tp.Sequence[str] = (), 182 | include: tp.Sequence[str] = (), 183 | ): 184 | self.available(throw=True) 185 | self.snapshot_dir = Path(snapshot_dir) 186 | self.root_dir = root_dir or run_cmd(["git", "rev-parse", "--show-toplevel"]) 187 | self.original_dir = Path.cwd() 188 | self.with_submodules = with_submodules 189 | self.exclude = exclude 190 | self.include = include 191 | 192 | @staticmethod 193 | def available(throw: bool = False) -> bool: 194 | if not shutil.which("rsync"): 195 | if throw: 196 | raise RuntimeError("RsyncSnapshot requires rsync to be installed.") 197 | return False 198 | return True 199 | 200 | def __enter__(self) -> None: 201 | self.original_dir = Path.cwd() 202 | # Get the repository root 203 | root_dir = str(self.root_dir) 204 | sub = "--recurse-submodules" if self.with_submodules else "-s" 205 | # Make a shallow git clone 206 | if not self.snapshot_dir.exists(): 207 | self.snapshot_dir.parent.mkdir(parents=True, exist_ok=True) 208 | subprocess.check_call(["git", "clone", "--depth=2", f"file://{root_dir}", str(self.snapshot_dir)]) 209 | 210 | # Get a list of all the checked in files that we can pass to rsync 211 | # Is Rsync faster than a `git pull` ? 
212 | with tempfile.NamedTemporaryFile() as tfile: 213 | # https://stackoverflow.com/a/51689219/4876946 214 | run_cmd(f"git ls-files {sub} | grep -v ^16 | cut -f2- > {tfile.name}", cwd=root_dir, shell=True) 215 | exclude = list(itertools.chain.from_iterable(("--exclude", pat) for pat in self.exclude)) 216 | with open(tfile.name, "a", encoding="utf8") as o: 217 | for inc in self.include: 218 | print(inc, file=o) 219 | run_cmd(["rsync", "-a", "--files-from", tfile.name, root_dir, str(self.snapshot_dir)] + exclude) 220 | os.chdir(self.snapshot_dir) 221 | 222 | def __exit__(self, *args): 223 | os.chdir(self.original_dir) 224 | 225 | 226 | def _default_custom_logging(monitoring_start_time: float, n_jobs: int, state_jobs: tp.Dict[str, tp.Set[int]]): 227 | run_time = time.time() - monitoring_start_time 228 | date_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 229 | failed_job_indices = sorted(state_jobs["FAILED"]) 230 | n_chars = len(str(n_jobs)) 231 | 232 | print( 233 | f"[{date_time}] Launched {int(run_time / 60)} minutes ago,", 234 | f"{len(state_jobs['RUNNING']):{n_chars}}/{n_jobs} jobs running,", 235 | f"{len(failed_job_indices):{n_chars}}/{n_jobs} jobs failed,", 236 | f"{len(state_jobs['DONE']) - len(failed_job_indices):{n_chars}}/{n_jobs} jobs done", 237 | flush=True, 238 | ) 239 | 240 | if len(failed_job_indices) > 0: 241 | print(f"[{date_time}] Failed jobs, indices {failed_job_indices}", flush=True) 242 | 243 | 244 | def monitor_jobs( 245 | jobs: tp.Sequence[core.Job[core.R]], 246 | poll_frequency: float = 30, 247 | test_mode: bool = False, 248 | custom_logging: tp.Callable = _default_custom_logging, 249 | ) -> None: 250 | """Continuously monitors given jobs until they are all done or failed. 251 | 252 | Parameters 253 | ---------- 254 | jobs: List[Jobs] 255 | A list of jobs to monitor 256 | poll_frequency: int 257 | The time (in seconds) between two refreshes of the monitoring. 258 | Can't be inferior to 30s. 
259 | test_mode: bool 260 | If in test mode, we do not check the length of poll_frequency 261 | """ 262 | 263 | if not test_mode: 264 | assert poll_frequency >= 30, "You can't refresh too often (>= 30s) to avoid overloading squeue" 265 | 266 | n_jobs = len(jobs) 267 | if n_jobs == 0: 268 | print("There are no jobs to monitor") 269 | return 270 | 271 | job_arrays = ", ".join(sorted(set(str(job.job_id).split("_", 1)[0] for job in jobs))) 272 | print(f"Monitoring {n_jobs} jobs from job arrays {job_arrays} \n") 273 | 274 | monitoring_start_time = time.time() 275 | while True: 276 | if not test_mode: 277 | jobs[0].get_info(mode="force") # Force update once to sync the state 278 | state_jobs = collections.defaultdict(set) 279 | for i, job in enumerate(jobs): 280 | state_jobs[job.state.upper()].add(i) 281 | if job.done(): 282 | state_jobs["DONE"].add(i) 283 | 284 | failed_job_indices = sorted(state_jobs["FAILED"]) 285 | if len(state_jobs["DONE"]) == len(jobs): 286 | print(f"All jobs finished, jobs with indices {failed_job_indices} failed", flush=True) 287 | break 288 | 289 | custom_logging(monitoring_start_time, n_jobs, state_jobs) 290 | time.sleep(poll_frequency) 291 | 292 | print(f"Whole process is finished, took {int((time.time() - monitoring_start_time) / 60)} minutes") 293 | 294 | 295 | @contextlib.contextmanager 296 | def clean_env(extra_names: tp.Sequence[str] = ()) -> tp.Iterator[None]: 297 | """Removes slurm and submitit related environment variables so as to avoid interferences 298 | when submiting a new job from a job. 299 | 300 | Parameters 301 | ---------- 302 | extra_names: Sequence[str] 303 | Additional environment variables to hide inside the context, 304 | e.g. TRITON_CACHE_DIR and TORCHINDUCTOR_CACHE_DIR when using torch.compile. 305 | 306 | Note 307 | ---- 308 | A slurm job submitted from within a slurm job inherits some of its attributes, which may 309 | be confusing a cause weird gres errors (or pytorch distributed). 
310 | Submitting within this context should prevent this. 311 | 312 | Usage 313 | ----- 314 | with submitit.helpers.clean_env(): 315 | executor.submit(...) 316 | """ 317 | distrib_names = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE") 318 | cluster_env = { 319 | x: os.environ.pop(x) 320 | for x in os.environ 321 | if ( 322 | x.startswith(("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_")) 323 | or x in distrib_names 324 | or x in extra_names 325 | ) 326 | } 327 | try: 328 | yield 329 | finally: 330 | os.environ.update(cluster_env) 331 | 332 | 333 | class TorchDistributedEnvironment: # pragma: no cover 334 | def __init__(self) -> None: 335 | """Construct a class holding the parameters required to properly setup 336 | PyTorch distributed (with the default env:// initialization method). 337 | 338 | Examples 339 | -------- 340 | >>> dist_env = TorchDistributedEnvironment().export() 341 | >>> torch.distributed.init_process_group(backend="nccl") 342 | >>> print(f"master: {dist_env.master_addr}:{dist_env.master_port}") 343 | """ 344 | self._job_env = JobEnvironment() 345 | self.master_addr = self._job_env.hostnames[0] 346 | self.master_port = self._get_master_port() 347 | self.rank = self._job_env.global_rank 348 | self.world_size = self._job_env.num_tasks 349 | self.local_rank = self._job_env.local_rank 350 | self.local_world_size = self._job_env.num_tasks // self._job_env.num_nodes 351 | 352 | def _get_master_port(self) -> int: 353 | # MIN_MASTER_PORT, MAX_MASTER_PORT = (1023, 65535) 354 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000) 355 | 356 | master_port_str = os.environ.get("MASTER_PORT") 357 | if master_port_str is None: 358 | rng = random.Random(self._job_env.job_id) 359 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) 360 | 361 | master_port = int(master_port_str) 362 | # assert MIN_MASTER_PORT <= master_port <= MIN_MASTER_PORT 363 | return master_port 364 | 365 | def export( 366 | self, 367 | 
set_cuda_visible_devices: bool = True, 368 | overwrite: bool = False, 369 | ) -> "TorchDistributedEnvironment": 370 | """Export all the environment variables required to properly setup 371 | PyTorch distributed (with the default env:// initialization method) i.e. 372 | MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE (to which LOCAL_RANK and 373 | LOCAL_WORLD_SIZE are added). 374 | 375 | Parameter 376 | ---------- 377 | set_cuda_visible_device: bool 378 | if True, updates CUDA_VISIBLE_DEVICES to use only the device 379 | matching the local rank. 380 | overwrite: bool 381 | if True, overwrites the environment variables if they exist; 382 | this can be useful when launching a job from another job. 383 | 384 | Returns 385 | -------- 386 | TorchDistributedEnvironment 387 | the current instance 388 | """ 389 | # See the "Environment variable initialization" section from 390 | # https://pytorch.org/docs/stable/distributed.html for the complete list of 391 | # environment variables required for the env:// initialization method. 392 | env_vars = { 393 | "MASTER_ADDR": self.master_addr, 394 | "MASTER_PORT": str(self.master_port), 395 | "RANK": str(self.rank), 396 | "WORLD_SIZE": str(self.world_size), 397 | "LOCAL_RANK": str(self.local_rank), # Not required 398 | "LOCAL_WORLD_SIZE": str(self.local_world_size), # Not required 399 | } 400 | if not overwrite: 401 | for key in env_vars: 402 | if key in os.environ: 403 | raise RuntimeError(f"Cannot export environment variables as {key} is already set") 404 | # Note: CUDA_VISIBLE_DEVICES may already be set with all available GPUs 405 | if set_cuda_visible_devices: 406 | env_vars["CUDA_VISIBLE_DEVICES"] = str(self.local_rank) 407 | os.environ.update(env_vars) 408 | return self 409 | --------------------------------------------------------------------------------