├── examples ├── train.py ├── pi.py ├── simple.py └── README.md ├── LICENSE ├── pyproject.toml ├── dawgz ├── __main__.py ├── utils.py ├── __init__.py ├── workflow.py └── schedulers.py ├── .gitignore └── README.md /examples/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from dawgz import after, job, schedule 4 | 5 | 6 | @job 7 | def preprocessing(): 8 | print("data preprocessing") 9 | 10 | 11 | evals = [] 12 | previous = preprocessing 13 | 14 | for i in range(1, 4): 15 | 16 | @after(previous) 17 | @job(name=f"train_{i}") 18 | def train(): 19 | print(f"training step {i}") 20 | 21 | @after(train) 22 | @job(name=f"eval_{i}") 23 | def evaluate(): 24 | print(f"evaluation step {i}") 25 | 26 | evals.append(evaluate) 27 | previous = train 28 | 29 | if __name__ == "__main__": 30 | schedule(*evals, name="train.py", backend="async") 31 | -------------------------------------------------------------------------------- /examples/pi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import numpy as np 5 | import os 6 | 7 | from dawgz import after, ensure, job, schedule 8 | 9 | samples = 10000 10 | tasks = 5 11 | 12 | 13 | @ensure(lambda i: os.path.exists(f"pi_{i}.npy")) 14 | @job(array=tasks, cpus=1, ram="2GB", time="5:00") 15 | def generate(i: int): 16 | print(f"Task {i + 1} / {tasks}") 17 | 18 | x = np.random.random(samples) 19 | y = np.random.random(samples) 20 | within_circle = x**2 + y**2 <= 1 21 | 22 | np.save(f"pi_{i}.npy", within_circle) 23 | 24 | 25 | @after(generate) 26 | @job(cpus=2, ram="4GB", time="15:00") 27 | def estimate(): 28 | files = glob.glob("pi_*.npy") 29 | stack = np.vstack([np.load(f) for f in files]) 30 | pi_estimate = stack.mean() * 4 31 | 32 | print(f"π ≈ {pi_estimate}") 33 | 34 | 35 | if __name__ == "__main__": 36 | schedule(estimate, name="pi.py", backend="async") 37 | -------------------------------------------------------------------------------- /examples/simple.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import time 5 | 6 | from dawgz import after, ensure, job, schedule, waitfor 7 | 8 | 9 | @job 10 | def a(): 11 | print("a") 12 | time.sleep(3) 13 | print("a") 14 | raise Exception() 15 | 16 | 17 | @job 18 | def b(): 19 | time.sleep(1) 20 | print("b") 21 | time.sleep(1) 22 | print("b") 23 | 24 | 25 | @after(a, status="success") 26 | @ensure(lambda: 2 + 2 == 2 * 2) 27 | @ensure(lambda: 1 + 2 + 3 == 1 * 2 * 3) 28 | @job 29 | def c(): 30 | print("c") 31 | 32 | 33 | @after(b) 34 | @ensure(lambda i: i != 42 or os.path.exists(f"{i}.log")) 35 | @job(array=100) 36 | def d(i: int): 37 | print(f"d{i}") 38 | 39 | with open(f"{i}.log", "w") as file: 40 | file.write("done") 41 | 42 | 43 | @after(a, d) 44 | @waitfor("any") 45 | @job 46 | def e(): 47 | print("e") 48 | 49 | 50 | if __name__ == "__main__": 51 | schedule(c, e, name="simple.py", backend="async", prune=True) 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 François Rozet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the 
rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dawgz" 7 | description = "Directed Acyclic Workflow Graph Scheduling" 8 | authors = [ 9 | {name = "François Rozet", email = "francois.rozet@outlook.com"}, 10 | {name = "Joeri Hermans"}, 11 | ] 12 | classifiers = [ 13 | "Intended Audience :: Developers", 14 | "Intended Audience :: Science/Research", 15 | "License :: OSI Approved :: MIT License", 16 | "Natural Language :: English", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python :: 3", 19 | ] 20 | dependencies = [ 21 | "cloudpickle>=2.0.0", 22 | "tabulate>=0.8.0", 23 | ] 24 | dynamic = ["version"] 25 | keywords = ["workflow", "scheduling", "slurm", "hpc"] 26 | readme = "README.md" 27 | requires-python = ">=3.8" 28 | 29 | [project.optional-dependencies] 30 | lint = [ 31 | "ruff==0.9.9", 32 | ] 33 | 34 | [project.scripts] 35 | dawgz = "dawgz.__main__:main" 36 | 37 | [project.urls] 38 | documentation = "https://github.com/francois-rozet/dawgz" 39 | source = "https://github.com/francois-rozet/dawgz" 40 | tracker = "https://github.com/francois-rozet/dawgz/issues" 41 | 42 | [tool.ruff] 43 | line-length = 99 44 | 45 | [tool.ruff.lint] 46 | extend-select = ["B", "I", "W"] 47 | ignore = ["B023", "E731"] 48 | preview = true 49 | 50 | [tool.ruff.lint.isort] 51 | lines-between-types = 1 52 | relative-imports-order = "closest-to-furthest" 53 | section-order = ["future", "third-party", "first-party", "local-folder"] 54 | 55 | [tool.ruff.format] 56 | preview = true 57 | 58 | [tool.setuptools.dynamic] 59 | version = {attr = "dawgz.__version__"} 60 | 61 | [tool.setuptools.packages.find] 62 | include = ["dawgz*"] 63 | -------------------------------------------------------------------------------- /dawgz/__main__.py: -------------------------------------------------------------------------------- 1 | r"""Module's main""" 2 | 3 | import argparse 4 | import csv 5 | 6 | from tabulate import tabulate 7 | from typing import List 8 | 9 | from .schedulers import DIR, Scheduler 10 | 11 | 12 | def table(workflows: List[List[str]], workflow: int = None, job: int = None): 13 | if workflow is None: 14 | headers = ("Name", "ID", "Date", "Backend", "Jobs", "Errors") 15 | rows = [(w[0], w[1][:8], *w[2:]) for w in workflows] 16 | 17 | table = tabulate(rows, headers, showindex=True) 18 | else: 19 | row = workflows[workflow] 20 | uuid = row[1] 21 | scheduler = Scheduler.load(DIR / uuid) 22 | 
23 | if job is None: 24 | table = scheduler.report() 25 | else: 26 | jobs = list(scheduler.order) 27 | job = jobs[job] 28 | 29 | table = scheduler.report(job) 30 | 31 | print(table) 32 | 33 | 34 | def cancel(workflows: List[List[str]], workflow: int, job: int = None): 35 | row = workflows[workflow] 36 | uuid = row[1] 37 | scheduler = Scheduler.load(DIR / uuid) 38 | 39 | if job is None: 40 | message = scheduler.cancel() 41 | else: 42 | jobs = list(scheduler.order) 43 | job = jobs[job] 44 | 45 | message = scheduler.cancel(job) 46 | 47 | if message: 48 | print(message) 49 | 50 | 51 | def main(): 52 | # Parser 53 | parser = argparse.ArgumentParser(description="DAWGZ's CLI") 54 | 55 | parser.add_argument("workflow", default=None, nargs="?", type=int, help="workflow index") 56 | parser.add_argument("job", default=None, nargs="?", type=int, help="job index") 57 | 58 | parser.add_argument("-c", "--cancel", default=False, action="store_true") 59 | 60 | args = parser.parse_args() 61 | 62 | # Workflows 63 | record = DIR / "workflows.csv" 64 | 65 | if record.exists(): 66 | with open(record) as f: 67 | workflows = list(csv.reader(f)) 68 | else: 69 | workflows = [] 70 | 71 | # Action 72 | if args.cancel: 73 | cancel(workflows, args.workflow, args.job) 74 | else: 75 | table(workflows, args.workflow, args.job) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ######### 2 | # DAWGZ # 3 | ######### 4 | 5 | .dawgz 6 | 7 | 8 | ########## 9 | # Others # 10 | ########## 11 | 12 | # macOS 13 | .DS_Store 14 | 15 | 16 | ########## 17 | # Python # 18 | ########## 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | cover/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | .pybuilder/ 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | # For a library or package, you might want to ignore these files since the code is 106 | # intended to run in multiple environments; otherwise, check them in: 107 | # .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | # Pipfile.lock 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # NumPy 160 | *.npy 161 | 162 | # Pickle 163 | *.pkl 164 | -------------------------------------------------------------------------------- /dawgz/utils.py: -------------------------------------------------------------------------------- 1 | r"""Miscellaneous helpers""" 2 | 3 | import asyncio 4 | import cloudpickle as pickle 5 | import inspect 6 | import sys 7 | import traceback 8 | 9 | from typing import Any, Callable, Iterable 10 | 11 | 12 | def accepts(f: Callable, /, *args, **kwargs) -> bool: 13 | r"""Checks whether a function `f` accepts arguments without errors.""" 14 | 15 | try: 16 | inspect.signature(f).bind(*args, **kwargs) 17 | except TypeError: 18 | return False 19 | else: 20 | return True 21 | 22 | 23 | def cat(text: str, width: int) -> str: 24 | r"""Formats text as it would be displayed in a terminal.""" 25 | 26 | lines = [] 27 | 28 | for line in text.split("\n"): 29 | s = "" 30 | 31 | for carriage in reversed(line.split("\r")): 32 | if len(carriage) > len(s): 33 | s = s + carriage[len(s) :] 34 | 35 | line = s 36 | 37 | if line: 38 | for i in range(0, len(line), width): 39 | lines.append(line[i : i + width]) 40 | else: 41 | lines.append(line) # keep empty lines 42 | 43 | return "\n".join(lines) 44 | 45 | 46 | def comma_separated(integers: Iterable[int]) -> str: 47 | r"""Formats integers as a comma separated list of intervals.""" 48 | 49 | integers = sorted(list(integers)) 50 | intervals = [] 51 | 52 | i = j = integers[0] 53 | 54 | for k in integers[1:]: 55 | if k > j + 1: 56 | intervals.append((i, j)) 57 | i = j = k 58 | else: 59 | j = k 60 | else: 61 | intervals.append((i, j)) 62 | 63 | fmt = lambda i, j: f"{i}" if i == j else f"{i}-{j}" 64 | 65 | return ",".join(map(fmt, *zip(*intervals))) 66 | 67 | 68 | def eprint(*args, **kwargs): 69 | r"""Prints to the standard error stream.""" 70 | 71 | print(*args, file=sys.stderr, **kwargs) 72 | 73 | 74 | def every(conditions: Iterable[Callable]) -> Callable: 75 | r"""Combines a list of conditions into a single condition.""" 76 | 77 | return lambda *args: all(c(*args) for c in conditions) 78 | 79 | 80 | def future(obj: Any, return_exceptions: bool = False) -> asyncio.Future: 81 | r"""Transforms any object to an awaitable future.""" 82 | 83 | if inspect.isawaitable(obj): 84 | if return_exceptions: 85 | fut = asyncio.Future() 86 | 87 | def callback(self): 88 | result = self.exception() 89 | if result is None: 90 | result = self.result() 91 | 92 | fut.set_result(result) 93 | 94 | asyncio.ensure_future(obj).add_done_callback(callback) 95 | else: 
96 | fut = asyncio.ensure_future(obj) 97 | else: 98 | fut = asyncio.Future() 99 | fut.set_result(obj) 100 | 101 | return fut 102 | 103 | 104 | def runpickle(f: bytes, /, *args, **kwargs) -> Any: 105 | r"""Runs a pickled function `f` with arguments.""" 106 | 107 | return pickle.loads(f)(*args, **kwargs) 108 | 109 | 110 | def slugify(text: str) -> str: 111 | r"""Slugifies text.""" 112 | 113 | return "".join(char if char.isalnum() else "_" for char in text) 114 | 115 | 116 | def trace(error: Exception) -> str: 117 | r"""Returns the trace of an error.""" 118 | 119 | lines = traceback.format_exception( 120 | type(error), 121 | error, 122 | error.__traceback__, 123 | ) 124 | 125 | return "".join(lines).strip("\n") 126 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## Simple example 4 | 5 | In [`simple.py`](simple.py) we define a workflow composed of 5 jobs. In summary, 6 | 7 | * `a` and `b` are concurrent. 8 | * `c` waits for `a` to complete with success. 9 | * `c` ensures that `2 + 2 == 2 * 2` and `1 + 2 + 3 == 1 * 2 * 3`. 10 | * `d` waits for `b` to complete with `'success'`. 11 | * `d` ensures that `i != 42` or `{i}.log` exists after successful completion. 12 | * `e` waits for `'any'` of its dependencies (either `a` or `d`) to complete. 13 | 14 | In `schedule`, the dependency graph of `c` and `e` is pruned with respect to the postconditions. 15 | 16 | * `c`'s postconditions are both always `True`, resulting in `c` being pruned out from the graph, even though its dependency `a` fails. 17 | * `d`'s postcondition returns `False` for `i = 42`. Therefore, all other indices are pruned out. 18 | 19 | Then, the jobs in the workflow graph are submitted, which results in the following output 20 | 21 | ``` 22 | a 23 | b 24 | b 25 | d42 26 | e 27 | a 28 | ``` 29 | 30 | as well as a table caused by the failure of `a`. 
31 | 32 | ``` 33 | Job Error 34 | -- ----- --------------------------------------------------------------------------------------------------- 35 | 0 a Traceback (most recent call last): 36 | File "/home/username/env/lib/python3.8/site-packages/dawgz/schedulers.py", line 241, in exec 37 | return await call() 38 | File "/home/username/env/lib/python3.8/site-packages/dawgz/schedulers.py", line 254, in remote 39 | return await asyncio.get_running_loop().run_in_executor( 40 | File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run 41 | result = self.fn(*self.args, **self.kwargs) 42 | File "/home/username/env/lib/python3.8/site-packages/dawgz/utils.py", line 90, in runpickle 43 | return pickle.loads(f)(*args, **kwargs) 44 | File "/home/username/env/lib/python3.8/site-packages/dawgz/workflow.py", line 90, in call 45 | result = f(*args) 46 | File "simple.py", line 12, in a 47 | raise Exception() 48 | Exception 49 | 50 | The above exception was the direct cause of the following exception: 51 | 52 | Traceback (most recent call last): 53 | File "/home/username/env/lib/python3.8/site-packages/dawgz/schedulers.py", line 136, in _submit 54 | return await self.exec(job) 55 | File "/home/username/env/lib/python3.8/site-packages/dawgz/schedulers.py", line 251, in exec 56 | raise JobFailedError(str(job)) from e 57 | dawgz.schedulers.JobFailedError: a 58 | ``` 59 | 60 | However, because `schedule` is called with `prune=True`, running the script again leads to the following output 61 | 62 | ``` 63 | e 64 | ``` 65 | 66 | as the file `42.log` now exists and `e` only requires `a` or `d` to complete. 67 | 68 | ## Train example 69 | 70 | In [`train.py`](train.py) we define a workflow that alternates between training and evaluation steps. The training steps are consecutive, meaning that the `i`th is always executed after the `i-1`th and before the `i+1`th. However, the evaluation steps can be executed directly after their respective training step, even though preceding evaluation steps have not completed yet. The workflow graph looks like 71 | 72 | ``` 73 | preprocessing → train_1 → train_2 → train_3 74 | ↓ ↓ ↓ 75 | eval_1 eval_2 eval_3 76 | ``` 77 | 78 | and scheduling the dependency graph results in the following output 79 | 80 | ``` 81 | data preprocessing 82 | training step 1 83 | evaluation step 1 84 | training step 2 85 | evaluation step 2 86 | training step 3 87 | evaluation step 3 88 | ``` 89 | 90 | If we change the backend to `'dummy'`, we observe that the evaluation steps are not necessarily consecutive. 
91 | 92 | ``` 93 | START preprocessing 94 | END preprocessing 95 | START train_1 96 | END train_1 97 | START eval_1 98 | START train_2 99 | END eval_1 100 | END train_2 101 | START eval_2 102 | START train_3 103 | END train_3 104 | START eval_3 105 | END eval_3 106 | END eval_2 107 | ``` 108 | -------------------------------------------------------------------------------- /dawgz/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Directed Acyclic Workflow Graph Scheduling""" 2 | 3 | __version__ = "1.0.4" 4 | 5 | from functools import partial 6 | from tabulate import tabulate 7 | from typing import Any, Callable, Dict, Iterable, Optional, Union 8 | 9 | from .schedulers import ( 10 | AsyncScheduler, 11 | DummyScheduler, 12 | Scheduler, 13 | SlurmScheduler, 14 | ) 15 | from .utils import eprint 16 | from .workflow import Job 17 | 18 | 19 | def job( 20 | f: Callable = None, 21 | *, 22 | name: Optional[str] = None, 23 | array: Optional[Union[int, Iterable[int]]] = None, 24 | array_throttle: Optional[int] = None, 25 | settings: Dict[str, Any] = {}, # noqa: B006 26 | **kwargs, 27 | ) -> Union[Callable, Job]: 28 | r"""Transforms a function into a job. 29 | 30 | Arguments: 31 | f: A function. 32 | name: The job name. 33 | array: An array size or set of indices. A job array is a group of jobs that can 34 | be launched concurrently. They are described by the same function, but 35 | differ by their index. 36 | array_throttle: The maximum number of simultaneously running jobs in an array. 37 | Only affects the Slurm backend. 38 | settings: The settings of the job, interpreted by the scheduler. Settings include 39 | the allocated resources (e.g. `cpus=4`, `ram="16GB"`), the estimated runtime 40 | (e.g. `time="03:14:15"`), the partition (e.g. `partition="gpu"`) and much 41 | more. 42 | kwargs: Additional keyword arguments added to `settings`. 43 | """ 44 | 45 | kwargs.update( 46 | name=name, 47 | array=array, 48 | array_throttle=array_throttle, 49 | settings=settings, 50 | ) 51 | 52 | if f is None: 53 | return partial(job, **kwargs) 54 | else: 55 | return Job(f, **kwargs) 56 | 57 | 58 | def after(*deps: Job, status: str = "success") -> Callable: 59 | r"""Adds dependencies to a job. 60 | 61 | Arguments: 62 | deps: A set of job dependencies. 63 | status: The desired dependency status. Options are `"success"`, `"failure"` or `"any"`. 64 | """ 65 | 66 | def decorator(self: Job) -> Job: 67 | self.after(*deps, status=status) 68 | return self 69 | 70 | return decorator 71 | 72 | 73 | def waitfor(mode: str) -> Callable: 74 | r"""Modifies the waiting mode of a job. 75 | 76 | Arguments: 77 | mode: The dependency waiting mode. Options are `"all"` (default) or `"any"`. 78 | """ 79 | 80 | def decorator(self: Job) -> Job: 81 | self.waitfor = mode 82 | return self 83 | 84 | return decorator 85 | 86 | 87 | def ensure(condition: Callable) -> Callable: 88 | r"""Adds a postcondition to a job. 89 | 90 | Arguments: 91 | condition: A predicate that should be `True` after the execution of the job. 92 | """ 93 | 94 | def decorator(self: Job) -> Job: 95 | self.ensure(condition) 96 | return self 97 | 98 | return decorator 99 | 100 | 101 | def schedule( 102 | *jobs: Job, 103 | backend: str, 104 | prune: bool = False, 105 | quiet: bool = False, 106 | **kwargs, 107 | ) -> Scheduler: 108 | r"""Schedules a group of jobs. 109 | 110 | Arguments: 111 | jobs: A group of jobs describing a workflow. 112 | backend: The scheduling backend. Options are `"async"`, `"dummy"` or `"slurm"`. 
113 | prune: Whether to prune jobs that have already been executed or not, 114 | as determined by their postconditions. 115 | quiet: Whether to display eventual job errors or not. 116 | kwargs: Keyword arguments passed to the scheduler's constructor. 117 | 118 | Returns: 119 | The workflow scheduler. 120 | """ 121 | 122 | backends = { 123 | s.backend: s 124 | for s in ( 125 | AsyncScheduler, 126 | DummyScheduler, 127 | SlurmScheduler, 128 | ) 129 | } 130 | 131 | scheduler = backends.get(backend)(**kwargs) 132 | scheduler(*jobs, prune=prune) 133 | scheduler.dump() 134 | 135 | if scheduler.traces and not quiet: 136 | eprint(tabulate(scheduler.traces.items(), ("Job", "Error"), showindex=True)) 137 | 138 | return scheduler 139 | -------------------------------------------------------------------------------- /dawgz/workflow.py: -------------------------------------------------------------------------------- 1 | r"""Workflow graph components""" 2 | 3 | from __future__ import annotations 4 | 5 | from functools import cached_property 6 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Set, Union 7 | 8 | from .utils import accepts, comma_separated, every, pickle 9 | 10 | 11 | class Node(object): 12 | r"""Abstract graph node""" 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | self.children = {} 18 | self.parents = {} 19 | 20 | def add_child(self, node: Node, edge: Any = None): 21 | self.children[node] = edge 22 | node.parents[self] = edge 23 | 24 | def add_parent(self, node: Node, edge: Any = None): 25 | node.add_child(self, edge) 26 | 27 | def rm_child(self, node: Node): 28 | del self.children[node] 29 | del node.parents[self] 30 | 31 | def rm_parent(self, node: Node): 32 | node.rm_child(self) 33 | 34 | 35 | class Job(Node): 36 | r"""Job node""" 37 | 38 | def __init__( 39 | self, 40 | f: Callable, 41 | *, 42 | name: str = None, 43 | array: Union[int, Iterable[int]] = None, 44 | array_throttle: int = None, 45 | interpreter: str = None, 46 | settings: Dict[str, Any] = {}, # noqa: B006 47 | **kwargs, 48 | ): 49 | super().__init__() 50 | 51 | assert callable(f), "job should be callable" 52 | 53 | if array is None: 54 | assert accepts(f), "job should not expect arguments" 55 | else: 56 | if type(array) is int: 57 | array = range(array) 58 | array = set(array) 59 | 60 | assert len(array) > 0, "array should not be empty" 61 | assert accepts(f, 0), "job array should expect an argument" 62 | 63 | self._f = pickle.dumps(f) 64 | self.name = f.__name__ if name is None else name 65 | self.array = array 66 | self.array_throttle = array_throttle 67 | self.interpreter = interpreter 68 | 69 | # Settings 70 | self.settings = settings.copy() 71 | self.settings.update(kwargs) 72 | 73 | # Dependencies 74 | self._waitfor = "all" 75 | self.unsatisfied = set() 76 | 77 | # Conditions 78 | self._postconditions = [] 79 | 80 | def __getstate__(self) -> Dict: 81 | state = self.__dict__.copy() 82 | 83 | for key in ["_f", "_postconditions"]: 84 | state.pop(key, None) 85 | 86 | return state 87 | 88 | @property 89 | def f(self) -> Callable: 90 | return pickle.loads(self._f) 91 | 92 | @property 93 | def run(self) -> Callable: 94 | name = self.name 95 | f = self.f 96 | cond = every(self.postconditions) 97 | 98 | def fun(*args) -> Any: 99 | result = f(*args) 100 | 101 | if not cond(*args): 102 | raise PostconditionNotSatisfiedError(f"{name}{list(args) if args else ''}") 103 | 104 | return result 105 | 106 | return fun 107 | 108 | def __call__(self, *args) -> Any: 109 | return self.run(*args) 110 | 
111 | def __str__(self) -> str: 112 | if self.array is None: 113 | return self.name 114 | else: 115 | return self.name + "[" + comma_separated(self.array) + "]" 116 | 117 | @property 118 | def dependencies(self) -> Dict[Job, str]: 119 | return self.parents 120 | 121 | def after(self, *deps: Job, status: str = "success"): 122 | assert status in ["success", "failure", "any"] 123 | 124 | for dep in deps: 125 | self.add_parent(dep, status) 126 | 127 | def detach(self, *deps: Job): 128 | for dep in deps: 129 | self.rm_parent(dep) 130 | 131 | @property 132 | def waitfor(self) -> str: 133 | return self._waitfor 134 | 135 | @waitfor.setter 136 | def waitfor(self, mode: str = "all"): 137 | assert mode in ["all", "any"] 138 | 139 | self._waitfor = mode 140 | 141 | def ensure(self, condition: Callable): 142 | if self.array is None: 143 | assert accepts(condition), "postcondition should not expect arguments" 144 | else: 145 | assert accepts(condition, 0), "postcondition should expect an argument" 146 | 147 | self._postconditions.append(pickle.dumps(condition)) 148 | 149 | @property 150 | def postconditions(self) -> List[Callable]: 151 | return list(map(pickle.loads, self._postconditions)) 152 | 153 | @cached_property 154 | def done(self) -> bool: 155 | if not self.postconditions: 156 | return False 157 | 158 | condition = every(self.postconditions) 159 | 160 | if self.array is None: 161 | return condition() 162 | else: 163 | return all(map(condition, self.array)) 164 | 165 | @property 166 | def satisfiable(self) -> bool: 167 | if self.unsatisfied: 168 | if self.waitfor == "all": 169 | return False 170 | elif self.waitfor == "any" and not self.dependencies: 171 | return False 172 | 173 | return True 174 | 175 | 176 | def dfs(*nodes: Node, backward: bool = False) -> Iterator[Node]: 177 | queue = list(nodes) 178 | visited = set() 179 | 180 | while queue: 181 | node = queue.pop() 182 | 183 | if node in visited: 184 | continue 185 | else: 186 | yield node 187 | 188 | queue.extend(node.parents if backward else node.children) 189 | visited.add(node) 190 | 191 | 192 | def leafs(*nodes: Node) -> Set[Node]: 193 | return {node for node in dfs(*nodes, backward=False) if not node.children} 194 | 195 | 196 | def roots(*nodes: Node) -> Set[Node]: 197 | return {node for node in dfs(*nodes, backward=True) if not node.parents} 198 | 199 | 200 | def cycles(*nodes: Node, backward: bool = False) -> Iterator[List[Node]]: 201 | queue = [list(nodes)] 202 | path = [] 203 | pathset = set() 204 | visited = set() 205 | 206 | while queue: 207 | branch = queue[-1] 208 | 209 | if not branch: 210 | if not path: 211 | break 212 | 213 | queue.pop() 214 | pathset.remove(path.pop()) 215 | continue 216 | 217 | node = branch.pop() 218 | 219 | if node in visited: 220 | if node in pathset: 221 | yield path + [node] 222 | continue 223 | 224 | queue.append(list(node.parents if backward else node.children)) 225 | path.append(node) 226 | pathset.add(node) 227 | visited.add(node) 228 | 229 | 230 | def prune(*jobs: Job) -> Set[Job]: 231 | for job in dfs(*jobs, backward=True): 232 | if job.done: 233 | job.detach(*job.dependencies) 234 | elif job.array is not None and job.postconditions: 235 | condition = every(job.postconditions) 236 | job.array = {i for i in job.array if not condition(i)} 237 | 238 | satisfied, unsatisfied, pending = [], [], [] 239 | 240 | for dep, status in job.dependencies.items(): 241 | if dep.done: 242 | if status == "failure": # first-order unsatisfiability 243 | unsatisfied.append(dep) 244 | else: 245 | 
satisfied.append(dep) 246 | else: 247 | pending.append(dep) 248 | 249 | job.detach(*satisfied, *unsatisfied) 250 | 251 | if job.waitfor == "any" and satisfied: 252 | job.detach(*pending) 253 | job.unsatisfied.clear() 254 | else: 255 | job.unsatisfied.update(unsatisfied) 256 | 257 | return {job for job in jobs if not job.done} 258 | 259 | 260 | class PostconditionNotSatisfiedError(Exception): 261 | pass 262 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Directed Acyclic Workflow Graph Scheduling 2 | 3 | Would you like fully reproducible and reusable experiments that run on HPC clusters as seamlessly as on your machine? Do you have to comment out large parts of your pipelines whenever something fails? Tired of writing and submitting [Slurm](https://wikipedia.org/wiki/Slurm_Workload_Manager) scripts? Then `dawgz` is made for you! 4 | 5 | The `dawgz` package provides a lightweight and intuitive Python interface to declare jobs along with their dependencies, requirements, settings, etc. A single line of code is then needed to automatically execute all or part of the workflow, while complying with the dependencies. Importantly, `dawgz` can also hand over the execution to resource management backends like [Slurm](https://wikipedia.org/wiki/Slurm_Workload_Manager), which makes it possible to execute the same workflow on your machine and on HPC clusters. 6 | 7 | ## Installation 8 | 9 | The `dawgz` package is available on [PyPI](https://pypi.org/project/dawgz/), which means it is installable via `pip`. 10 | 11 | ``` 12 | pip install dawgz 13 | ``` 14 | 15 | Alternatively, if you need the latest features, you can install it using 16 | 17 | ``` 18 | pip install git+https://github.com/francois-rozet/dawgz 19 | ``` 20 | 21 | ## Getting started 22 | 23 | In `dawgz`, a job is a Python function decorated by `@dawgz.job`. This decorator allows you to define the job's parameters, like its name, whether it is a job array, the resources it needs, etc. The job's dependencies are declared with the `@dawgz.after` decorator. Finally, the `dawgz.schedule` function takes care of scheduling the jobs and their dependencies, with a selected backend. For more information, check out the [interface](#Interface) and the [examples](examples/). 24 | 25 | Below is a small example demonstrating how one could use `dawgz` to calculate `π` (very roughly) using the [Monte Carlo method](https://en.wikipedia.org/wiki/Monte_Carlo_method). We define two jobs: `generate` and `estimate`. The former is a *job array*, meaning that it is executed concurrently for all values of `i = 0` up to `tasks - 1`. It also defines a [postcondition](https://en.wikipedia.org/wiki/Postconditions) ensuring that the file `pi_{i}.npy` exists after the job's completion. The job `estimate` has `generate` as a dependency, meaning it should only start after `generate` has succeeded. 
26 | 27 | ```python 28 | import glob 29 | import numpy as np 30 | import os 31 | 32 | from dawgz import job, after, ensure, schedule 33 | 34 | samples = 10000 35 | tasks = 5 36 | 37 | @ensure(lambda i: os.path.exists(f"pi_{i}.npy")) 38 | @job(array=tasks, cpus=1, ram="2GB", time="5:00") 39 | def generate(i: int): 40 | print(f"Task {i + 1} / {tasks}") 41 | 42 | x = np.random.random(samples) 43 | y = np.random.random(samples) 44 | within_circle = x**2 + y**2 <= 1 45 | 46 | np.save(f"pi_{i}.npy", within_circle) 47 | 48 | @after(generate) 49 | @job(cpus=2, ram="4GB", time="15:00") 50 | def estimate(): 51 | files = glob.glob("pi_*.npy") 52 | stack = np.vstack([np.load(f) for f in files]) 53 | pi_estimate = stack.mean() * 4 54 | 55 | print(f"π ≈ {pi_estimate}") 56 | 57 | schedule(estimate, name="pi.py", backend="async") 58 | ``` 59 | 60 | Running this script with the `"async"` backend displays 61 | 62 | ``` 63 | $ python examples/pi.py 64 | Task 1 / 5 65 | Task 2 / 5 66 | Task 3 / 5 67 | Task 4 / 5 68 | Task 5 / 5 69 | π ≈ 3.141865 70 | ``` 71 | 72 | Alternatively, on a Slurm HPC cluster, changing the backend to `"slurm"` results in the following job queue. 73 | 74 | ``` 75 | $ squeue -u username 76 | JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) 77 | 1868832 all estimate username PD 0:00 1 (Dependency) 78 | 1868831_[2-4] all generate username PD 0:00 1 (Resources) 79 | 1868831_0 all generate username R 0:01 1 node-x 80 | 1868831_1 all generate username R 0:01 1 node-y 81 | ``` 82 | 83 | In addition to the Python interface, `dawgz` provides a simple command-line interface (CLI) to list the scheduled workflows, the jobs of a workflow or the output(s) of a job. 84 | 85 | ``` 86 | $ dawgz 87 | Name ID Date Backend Jobs Errors 88 | -- ------ -------- ------------------- --------- ------ -------- 89 | 0 pi.py 8094aa20 2022-02-28 16:37:58 async 2 0 90 | 1 pi.py 9cc409fd 2022-02-28 16:38:33 slurm 2 0 91 | $ dawgz 1 92 | Name ID State 93 | -- ------------- ------- ------------------------- 94 | 0 generate[0-4] 1868831 COMPLETED,PENDING,RUNNING 95 | 1 estimate 1868832 PENDING 96 | $ dawgz 1 0 97 | Name State Output 98 | -- ----------- --------- ---------- 99 | 0 generate[0] COMPLETED Task 1 / 5 100 | 1 generate[1] COMPLETED Task 2 / 5 101 | 2 generate[2] RUNNING 102 | 3 generate[3] RUNNING 103 | 4 generate[4] PENDING 104 | ``` 105 | 106 | ## Interface 107 | 108 | ### Decorators 109 | 110 | The package provides four decorators: 111 | 112 | * `@dawgz.job` registers a function as a job, with its settings (name, array, resources, ...). It should always be the first (lowest) decorator. In the following example, `a` is a job with the name `"A"` and a time limit of one hour. 113 | 114 | ```python 115 | @job(name="A", time="01:00:00") 116 | def a(): 117 | ``` 118 | 119 | All keyword arguments other than `name`, `array` and `array_throttle` are passed as settings to the scheduler. For example, with the `slurm` backend, the following would lead to a job array of 64 tasks, with a maximum of 3 simultaneous tasks running exclusively on `tesla` or `quadro` partitions. 120 | 121 | ```python 122 | @job(array=64, array_throttle=3, partition="tesla,quadro") 123 | ``` 124 | 125 | Importantly, a job is **shipped with its context**, meaning that modifying global variables after it has been created does not affect its execution. 126 | 127 | However, a job is **not** shipped with its dependencies. This means that updating or modifying a dependency (i.e. 
a module) after a job has been submitted can affect its execution. If this is an issue for you, you can register your module such that it is pickled by value rather than by reference. 128 | 129 | ```python 130 | import cloudpickle 131 | import my_module 132 | 133 | cloudpickle.register_pickle_by_value(my_module) 134 | 135 | @job 136 | def a(): 137 | my_module.my_function() 138 | ``` 139 | 140 | * `@dawgz.after` adds one or more dependencies to a job. By default, the job waits for its dependencies to complete with success. The desired status can be set to `"success"` (default), `"failure"` or `"any"`. In the following example, `b` waits for `a` to complete with `"failure"`. 141 | 142 | ```python 143 | @after(a, status="failure") 144 | @job 145 | def b(): 146 | ``` 147 | 148 | * `@dawgz.waitfor` declares whether the job has to wait for `"all"` (default) or `"any"` of its dependencies to be satisfied before starting. In the following example, `c` waits for either `a` or `b` to complete (with success). 149 | 150 | ```python 151 | @after(a, b) 152 | @waitfor("any") 153 | @job 154 | def c(): 155 | ``` 156 | 157 | * `@dawgz.ensure` adds a [postcondition](https://wikipedia.org/wiki/Postconditions) to a job, i.e. a condition that must be `True` after the execution of the job. Not satisfying all postconditions after execution results in a `PostconditionNotSatisfiedError` at runtime. In the following example, `d` ensures that the file `log.txt` exists. 158 | 159 | ```python 160 | @ensure(lambda: os.path.exists("log.txt")) 161 | @job 162 | def d(): 163 | ``` 164 | 165 | Traditionally, postconditions are only **necessary** indicators that a task completed with success. In `dawgz`, they are considered both necessary and **sufficient** indicators. Therefore, postconditions can be used to detect jobs that have already been executed and prune them out of the workflow. To do so, set `prune=True` in `dawgz.schedule`. 166 | 167 | ### Backends 168 | 169 | Currently, `dawgz.schedule` supports three backends: `async`, `dummy` and `slurm`. 170 | 171 | * `async` waits asynchronously for dependencies to complete before executing each job. The jobs are executed by the current Python interpreter. 172 | * `dummy` is equivalent to `async`, but instead of executing the jobs, prints their name before and after a short (random) sleep time. The main use of `dummy` is debugging. 173 | * `slurm` submits the jobs to the Slurm workload manager by automatically generating `sbatch` submission scripts, as sketched in the example below. 
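For instance, the π estimation workflow from the getting started section can be handed over to Slurm by changing only the `schedule` call. The snippet below is a minimal sketch rather than a prescription: the partition name and the environment activation command are placeholders that depend on your cluster and your Python environment.

```python
from dawgz import schedule

# Reuses the `estimate` job defined in the getting started example. Keyword
# arguments of `schedule` are forwarded to the scheduler: most of them (e.g.
# `partition`, `account`) are passed directly to `sbatch`, while `env` is a
# sequence of shell commands executed before each job is launched. Job-level
# settings (e.g. `cpus`, `ram`, `time`) take precedence over scheduler-level ones.
schedule(
    estimate,
    name="pi.py",
    backend="slurm",
    partition="batch",  # placeholder: pick a partition that exists on your cluster
    env=["source ~/venvs/dawgz/bin/activate"],  # placeholder environment setup
)
```

The generated submission scripts and job logs are stored under the `.dawgz` directory (or `$DAWGZ_DIR`, if set), which is also where the `dawgz` CLI looks up past workflows.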
174 | -------------------------------------------------------------------------------- /dawgz/schedulers.py: -------------------------------------------------------------------------------- 1 | r"""Scheduling backends""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import concurrent.futures as cf 7 | import csv 8 | import os 9 | import shutil 10 | import subprocess 11 | import uuid 12 | 13 | from abc import ABC, abstractmethod 14 | from contextlib import contextmanager 15 | from datetime import datetime 16 | from functools import lru_cache, partial 17 | from inspect import isawaitable 18 | from pathlib import Path 19 | from random import random 20 | from tabulate import tabulate 21 | from typing import Any, Callable, Dict, Sequence 22 | 23 | from .utils import cat, comma_separated, future, pickle, runpickle, slugify, trace 24 | from .workflow import Job, cycles 25 | from .workflow import prune as _prune 26 | 27 | DIR = os.environ.get("DAWGZ_DIR", ".dawgz") 28 | DIR = Path(DIR).resolve() 29 | 30 | 31 | class Scheduler(ABC): 32 | r"""Abstract workflow scheduler.""" 33 | 34 | backend: str = None 35 | 36 | def __init__( 37 | self, 38 | name: str = None, 39 | settings: Dict[str, Any] = {}, # noqa: B006 40 | **kwargs, 41 | ): 42 | r""" 43 | Arguments: 44 | name: The name of the workflow. 45 | settings: A dictionnary of settings. 46 | kwargs: Keyword arguments added to `settings`. 47 | """ 48 | 49 | super().__init__() 50 | 51 | self.name = name 52 | self.date = datetime.now().replace(microsecond=0) 53 | self.uuid = uuid.uuid4().hex 54 | 55 | self.path = DIR / self.uuid 56 | self.path.mkdir(parents=True) 57 | 58 | # Settings 59 | self.settings = settings.copy() 60 | self.settings.update(kwargs) 61 | 62 | # Jobs 63 | self.order = {} 64 | self.results = {} 65 | self.traces = {} 66 | 67 | def dump(self): 68 | with open(self.path / "dump.pkl", "wb") as f: 69 | pickle.dump(self, f) 70 | 71 | with open(self.path.parent / "workflows.csv", mode="a", newline="") as f: 72 | csv.writer(f).writerow(( 73 | self.name, 74 | self.uuid, 75 | self.date, 76 | self.backend, 77 | len(self.order), 78 | len(self.traces), 79 | )) 80 | 81 | @staticmethod 82 | def load(path: Path) -> Scheduler: 83 | with open(path / "dump.pkl", "rb") as f: 84 | return pickle.load(f) 85 | 86 | def tag(self, job: Job) -> str: 87 | if job in self.order: 88 | i = self.order[job] 89 | else: 90 | i = self.order[job] = len(self.order) 91 | 92 | return f"{i:04d}_{slugify(job.name)}" 93 | 94 | def state(self, job: Job, i: int = None) -> str: 95 | if job in self.traces: 96 | return "FAILED" 97 | else: 98 | return "COMPLETED" 99 | 100 | def output(self, job: Job, i: int = None) -> Any: 101 | if job.array is None: 102 | return self.results[job] 103 | else: 104 | return self.results[job].get(i) 105 | 106 | def report(self, job: Job = None) -> str: 107 | if job is None: 108 | headers = ("Name", "State") 109 | rows = [(str(job), self.state(job)) for job in self.order] 110 | 111 | return tabulate(rows, headers, showindex=True) 112 | else: 113 | headers = ("Name", "State", "Output") 114 | array = [None] 115 | 116 | if job in self.traces: 117 | rows = [(str(job), self.state(job), self.traces[job])] 118 | elif job.array is None: 119 | rows = [(str(job), self.state(job), self.output(job))] 120 | else: 121 | array = sorted(job.array) 122 | rows = [ 123 | (f"{job.name}[{i}]", self.state(job, i), self.output(job, i)) for i in array 124 | ] 125 | 126 | rows = [ 127 | ( 128 | name, 129 | state, 130 | None if output is None else cat(output, 
width=120), 131 | ) 132 | for name, state, output in rows 133 | ] 134 | 135 | return tabulate(rows, headers, showindex=array) 136 | 137 | def cancel(self, job: Job = None) -> str: 138 | raise NotImplementedError(f"'cancel' is not implemented for the {self.backend} backend.") 139 | 140 | @contextmanager 141 | def context(self): 142 | try: 143 | yield None 144 | finally: 145 | pass 146 | 147 | def __call__(self, *jobs: Job, prune: bool = False): 148 | for cycle in cycles(*jobs, backward=True): 149 | raise CyclicDependencyGraphError(" <- ".join(map(str, cycle))) 150 | 151 | if prune: 152 | jobs = _prune(*jobs) 153 | 154 | with self.context(): 155 | asyncio.run(self.wait(*jobs)) 156 | 157 | async def wait(self, *jobs: Job): 158 | if jobs: 159 | await asyncio.wait(map(asyncio.create_task, map(self.submit, jobs))) 160 | await asyncio.wait(map(asyncio.create_task, map(self.submit, self.order))) 161 | 162 | async def submit(self, job: Job) -> Any: 163 | if job in self.results: 164 | result = self.results[job] 165 | else: 166 | result = self.results[job] = future(self._submit(job), return_exceptions=True) 167 | 168 | if isawaitable(result): 169 | result = self.results[job] = await result 170 | 171 | if isinstance(result, Exception): 172 | self.traces[job] = trace(result) 173 | 174 | return result 175 | 176 | async def _submit(self, job: Job) -> Any: 177 | try: 178 | if job.satisfiable: 179 | await self.satisfy(job) 180 | else: 181 | raise DependencyNeverSatisfiedError(str(job)) 182 | finally: 183 | self.tag(job) 184 | 185 | return await self.exec(job) 186 | 187 | @abstractmethod 188 | async def satisfy(self, job: Job): 189 | pass 190 | 191 | @abstractmethod 192 | async def exec(self, job: Job) -> Any: 193 | pass 194 | 195 | 196 | class AsyncScheduler(Scheduler): 197 | r"""Asynchronous scheduler. 198 | 199 | Jobs are executed asynchronously. A job is launched as soon as its dependencies are 200 | satisfied. 201 | """ 202 | 203 | backend: str = "async" 204 | 205 | def __init__(self, name: str = None, pools: int = None, **kwargs): 206 | r""" 207 | Arguments: 208 | name: The name of the workflow. 209 | pools: The number of processing pools. If `None`, use threads instead. 210 | kwargs: Keyword arguments passed to :class:`Scheduler`. 
211 | """ 212 | 213 | super().__init__(name=name, **kwargs) 214 | 215 | self.pools = pools 216 | 217 | @contextmanager 218 | def context(self): 219 | if self.pools is None: 220 | self.executor = cf.ThreadPoolExecutor() 221 | else: 222 | self.executor = cf.ProcessPoolExecutor(self.pools) 223 | 224 | try: 225 | yield None 226 | finally: 227 | del self.executor 228 | 229 | async def satisfy(self, job: Job): 230 | pending = [ 231 | asyncio.gather(self.submit(dep), future(status)) 232 | for dep, status in job.dependencies.items() 233 | ] 234 | 235 | while pending: 236 | done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) 237 | 238 | for task in done: 239 | result, status = task.result() 240 | 241 | if isinstance(result, JobFailedError) and status != "success": 242 | result = None 243 | elif not isinstance(result, Exception) and status == "failure": 244 | result = JobNotFailedError(f"{job}") 245 | 246 | if isinstance(result, Exception): 247 | if job.waitfor == "all": 248 | raise DependencyNeverSatisfiedError(str(job)) from result 249 | elif job.waitfor == "any": 250 | break 251 | else: 252 | continue 253 | break 254 | else: 255 | if job.dependencies and job.waitfor == "any": 256 | raise DependencyNeverSatisfiedError(str(job)) 257 | 258 | async def exec(self, job: Job) -> Any: 259 | dump = pickle.dumps(job.run) 260 | call = partial(self.remote, runpickle, dump) 261 | 262 | try: 263 | if job.array is None: 264 | return await call() 265 | else: 266 | results = await asyncio.gather(*map(call, job.array), return_exceptions=True) 267 | 268 | for result in results: 269 | if isinstance(result, Exception): 270 | raise result 271 | 272 | return dict(zip(job.array, results)) 273 | except Exception as e: 274 | raise JobFailedError(str(job)) from e 275 | 276 | async def remote(self, f: Callable, /, *args) -> Any: 277 | return await asyncio.get_running_loop().run_in_executor(self.executor, f, *args) 278 | 279 | 280 | class DummyScheduler(AsyncScheduler): 281 | r"""Dummy asynchronous scheduler. 282 | 283 | Jobs are scheduled asynchronously, but instead of executing them, their name is 284 | printed before and after a short (random) sleep time. Useful for debugging. 285 | """ 286 | 287 | backend: str = "dummy" 288 | 289 | async def exec(self, job: Job): 290 | print(f"START {job}") 291 | await asyncio.sleep(random()) 292 | print(f"END {job}") 293 | 294 | return None if job.array is None else {} 295 | 296 | 297 | class SlurmScheduler(Scheduler): 298 | r"""Slurm scheduler. 299 | 300 | Jobs are submitted to the Slurm queue. Resources are allocated by the Slurm manager 301 | according to the job and scheduler settings. Job settings have precendence over 302 | scheduler settings. 303 | 304 | Most settings (e.g. `account`, `export`, `partition`) are passed directly to 305 | `sbatch`. A few settings (e.g. `cpus`, `gpus`, `ram`) are translated into their 306 | `sbatch` equivalents. 307 | """ 308 | 309 | backend: str = "slurm" 310 | translate: Dict[str, str] = { 311 | "cpus": "cpus-per-task", 312 | "gpus": "gpus-per-task", 313 | "ram": "mem", 314 | "memory": "mem", 315 | "timelimit": "time", 316 | } 317 | 318 | def __init__( 319 | self, 320 | name: str = None, 321 | shell: str = os.environ.get("SHELL", "/bin/sh"), 322 | interpreter: str = "python", 323 | env: Sequence[str] = [], # noqa: B006 324 | **kwargs, 325 | ): 326 | r""" 327 | Arguments: 328 | name: The name of the workflow. 329 | shell: The scripting shell. 330 | interpreter: The Python interpreter. 
331 | env: A sequence of commands to execute before each job is launched. 332 | kwargs: Keyword arguments passed to :class:`Scheduler`. 333 | """ 334 | 335 | super().__init__(name=name, **kwargs) 336 | 337 | assert shutil.which("sbatch") is not None, "sbatch executable not found" 338 | 339 | # Environment 340 | self.shell = shell 341 | self.interpreter = interpreter 342 | self.env = env 343 | 344 | @lru_cache(None) # noqa: B019 345 | def sacct(self, jobid: str) -> Dict[str, str]: 346 | text = subprocess.run( 347 | ["sacct", "-j", jobid, "-o", "JobID,State", "-n", "-P", "-X"], 348 | capture_output=True, 349 | check=True, 350 | text=True, 351 | ).stdout 352 | 353 | return dict(line.split("|") for line in text.splitlines()) 354 | 355 | def state(self, job: Job, i: int = None) -> str: 356 | if job in self.traces: 357 | return "CANCELLED" 358 | 359 | jobid = self.results[job] 360 | table = self.sacct(jobid) 361 | 362 | if job.array is None: 363 | return table.get(jobid, None) 364 | elif i is None: 365 | if table: 366 | return ",".join(sorted(set(table.values()))) 367 | else: 368 | return None 369 | elif i in job.array: 370 | return table.get(f"{jobid}_{i}", None) 371 | else: 372 | return None 373 | 374 | def output(self, job: Job, i: int = None) -> str: 375 | tag = self.tag(job) 376 | 377 | if job.array is None: 378 | logfile = self.path / f"{tag}.log" 379 | else: 380 | logfile = self.path / f"{tag}_{i}.log" 381 | 382 | if logfile.exists(): 383 | with open(logfile, newline="", errors="replace") as f: 384 | return f.read() 385 | else: 386 | return None 387 | 388 | def report(self, job: Job = None) -> str: 389 | if job is None: 390 | headers = ("Name", "ID", "State") 391 | rows = [] 392 | 393 | for job in self.order: 394 | if job in self.traces: 395 | jobid = None 396 | else: 397 | jobid = self.results[job] 398 | 399 | rows.append((str(job), jobid, self.state(job))) 400 | 401 | return tabulate(rows, headers, showindex=True) 402 | else: 403 | return super().report(job) 404 | 405 | def cancel(self, job: Job = None) -> str: 406 | if job is None: 407 | jobids = list(self.results.values()) 408 | else: 409 | jobid = self.results[job] 410 | jobids = [jobid] 411 | 412 | return subprocess.run( 413 | ["scancel", "-v", *jobids], 414 | capture_output=True, 415 | check=True, 416 | text=True, 417 | ).stderr.strip("\n") 418 | 419 | async def satisfy(self, job: Job) -> str: 420 | results = await asyncio.gather(*map(self.submit, job.dependencies)) 421 | 422 | for result in results: 423 | if isinstance(result, Exception): 424 | raise DependencyNeverSatisfiedError(str(job)) from result 425 | 426 | async def exec(self, job: Job) -> Any: 427 | # Submission script 428 | lines = [ 429 | f"#!{self.shell}", 430 | "#", 431 | f'#SBATCH --job-name="{job.name}"', 432 | ] 433 | 434 | if job.array is not None: 435 | indices = comma_separated(job.array) 436 | 437 | if job.array_throttle is None: 438 | lines.append(f"#SBATCH --array={indices}") 439 | else: 440 | lines.append(f"#SBATCH --array={indices}%{job.array_throttle}") 441 | 442 | tag = self.tag(job) 443 | 444 | if job.array is None: 445 | logfile = self.path / f"{tag}.log" 446 | else: 447 | logfile = self.path / f"{tag}_%a.log" 448 | 449 | lines.append(f"#SBATCH --output={logfile}") 450 | 451 | ## Settings 452 | settings = self.settings.copy() 453 | settings.update(job.settings) 454 | 455 | assert "clusters" not in settings, "multi-cluster jobs not supported" 456 | 457 | for key in settings: 458 | assert not key.startswith("ntasks"), "multi-task jobs not supported" 459 | 
460 | nodes = settings.pop("nodes", 1) 461 | 462 | lines.append("#") 463 | lines.append("#SBATCH --nodes=" + f"{nodes}") 464 | lines.append("#SBATCH --ntasks-per-node=1") 465 | 466 | for key, value in settings.items(): 467 | key = self.translate.get(key, key) 468 | 469 | if type(value) is bool: 470 | if value: 471 | lines.append(f"#SBATCH --{key}") 472 | else: 473 | lines.append(f"#SBATCH --{key}={value}") 474 | 475 | ## Dependencies 476 | sep = "?" if job.waitfor == "any" else "," 477 | types = { 478 | "success": "afterok", 479 | "failure": "afternotok", 480 | "any": "afterany", 481 | } 482 | 483 | deps = [ 484 | f"{types[status]}:{await self.submit(dep)}" for dep, status in job.dependencies.items() 485 | ] 486 | 487 | if deps: 488 | lines.append("#") 489 | lines.append("#SBATCH --dependency=" + sep.join(deps)) 490 | 491 | lines.append("") 492 | 493 | ## Environment 494 | if self.env: 495 | lines.extend([*self.env, ""]) 496 | 497 | ## Pickle job 498 | pklfile = self.path / f"{tag}.pkl" 499 | 500 | with open(pklfile, "wb") as f: 501 | pickle.dump(job.run, f) 502 | 503 | pyfile = self.path / f"{tag}.py" 504 | 505 | with open(pyfile, "w") as f: 506 | f.write( 507 | "\n".join([ 508 | "import argparse", 509 | "import pickle", 510 | "", 511 | "parser = argparse.ArgumentParser()", 512 | "parser.add_argument('-i', '--index', type=int, default=None)", 513 | "", 514 | "args = parser.parse_args()", 515 | "", 516 | "with open('{}', 'rb') as f:".format(pklfile), 517 | " if args.index is None:", 518 | " pickle.load(f)()", 519 | " else:", 520 | " pickle.load(f)(args.index)", 521 | "", 522 | ]) 523 | ) 524 | 525 | if job.interpreter is None: 526 | interpreter = self.interpreter 527 | else: 528 | interpreter = job.interpreter 529 | 530 | if job.array is None: 531 | lines.append(f"srun {interpreter} {pyfile}") 532 | else: 533 | lines.append(f"srun {interpreter} {pyfile} -i $SLURM_ARRAY_TASK_ID") 534 | 535 | lines.append("") 536 | 537 | ## Save 538 | shfile = self.path / f"{tag}.sh" 539 | 540 | with open(shfile, "w") as f: 541 | f.write("\n".join(lines)) 542 | 543 | # Submit script 544 | try: 545 | text = subprocess.run( 546 | ["sbatch", "--parsable", str(shfile)], 547 | capture_output=True, 548 | check=True, 549 | text=True, 550 | ).stdout 551 | 552 | jobid, *_ = text.strip("\n").split(";") # ignore cluster name 553 | 554 | return jobid 555 | except Exception as e: 556 | if isinstance(e, subprocess.CalledProcessError): 557 | e = subprocess.SubprocessError(e.stderr.strip("\n")) 558 | 559 | raise JobSubmissionError(str(job)) from e 560 | 561 | 562 | class CyclicDependencyGraphError(Exception): 563 | pass 564 | 565 | 566 | class DependencyNeverSatisfiedError(Exception): 567 | pass 568 | 569 | 570 | class JobFailedError(Exception): 571 | pass 572 | 573 | 574 | class JobNotFailedError(Exception): 575 | pass 576 | 577 | 578 | class JobSubmissionError(Exception): 579 | pass 580 | --------------------------------------------------------------------------------