├── test ├── __init__.py ├── test_train.py └── test.yaml ├── models └── decima │ └── model.pt ├── schedulers ├── decima │ ├── __init__.py │ ├── env_wrapper.py │ ├── utils.py │ └── scheduler.py ├── heuristics │ ├── __init__.py │ ├── random_scheduler.py │ ├── utils.py │ └── round_robin.py ├── __init__.py └── scheduler.py ├── spark_sched_sim ├── wrappers │ ├── __init__.py │ └── stochastic_time_limit.py ├── components │ ├── __init__.py │ ├── task.py │ ├── event.py │ ├── executor.py │ ├── stage.py │ ├── job.py │ ├── renderer.py │ └── executor_tracker.py ├── __init__.py ├── data_samplers │ ├── __init__.py │ ├── data_sampler.py │ └── tpch.py ├── metrics.py ├── utils.py └── spark_sched_sim.py ├── train.py ├── .gitignore ├── trainers ├── utils │ ├── __init__.py │ ├── hidden_prints.py │ ├── profiler.py │ ├── baselines.py │ └── returns_calculator.py ├── __init__.py ├── vpg.py ├── ppo.py ├── rollout_worker.py └── trainer.py ├── cfg_loader.py ├── LICENSE ├── requirements.txt ├── .github └── workflows │ └── python-app.yml ├── config └── decima_tpch.yaml ├── README.md └── examples.py /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/decima/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArchieGertsman/spark-sched-sim/HEAD/models/decima/model.pt -------------------------------------------------------------------------------- /schedulers/decima/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DecimaScheduler"] 2 | 3 | from .scheduler import DecimaScheduler 4 | -------------------------------------------------------------------------------- /spark_sched_sim/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["StochasticTimeLimit"] 2 | 3 | from .stochastic_time_limit import StochasticTimeLimit 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from cfg_loader import load 2 | from trainers import make_trainer 3 | 4 | 5 | if __name__ == "__main__": 6 | cfg = load() 7 | make_trainer(cfg).train() 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pytest_cache/ 3 | .ipynb_checkpoints/ 4 | *.egg-info/ 5 | .DS_Store 6 | 7 | artifacts/ 8 | data/ 9 | count_params.py 10 | eval* 11 | test/old_test.py -------------------------------------------------------------------------------- /test/test_train.py: -------------------------------------------------------------------------------- 1 | from cfg_loader import load 2 | from trainers import make_trainer 3 | 4 | 5 | def test_train(): 6 | cfg = load("test/test.yaml") 7 | make_trainer(cfg).train() 8 | -------------------------------------------------------------------------------- /schedulers/heuristics/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["RandomScheduler", "RoundRobinScheduler"] 2 | 3 | from .random_scheduler import RandomScheduler 4 | from .round_robin import RoundRobinScheduler 5 | -------------------------------------------------------------------------------- 
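Before diving into the individual modules, here is a brief sketch of how the pieces in the tree above fit together: the pretrained checkpoint `models/decima/model.pt` is loaded into a `DecimaScheduler` through the `make_scheduler` factory. This mirrors `decima_example()` in `examples.py` further down; the `num_executors` value of 10 is only an illustrative choice.

```python
# Sketch (based on examples.py) of loading the pretrained Decima scheduler.
import os.path as osp

from cfg_loader import load
from schedulers import make_scheduler

cfg = load(filename=osp.join("config", "decima_tpch.yaml"))

# merge the agent config with the checkpoint path and the cluster size
agent_cfg = cfg["agent"] | {
    "num_executors": 10,  # illustrative value; examples.py uses 10
    "state_dict_path": osp.join("models", "decima", "model.pt"),
}

scheduler = make_scheduler(agent_cfg)  # returns a DecimaScheduler
```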
/spark_sched_sim/components/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Job", "Stage", "Task", "Executor"] 2 | 3 | from .job import Job 4 | from .stage import Stage 5 | from .task import Task 6 | from .executor import Executor 7 | -------------------------------------------------------------------------------- /spark_sched_sim/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["SparkSchedSimEnv"] 2 | 3 | from gymnasium.envs.registration import register 4 | from .spark_sched_sim import SparkSchedSimEnv 5 | 6 | register(id="SparkSchedSimEnv-v0", entry_point="spark_sched_sim:SparkSchedSimEnv") 7 | -------------------------------------------------------------------------------- /trainers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["HiddenPrints", "Profiler", "ReturnsCalculator", "Baseline"] 2 | 3 | from .hidden_prints import HiddenPrints 4 | from .profiler import Profiler 5 | from .returns_calculator import ReturnsCalculator 6 | from .baselines import Baseline 7 | -------------------------------------------------------------------------------- /trainers/utils/hidden_prints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | class HiddenPrints: 6 | def __enter__(self): 7 | self._original_stdout = sys.stdout 8 | sys.stdout = open(os.devnull, "w") 9 | 10 | def __exit__(self, exc_type, exc_val, exc_tb): 11 | sys.stdout.close() 12 | sys.stdout = self._original_stdout 13 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["VPG", "PPO", "make_trainer"] 2 | 3 | from .vpg import VPG 4 | from .ppo import PPO 5 | 6 | 7 | def make_trainer(cfg): 8 | glob = globals() 9 | trainer_cls = cfg["trainer"]["trainer_cls"] 10 | assert trainer_cls in glob, f"'{trainer_cls}' is not a valid trainer." 11 | return glob[trainer_cls]( 12 | agent_cfg=cfg["agent"], env_cfg=cfg["env"], train_cfg=cfg["trainer"] 13 | ) 14 | -------------------------------------------------------------------------------- /spark_sched_sim/data_samplers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DataSampler", "TPCHDataSampler", "make_data_sampler"] 2 | 3 | from copy import deepcopy 4 | 5 | from .data_sampler import DataSampler 6 | from .tpch import TPCHDataSampler 7 | 8 | 9 | def make_data_sampler(data_sampler_cfg): 10 | glob = globals() 11 | data_sampler_cls = data_sampler_cfg["data_sampler_cls"] 12 | assert ( 13 | data_sampler_cls in glob 14 | ), f"'{data_sampler_cls}' is not a valid data sampler." 
15 | return glob[data_sampler_cls](**deepcopy(data_sampler_cfg)) 16 | -------------------------------------------------------------------------------- /schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "Scheduler", 3 | "TrainableScheduler", 4 | "DecimaScheduler", 5 | "RandomScheduler", 6 | "RoundRobinScheduler", 7 | "make_scheduler", 8 | ] 9 | 10 | from copy import deepcopy 11 | 12 | from .scheduler import Scheduler, TrainableScheduler 13 | from .decima import DecimaScheduler 14 | from .heuristics import RandomScheduler, RoundRobinScheduler 15 | 16 | 17 | def make_scheduler(agent_cfg): 18 | glob = globals() 19 | agent_cls = agent_cfg["agent_cls"] 20 | assert agent_cls in glob, f"'{agent_cls}' is not a valid scheduler." 21 | return glob[agent_cls](**deepcopy(agent_cfg)) 22 | -------------------------------------------------------------------------------- /spark_sched_sim/data_samplers/data_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Iterable 3 | 4 | import numpy as np 5 | 6 | from spark_sched_sim.components import Job, Stage, Task, Executor 7 | 8 | 9 | class DataSampler(ABC): 10 | np_random: np.random.Generator | None 11 | 12 | def reset(self, np_random: np.random.Generator): 13 | self.np_random = np_random 14 | 15 | @abstractmethod 16 | def job_sequence(self, max_time: float) -> Iterable[tuple[float, Job]]: 17 | pass 18 | 19 | @abstractmethod 20 | def task_duration( 21 | self, job: Job, stage: Stage, task: Task, executor: Executor 22 | ) -> float: 23 | pass 24 | -------------------------------------------------------------------------------- /spark_sched_sim/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def job_durations(env): 5 | durations = [] 6 | for job_id in env.unwrapped.active_job_ids + list(env.unwrapped.completed_job_ids): 7 | job = env.unwrapped.jobs[job_id] 8 | t_end = min(job.t_completed, env.unwrapped.wall_time) 9 | durations += [t_end - job.t_arrival] 10 | return durations 11 | 12 | 13 | def avg_job_duration(env): 14 | return np.mean(job_durations(env)) 15 | 16 | 17 | def avg_num_jobs(env): 18 | return sum(job_durations(env)) / env.unwrapped.wall_time 19 | 20 | 21 | def job_duration_percentiles(env): 22 | jd = job_durations(env) 23 | return np.percentile(jd, [25, 50, 75, 100]) 24 | -------------------------------------------------------------------------------- /spark_sched_sim/components/task.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | 5 | 6 | @dataclass 7 | class Task: 8 | id_: int 9 | stage_id: int 10 | job_id: int 11 | executor_id: int | None = None 12 | t_accepted: float = np.inf 13 | t_completed: float = np.inf 14 | 15 | @property 16 | def __unique_id(self) -> tuple[int, int, int]: 17 | return (self.job_id, self.stage_id, self.id_) 18 | 19 | def __hash__(self) -> int: 20 | return hash(self.__unique_id) 21 | 22 | def __eq__(self, other) -> bool: 23 | if type(other) is type(self): 24 | return self.__unique_id == other.__unique_id 25 | else: 26 | return False 27 | -------------------------------------------------------------------------------- /cfg_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from argparse import 
ArgumentParser, ArgumentDefaultsHelpFormatter 3 | 4 | 5 | def load(filename=None): 6 | if not filename: 7 | args = make_parser().parse_args() 8 | filename = args.filename 9 | 10 | with open(filename, "r") as stream: 11 | cfg = yaml.safe_load(stream) 12 | 13 | return cfg 14 | 15 | 16 | def make_parser(): 17 | parser = ArgumentParser( 18 | description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter 19 | ) 20 | 21 | parser.add_argument( 22 | "-f", 23 | "--file", 24 | dest="filename", 25 | help="experiment definition file", 26 | metavar="FILE", 27 | required=True, 28 | ) 29 | 30 | return parser 31 | -------------------------------------------------------------------------------- /spark_sched_sim/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import ndarray 3 | 4 | 5 | def subgraph(edge_links: ndarray, node_mask: ndarray): 6 | """ 7 | Minimal numpy version of PyG's subgraph utility function 8 | Args: 9 | edge_links: array of edges of shape (num_edges, 2), 10 | following the convention of gymnasium Graph space 11 | node_mask: indicates which nodes should be used for 12 | inducing the subgraph 13 | """ 14 | edge_mask = node_mask[edge_links[:, 0]] & node_mask[edge_links[:, 1]] 15 | edge_links = edge_links[edge_mask] 16 | 17 | # relabel the nodes 18 | node_idx = np.zeros(node_mask.size, dtype=int) 19 | node_idx[node_mask] = np.arange(node_mask.sum()) 20 | edge_links = node_idx[edge_links] 21 | 22 | return edge_links 23 | -------------------------------------------------------------------------------- /trainers/utils/profiler.py: -------------------------------------------------------------------------------- 1 | import cProfile 2 | import pstats 3 | import io 4 | from pstats import SortKey 5 | 6 | 7 | class Profiler: 8 | """context manager which profiles a block of code, 9 | then prints out the function calls sorted by cumulative 10 | execution time 11 | """ 12 | 13 | def __init__(self, amount=20): 14 | self.pr = cProfile.Profile() 15 | self.amount = amount 16 | 17 | def __enter__( 18 | self, 19 | ): 20 | self.pr.enable() 21 | return self 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | self.pr.disable() 25 | 26 | # log stats 27 | s = io.StringIO() 28 | ps = pstats.Stats(self.pr, stream=s).sort_stats(SortKey.CUMULATIVE) 29 | ps.print_stats(self.amount) 30 | print(s.getvalue(), flush=True) 31 | -------------------------------------------------------------------------------- /test/test.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | trainer_cls: 'PPO' 3 | num_iterations: 1 4 | num_sequences: 1 5 | num_rollouts: 2 6 | seed: 42 7 | artifacts_dir: 'test/artifacts' 8 | checkpointing_freq: 50 9 | use_tensorboard: False 10 | num_epochs: 3 11 | num_batches: 10 12 | clip_range: .2 13 | target_kl: .01 14 | entropy_coeff: .04 15 | beta_discount: 5.e-3 16 | opt_cls: 'Adam' 17 | opt_kwargs: 18 | lr: 3.e-4 19 | max_grad_norm: .5 20 | 21 | agent: 22 | agent_cls: 'DecimaScheduler' 23 | embed_dim: 16 24 | gnn_mlp_kwargs: 25 | hid_dims: [32, 16] 26 | act_cls: 'LeakyReLU' 27 | act_kwargs: 28 | inplace: True 29 | negative_slope: .2 30 | policy_mlp_kwargs: 31 | hid_dims: [64, 64] 32 | act_cls: 'Tanh' 33 | 34 | env: 35 | num_executors: 50 36 | job_arrival_cap: 10 37 | moving_delay: 2000. 38 | mean_time_limit: 2.e+7 39 | job_arrival_rate: 4.e-5 40 | warmup_delay: 1000. 
41 | data_sampler_cls: 'TPCHDataSampler' -------------------------------------------------------------------------------- /schedulers/heuristics/random_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..scheduler import Scheduler 4 | from .utils import preprocess_obs, find_stage 5 | 6 | 7 | class RandomScheduler(Scheduler): 8 | def __init__(self, seed=42): 9 | self.name = "Random" 10 | self.env_wrapper_cls = None 11 | self.set_seed(seed) 12 | 13 | def set_seed(self, seed): 14 | self.np_random = np.random.RandomState(seed) 15 | 16 | def schedule(self, obs: dict) -> tuple[dict, dict]: 17 | preprocess_obs(obs) 18 | num_active_jobs = len(obs["exec_supplies"]) 19 | 20 | job_idxs = list(range(num_active_jobs)) 21 | stage_idx = -1 22 | while len(job_idxs) > 0: 23 | j = self.np_random.choice(job_idxs) 24 | stage_idx = find_stage(obs, j) 25 | if stage_idx != -1: 26 | break 27 | else: 28 | job_idxs.remove(j) 29 | 30 | num_exec = self.np_random.randint(1, obs["num_committable_execs"] + 1) 31 | 32 | return {"stage_idx": stage_idx, "num_exec": num_exec}, {} 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Arkadiy Gertsman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
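Returning to the scheduler code above the license: `RandomScheduler` illustrates the heuristic plug-in API. The sketch below implements the same `Scheduler` interface (defined in `schedulers/scheduler.py`) with a hypothetical `GreedyScheduler` that is not part of the repository; it simply commits every available executor to the first schedulable stage it finds.

```python
# Hypothetical GreedyScheduler: a minimal sketch of the Scheduler plug-in API.
from schedulers import Scheduler
from schedulers.heuristics.utils import preprocess_obs, find_stage


class GreedyScheduler(Scheduler):
    """toy heuristic: commit all committable executors to the first schedulable stage"""

    def __init__(self):
        self.name = "Greedy"
        self.env_wrapper_cls = None  # no observation/action wrapping needed

    def schedule(self, obs: dict) -> tuple[dict, dict]:
        preprocess_obs(obs)  # annotates obs with frontier/schedulable stage info

        for j in range(len(obs["exec_supplies"])):
            stage_idx = find_stage(obs, j)
            if stage_idx != -1:
                return {
                    "stage_idx": stage_idx,
                    "num_exec": obs["num_committable_execs"],
                }, {}

        # nothing schedulable: -1 tells the env to move on (as in RoundRobinScheduler)
        return {"stage_idx": -1, "num_exec": obs["num_committable_execs"]}, {}
```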
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | cachetools==5.3.1 3 | certifi==2023.7.22 4 | charset-normalizer==3.2.0 5 | cloudpickle==2.2.1 6 | Farama-Notifications==0.0.4 7 | filelock==3.12.4 8 | google-auth==2.23.0 9 | google-auth-oauthlib==1.0.0 10 | grpcio==1.58.0 11 | gymnasium==0.29.1 12 | idna==3.4 13 | importlib-metadata==6.8.0 14 | install==1.3.5 15 | Jinja2==3.1.2 16 | joblib==1.3.2 17 | Markdown==3.4.4 18 | MarkupSafe==2.1.3 19 | mpmath==1.3.0 20 | networkx==3.1 21 | numpy==1.26.0 22 | oauthlib==3.2.2 23 | Pillow==10.0.1 24 | protobuf==4.24.3 25 | psutil==5.9.5 26 | pyasn1==0.5.0 27 | pyasn1-modules==0.3.0 28 | pygame==2.5.2 29 | pyparsing==3.1.1 30 | PyYAML==6.0.1 31 | requests==2.31.0 32 | requests-oauthlib==1.3.1 33 | rsa==4.9 34 | scikit-learn==1.3.0 35 | scipy==1.11.2 36 | sympy==1.12 37 | tensorboard==2.14.0 38 | tensorboard-data-server==0.7.1 39 | threadpoolctl==3.2.0 40 | torch==2.0.1 41 | torch-geometric==2.3.1 42 | -f https://download.pytorch.org/whl/cpu/torch_stable.html 43 | -f https://data.pyg.org/whl/torch-2.0.1+cpu.html 44 | torch-scatter==2.1.1 45 | torch-sparse==0.6.17 46 | torchaudio==2.0.2 47 | torchvision==0.15.2 48 | tqdm==4.66.1 49 | typing_extensions==4.8.0 50 | urllib3==1.26.16 51 | Werkzeug==2.3.7 52 | zipp==3.17.0 53 | -------------------------------------------------------------------------------- /schedulers/heuristics/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy as np 3 | 4 | 5 | def preprocess_obs(obs: dict[str, Any]) -> None: 6 | frontier_mask = np.ones(obs["dag_batch"].nodes.shape[0], dtype=bool) 7 | dst_nodes = obs["dag_batch"].edge_links[:, 1] 8 | frontier_mask[dst_nodes] = False 9 | stage_mask = obs["dag_batch"].nodes[:, 2].astype(bool) 10 | 11 | obs["frontier_stages"] = set(frontier_mask.nonzero()[0]) 12 | obs["schedulable_stages"] = dict( 13 | zip(stage_mask.nonzero()[0], np.arange(stage_mask.sum())) 14 | ) 15 | 16 | 17 | def find_stage(obs: dict[str, Any], job_idx: int) -> int: 18 | """searches for a schedulable stage in a given job, prioritizing 19 | frontier stages 20 | """ 21 | stage_idx_start = obs["dag_ptr"][job_idx] 22 | stage_idx_end = obs["dag_ptr"][job_idx + 1] 23 | 24 | selected_stage_idx = -1 25 | for node in range(stage_idx_start, stage_idx_end): 26 | if node in obs["schedulable_stages"]: 27 | i = obs["schedulable_stages"][node] 28 | else: 29 | continue 30 | 31 | if node in obs["frontier_stages"]: 32 | return i 33 | 34 | if selected_stage_idx == -1: 35 | selected_stage_idx = i 36 | 37 | return selected_stage_idx 38 | -------------------------------------------------------------------------------- /spark_sched_sim/components/event.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import itertools 3 | from enum import Enum, auto 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class Event: 9 | class Type(Enum): 10 | JOB_ARRIVAL = auto() 11 | TASK_FINISHED = auto() 12 | EXECUTOR_READY = auto() 13 | 14 | type: Type 15 | 16 | data: dict 17 | 18 | 19 | class EventQueue: 20 | def __init__(self) -> None: 21 | # priority queue 22 | self._pq: list[tuple[float, int, Event]] = [] 23 | 24 | # tie breaker 25 | self._counter = itertools.count() 26 | 27 | def reset(self) -> None: 28 | self._pq.clear() 29 | self._counter = 
itertools.count() 30 | 31 | def __bool__(self) -> bool: 32 | return bool(self._pq) 33 | 34 | def push(self, t: float, event: Event) -> None: 35 | heapq.heappush(self._pq, (t, next(self._counter), event)) 36 | 37 | def top(self) -> tuple[float, Event] | None: 38 | if not self._pq: 39 | return None 40 | 41 | t, _, event = self._pq[0] 42 | return t, event 43 | 44 | def pop(self) -> tuple[float, Event] | None: 45 | if not self._pq: 46 | return None 47 | 48 | t, _, event = heapq.heappop(self._pq) 49 | return t, event 50 | -------------------------------------------------------------------------------- /spark_sched_sim/wrappers/stochastic_time_limit.py: -------------------------------------------------------------------------------- 1 | from gymnasium import Wrapper 2 | import numpy as np 3 | 4 | 5 | class StochasticTimeLimit(Wrapper): 6 | """Samples each episode's time limit from an exponential distribution""" 7 | 8 | def __init__(self, env, mean_time_limit, seed=42): 9 | super().__init__(env) 10 | self.mean_time_limit = mean_time_limit 11 | self.np_random = np.random.RandomState(seed) 12 | 13 | def reset(self, seed=None, options=None): 14 | """samples a new time limit prior to resetting""" 15 | if seed: 16 | self.np_random = np.random.RandomState(seed) 17 | self.time_limit = self.np_random.exponential(self.mean_time_limit) 18 | print( 19 | f"resetting. seed={seed}, timelim={int(self.time_limit*1e-3)}s", flush=True 20 | ) 21 | if not options: 22 | options = {} 23 | options["time_limit"] = self.time_limit 24 | return self.env.reset(seed=seed, options=options) 25 | 26 | def step(self, act): 27 | """modifies `truncated` signal when time limit is reached""" 28 | obs, rew, term, trunc, info = self.env.step(act) 29 | if info["wall_time"] >= self.time_limit: 30 | trunc = True 31 | return obs, rew, term, trunc, info 32 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.10" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --ignore=E203,W503 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /schedulers/scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Iterable 3 | from gymnasium import Wrapper 4 | from torch import Tensor 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Scheduler(ABC): 11 | """Interface for all schedulers""" 12 | 13 | name: str 14 | env_wrapper_cls: type[Wrapper] | None 15 | 16 | @abstractmethod 17 | def schedule(self, obs: dict) -> tuple[dict, dict]: 18 | pass 19 | 20 | 21 | class TrainableScheduler(Scheduler, nn.Module): 22 | """Interface for all trainable schedulers""" 23 | 24 | optim: torch.optim.Optimizer | None 25 | max_grad_norm: float | None 26 | 27 | @abstractmethod 28 | def evaluate_actions( 29 | self, obsns: Iterable[dict], actions: Iterable[tuple] 30 | ) -> dict[str, Tensor]: 31 | pass 32 | 33 | @property 34 | def device(self) -> torch.device: 35 | return next(self.parameters()).device 36 | 37 | def update_parameters(self, loss: Tensor | None = None) -> None: 38 | assert self.optim 39 | 40 | if loss: 41 | # accumulate gradients 42 | loss.backward() 43 | 44 | if self.max_grad_norm: 45 | # clip grads 46 | torch.nn.utils.clip_grad_norm_( 47 | self.parameters(), self.max_grad_norm, error_if_nonfinite=True 48 | ) 49 | 50 | # update model parameters 51 | self.optim.step() 52 | 53 | # clear accumulated gradients 54 | self.optim.zero_grad() 55 | -------------------------------------------------------------------------------- /spark_sched_sim/components/executor.py: -------------------------------------------------------------------------------- 1 | from .task import Task 2 | 3 | 4 | class Executor: 5 | def __init__(self, id_: int) -> None: 6 | # unique identifier of this executor within the cluster 7 | self.id_ = id_ 8 | 9 | # task that this executor is or just finished executing, 10 | # or `None` if the executor is idle 11 | self.task: Task | None = None 12 | 13 | # id of current job that this executor is local to, if any 14 | self.job_id: int | None = None 15 | 16 | # whether or not this executor is executing a task. 17 | # NOTE: can be `False` while `self.task is not None`, 18 | # if the executor just finished executing 19 | self.is_executing = False 20 | 21 | # list of pairs [t, job_id], where `t` is the wall time that this executor 22 | # was released from job with id `job_id`, or `None` if it has not been released 23 | # yet. `job_id` is -1 if the executor is at the general pool. 
24 | # NOTE: only used for rendering 25 | self.history: list[list] = [[None, -1]] 26 | 27 | @property 28 | def is_idle(self) -> bool: 29 | return self.task is None 30 | 31 | def is_at_job(self, job_id: int) -> bool: 32 | return self.job_id == job_id 33 | 34 | def add_history(self, wall_time: float, job_id: int) -> None: 35 | """should be called whenever this executor is released from a job""" 36 | if self.history is None: 37 | self.history = [] 38 | 39 | if len(self.history) > 0: 40 | # add release time to most recent history 41 | self.history[-1][0] = wall_time 42 | 43 | # add new history 44 | self.history += [[None, job_id]] 45 | -------------------------------------------------------------------------------- /trainers/vpg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.profiler 4 | 5 | from .trainer import Trainer 6 | 7 | 8 | EPS = 1e-8 9 | 10 | 11 | class VPG(Trainer): 12 | """Vanilla Policy Gradient""" 13 | 14 | def __init__(self, agent_cfg, env_cfg, train_cfg): 15 | super().__init__(agent_cfg, env_cfg, train_cfg) 16 | 17 | self.entropy_coeff = train_cfg.get("entropy_coeff", 0.0) 18 | 19 | def train_on_rollouts(self, rollout_buffers): 20 | data = self._preprocess_rollouts(rollout_buffers) 21 | 22 | policy_losses = [] 23 | entropy_losses = [] 24 | 25 | for obsns, actions, returns, baselines, old_lgprobs in zip(*data.values()): 26 | eval_res = self.scheduler.evaluate_actions(obsns, actions) 27 | 28 | # re-computed log-probs don't exactly match the original ones, 29 | # but it doesn't seem to affect training 30 | # with torch.no_grad(): 31 | # diff = (lgprobs - torch.tensor(old_lgprobs)).abs() 32 | # assert lgprobs.allclose(torch.tensor(old_lgprobs)) 33 | 34 | adv = torch.from_numpy(returns - baselines).float() 35 | adv = (adv - adv.mean()) / (adv.std() + EPS) 36 | policy_loss = -(eval_res["lgprobs"] * adv).mean() 37 | policy_losses += [policy_loss.item()] 38 | 39 | entropy_loss = -eval_res["entropies"].mean() 40 | entropy_losses += [entropy_loss.item()] 41 | 42 | loss = policy_loss + self.entropy_coeff * entropy_loss 43 | loss.backward() 44 | 45 | self.scheduler.update_parameters() 46 | 47 | return { 48 | "policy loss": np.mean(policy_losses), 49 | "entropy": np.mean(entropy_losses), 50 | } 51 | -------------------------------------------------------------------------------- /trainers/utils/baselines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Baseline: 5 | def __init__(self, num_sequences, num_rollouts): 6 | self.num_sequences = num_sequences 7 | self.num_rollouts = num_rollouts 8 | 9 | def __call__(self, ts_list, ys_list): 10 | return self.average(ts_list, ys_list) 11 | 12 | def average(self, ts_list, ys_list): 13 | baseline_list = [] 14 | for j in range(self.num_sequences): 15 | start = j * self.num_rollouts 16 | end = start + self.num_rollouts 17 | baseline_list += self._average(ts_list[start:end], ys_list[start:end]) 18 | return baseline_list 19 | 20 | def _average(self, ts_list, ys_list): 21 | ts_unique = np.unique(np.hstack(ts_list)) 22 | 23 | # shape: (num envs, len(ts_unique)) 24 | # y_hats[i, t] is the linear interpolation of (ts_list[i], ys_list[i]) 25 | # at time t 26 | y_hats = np.vstack( 27 | [np.interp(ts_unique, ts, ys) for ts, ys in zip(ts_list, ys_list)] 28 | ) 29 | 30 | # find baseline at each unique time point 31 | baseline = {} 32 | for t, y_hat in zip(ts_unique, y_hats.T): 33 | baseline[t] = 
y_hat.mean() 34 | 35 | baseline_list = [np.array([baseline[t] for t in ts]) for ts in ts_list] 36 | 37 | return baseline_list 38 | 39 | 40 | # def pairwise_average(ts_list, ys_list): 41 | # num_envs = len(ts_list) 42 | 43 | # baseline_list = [None] * num_envs 44 | 45 | # for i, j in zip(np.arange(num_envs//2), 46 | # np.arange(num_envs//2, num_envs)): 47 | # baseline_list[i] = \ 48 | # .5 * (ys_list[i] + np.interp(ts_list[i], ts_list[j], ys_list[j])) 49 | 50 | # baseline_list[j] = \ 51 | # .5 * (ys_list[j] + np.interp(ts_list[j], ts_list[i], ys_list[i])) 52 | 53 | # return baseline_list 54 | -------------------------------------------------------------------------------- /schedulers/heuristics/round_robin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..scheduler import Scheduler 4 | from .utils import preprocess_obs, find_stage 5 | 6 | 7 | class RoundRobinScheduler(Scheduler): 8 | def __init__(self, num_executors, dynamic_partition=True): 9 | self.name = "Fair" if dynamic_partition else "FIFO" 10 | self.num_executors = num_executors 11 | self.dynamic_partition = dynamic_partition 12 | self.env_wrapper_cls = None 13 | 14 | def schedule(self, obs: dict) -> tuple[dict, dict]: 15 | preprocess_obs(obs) 16 | num_active_jobs = len(obs["exec_supplies"]) 17 | 18 | if self.dynamic_partition: 19 | executor_cap = self.num_executors / max(1, num_active_jobs) 20 | executor_cap = int(np.ceil(executor_cap)) 21 | else: 22 | executor_cap = self.num_executors 23 | 24 | # first, try to find a stage in the same job that is releasing executors 25 | if obs["source_job_idx"] < num_active_jobs: 26 | selected_stage_idx = find_stage(obs, obs["source_job_idx"]) 27 | 28 | if selected_stage_idx != -1: 29 | return { 30 | "stage_idx": selected_stage_idx, 31 | "num_exec": obs["num_committable_execs"], 32 | }, {} 33 | 34 | # search through jobs by order of arrival. 
35 | for j in range(num_active_jobs): 36 | if obs["exec_supplies"][j] >= executor_cap or j == obs["source_job_idx"]: 37 | continue 38 | 39 | selected_stage_idx = find_stage(obs, j) 40 | if selected_stage_idx == -1: 41 | continue 42 | 43 | num_exec = min( 44 | obs["num_committable_execs"], executor_cap - obs["exec_supplies"][j] 45 | ) 46 | return {"stage_idx": selected_stage_idx, "num_exec": num_exec}, {} 47 | 48 | # didn't find any stages to schedule 49 | return {"stage_idx": -1, "num_exec": obs["num_committable_execs"]}, {} 50 | -------------------------------------------------------------------------------- /spark_sched_sim/components/stage.py: -------------------------------------------------------------------------------- 1 | from .task import Task 2 | 3 | 4 | class Stage: 5 | def __init__( 6 | self, id: int, job_id: int, num_tasks: int, rough_task_duration: float 7 | ) -> None: 8 | self.id_ = id 9 | self.job_id = job_id 10 | self.most_recent_duration = rough_task_duration 11 | self.num_tasks = num_tasks 12 | self.remaining_tasks = [ 13 | Task(id_=i, stage_id=self.id_, job_id=self.job_id) for i in range(num_tasks) 14 | ] 15 | self.num_remaining_tasks = num_tasks 16 | self.num_executing_tasks = 0 17 | self.num_completed_tasks = 0 18 | self.is_schedulable = False 19 | 20 | def __hash__(self) -> int: 21 | return hash(self.pool_key) 22 | 23 | def __eq__(self, other) -> bool: 24 | if type(other) is type(self): 25 | return self.pool_key == other.pool_key 26 | else: 27 | return False 28 | 29 | @property 30 | def pool_key(self) -> tuple[int, int]: 31 | return (self.job_id, self.id_) 32 | 33 | @property 34 | def job_pool_key(self) -> tuple[int, None]: 35 | return (self.job_id, None) 36 | 37 | @property 38 | def completed(self) -> bool: 39 | return self.num_completed_tasks == self.num_tasks 40 | 41 | @property 42 | def num_saturated_tasks(self) -> int: 43 | return self.num_executing_tasks + self.num_completed_tasks 44 | 45 | @property 46 | def next_task_id(self) -> int: 47 | return self.num_saturated_tasks 48 | 49 | @property 50 | def approx_remaining_work(self) -> float: 51 | return self.most_recent_duration * self.num_remaining_tasks 52 | 53 | def launch_next_task(self) -> Task: 54 | assert self.num_saturated_tasks < self.num_tasks 55 | task = self.remaining_tasks.pop() 56 | self.num_remaining_tasks -= 1 57 | self.num_executing_tasks += 1 58 | return task 59 | 60 | def record_task_completion(self) -> None: 61 | self.num_executing_tasks -= 1 62 | self.num_completed_tasks += 1 63 | -------------------------------------------------------------------------------- /config/decima_tpch.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | # where the training happens (cpu, cuda, cuda:0, ...) 3 | # note: rollouts are always collected using only the CPU 4 | device: 'cuda' 5 | 6 | # name of the trainer's class 7 | trainer_cls: 'PPO' 8 | 9 | # number of training iterations 10 | num_iterations: 500 11 | 12 | # number of unique job sequences sampled per training iteration 13 | num_sequences: 4 14 | 15 | # number of rollouts experienced per unique job sequence 16 | # `num_sequences` x `num_rollouts` 17 | # = total number of rollouts per training iteration 18 | # = number of rollout workers running in parallel 19 | num_rollouts: 4 20 | 21 | # base random seed; each worker gets its own seed which is offset from this. 22 | seed: 42 23 | 24 | # name of directory where all training artifacts are saved (e.g. 
tensorboard) 25 | artifacts_dir: 'artifacts' 26 | 27 | # if checkpointing_freq = n, then every n iterations, the best model from the 28 | # past n iterations is saved 29 | checkpointing_freq: 50 30 | 31 | # if true, then records training metrics to a tensorboard file 32 | use_tensorboard: False 33 | 34 | # PPO: number of times to train through all of the data from the most recent 35 | # iteration 36 | num_epochs: 3 37 | 38 | # PPO: number of batches to split the last iteration's training data into 39 | num_batches: 10 40 | 41 | # PPO: hyperparameter for clamping the importance sampling ratio 42 | clip_range: .2 43 | 44 | # PPO: end training cycle if approximate KL divergence exceeds `target_kl` 45 | target_kl: .01 46 | 47 | # PPO: coefficient of entropy bonus term (if 0 then no entropy bonus) 48 | entropy_coeff: .04 49 | 50 | # discount factor for (continuously) discounted returns 51 | beta_discount: 5.e-3 52 | 53 | # max reward window size for differential returns 54 | # reward_buff_cap: 200000 55 | 56 | # note: only one of `beta_discount` and `reward_buff_cap` must be specified, 57 | # indicating whether to use discounted or differential returns 58 | 59 | # optimizer settings 60 | opt_cls: 'Adam' 61 | opt_kwargs: 62 | lr: 3.e-4 63 | max_grad_norm: .5 64 | 65 | 66 | agent: 67 | agent_cls: 'DecimaScheduler' 68 | embed_dim: 16 69 | gnn_mlp_kwargs: 70 | hid_dims: [32, 16] 71 | act_cls: 'LeakyReLU' 72 | act_kwargs: 73 | inplace: True 74 | negative_slope: .2 75 | policy_mlp_kwargs: 76 | hid_dims: [64, 64] 77 | act_cls: 'Tanh' 78 | 79 | 80 | env: 81 | num_executors: 50 82 | job_arrival_cap: 200 83 | job_arrival_rate: 4.e-5 84 | moving_delay: 2000. 85 | warmup_delay: 1000. 86 | data_sampler_cls: 'TPCHDataSampler' 87 | mean_time_limit: 2.e+7 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spark-sched-sim 2 | 3 | An Apache Spark job scheduling simulator, implemented as a [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) environment. 4 | 5 | | ![](https://i.imgur.com/6BpPWxI.png)| 6 | |:--:| 7 | | *Two Gantt charts comparing the behavior of different job scheduling algorithms. In these experiments, 50 jobs are identified by unique colors and processed in parallel by 10 identical executors (stacked vertically). Decima achieves better resource packing and lower average job completion time than Spark's fair scheduler.* | 8 | 9 | _What is job scheduling in Spark?_ 10 | - A Spark _application_ is a long-running program within the cluster that submits _jobs_ to be processed by its share of the cluster's resources. Each job encodes a directed acyclic graph (DAG) of _stages_ that depend on each other, where a dependency $A\to B$ means that stage $A$ must finish executing before stage $B$ can begin. Each stage consists of many identical _tasks_ which are units of work that operate over different shards of data. Tasks are processed by _executors_, which are JVMs running on the cluster's _worker_ nodes. 11 | - Scheduling jobs means designating which tasks run on which executors at each point in time. 12 | - For more background on Spark, see [this article](https://spark.apache.org/docs/latest/job-scheduling.html). 13 | 14 | _Why this simulator?_ 15 | - Job scheduling is important because a smarter scheduling algorithm can result in faster job turnaround time. 16 | - This simulator allows researchers to test scheduling heuristics and train neural schedulers using reinforcement learning. 
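As a quick taste of the API, the sketch below runs one simulated episode with the built-in fair (round-robin) scheduler and reports the average job completion time. It condenses `run_episode()` from `examples.py`, using that script's environment settings (rendering is omitted here to keep the example headless).

```python
import gymnasium as gym
from schedulers import RoundRobinScheduler
from spark_sched_sim import metrics

env_cfg = {
    "num_executors": 10,
    "job_arrival_cap": 50,
    "job_arrival_rate": 4.0e-5,
    "moving_delay": 2000.0,
    "warmup_delay": 1000.0,
    "data_sampler_cls": "TPCHDataSampler",
}
env = gym.make("spark_sched_sim:SparkSchedSimEnv-v0", env_cfg=env_cfg)
scheduler = RoundRobinScheduler(env_cfg["num_executors"], dynamic_partition=True)

obs, _ = env.reset(seed=1234, options=None)
terminated = truncated = False
while not (terminated or truncated):
    action, _ = scheduler.schedule(obs)
    obs, _, terminated, truncated, _ = env.step(action)

print(f"Average job duration: {metrics.avg_job_duration(env) * 1e-3:.1f}s")
env.close()
```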
17 | 18 | --- 19 | 20 | This repository is a PyTorch Geometric implementation of the [Decima codebase](https://github.com/hongzimao/decima-sim), adhering to the Gymnasium interface. It also includes enhancements to the reinforcement learning algorithm and model design, along with a basic PyGame renderer that generates the above charts in real time. 21 | 22 | Enhancements include: 23 | - Continuously discounted returns, improving training speed 24 | - Proximal Policy Optimization (PPO), improving training speed and stability 25 | - A restricted action space, encouraging a fairer policy to be learned 26 | - Multiple different job sequences experienced per training iteration, reducing variance in the policy gradient (PG) estimate 27 | - No learning curriculum, improving training speed 28 | 29 | --- 30 | 31 | After cloning this repo, please run `pip install -r requirements.txt` to install the project's dependencies. 32 | 33 | To start out, try running examples via `examples.py --sched [fair|decima]`. To train Decima from scratch, modify the provided config file `config/decima_tpch.yaml` as needed, then provide the config to `train.py -f CFG_FILE`. -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | """Examples of how to run job scheduling simulations with different schedulers 2 | """ 3 | import os.path as osp 4 | from pprint import pprint 5 | 6 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 7 | import gymnasium as gym 8 | import pathlib 9 | 10 | from cfg_loader import load 11 | from schedulers import RoundRobinScheduler, make_scheduler 12 | from spark_sched_sim import metrics 13 | 14 | 15 | ENV_CFG = { 16 | "num_executors": 10, 17 | "job_arrival_cap": 50, 18 | "job_arrival_rate": 4.0e-5, 19 | "moving_delay": 2000.0, 20 | "warmup_delay": 1000.0, 21 | "data_sampler_cls": "TPCHDataSampler", 22 | "render_mode": "human", 23 | } 24 | 25 | 26 | def main(): 27 | # save final rendering to artifacts dir 28 | pathlib.Path("artifacts").mkdir(parents=True, exist_ok=True) 29 | 30 | parser = ArgumentParser( 31 | description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter 32 | ) 33 | 34 | parser.add_argument( 35 | "--sched", 36 | choices=["fair", "decima"], 37 | dest="sched", 38 | help="which scheduler to run", 39 | required=True, 40 | ) 41 | 42 | args = parser.parse_args() 43 | 44 | sched_map = {"fair": fair_example, "decima": decima_example} 45 | 46 | sched_map[args.sched]() 47 | 48 | 49 | def fair_example(): 50 | # Fair scheduler 51 | scheduler = RoundRobinScheduler(ENV_CFG["num_executors"], dynamic_partition=True) 52 | 53 | print("Example: Fair Scheduler") 54 | print("Env settings:") 55 | pprint(ENV_CFG) 56 | 57 | print("Running episode...") 58 | avg_job_duration = run_episode(ENV_CFG, scheduler) 59 | 60 | print(f"Done! Average job duration: {avg_job_duration:.1f}s", flush=True) 61 | print() 62 | 63 | 64 | def decima_example(): 65 | cfg = load(filename=osp.join("config", "decima_tpch.yaml")) 66 | 67 | agent_cfg = cfg["agent"] | { 68 | "num_executors": ENV_CFG["num_executors"], 69 | "state_dict_path": osp.join("models", "decima", "model.pt"), 70 | } 71 | 72 | scheduler = make_scheduler(agent_cfg) 73 | 74 | print("Example: Decima") 75 | print("Env settings:") 76 | pprint(ENV_CFG) 77 | 78 | print("Running episode...") 79 | avg_job_duration = run_episode(ENV_CFG, scheduler) 80 | 81 | print(f"Done! 
Average job duration: {avg_job_duration:.1f}s", flush=True) 82 | 83 | 84 | def run_episode(env_cfg, scheduler, seed=1234): 85 | env = gym.make("spark_sched_sim:SparkSchedSimEnv-v0", env_cfg=env_cfg) 86 | 87 | if scheduler.env_wrapper_cls: 88 | env = scheduler.env_wrapper_cls(env) 89 | 90 | obs, _ = env.reset(seed=seed, options=None) 91 | terminated = truncated = False 92 | 93 | while not (terminated or truncated): 94 | action, _ = scheduler.schedule(obs) 95 | obs, _, terminated, truncated, _ = env.step(action) 96 | 97 | avg_job_duration = metrics.avg_job_duration(env) * 1e-3 98 | 99 | # cleanup rendering 100 | env.close() 101 | 102 | return avg_job_duration 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /trainers/utils/returns_calculator.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import numpy as np 4 | 5 | 6 | class CircularArray: 7 | def __init__(self, cap, num_cols): 8 | self.cap = cap 9 | self.data = np.zeros((cap, num_cols)) 10 | 11 | def extend(self, new_data): 12 | num_new = new_data.shape[0] 13 | if num_new > self.cap: 14 | new_data = new_data[-self.cap :] 15 | num_new = self.cap 16 | 17 | num_keep = self.cap - num_new 18 | if num_keep > 0: 19 | self.data[:num_keep] = self.data[-num_keep:] 20 | 21 | self.data[num_keep:] = new_data 22 | 23 | 24 | class ReturnsCalculator: 25 | def __init__(self, buff_cap=None, beta=None): 26 | assert bool(buff_cap) ^ bool( 27 | beta 28 | ), "exactly one of `buff_cap` and `beta` must be specified" 29 | 30 | self.buff_cap = buff_cap 31 | self.beta = beta 32 | 33 | # estimate of the long-run average number of concurrent jobs under 34 | # the current policy 35 | self.avg_num_jobs = None 36 | 37 | if buff_cap: 38 | # circular buffer used for computing the moving average. 
each row 39 | # corresponds to a time-step; the first column is the duration of 40 | # that step in ms, and the second column is the reward from that 41 | # step 42 | self.buff = CircularArray(buff_cap, num_cols=2) 43 | 44 | def __call__(self, rewards_list, times_list, resets_list): 45 | dt_list = [np.array(ts[1:]) - np.array(ts[:-1]) for ts in times_list] 46 | 47 | if self.beta: 48 | return self._calc_discounted_returns(dt_list, rewards_list) 49 | else: 50 | return self._calc_differential_returns(dt_list, rewards_list) 51 | 52 | def _calc_differential_returns(self, dt_list, rewards_list): 53 | self._update_avg_num_jobs(dt_list, rewards_list) 54 | 55 | diff_returns_list = [] 56 | for dts, rs in zip(dt_list, rewards_list): 57 | diff_returns = np.zeros(len(rs)) 58 | R = 0 59 | for k, (dt, r) in reversed(list(enumerate(zip(dts, rs)))): 60 | job_time = -r 61 | expected_job_time = dt * self.avg_num_jobs 62 | R = -(job_time - expected_job_time) + R 63 | diff_returns[k] = R 64 | diff_returns_list += [diff_returns] 65 | return diff_returns_list 66 | 67 | def _calc_discounted_returns(self, dt_list, rewards_list): 68 | disc_returns_list = [] 69 | for dts, rs in zip(dt_list, rewards_list): 70 | disc_returns = np.zeros(len(rs)) 71 | R = 0 72 | for k, (dt, r) in reversed(list(enumerate(zip(dts, rs)))): 73 | R = r + np.exp(-self.beta * 1e-3 * dt) * R 74 | disc_returns[k] = R 75 | disc_returns_list += [disc_returns] 76 | return disc_returns_list 77 | 78 | def _update_avg_num_jobs(self, deltas_list, rewards_list): 79 | new_data = np.array(list(zip(chain(*deltas_list), chain(*rewards_list)))) 80 | 81 | # filter out timesteps that have a duration of 0ms 82 | new_data = new_data[new_data[:, 0] > 0] 83 | 84 | # add new data to circular buffer, discarding some of the old data 85 | self.buff.extend(new_data) 86 | 87 | total_time, rew_sum = self.buff.data.sum(0) 88 | total_job_time = -rew_sum 89 | self.avg_num_jobs = total_job_time / total_time 90 | -------------------------------------------------------------------------------- /spark_sched_sim/components/job.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Generator 2 | import numpy as np 3 | import networkx as nx 4 | 5 | from .stage import Stage 6 | from .executor import Executor 7 | 8 | 9 | class Job: 10 | """An object representing a job in the system, containing a set of stages with dependencies stored in a dag.""" 11 | 12 | def __init__( 13 | self, id_: int, stages: list[Stage], dag: nx.DiGraph, t_arrival: float 14 | ) -> None: 15 | # unique identifier of this job 16 | self.id_ = id_ 17 | 18 | # list of objects of all the stages that belong to this job 19 | self.stages = stages 20 | 21 | # all incomplete stages 22 | self.active_stages = stages.copy() 23 | 24 | # incomplete stages whose parents have completed 25 | self.frontier_stages: set[Stage] = set() 26 | 27 | # networkx dag storing the stage dependencies 28 | self.dag = dag 29 | 30 | # time that this job arrived into the system 31 | self.t_arrival = t_arrival 32 | 33 | # time that this job completed, i.e. 
when the last 34 | # stage completed 35 | self.t_completed = np.inf 36 | 37 | # set of executors that are local to this job 38 | self.local_executors: set[int] = set() 39 | 40 | # count of stages who have no remaining tasks 41 | self.saturated_stage_count = 0 42 | 43 | self._init_frontier() 44 | 45 | @property 46 | def pool_key(self) -> tuple[int, None]: 47 | return (self.id_, None) 48 | 49 | @property 50 | def completed(self) -> bool: 51 | return not self.num_active_stages 52 | 53 | @property 54 | def saturated(self) -> bool: 55 | return self.saturated_stage_count == len(self.stages) 56 | 57 | @property 58 | def num_stages(self) -> int: 59 | return len(self.stages) 60 | 61 | @property 62 | def num_active_stages(self) -> int: 63 | return len(self.active_stages) 64 | 65 | def record_stage_completion(self, stage: Stage) -> bool: 66 | """increments the count of completed stages""" 67 | self.active_stages.remove(stage) 68 | self.frontier_stages.remove(stage) 69 | 70 | new_stages = self._find_new_frontier_stages(stage) 71 | self.frontier_stages |= new_stages 72 | 73 | return bool(new_stages) 74 | 75 | def get_children_stages(self, stage: Stage) -> Generator[Stage, None, None]: 76 | return (self.stages[stage_id] for stage_id in self.dag.successors(stage.id_)) 77 | 78 | def get_parent_stages(self, stage: Stage) -> Generator[Stage, None, None]: 79 | return (self.stages[stage_id] for stage_id in self.dag.predecessors(stage.id_)) 80 | 81 | def attach_executor(self, executor: Executor) -> None: 82 | assert executor.task is None 83 | self.local_executors.add(executor.id_) 84 | executor.job_id = self.id_ 85 | 86 | def detach_executor(self, executor: Executor) -> None: 87 | self.local_executors.remove(executor.id_) 88 | executor.job_id = None 89 | executor.task = None 90 | 91 | # internal methods 92 | 93 | def _init_frontier(self) -> None: 94 | """returns a set containing all the stages which are 95 | source nodes in the dag, i.e. which have no dependencies 96 | """ 97 | assert not self.frontier_stages 98 | self.frontier_stages |= self._get_source_stages() 99 | 100 | def _check_dependencies(self, stage_id: int) -> bool: 101 | """searches to see if all the dependencies of stage with id `stage_id` are satisfied.""" 102 | for dep_id in self.dag.predecessors(stage_id): 103 | if not self.stages[dep_id].completed: 104 | return False 105 | 106 | return True 107 | 108 | def _get_source_stages(self) -> set[Stage]: 109 | return set( 110 | self.stages[node] for node, in_deg in self.dag.in_degree() if in_deg == 0 111 | ) 112 | 113 | def _find_new_frontier_stages(self, stage: Stage) -> set[Stage]: 114 | """if ` stage` is completed, returns all of its successors whose other dependencies are also 115 | completed, if any exist. 
116 | """ 117 | if not stage.completed: 118 | return set() 119 | 120 | new_stages = set() 121 | # search through stage's children 122 | for suc_stage_id in self.dag.successors(stage.id_): 123 | # if all dependencies are satisfied, then add this child to the frontier 124 | new_stage = self.stages[suc_stage_id] 125 | if not new_stage.completed and self._check_dependencies(suc_stage_id): 126 | new_stages.add(new_stage) 127 | 128 | return new_stages 129 | -------------------------------------------------------------------------------- /trainers/ppo.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | from itertools import chain 3 | from typing import SupportsFloat 4 | from torch import Tensor 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | from .trainer import Trainer 11 | 12 | 13 | EPS = 1e-8 14 | 15 | 16 | class RolloutDataset(Dataset): 17 | def __init__(self, obsns, acts, advgs, lgprobs): 18 | self.obsns = obsns 19 | self.acts = acts 20 | self.advgs = advgs 21 | self.lgprobs = lgprobs 22 | 23 | def __len__(self): 24 | return len(self.obsns) 25 | 26 | def __getitem__(self, idx): 27 | return self.obsns[idx], self.acts[idx], self.advgs[idx], self.lgprobs[idx] 28 | 29 | 30 | # def collate_fn(batch): 31 | # obsns, acts, advgs, lgprobs = zip(*batch) 32 | # obsns = collate_obsns(obsns) 33 | # acts = torch.stack(acts) 34 | # advgs = torch.stack(advgs) 35 | # lgprobs = torch.stack(lgprobs) 36 | # return obsns, acts, advgs, lgprobs 37 | 38 | 39 | class PPO(Trainer): 40 | """Proximal Policy Optimization""" 41 | 42 | def __init__(self, agent_cfg, env_cfg, train_cfg): 43 | super().__init__(agent_cfg, env_cfg, train_cfg) 44 | 45 | self.entropy_coeff = train_cfg.get("entropy_coeff", 0.0) 46 | self.clip_range = train_cfg.get("clip_range", 0.2) 47 | self.target_kl = train_cfg.get("target_kl", 0.01) 48 | self.num_epochs = train_cfg.get("num_epochs", 10) 49 | self.num_batches = train_cfg.get("num_batches", 3) 50 | 51 | def train_on_rollouts(self, rollout_buffers): 52 | data = self._preprocess_rollouts(rollout_buffers) 53 | 54 | returns = np.array(list(chain(*data["returns_list"]))) 55 | baselines = np.concatenate(data["baselines_list"]) 56 | 57 | dataset = RolloutDataset( 58 | obsns=list(chain(*data["obsns_list"])), 59 | acts=list(chain(*data["actions_list"])), 60 | advgs=returns - baselines, 61 | lgprobs=list(chain(*data["lgprobs_list"])), 62 | ) 63 | 64 | dataloader = DataLoader( 65 | dataset, 66 | batch_size=len(dataset) // self.num_batches + 1, 67 | shuffle=True, 68 | collate_fn=lambda batch: zip(*batch), 69 | ) 70 | 71 | return self._train(dataloader) 72 | 73 | def _train(self, dataloader): 74 | policy_losses = [] 75 | entropy_losses = [] 76 | approx_kl_divs = [] 77 | continue_training = True 78 | 79 | for _ in range(self.num_epochs): 80 | if not continue_training: 81 | break 82 | 83 | for obsns, actions, advgs, old_lgprobs in dataloader: 84 | loss, info = self._compute_loss(obsns, actions, advgs, old_lgprobs) 85 | 86 | kl = info["approx_kl_div"] 87 | 88 | policy_losses += [info["policy_loss"]] 89 | entropy_losses += [info["entropy_loss"]] 90 | approx_kl_divs.append(kl) 91 | 92 | if self.target_kl is not None and kl > 1.5 * self.target_kl: 93 | print(f"Early stopping due to reaching max kl: " f"{kl:.3f}") 94 | continue_training = False 95 | break 96 | 97 | self.scheduler.update_parameters(loss) 98 | 99 | return { 100 | "policy loss": np.abs(np.mean(policy_losses)), 101 | 
"entropy": np.abs(np.mean(entropy_losses)), 102 | "approx kl div": np.abs(np.mean(approx_kl_divs)), 103 | } 104 | 105 | def _compute_loss( 106 | self, 107 | obsns: Iterable[dict], 108 | acts: Iterable[tuple], 109 | advantages: Iterable[SupportsFloat], 110 | old_lgprobs: Iterable[SupportsFloat], 111 | ) -> tuple[Tensor, dict[str, SupportsFloat]]: 112 | """CLIP loss""" 113 | eval_res = self.scheduler.evaluate_actions(obsns, acts) 114 | 115 | advgs = torch.tensor(advantages).float() 116 | advgs = (advgs - advgs.mean()) / (advgs.std() + EPS) 117 | 118 | log_ratio = eval_res["lgprobs"] - torch.tensor(old_lgprobs) 119 | ratio = log_ratio.exp() 120 | 121 | policy_loss1 = advgs * ratio 122 | policy_loss2 = advgs * torch.clamp( 123 | ratio, 1 - self.clip_range, 1 + self.clip_range 124 | ) 125 | policy_loss = -torch.min(policy_loss1, policy_loss2).mean() 126 | 127 | entropy_loss = -eval_res["entropies"].mean() 128 | 129 | loss = policy_loss + self.entropy_coeff * entropy_loss 130 | 131 | with torch.no_grad(): 132 | approx_kl_div = ((ratio - 1) - log_ratio).mean().item() 133 | 134 | return loss, { 135 | "policy_loss": policy_loss.item(), 136 | "entropy_loss": entropy_loss.item(), 137 | "approx_kl_div": approx_kl_div, 138 | } 139 | -------------------------------------------------------------------------------- /spark_sched_sim/components/renderer.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import numpy as np 3 | 4 | 5 | class Renderer: 6 | """renders frames that visualize the job scheduling simulation in 7 | real time. A gantt chart is displayed, with the traces of all the 8 | workers are stacked vertically. Job completions are indicated by 9 | red markers, and info about the simulation is displayed in text. 
10 | """ 11 | 12 | def __init__( 13 | self, 14 | num_workers: int, 15 | num_total_jobs: int | None = None, 16 | window_width: int = 400, 17 | window_height: int = 300, 18 | font_name: str = "couriernew", 19 | font_size: int = 16, 20 | render_fps: int = 30, 21 | ): 22 | self.num_workers = num_workers 23 | self.num_total_jobs = num_total_jobs 24 | self.window_width = window_width 25 | self.window_height = window_height 26 | self.font_name = font_name 27 | self.font_size = font_size 28 | self.render_fps = render_fps 29 | 30 | self.WORKER_RECT_H = np.ceil(self.window_height / self.num_workers) 31 | self.clock = pygame.time.Clock() 32 | 33 | @property 34 | def window(self) -> pygame.Surface: 35 | """lazy-init the window""" 36 | if not getattr(self, "_window", None): 37 | pygame.init() 38 | pygame.display.init() 39 | self._window = pygame.display.set_mode( 40 | (self.window_width, self.window_height) 41 | ) 42 | self.font = pygame.font.SysFont(self.font_name, self.font_size) 43 | return self._window 44 | 45 | def render_frame( 46 | self, 47 | worker_histories, 48 | job_completion_times, 49 | wall_time: float, 50 | avg_job_duration: float, 51 | num_active_jobs: int, 52 | num_jobs_completed: int, 53 | ) -> None: 54 | assert self.num_total_jobs 55 | 56 | # draw canvas 57 | canvas = pygame.Surface((self.window_width, self.window_height)) 58 | canvas.fill((255, 255, 255)) 59 | self._draw_worker_histories(canvas, worker_histories, wall_time) 60 | self._draw_job_completion_markers(canvas, job_completion_times, wall_time) 61 | self.window.blit(canvas, canvas.get_rect()) 62 | 63 | # draw text 64 | text_surfaces = self._make_text_surfaces( 65 | wall_time, avg_job_duration, num_active_jobs, num_jobs_completed 66 | ) 67 | for text_surface, pos in text_surfaces: 68 | self.window.blit(text_surface, pos) 69 | 70 | pygame.event.pump() 71 | pygame.display.update() 72 | 73 | self.clock.tick(self.render_fps) 74 | 75 | def close(self) -> None: 76 | if self.window is not None: 77 | pygame.image.save(self.window, "screenshot.png") 78 | pygame.display.quit() 79 | pygame.quit() 80 | 81 | # internal methods 82 | 83 | def _draw_worker_histories(self, canvas, worker_histories, wall_time): 84 | for i, history in enumerate(worker_histories): 85 | y_rect = i * self.WORKER_RECT_H 86 | x_rect = 0 87 | for i in range(len(history)): 88 | t, job_id = history[i] 89 | if i > 0: 90 | t_prev = history[i - 1][0] 91 | assert t_prev is not None 92 | else: 93 | t_prev = 0 94 | 95 | if t is None: 96 | t = wall_time 97 | 98 | width_ratio = (t - t_prev) / wall_time 99 | w_rect = np.ceil(self.window_width * width_ratio) 100 | 101 | if job_id == -1: 102 | color = (0, 0, 0) 103 | else: 104 | color1 = np.array((0, 100, 255)) 105 | color2 = np.array((2, 247, 112)) 106 | p = (job_id + 1) / self.num_total_jobs 107 | color = color1 + p * (color2 - color1) 108 | 109 | pygame.draw.rect( 110 | canvas, 111 | color, 112 | pygame.Rect( 113 | (x_rect, y_rect), 114 | (w_rect, self.WORKER_RECT_H), 115 | ), 116 | ) 117 | 118 | x_rect += w_rect 119 | 120 | def _draw_job_completion_markers(self, canvas, job_completion_times, wall_time): 121 | for t in job_completion_times: 122 | x = self.window_width * (t / wall_time) 123 | 124 | pygame.draw.rect( 125 | canvas, (255, 0, 0), pygame.Rect((x, 0), (1, self.window_height)) 126 | ) 127 | 128 | def _make_text_surfaces( 129 | self, wall_time, avg_job_duration, num_active_jobs, num_jobs_completed, dy=20 130 | ): 131 | wall_time = int(wall_time * 1e-3) 132 | 133 | surfs = [ 134 | self.font.render(f"Wall time: 
{wall_time}s", False, (255,) * 3), 135 | self.font.render( 136 | f"Avg job duration: {avg_job_duration}s", False, (255,) * 3 137 | ), 138 | self.font.render(f"Num active jobs: {num_active_jobs}", False, (255,) * 3), 139 | self.font.render( 140 | f"Num jobs completed: {num_jobs_completed}", False, (255,) * 3 141 | ), 142 | ] 143 | 144 | return [(surf, (0, dy * i)) for i, surf in enumerate(surfs)] 145 | -------------------------------------------------------------------------------- /schedulers/decima/env_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from numpy import ndarray 3 | from gymnasium import Wrapper, ActionWrapper, ObservationWrapper 4 | import numpy as np 5 | import gymnasium.spaces as sp 6 | 7 | from . import utils 8 | 9 | NUM_NODE_FEATURES = 5 10 | 11 | 12 | class DecimaEnvWrapper(Wrapper): 13 | def __init__(self, env): 14 | env = DecimaActWrapper(env) 15 | env = DecimaObsWrapper(env) 16 | super().__init__(env) 17 | 18 | 19 | class DecimaActWrapper(ActionWrapper): 20 | """converts Decima's actions to the environment's format""" 21 | 22 | def __init__(self, env) -> None: 23 | super().__init__(env) 24 | 25 | self.action_space = sp.Dict( 26 | { 27 | "stage_idx": sp.Discrete(1), 28 | "job_idx": sp.Discrete(1), 29 | "num_exec": sp.Discrete(env.unwrapped.num_executors), 30 | } 31 | ) 32 | 33 | def action(self, act: dict[str, Any]) -> dict[str, Any]: 34 | return {"stage_idx": act["stage_idx"], "num_exec": 1 + act["num_exec"]} 35 | 36 | 37 | class DecimaObsWrapper(ObservationWrapper): 38 | """transforms environment observations into a format that's more suitable for Decima""" 39 | 40 | def __init__( 41 | self, env, num_tasks_scale: int = 200, work_scale: float = 1e5 42 | ) -> None: 43 | super().__init__(env) 44 | 45 | self.num_tasks_scale = num_tasks_scale 46 | self.work_scale = work_scale 47 | self.num_executors = env.unwrapped.num_executors 48 | 49 | # cache edge masks, because dag batch doesn't always change between observations 50 | self._cache: dict[str, Any] = { 51 | "num_nodes": -1, 52 | "edge_links": None, 53 | "edge_masks": None, 54 | } 55 | 56 | self.observation_space = sp.Dict( 57 | { 58 | "dag_batch": sp.Graph( 59 | node_space=sp.Box(-np.inf, np.inf, (NUM_NODE_FEATURES,)), 60 | edge_space=sp.Discrete(1), 61 | ), 62 | "dag_ptr": sp.Sequence(sp.Discrete(1)), 63 | "stage_mask": sp.Sequence(sp.Discrete(2)), 64 | "exec_mask": sp.Sequence(sp.MultiBinary(self.num_executors)), 65 | "edge_masks": sp.MultiBinary((1, 1)), 66 | } 67 | ) 68 | 69 | def observation(self, obs: dict[str, Any]) -> dict[str, Any]: 70 | dag_batch = obs["dag_batch"] 71 | 72 | exec_supplies = np.array(obs["exec_supplies"]) 73 | num_committable_execs = obs["num_committable_execs"] 74 | gap = np.maximum(self.num_executors - exec_supplies, 0) 75 | 76 | # cap on number of execs that can be committed to each job 77 | commit_caps = np.minimum(gap, num_committable_execs) 78 | 79 | j_src = obs["source_job_idx"] 80 | num_jobs = exec_supplies.size 81 | if j_src < num_jobs: 82 | commit_caps[j_src] = num_committable_execs 83 | 84 | graph_instance = sp.GraphInstance( 85 | nodes=self._build_node_features(obs, commit_caps), 86 | edges=dag_batch.edges, 87 | edge_links=dag_batch.edge_links, 88 | ) 89 | 90 | stage_mask = dag_batch.nodes[:, 2].astype(bool) 91 | 92 | exec_mask = np.zeros((num_jobs, self.num_executors), dtype=bool) 93 | for j, cap in enumerate(commit_caps): 94 | exec_mask[j, :cap] = True 95 | 96 | self._validate_cache(obs) 97 | 98 | obs = { 
99 | "dag_batch": graph_instance, 100 | "dag_ptr": obs["dag_ptr"], 101 | "stage_mask": stage_mask, 102 | "exec_mask": exec_mask, 103 | "edge_masks": self._cache["edge_masks"], 104 | } 105 | 106 | self.observation_space["dag_ptr"].feature_space.n = dag_batch.nodes.shape[0] + 1 107 | self.observation_space["edge_masks"].n = obs["edge_masks"].shape 108 | return obs 109 | 110 | def _build_node_features( 111 | self, obs: dict[str, Any], commit_caps: ndarray 112 | ) -> ndarray: 113 | dag_batch = obs["dag_batch"] 114 | num_nodes = dag_batch.nodes.shape[0] 115 | ptr = np.array(obs["dag_ptr"]) 116 | node_counts = ptr[1:] - ptr[:-1] 117 | exec_supplies = obs["exec_supplies"] 118 | num_active_jobs = len(exec_supplies) 119 | source_job_idx = obs["source_job_idx"] 120 | 121 | nodes = np.zeros((num_nodes, NUM_NODE_FEATURES), dtype=np.float32) 122 | 123 | # how many exec can be added to each node 124 | nodes[:, 0] = np.repeat(commit_caps, node_counts) / self.num_executors 125 | 126 | # whether or not a node belongs to the source job 127 | nodes[:, 1] = -1 128 | if source_job_idx < num_active_jobs: 129 | i = source_job_idx 130 | nodes[ptr[i] : ptr[i + 1], 1] = 1 131 | 132 | # current supply of executors for each node's job 133 | nodes[:, 2] = np.repeat(exec_supplies, node_counts) / self.num_executors 134 | 135 | # number of remaining tasks in each node 136 | num_remaining_tasks = dag_batch.nodes[:, 0] 137 | nodes[:, 3] = num_remaining_tasks / self.num_tasks_scale 138 | 139 | # approximate remaining work in each node 140 | most_recent_duration = dag_batch.nodes[:, 1] 141 | nodes[:, 4] = num_remaining_tasks * most_recent_duration / self.work_scale 142 | 143 | return nodes 144 | 145 | def _validate_cache(self, obs: dict[str, Any]) -> None: 146 | dag_batch = obs["dag_batch"] 147 | num_nodes = dag_batch.nodes.shape[0] 148 | 149 | if ( 150 | self._cache["edge_links"] is None 151 | or num_nodes != self._cache["num_nodes"] 152 | or not np.array_equal(dag_batch.edge_links, self._cache["edge_links"]) 153 | ): 154 | # dag batch has changed, so synchronize the cache 155 | self._cache = { 156 | "num_nodes": num_nodes, 157 | "edge_links": dag_batch.edge_links, 158 | "edge_masks": utils.make_dag_layer_edge_masks( 159 | (dag_batch.edge_links, num_nodes) 160 | ), 161 | } 162 | -------------------------------------------------------------------------------- /trainers/rollout_worker.py: -------------------------------------------------------------------------------- 1 | from typing import Any, SupportsFloat 2 | from multiprocessing.synchronize import Lock 3 | from multiprocessing.connection import Connection 4 | import sys 5 | from abc import ABC, abstractmethod 6 | import os.path as osp 7 | import random 8 | 9 | import gymnasium as gym 10 | import torch 11 | 12 | from spark_sched_sim.wrappers import StochasticTimeLimit 13 | from schedulers import make_scheduler 14 | from .utils import Profiler # , HiddenPrints 15 | from spark_sched_sim.metrics import avg_num_jobs 16 | 17 | 18 | class RolloutBuffer: 19 | def __init__(self, async_rollouts: bool = False) -> None: 20 | self.obsns: list[dict] = [] 21 | self.wall_times: list[float] = [] 22 | self.actions: list[tuple] = [] 23 | self.lgprobs: list[float] = [] 24 | self.rewards: list[SupportsFloat] = [] 25 | self.resets: set[int] | None = set() if async_rollouts else None 26 | 27 | def add( 28 | self, 29 | obs: dict, 30 | wall_time: float, 31 | action: tuple, 32 | lgprob: float, 33 | reward: SupportsFloat, 34 | ) -> None: 35 | self.obsns += [obs] 36 | self.wall_times += [wall_time] 
37 | self.actions += [action] 38 | self.rewards += [reward] 39 | self.lgprobs += [lgprob] 40 | 41 | def add_reset(self, step: int) -> None: 42 | assert self.resets is not None, "resets are for async rollouts only." 43 | self.resets.add(step) 44 | 45 | def __len__(self) -> int: 46 | return len(self.obsns) 47 | 48 | 49 | class RolloutWorker(ABC): 50 | def __init__(self) -> None: 51 | self.reset_count = 0 52 | 53 | def __call__( 54 | self, 55 | rank: int, 56 | conn: Connection, 57 | env_cfg: dict[str, Any], 58 | scheduler_kwargs: dict[str, Any], 59 | stdout_dir: str, 60 | base_seed: int, 61 | seed_step: int, 62 | lock: Lock, 63 | ) -> None: 64 | self.rank = rank 65 | self.conn = conn 66 | self.base_seed = base_seed 67 | self.seed_step = seed_step 68 | self.reset_count = 0 69 | 70 | # log each of the processes to separate files 71 | sys.stdout = open(osp.join(stdout_dir, f"{rank}.out"), "a") 72 | 73 | self.scheduler = make_scheduler(scheduler_kwargs) 74 | self.scheduler.eval() 75 | 76 | # might need to download dataset, and only one process should do this. 77 | # this can be achieved using a lock, such that the first process to 78 | # acquire it downloads the dataset, and any subsequent processes notices 79 | # that the dataset is already present once it acquires the lock. 80 | with lock: 81 | env = gym.make("spark_sched_sim:SparkSchedSimEnv-v0", env_cfg=env_cfg) 82 | 83 | env = StochasticTimeLimit(env, env_cfg["mean_time_limit"]) 84 | env = self.scheduler.env_wrapper_cls(env) 85 | self.env = env 86 | 87 | # IMPORTANT! Each worker needs to produce unique rollouts, which are 88 | # determined by the rng seed 89 | torch.manual_seed(rank) 90 | random.seed(rank) 91 | 92 | # torch multiprocessing is very slow without this 93 | torch.set_num_threads(1) 94 | 95 | self.run() 96 | 97 | def run(self) -> None: 98 | while data := self.conn.recv(): 99 | # load updated model parameters 100 | self.scheduler.load_state_dict(data["state_dict"]) 101 | 102 | try: 103 | with Profiler(100): # , HiddenPrints(): 104 | rollout_buffer = self.collect_rollout() 105 | 106 | self.conn.send( 107 | {"rollout_buffer": rollout_buffer, "stats": self.collect_stats()} 108 | ) 109 | 110 | except Exception as e: 111 | print(repr(e), "\nAborting rollout.", flush=True) 112 | self.conn.send(e) 113 | 114 | @abstractmethod 115 | def collect_rollout(self) -> RolloutBuffer: 116 | pass 117 | 118 | @property 119 | def seed(self) -> int: 120 | return self.base_seed + self.seed_step * self.reset_count 121 | 122 | def collect_stats(self) -> dict[str, Any]: 123 | return { 124 | "avg_job_duration": self.env.unwrapped.avg_job_duration, 125 | "avg_num_jobs": avg_num_jobs(self.env), 126 | "num_completed_jobs": self.env.unwrapped.num_completed_jobs, 127 | "num_job_arrivals": self.env.unwrapped.num_completed_jobs 128 | + self.env.unwrapped.num_active_jobs, 129 | } 130 | 131 | 132 | class RolloutWorkerSync(RolloutWorker): 133 | """model updates are synchronized with environment resets""" 134 | 135 | def collect_rollout(self) -> RolloutBuffer: 136 | rollout_buffer = RolloutBuffer() 137 | 138 | obs, _ = self.env.reset(seed=self.seed) 139 | self.reset_count += 1 140 | 141 | wall_time = 0 142 | terminated = truncated = False 143 | while not (terminated or truncated): 144 | action, info = self.scheduler.schedule(obs) 145 | lgprob = info["lgprob"] 146 | 147 | new_obs, reward, terminated, truncated, info = self.env.step(action) 148 | next_wall_time = info["wall_time"] 149 | 150 | rollout_buffer.add(obs, wall_time, tuple(action.values()), lgprob, reward) 151 
| 152 | obs = new_obs 153 | wall_time = next_wall_time 154 | 155 | rollout_buffer.wall_times += [wall_time] 156 | 157 | return rollout_buffer 158 | 159 | 160 | class RolloutWorkerAsync(RolloutWorker): 161 | """model updates occur at regular intervals, regardless of when the 162 | environment resets 163 | """ 164 | 165 | def __init__(self, rollout_duration: float) -> None: 166 | super().__init__() 167 | self.rollout_duration = rollout_duration 168 | self.next_obs = None 169 | self.next_wall_time = 0.0 170 | 171 | def collect_rollout(self) -> RolloutBuffer: 172 | rollout_buffer = RolloutBuffer(async_rollouts=True) 173 | 174 | if self.reset_count == 0: 175 | self.next_obs, _ = self.env.reset(seed=self.seed) 176 | self.reset_count += 1 177 | 178 | elapsed_time = 0 179 | step = 0 180 | while elapsed_time < self.rollout_duration: 181 | obs, wall_time = self.next_obs, self.next_wall_time 182 | 183 | action, info = self.scheduler.schedule(obs) 184 | lgprob = info["lgprob"] 185 | 186 | self.next_obs, reward, terminated, truncated, info = self.env.step(action) 187 | 188 | self.next_wall_time = info["wall_time"] 189 | 190 | assert obs 191 | rollout_buffer.add(obs, elapsed_time, list(action.values()), lgprob, reward) 192 | 193 | # add the duration of the this step to the total 194 | elapsed_time += self.next_wall_time - wall_time 195 | 196 | if terminated or truncated: 197 | self.next_obs, _ = self.env.reset(seed=self.seed) 198 | self.reset_count += 1 199 | self.next_wall_time = 0 200 | rollout_buffer.add_reset(step) 201 | 202 | step += 1 203 | 204 | rollout_buffer.wall_times += [elapsed_time] 205 | 206 | return rollout_buffer 207 | -------------------------------------------------------------------------------- /schedulers/decima/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | from typing import Any 3 | from torch import Tensor 4 | from numpy import ndarray 5 | 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch_geometric as pyg 12 | from torch_sparse import SparseTensor 13 | from torch_scatter import segment_csr 14 | import networkx as nx 15 | from torch.distributions.utils import clamp_probs 16 | import numpy as np 17 | 18 | 19 | def sample(logits: Tensor) -> tuple[int, float]: 20 | pi = F.softmax(logits, 0).numpy() 21 | idx = random.choices(np.arange(pi.size), pi)[0] 22 | lgprob = np.log(pi[idx]) 23 | return idx, lgprob 24 | 25 | 26 | def evaluate( 27 | scores: Tensor, counts: Tensor, selections: Tensor 28 | ) -> tuple[Tensor, Tensor]: 29 | """ 30 | scores: scores that the model assigned to each action at each step 31 | counts: count of available actions at each step 32 | selections: actions that the scheduler sampled at each step 33 | """ 34 | ptr = counts.cumsum(0) 35 | ptr = torch.cat([torch.tensor([0]), ptr], 0) 36 | selections += ptr[:-1] 37 | probs = pyg.utils.softmax(scores, ptr=ptr) 38 | probs = clamp_probs(probs) 39 | log_probs = probs.log() 40 | selection_log_probs = log_probs[selections] 41 | entropies = -segment_csr(log_probs * probs, ptr) 42 | return selection_log_probs, entropies 43 | 44 | 45 | def make_mlp( 46 | input_dim: int, 47 | hid_dims: list[int], 48 | output_dim: int, 49 | act_cls: str, 50 | act_kwargs: dict[str, Any] | None = None, 51 | ) -> nn.Module: 52 | act_clss = getattr(torch.nn.modules.activation, act_cls) 53 | 54 | mlp = nn.Sequential() 55 | prev_dim = input_dim 56 | hid_dims = hid_dims + [output_dim] 57 | for i, 
dim in enumerate(hid_dims): 58 | mlp.append(nn.Linear(prev_dim, dim)) 59 | if i == len(hid_dims) - 1: 60 | break 61 | act_fn = act_clss(**(act_kwargs or {})) 62 | mlp.append(act_fn) 63 | prev_dim = dim 64 | return mlp 65 | 66 | 67 | def make_adj(edge_index: Tensor, num_nodes: int) -> SparseTensor: 68 | """returns a sparse COO adjacency matrix""" 69 | return SparseTensor( 70 | row=edge_index[0], 71 | col=edge_index[1], 72 | sparse_sizes=(num_nodes, num_nodes), 73 | is_sorted=True, 74 | trust_data=True, 75 | ) 76 | 77 | 78 | def ptr_to_counts(ptr): 79 | return ptr[1:] - ptr[:-1] 80 | 81 | 82 | def counts_to_ptr(x: Tensor) -> Tensor: 83 | ptr = x.cumsum(0) 84 | ptr = torch.cat([torch.tensor([0]), ptr], 0) 85 | return ptr 86 | 87 | 88 | def obs_to_pyg(obs: dict[str, Any]) -> pyg.data.Batch: 89 | """converts an env observation into a PyG `Batch` object""" 90 | obs_dag_batch = obs["dag_batch"] 91 | ptr = np.array(obs["dag_ptr"]) 92 | num_nodes_per_dag = ptr_to_counts(ptr) 93 | num_active_jobs = len(num_nodes_per_dag) 94 | 95 | dag_batch = pyg.data.Batch( 96 | x=torch.from_numpy(obs_dag_batch.nodes), 97 | edge_index=torch.from_numpy(obs_dag_batch.edge_links.T), 98 | ptr=torch.from_numpy(ptr), 99 | batch=torch.from_numpy( 100 | np.repeat(np.arange(num_active_jobs), num_nodes_per_dag) 101 | ), 102 | _num_graphs=num_active_jobs, 103 | ) 104 | 105 | dag_batch["stage_mask"] = torch.tensor(obs["stage_mask"], dtype=torch.bool) 106 | dag_batch["exec_mask"] = torch.from_numpy(obs["exec_mask"]) 107 | dag_batch["num_nodes_per_dag"] = ptr_to_counts(dag_batch.ptr) 108 | 109 | if "edge_masks" in obs: 110 | dag_batch["edge_masks"] = torch.from_numpy(obs["edge_masks"]) 111 | 112 | if "node_depth" in obs: 113 | dag_batch["node_depth"] = torch.from_numpy(obs["node_depth"]).float() 114 | 115 | return dag_batch 116 | 117 | 118 | def collate_obsns(obsns: Iterable[dict[str, Any]]) -> pyg.data.Batch: 119 | keys = ["dag_batch", "dag_ptr", "stage_mask", "exec_mask"] 120 | dag_batches, dag_ptrs, stage_masks, exec_masks = zip( 121 | *([obs[key] for key in keys] for obs in obsns) 122 | ) 123 | 124 | dag_batch = collate_dag_batches(dag_batches, dag_ptrs) 125 | 126 | dag_batch["stage_mask"] = torch.from_numpy(np.concatenate(stage_masks).astype(bool)) 127 | 128 | dag_batch["exec_mask"] = torch.from_numpy(np.vstack(exec_masks)) 129 | 130 | # number of available stage actions at each step 131 | dag_batch["num_stage_acts"] = torch.tensor([msk.sum() for msk in stage_masks]) 132 | 133 | # number of available exec actions at each step 134 | dag_batch["num_exec_acts"] = dag_batch["exec_mask"].sum(-1) 135 | 136 | if "edge_masks" in next(iter(obsns)): 137 | edge_masks_list = [obs["edge_masks"] for obs in obsns] 138 | total_num_edges = dag_batch.edge_index.shape[1] 139 | dag_batch["edge_masks"] = collate_edge_masks(edge_masks_list, total_num_edges) 140 | 141 | if "node_depth" in next(iter(obsns)): 142 | node_depth_list = [obs["node_depth"] for obs in obsns] 143 | dag_batch["node_depth"] = torch.from_numpy( 144 | np.concatenate(node_depth_list) 145 | ).float() 146 | 147 | return dag_batch 148 | 149 | 150 | def collate_edge_masks( 151 | edge_masks_list: Iterable[ndarray], total_num_edges: int 152 | ) -> ndarray: 153 | """collates list of edge mask batches from each message passing path. 
Since the 154 | message passing depth varies between observations, edge mask batches are padded 155 | to the maximum depth.""" 156 | max_depth = max(edge_masks.shape[0] for edge_masks in edge_masks_list) 157 | 158 | # array that will be populated with the masks from all the observations 159 | edge_masks = np.zeros((max_depth, total_num_edges), dtype=bool) 160 | 161 | i = 0 162 | for masks in edge_masks_list: 163 | # copy the data from these masks into the output array 164 | depth, num_edges = masks.shape 165 | if depth > 0: 166 | edge_masks[:depth, i : (i + num_edges)] = masks 167 | i += num_edges 168 | 169 | return edge_masks 170 | 171 | 172 | def collate_dag_batches( 173 | dag_batches: Iterable[pyg.data.Batch], dag_ptrs: Iterable[ndarray] 174 | ) -> pyg.data.Batch: 175 | """collates the dag batches from each observation into one large dag batch""" 176 | num_dags_per_obs_tup, num_nodes_per_dag_tup = zip( 177 | *( 178 | (len(dag_ptr) - 1, ptr_to_counts(torch.tensor(dag_ptr))) 179 | for dag_ptr in dag_ptrs 180 | ) 181 | ) 182 | num_dags_per_obs = torch.tensor(num_dags_per_obs_tup) 183 | num_nodes_per_dag = torch.cat(num_nodes_per_dag_tup) 184 | obs_ptr = counts_to_ptr(num_dags_per_obs) 185 | num_nodes_per_obs = segment_csr(num_nodes_per_dag, obs_ptr) 186 | num_graphs = num_dags_per_obs.sum().item() 187 | 188 | x = torch.from_numpy(np.concatenate([dag_batch.nodes for dag_batch in dag_batches])) 189 | 190 | dag_batch = pyg.data.Batch( 191 | x=x, 192 | edge_index=collate_edges(dag_batches, num_nodes_per_obs), 193 | ptr=counts_to_ptr(num_nodes_per_dag), 194 | batch=torch.arange(num_graphs).repeat_interleave( 195 | num_nodes_per_dag, output_size=x.shape[0] 196 | ), 197 | _num_graphs=num_graphs, 198 | ) 199 | 200 | # store bookkeeping attributes 201 | dag_batch["num_dags_per_obs"] = num_dags_per_obs 202 | dag_batch["num_nodes_per_dag"] = num_nodes_per_dag 203 | dag_batch["num_nodes_per_obs"] = num_nodes_per_obs 204 | dag_batch["obs_ptr"] = obs_ptr 205 | 206 | return dag_batch 207 | 208 | 209 | def collate_edges( 210 | dag_batches: Iterable[pyg.data.Batch], num_nodes_per_obs: Tensor 211 | ) -> Tensor: 212 | edge_counts_tup, edge_links_tup = zip( 213 | *( 214 | (dag_batch.edge_links.shape[0], dag_batch.edge_links) 215 | for dag_batch in dag_batches 216 | ) 217 | ) 218 | edge_counts = torch.tensor(edge_counts_tup) 219 | edge_links = np.concatenate(edge_links_tup) 220 | 221 | edge_index = torch.from_numpy(edge_links.T) 222 | 223 | # relabel the edges 224 | node_ptr = counts_to_ptr(num_nodes_per_obs) 225 | edge_index += ( 226 | node_ptr[:-1] 227 | .repeat_interleave(edge_counts, output_size=edge_index.shape[1]) 228 | .unsqueeze(0) 229 | ) 230 | 231 | return edge_index 232 | 233 | 234 | def make_edge_mask(edge_links: ndarray, node_mask: ndarray) -> ndarray: 235 | return node_mask[edge_links[:, 0]] & node_mask[edge_links[:, 1]] 236 | 237 | 238 | def make_dag_layer_edge_masks( 239 | graph_or_data: nx.DiGraph | tuple[ndarray, int] 240 | ) -> ndarray: 241 | """returns a batch of edge masks of shape (msg passing depth, num edges), 242 | where the i'th mask indicates which edges participate in the i'th root-to-leaf 243 | message passing step. 
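    For illustration, a three-node chain with edge_links = [[0, 1], [1, 2]]
    and num_nodes = 3 has three topological generations, so two masks are
    returned: [[True, False], [False, True]]. The first mask selects the edge
    leaving the root, and the second selects the edge entering the leaf.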
244 | """ 245 | if isinstance(graph_or_data, nx.DiGraph): 246 | G = graph_or_data 247 | else: 248 | edge_links, num_nodes = graph_or_data 249 | G = np_to_nx(edge_links, num_nodes) 250 | 251 | node_levels = list(nx.topological_generations(G)) 252 | 253 | if len(node_levels) <= 1: 254 | # no message passing to do 255 | return np.zeros((0, edge_links.shape[0]), dtype=bool) 256 | 257 | node_mask = np.zeros(len(G), dtype=bool) 258 | 259 | edge_masks = [] 260 | for node_level in node_levels[:-1]: 261 | succ = set.union(*[set(G.successors(n)) for n in node_level]) 262 | node_mask[:] = 0 263 | node_mask[node_level + list(succ)] = True 264 | edge_mask = make_edge_mask(edge_links, node_mask) 265 | edge_masks += [edge_mask] 266 | 267 | return np.stack(edge_masks) 268 | 269 | 270 | def np_to_nx(edge_links: ndarray, num_nodes: int) -> nx.DiGraph: 271 | G = nx.DiGraph() 272 | G.add_nodes_from(range(num_nodes)) 273 | G.add_edges_from(edge_links) 274 | return G 275 | -------------------------------------------------------------------------------- /spark_sched_sim/components/executor_tracker.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | 4 | PoolKey = tuple[int | None, int | None] 5 | OptPoolKey = PoolKey | None 6 | JobPoolKey = tuple[int, None] 7 | StagePoolKey = tuple[int, int] 8 | CommonPoolKey = tuple[None, None] 9 | 10 | COMMON_POOL_KEY = (None, None) 11 | 12 | 13 | class ExecutorTracker: 14 | """Maintains all executor assignments. These include: 15 | - current location of each executor, called a 'pool' 16 | - commitments from one executor pool to another, and 17 | - the executors moving between pools. 18 | 19 | The following executor pools exist: 20 | - Placeholder pool (key `None`): not an actual pool; it never contains any executors. 21 | the executor source is set to null pool when no pool is ready to schedule its executors. 22 | - General pool (key `(None, None)`): the pool where executors reside when they are not at 23 | any job. All executors start out in the common pool. 
24 | - Job pool (key `(job_id, None)`): pool of idle executors at a job 25 | - Operation pool (key `(job_id, stage_id)`): pool of executors at a stage, including idle 26 | and busy 27 | """ 28 | 29 | def __init__(self, num_executors: int) -> None: 30 | self.num_executors = num_executors 31 | 32 | def reset(self) -> None: 33 | # executor id -> key of pool where the executor currently resides 34 | self._executor_locations: dict[int, OptPoolKey] = { 35 | executor_id: COMMON_POOL_KEY for executor_id in range(self.num_executors) 36 | } 37 | 38 | # pool key -> set of id's of executors who reside at this pool 39 | self._pools: dict[OptPoolKey, set[int]] = { 40 | None: set(), 41 | COMMON_POOL_KEY: set(range(self.num_executors)), 42 | } 43 | 44 | # pool key A -> 45 | # (pool key B -> 46 | # number of commitments from 47 | # pool A to pool B) 48 | self._commitments: dict[OptPoolKey, dict[OptPoolKey, int]] = { 49 | None: {}, 50 | COMMON_POOL_KEY: {}, 51 | } 52 | 53 | # pool key -> total number of outgoing commitments from this pool 54 | self._num_commitments_from: dict[OptPoolKey, int] = { 55 | None: 0, 56 | COMMON_POOL_KEY: 0, 57 | } 58 | 59 | # stage pool key -> total number of commitments to stage 60 | self._num_commitments_to_stage: dict[PoolKey, int] = {COMMON_POOL_KEY: 0} 61 | 62 | # stage pool key -> number of executors moving to stage 63 | self._num_moving_to_stage: dict[StagePoolKey, int] = {} 64 | 65 | # job id -> size of job's pool plus the total number of external commitments 66 | # and executors moving to any of its stages 67 | self._total_executor_count: dict[int | None, int] = {None: 0} 68 | 69 | # initialize executor source 70 | self._curr_source: OptPoolKey = COMMON_POOL_KEY 71 | 72 | def add_job_pool(self, pool_key: JobPoolKey) -> None: 73 | if pool_key in self._pools: 74 | raise ValueError("job pool already exists") 75 | 76 | job_id, _ = pool_key 77 | self._pools[pool_key] = set() 78 | self._commitments[pool_key] = {} 79 | self._num_commitments_from[pool_key] = 0 80 | self._total_executor_count[job_id] = 0 81 | 82 | def add_stage_pool(self, pool_key: StagePoolKey) -> None: 83 | if pool_key in self._pools: 84 | raise ValueError("stage pool already exists") 85 | 86 | job_id, _ = pool_key 87 | if (job_id, None) not in self._pools: 88 | raise ValueError(f"job with id {job_id} does not exist") 89 | 90 | self._pools[pool_key] = set() 91 | self._commitments[pool_key] = {} 92 | self._num_commitments_from[pool_key] = 0 93 | self._num_commitments_to_stage[pool_key] = 0 94 | self._num_moving_to_stage[pool_key] = 0 95 | 96 | def get_source(self) -> OptPoolKey: 97 | return self._curr_source 98 | 99 | def source_job_id(self) -> int | None: 100 | if not self._curr_source or self._curr_source is COMMON_POOL_KEY: 101 | return None 102 | else: 103 | return self._curr_source[0] 104 | 105 | def num_committable_execs(self) -> int: 106 | num_uncommitted = ( 107 | len(self._pools[self._curr_source]) 108 | - self._num_commitments_from[self._curr_source] 109 | ) 110 | assert num_uncommitted >= 0, "[num_committable_execs]" 111 | return num_uncommitted 112 | 113 | def common_pool_has_executors(self) -> bool: 114 | return bool(self._pools[COMMON_POOL_KEY]) 115 | 116 | def num_executors_moving_to_stage(self, stage_pool_key: StagePoolKey) -> int: 117 | return self._num_moving_to_stage[stage_pool_key] 118 | 119 | def num_commitments_to_stage(self, stage_pool_key: StagePoolKey) -> int: 120 | return self._num_commitments_to_stage[stage_pool_key] 121 | 122 | def exec_supply(self, job_id: int) -> int: 123 | return 
self._total_executor_count[job_id] 124 | 125 | def update_executor_source(self, pool_key: PoolKey) -> None: 126 | self._curr_source = pool_key 127 | 128 | def clear_executor_source(self) -> None: 129 | self._curr_source = None 130 | 131 | def get_source_commitments(self) -> dict[OptPoolKey, int]: 132 | return self._commitments[self._curr_source].copy() 133 | 134 | def get_pool(self, pool_key: OptPoolKey) -> set[int]: 135 | return self._pools[pool_key].copy() 136 | 137 | def pool_size(self, pool_key: OptPoolKey) -> int: 138 | return len(self._pools[pool_key]) 139 | 140 | def get_source_pool(self) -> set[int]: 141 | return self.get_pool(self._curr_source) 142 | 143 | def executor_location(self, executor_id: int) -> OptPoolKey: 144 | return self._executor_locations[executor_id] 145 | 146 | def add_commitment(self, num_executors: int, dst_pool_key: PoolKey) -> None: 147 | assert self._curr_source, "[add_commitment]" 148 | src_job_id = self._curr_source[0] 149 | dst_job_id = dst_pool_key[0] 150 | 151 | self._increment_commitments(dst_pool_key, n=num_executors) 152 | 153 | if dst_job_id != src_job_id: 154 | self._total_executor_count[dst_job_id] += num_executors 155 | 156 | def remove_commitment(self, executor_id: int, dst_pool_key: PoolKey) -> PoolKey: 157 | src_pool_key = self._executor_locations[executor_id] 158 | assert src_pool_key, "[remove_commitment]" 159 | 160 | if dst_pool_key not in self._commitments[src_pool_key]: 161 | raise ValueError(f"no commitments from {src_pool_key} to {dst_pool_key}") 162 | 163 | src_job_id = src_pool_key[0] 164 | dst_job_id = dst_pool_key[0] 165 | 166 | # update commitment from source to dest stage 167 | self._decrement_commitments(src_pool_key, dst_pool_key) 168 | 169 | if dst_job_id != src_job_id: 170 | self._total_executor_count[dst_job_id] -= 1 171 | assert self._total_executor_count[dst_job_id] >= 0 172 | 173 | return src_pool_key 174 | 175 | def peek_commitment(self, pool_key: OptPoolKey) -> OptPoolKey: 176 | try: 177 | return next(iter(self._commitments[pool_key])) 178 | except (KeyError, StopIteration): 179 | # no outgoing commitments from this pool 180 | return None 181 | 182 | def record_executor_arrival(self, stage_pool_key: StagePoolKey) -> None: 183 | self._num_moving_to_stage[stage_pool_key] -= 1 184 | assert self._num_moving_to_stage[stage_pool_key] >= 0 185 | 186 | def move_executor_to_pool( 187 | self, executor_id: int, new_pool_key: OptPoolKey, send: bool = False 188 | ) -> None: 189 | if send and ( 190 | not new_pool_key or new_pool_key[0] is None or new_pool_key[1] is None 191 | ): 192 | raise ValueError("can only send executors to stages") 193 | 194 | old_pool_key = self._executor_locations[executor_id] 195 | 196 | if old_pool_key is not None: 197 | # remove executor from old pool 198 | self._pools[old_pool_key].remove(executor_id) 199 | self._executor_locations[executor_id] = None 200 | 201 | if not send: 202 | # directly move executor into new pool 203 | self._executor_locations[executor_id] = new_pool_key 204 | self._pools[new_pool_key].add(executor_id) 205 | return 206 | 207 | # send the executor to the stage 208 | 209 | new_pool_key = cast(StagePoolKey, new_pool_key) 210 | 211 | self._num_moving_to_stage[new_pool_key] += 1 212 | 213 | old_job_id = old_pool_key[0] if old_pool_key is not None else None 214 | new_job_id = new_pool_key[0] 215 | assert old_job_id != new_job_id 216 | 217 | self._total_executor_count[new_job_id] += 1 218 | if old_job_id is not None: 219 | self._total_executor_count[old_job_id] -= 1 220 | assert 
self._total_executor_count[old_job_id] >= 0 221 | 222 | # internal methods 223 | 224 | def _increment_commitments(self, dst_pool_key: PoolKey, n: int) -> None: 225 | try: 226 | self._commitments[self._curr_source][dst_pool_key] += n 227 | except KeyError: 228 | # key not in dict yet 229 | self._commitments[self._curr_source][dst_pool_key] = n 230 | 231 | self._num_commitments_from[self._curr_source] += n 232 | self._num_commitments_to_stage[dst_pool_key] += n 233 | 234 | supply = len(self._pools[self._curr_source]) 235 | demand = self._num_commitments_from[self._curr_source] 236 | assert supply >= demand 237 | 238 | def _decrement_commitments( 239 | self, src_pool_key: PoolKey, dst_pool_key: PoolKey 240 | ) -> None: 241 | self._commitments[src_pool_key][dst_pool_key] -= 1 242 | self._num_commitments_from[src_pool_key] -= 1 243 | self._num_commitments_to_stage[dst_pool_key] -= 1 244 | 245 | assert self._num_commitments_from[src_pool_key] >= 0 246 | assert self._num_commitments_to_stage[dst_pool_key] >= 0 247 | 248 | if self._commitments[src_pool_key][dst_pool_key] == 0: 249 | self._commitments[src_pool_key].pop(dst_pool_key) 250 | -------------------------------------------------------------------------------- /spark_sched_sim/data_samplers/tpch.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pathlib 3 | from io import BytesIO 4 | from zipfile import ZipFile 5 | from urllib.request import urlopen 6 | 7 | import numpy as np 8 | import networkx as nx 9 | 10 | from .data_sampler import DataSampler 11 | from ..components import Job, Stage 12 | 13 | TPCH_URL = "https://bit.ly/3F1Go8t" 14 | QUERY_SIZES = ["2g", "5g", "10g", "20g", "50g", "80g", "100g"] 15 | NUM_QUERIES = 22 16 | 17 | 18 | class TPCHDataSampler(DataSampler): 19 | def __init__( 20 | self, 21 | job_arrival_rate: float, 22 | job_arrival_cap: int, 23 | num_executors: int, 24 | warmup_delay: int, 25 | **kwargs, 26 | ): 27 | """ 28 | job_arrival_rate (float): non-negative number that controls how 29 | quickly new jobs arrive into the system. This is the parameter 30 | of an exponential distributions, and so its inverse is the 31 | mean job inter-arrival time in ms. 32 | job_arrival_cap: (optional int): limit on the number of jobs that 33 | arrive throughout the simulation. If set to `None`, then the 34 | episode ends when a time limit is reached. 35 | num_executors (int): number of simulated executors. More executors 36 | means a higher possible level of parallelism. 
37 | warmup_delay (int): an executor is slower on its first task from 38 | a stage if it was previously idle or moving jobs, which is 39 | caputred by adding a warmup delay (ms) to the task duration 40 | """ 41 | self.job_arrival_cap = job_arrival_cap 42 | self.mean_interarrival_time = 1 / job_arrival_rate 43 | self.warmup_delay = warmup_delay 44 | 45 | self.np_random = None 46 | self._init_executor_intervals(num_executors) 47 | 48 | if not osp.isdir("data/tpch"): 49 | self._download_tpch_dataset() 50 | 51 | def reset(self, np_random: np.random.Generator): 52 | self.np_random = np_random 53 | 54 | def job_sequence(self, max_time): 55 | """generates a sequence of job arrivals over time, which follow a 56 | Poisson process parameterized by `self.job_arrival_rate` 57 | """ 58 | assert self.np_random 59 | job_sequence = [] 60 | 61 | t = 0 62 | job_idx = 0 63 | while t < max_time and ( 64 | not self.job_arrival_cap or job_idx < self.job_arrival_cap 65 | ): 66 | job = self._sample_job(job_idx, t) 67 | job_sequence.append((t, job)) 68 | 69 | # sample time in ms until next arrival 70 | t += self.np_random.exponential(self.mean_interarrival_time) 71 | job_idx += 1 72 | 73 | return job_sequence 74 | 75 | def task_duration(self, job, stage, task, executor): 76 | num_local_executors = len(job.local_executors) 77 | 78 | assert num_local_executors > 0 79 | assert self.np_random 80 | 81 | data = stage.task_duration_data 82 | 83 | # sample an executor point in the data 84 | executor_key = self._sample_executor_key(data, num_local_executors) 85 | 86 | if executor.is_idle: 87 | # the executor was just sitting idly or moving between jobs, so it needs time to warm up 88 | try: 89 | return self._sample_task_duration(data, "fresh_durations", executor_key) 90 | except (ValueError, KeyError): 91 | return self._sample_task_duration( 92 | data, "first_wave", executor_key, warmup=True 93 | ) 94 | 95 | if executor.task.stage_id == task.stage_id: 96 | # the executor is continuing work on the same stage, which is relatively fast 97 | try: 98 | return self._sample_task_duration(data, "rest_wave", executor_key) 99 | except (ValueError, KeyError): 100 | pass 101 | 102 | # the executor is new to this stage (or 'rest_wave' data was not available) 103 | try: 104 | return self._sample_task_duration(data, "first_wave", executor_key) 105 | except (ValueError, KeyError): 106 | return self._sample_task_duration(data, "fresh_durations", executor_key) 107 | 108 | @classmethod 109 | def _download_tpch_dataset(cls): 110 | print("Downloading the TPC-H dataset...", flush=True) 111 | pathlib.Path("data").mkdir(parents=True, exist_ok=True) 112 | with urlopen(TPCH_URL) as zipresp: 113 | with ZipFile(BytesIO(zipresp.read())) as zfile: 114 | zfile.extractall("data") 115 | print("Done.", flush=True) 116 | 117 | @classmethod 118 | def _load_query(cls, query_num, query_size): 119 | query_path = osp.join("data/tpch", str(query_size)) 120 | 121 | adj_matrix = np.load( 122 | osp.join(query_path, f"adj_mat_{query_num}.npy"), allow_pickle=True 123 | ) 124 | 125 | task_duration_data = np.load( 126 | osp.join(query_path, f"task_duration_{query_num}.npy"), allow_pickle=True 127 | ).item() 128 | 129 | assert adj_matrix.shape[0] == adj_matrix.shape[1] 130 | assert adj_matrix.shape[0] == len(task_duration_data) 131 | 132 | return adj_matrix, task_duration_data 133 | 134 | @classmethod 135 | def _pre_process_task_duration(cls, task_duration): 136 | # remove fresh durations from first wave 137 | clean_first_wave = {} 138 | for e in 
task_duration["first_wave"]: 139 | clean_first_wave[e] = [] 140 | fresh_durations = MultiSet() 141 | # O(1) access 142 | for d in task_duration["fresh_durations"][e]: 143 | fresh_durations.add(d) 144 | for d in task_duration["first_wave"][e]: 145 | if d not in fresh_durations: 146 | clean_first_wave[e].append(d) 147 | else: 148 | # prevent duplicated fresh duration blocking first wave 149 | fresh_durations.remove(d) 150 | 151 | # fill in nearest neighour first wave 152 | last_first_wave = [] 153 | for e in sorted(clean_first_wave.keys()): 154 | if len(clean_first_wave[e]) == 0: 155 | clean_first_wave[e] = last_first_wave 156 | last_first_wave = clean_first_wave[e] 157 | 158 | # swap the first wave with fresh durations removed 159 | task_duration["first_wave"] = clean_first_wave 160 | 161 | @classmethod 162 | def _rough_task_duration(cls, task_duration_data): 163 | def durations(key): 164 | durations = task_duration_data[key].values() 165 | durations = [t for ts in durations for t in ts] 166 | return durations 167 | 168 | all_durations = ( 169 | durations("fresh_durations") 170 | + durations("first_wave") 171 | + durations("rest_wave") 172 | ) 173 | 174 | return np.mean(all_durations) 175 | 176 | def _sample_job(self, job_id, t_arrival): 177 | query_num = 1 + self.np_random.integers(NUM_QUERIES) 178 | query_size = self.np_random.choice(QUERY_SIZES) 179 | adj_mat, task_duration_data = self._load_query(query_num, query_size) 180 | 181 | num_stages = adj_mat.shape[0] 182 | stages = [] 183 | for stage_id in range(num_stages): 184 | data = task_duration_data[stage_id] 185 | e = next(iter(data["first_wave"])) 186 | 187 | num_tasks = len(data["first_wave"][e]) + len(data["rest_wave"][e]) 188 | 189 | # remove fresh duration from first wave duration 190 | # drag nearest neighbor first wave duration to empty spots 191 | self._pre_process_task_duration(data) 192 | 193 | # generate a node 194 | stage = Stage(stage_id, job_id, num_tasks, self._rough_task_duration(data)) 195 | stage.task_duration_data = data 196 | stages += [stage] 197 | 198 | # generate DAG 199 | dag = nx.from_numpy_array(adj_mat, create_using=nx.DiGraph) 200 | for _, _, d in dag.edges(data=True): 201 | d.clear() 202 | 203 | job = Job(job_id, stages, dag, t_arrival) 204 | job.query_num = query_num 205 | job.query_size = query_size 206 | return job 207 | 208 | def _sample_task_duration(self, data, wave, executor_key, warmup=False): 209 | """raises an exception if `executor_key` is not found in the durations from `wave`""" 210 | durations = data[wave][executor_key] 211 | duration = self.np_random.choice(durations) 212 | if warmup: 213 | duration += self.warmup_delay 214 | return duration 215 | 216 | def _sample_executor_key(self, data, num_local_executors): 217 | left_exec, right_exec = self.executor_intervals[num_local_executors] 218 | 219 | executor_key = None 220 | 221 | if left_exec == right_exec: 222 | executor_key = left_exec 223 | else: 224 | # faster than random.randint 225 | rand_pt = 1 + int(self.np_random.random() * (right_exec - left_exec)) 226 | if rand_pt <= num_local_executors - left_exec: 227 | executor_key = left_exec 228 | else: 229 | executor_key = right_exec 230 | 231 | if executor_key not in data["first_wave"]: 232 | # more executors than number of tasks in the job 233 | executor_key = max(data["first_wave"]) 234 | 235 | return executor_key 236 | 237 | def _init_executor_intervals(self, exec_cap): 238 | exec_levels = [5, 10, 20, 40, 50, 60, 80, 100] 239 | 240 | intervals = np.zeros((exec_cap + 1, 2)) 241 | 242 | # get 
the left most map 243 | intervals[: exec_levels[0] + 1] = exec_levels[0] 244 | 245 | # get the center map 246 | for i in range(len(exec_levels) - 1): 247 | intervals[exec_levels[i] + 1 : exec_levels[i + 1]] = ( 248 | exec_levels[i], 249 | exec_levels[i + 1], 250 | ) 251 | 252 | if exec_levels[i + 1] > exec_cap: 253 | break 254 | 255 | # at the data point 256 | intervals[exec_levels[i + 1]] = exec_levels[i + 1] 257 | 258 | # get the residual map 259 | if exec_cap > exec_levels[-1]: 260 | intervals[exec_levels[-1] + 1 : exec_cap] = exec_levels[-1] 261 | 262 | self.executor_intervals = intervals 263 | 264 | 265 | class MultiSet: 266 | """ 267 | allow duplication in set 268 | """ 269 | 270 | def __init__(self): 271 | self.set = {} 272 | 273 | def __contains__(self, item): 274 | return item in self.set 275 | 276 | def add(self, item): 277 | if item in self.set: 278 | self.set[item] += 1 279 | else: 280 | self.set[item] = 1 281 | 282 | def clear(self): 283 | self.set.clear() 284 | 285 | def remove(self, item): 286 | self.set[item] -= 1 287 | if self.set[item] == 0: 288 | del self.set[item] 289 | -------------------------------------------------------------------------------- /trainers/trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Iterable 3 | from typing import Any 4 | import shutil 5 | import os 6 | import os.path as osp 7 | import sys 8 | from copy import deepcopy 9 | import json 10 | import pathlib 11 | 12 | import numpy as np 13 | import torch 14 | import multiprocessing as mp 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | from schedulers import make_scheduler, TrainableScheduler 18 | from .rollout_worker import RolloutWorkerSync, RolloutWorkerAsync, RolloutBuffer 19 | from .utils import Baseline, ReturnsCalculator 20 | 21 | 22 | CfgType = dict[str, Any] 23 | 24 | 25 | class Trainer(ABC): 26 | """Base training algorithm class. 
Each algorithm must implement the 27 | abstract method `train_on_rollouts` 28 | """ 29 | 30 | def __init__( 31 | self, agent_cfg: CfgType, env_cfg: CfgType, train_cfg: CfgType 32 | ) -> None: 33 | self.seed = train_cfg["seed"] 34 | torch.manual_seed(self.seed) 35 | 36 | self.scheduler_cls = agent_cfg["agent_cls"] 37 | 38 | self.device = torch.device( 39 | train_cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu") 40 | ) 41 | 42 | # number of training iterations 43 | self.num_iterations: int = train_cfg["num_iterations"] 44 | 45 | # number of unique job sequences per iteration 46 | self.num_sequences: int = train_cfg["num_sequences"] 47 | 48 | # number of rollouts per job sequence 49 | self.num_rollouts: int = int(train_cfg["num_rollouts"]) 50 | 51 | self.artifacts_dir: str = train_cfg["artifacts_dir"] 52 | pathlib.Path(self.artifacts_dir).mkdir(parents=True, exist_ok=True) 53 | 54 | self.stdout_dir = osp.join(self.artifacts_dir, "stdout") 55 | self.tb_dir = osp.join(self.artifacts_dir, "tb") 56 | self.checkpointing_dir = osp.join(self.artifacts_dir, "checkpoints") 57 | self.use_tensorboard: bool = train_cfg["use_tensorboard"] 58 | self.checkpointing_freq: int = train_cfg["checkpointing_freq"] 59 | self.env_cfg = env_cfg 60 | 61 | self.baseline = Baseline(self.num_sequences, self.num_rollouts) 62 | 63 | self.rollout_duration: float | None = train_cfg.get("rollout_duration") 64 | 65 | assert ("reward_buff_cap" in train_cfg) ^ ( 66 | "beta_discount" in train_cfg 67 | ), "must provide exactly one of `reward_buff_cap` and `beta_discount` in config" 68 | 69 | if "reward_buff_cap" in train_cfg: 70 | self.return_calc = ReturnsCalculator(buff_cap=train_cfg["reward_buff_cap"]) 71 | else: 72 | beta: float = train_cfg["beta_discount"] 73 | env_cfg |= {"beta": beta} 74 | self.return_calc = ReturnsCalculator(beta=beta) 75 | 76 | self.scheduler_cfg = ( 77 | agent_cfg 78 | | {"num_executors": env_cfg["num_executors"]} 79 | | {k: train_cfg[k] for k in ["opt_cls", "opt_kwargs", "max_grad_norm"]} 80 | ) 81 | scheduler = make_scheduler(self.scheduler_cfg) 82 | assert isinstance(scheduler, TrainableScheduler), "scheduler must be trainable." 83 | self.scheduler: TrainableScheduler = scheduler 84 | 85 | def train(self) -> None: 86 | """trains the model on different job arrival sequences. 
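        Each of the `num_sequences` job sequences corresponds to a distinct
        base seed, and each sequence is rolled out by `num_rollouts` workers
        in parallel (see `_start_rollout_workers`).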
87 | For each job sequence: 88 | - multiple rollouts are collected in parallel, asynchronously 89 | - the rollouts are gathered at the center, where model parameters are 90 | updated, and 91 | - new model parameters are scattered to the rollout workers 92 | """ 93 | self._setup() 94 | 95 | # every n'th iteration, save the best model from the past n iterations, 96 | # where `n = self.model_save_freq` 97 | best_state = None 98 | 99 | exception: Exception | None = None 100 | 101 | print("Beginning training.\n", flush=True) 102 | 103 | for i in range(self.num_iterations): 104 | state_dict = deepcopy(self.scheduler.state_dict()) 105 | 106 | # # move params to GPU for learning 107 | self.scheduler.to(self.device, non_blocking=True) 108 | 109 | # scatter 110 | for conn in self.conns: 111 | conn.send({"state_dict": state_dict}) 112 | 113 | # gather 114 | results = [] 115 | for i, conn in enumerate(self.conns): 116 | res = conn.recv() 117 | if isinstance(res, Exception): 118 | print(f"An exception occured in process {i}", flush=True) 119 | exception = res 120 | break 121 | results += [res] 122 | 123 | if exception: 124 | break 125 | 126 | rollout_buffers, rollout_stats_list = zip( 127 | *[(res["rollout_buffer"], res["stats"]) for res in results if res] 128 | ) 129 | 130 | # update parameters 131 | learning_stats = self.train_on_rollouts(rollout_buffers) 132 | 133 | # return params to CPU before scattering updated state dict to the rollout workers 134 | self.scheduler.to("cpu", non_blocking=True) 135 | 136 | avg_num_jobs = self.return_calc.avg_num_jobs or np.mean( 137 | [stats["avg_num_jobs"] for stats in rollout_stats_list] 138 | ) 139 | 140 | # check if model is the current best 141 | if not best_state or avg_num_jobs < best_state["avg_num_jobs"]: 142 | best_state = self._capture_state( 143 | i, avg_num_jobs, state_dict, rollout_stats_list 144 | ) 145 | 146 | if (i + 1) % self.checkpointing_freq == 0: 147 | self._checkpoint(i, best_state) 148 | best_state = None 149 | 150 | if self.use_tensorboard: 151 | ep_lens = [len(buff) for buff in rollout_buffers if buff] 152 | self._write_stats(i, learning_stats, rollout_stats_list, ep_lens) 153 | 154 | print( 155 | f"Iteration {i+1} complete. Avg. 
# jobs: " f"{avg_num_jobs:.3f}", 156 | flush=True, 157 | ) 158 | 159 | self._cleanup() 160 | 161 | if exception: 162 | raise exception 163 | 164 | @abstractmethod 165 | def train_on_rollouts( 166 | self, rollout_buffers: Iterable[RolloutBuffer] 167 | ) -> dict[str, Any]: 168 | pass 169 | 170 | # internal methods 171 | 172 | def _preprocess_rollouts( 173 | self, rollout_buffers: Iterable[RolloutBuffer] 174 | ) -> dict[str, tuple]: 175 | ( 176 | obsns_list, 177 | actions_list, 178 | wall_times_list, 179 | rewards_list, 180 | lgprobs_list, 181 | resets_list, 182 | ) = zip( 183 | *( 184 | ( 185 | buff.obsns, 186 | buff.actions, 187 | buff.wall_times, 188 | buff.rewards, 189 | buff.lgprobs, 190 | buff.resets, 191 | ) 192 | for buff in rollout_buffers 193 | if buff is not None 194 | ) 195 | ) 196 | 197 | returns_list = self.return_calc( 198 | rewards_list, 199 | wall_times_list, 200 | resets_list, 201 | ) 202 | 203 | wall_times_list = tuple([wall_times[:-1] for wall_times in wall_times_list]) 204 | baselines_list = self.baseline(wall_times_list, returns_list) 205 | 206 | return { 207 | "obsns_list": obsns_list, 208 | "actions_list": actions_list, 209 | "returns_list": returns_list, 210 | "baselines_list": baselines_list, 211 | "lgprobs_list": lgprobs_list, 212 | } 213 | 214 | def _setup(self) -> None: 215 | # logging 216 | shutil.rmtree(self.stdout_dir, ignore_errors=True) 217 | os.mkdir(self.stdout_dir) 218 | sys.stdout = open(osp.join(self.stdout_dir, "main.out"), "a") 219 | 220 | if self.use_tensorboard: 221 | self.summary_writer = SummaryWriter(self.tb_dir) 222 | 223 | # model checkpoints 224 | shutil.rmtree(self.checkpointing_dir, ignore_errors=True) 225 | os.mkdir(self.checkpointing_dir) 226 | 227 | # torch 228 | torch.multiprocessing.set_start_method("spawn") 229 | # print('cuda available:', torch.cuda.is_available()) 230 | # torch.autograd.set_detect_anomaly(True) 231 | 232 | self.scheduler.train() 233 | 234 | self._start_rollout_workers() 235 | 236 | def _cleanup(self) -> None: 237 | self._terminate_rollout_workers() 238 | 239 | if self.use_tensorboard: 240 | self.summary_writer.close() 241 | 242 | print("\nTraining complete.", flush=True) 243 | 244 | def _capture_state( 245 | self, i: int, avg_num_jobs: float, state_dict: dict, stats_list: Iterable[dict] 246 | ) -> dict[str, Any]: 247 | return { 248 | "iteration": i, 249 | "avg_num_jobs": np.round(avg_num_jobs, 3), 250 | "state_dict": state_dict, 251 | "completed_job_count": int( 252 | np.mean([stats["num_completed_jobs"] for stats in stats_list]) 253 | ), 254 | } 255 | 256 | def _checkpoint(self, i: int, best_state: dict) -> None: 257 | dir = osp.join(self.checkpointing_dir, f"{i+1}") 258 | os.mkdir(dir) 259 | best_sd = best_state.pop("state_dict") 260 | torch.save(best_sd, osp.join(dir, "model.pt")) 261 | with open(osp.join(dir, "state.json"), "w") as fp: 262 | json.dump(best_state, fp) 263 | 264 | def _start_rollout_workers(self) -> None: 265 | self.procs = [] 266 | self.conns = [] 267 | 268 | base_seeds = self.seed + np.arange(self.num_sequences) 269 | base_seeds = np.repeat(base_seeds, self.num_rollouts) 270 | seed_step = self.num_sequences 271 | lock = mp.Lock() 272 | for rank, base_seed in enumerate(base_seeds): 273 | conn_main, conn_sub = mp.Pipe() 274 | self.conns += [conn_main] 275 | 276 | proc = mp.Process( 277 | target=RolloutWorkerAsync(self.rollout_duration) 278 | if self.rollout_duration 279 | else RolloutWorkerSync(), 280 | args=( 281 | rank, 282 | conn_sub, 283 | self.env_cfg, 284 | self.scheduler_cfg, 285 | 
self.stdout_dir, 286 | int(base_seed), 287 | seed_step, 288 | lock, 289 | ), 290 | ) 291 | 292 | self.procs += [proc] 293 | proc.start() 294 | 295 | for proc in self.procs: 296 | proc.join(5) 297 | 298 | def _terminate_rollout_workers(self) -> None: 299 | for conn in self.conns: 300 | conn.send(None) 301 | 302 | for proc in self.procs: 303 | proc.join() 304 | 305 | def _write_stats( 306 | self, 307 | epoch: int, 308 | learning_stats: dict, 309 | stats_list: Iterable[dict], 310 | ep_lens: list[int], 311 | ) -> None: 312 | episode_stats = learning_stats | { 313 | "avg num concurrent jobs": np.mean( 314 | [stats["avg_num_jobs"] for stats in stats_list] 315 | ), 316 | "avg job duration": np.mean( 317 | [stats["avg_job_duration"] for stats in stats_list] 318 | ), 319 | "completed jobs count": np.mean( 320 | [stats["num_completed_jobs"] for stats in stats_list] 321 | ), 322 | "job arrival count": np.mean( 323 | [stats["num_job_arrivals"] for stats in stats_list] 324 | ), 325 | "episode length": np.mean(ep_lens), 326 | } 327 | 328 | for name, stat in episode_stats.items(): 329 | self.summary_writer.add_scalar(name, stat, epoch) 330 | -------------------------------------------------------------------------------- /schedulers/decima/scheduler.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | from typing import Any 3 | from torch import Tensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch_scatter import segment_csr 8 | import torch_geometric as pyg 9 | import torch_sparse 10 | 11 | from ..scheduler import TrainableScheduler 12 | from .env_wrapper import DecimaEnvWrapper 13 | from . import utils 14 | 15 | 16 | class DecimaScheduler(TrainableScheduler): 17 | """Original Decima architecture, which uses asynchronous message passing 18 | as in DAGNN. 
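    Here "asynchronous" means that messages are propagated one DAG level at a
    time (see `NodeEncoder.forward`) rather than in a fixed number of
    synchronous rounds over all nodes at once.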
19 | Paper: https://dl.acm.org/doi/abs/10.1145/3341302.3342080 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_executors: int, 25 | embed_dim: int, 26 | gnn_mlp_kwargs: dict[str, Any], 27 | policy_mlp_kwargs: dict[str, Any], 28 | state_dict_path: str | None = None, 29 | opt_cls: str | None = None, 30 | opt_kwargs: dict[str, Any] | None = None, 31 | max_grad_norm: float | None = None, 32 | num_node_features: int = 5, 33 | num_dag_features: int = 3, 34 | **kwargs, 35 | ): 36 | super().__init__() 37 | 38 | self.name = "Decima" 39 | self.env_wrapper_cls = DecimaEnvWrapper 40 | self.max_grad_norm = max_grad_norm 41 | self.num_executors = num_executors 42 | 43 | self.encoder = EncoderNetwork(num_node_features, embed_dim, gnn_mlp_kwargs) 44 | 45 | emb_dims = {"node": embed_dim, "dag": embed_dim, "glob": embed_dim} 46 | 47 | self.stage_policy_network = StagePolicyNetwork( 48 | num_node_features, emb_dims, policy_mlp_kwargs 49 | ) 50 | 51 | self.exec_policy_network = ExecPolicyNetwork( 52 | num_executors, num_dag_features, emb_dims, policy_mlp_kwargs 53 | ) 54 | 55 | self._reset_biases() 56 | 57 | if state_dict_path: 58 | self.name += f":{state_dict_path}" 59 | self.load_state_dict(torch.load(state_dict_path)) 60 | 61 | if opt_cls: 62 | self.optim = getattr(torch.optim, opt_cls)( 63 | self.parameters(), **(opt_kwargs or {}) 64 | ) 65 | 66 | def _reset_biases(self) -> None: 67 | for name, param in self.named_parameters(): 68 | if "bias" in name: 69 | param.data.zero_() 70 | 71 | @torch.no_grad() 72 | def schedule(self, obs: dict) -> tuple[dict, dict]: 73 | dag_batch = utils.obs_to_pyg(obs) 74 | stage_to_job_map = dag_batch.batch 75 | stage_mask = dag_batch["stage_mask"] 76 | 77 | dag_batch.to(self.device, non_blocking=True) 78 | 79 | # 1. compute node, dag, and global representations 80 | h_dict = self.encoder(dag_batch) 81 | 82 | # 2. select a schedulable stage 83 | stage_scores = self.stage_policy_network(dag_batch, h_dict) 84 | stage_idx, stage_lgprob = utils.sample(stage_scores) 85 | 86 | # retrieve index of selected stage's job 87 | stage_idx_glob = pyg.utils.mask_to_index(stage_mask)[stage_idx] 88 | job_idx = stage_to_job_map[stage_idx_glob].item() 89 | 90 | # 3. 
select the number of executors to add to that stage, conditioned 91 | # on that stage's job 92 | exec_scores = self.exec_policy_network(dag_batch, h_dict, job_idx) 93 | num_exec, exec_lgprob = utils.sample(exec_scores) 94 | 95 | action = {"stage_idx": stage_idx, "job_idx": job_idx, "num_exec": num_exec} 96 | 97 | lgprob = stage_lgprob + exec_lgprob 98 | 99 | return action, {"lgprob": lgprob} 100 | 101 | def evaluate_actions( 102 | self, obsns: Iterable[dict], actions: Iterable[tuple] 103 | ) -> dict[str, Tensor]: 104 | dag_batch = utils.collate_obsns(obsns) 105 | actions_ten = torch.tensor(actions) 106 | 107 | # split columns of `actions` into separate tensors 108 | # NOTE: columns need to be cloned to avoid in-place operation 109 | stage_selections, job_indices, exec_selections = [ 110 | col.clone() for col in actions_ten.T 111 | ] 112 | 113 | num_stage_acts = dag_batch["num_stage_acts"] 114 | num_exec_acts = dag_batch["num_exec_acts"] 115 | num_nodes_per_obs = dag_batch["num_nodes_per_obs"] 116 | obs_ptr = dag_batch["obs_ptr"] 117 | job_indices += obs_ptr[:-1] 118 | 119 | # re-feed all the observations into the model with grads enabled 120 | dag_batch.to(self.device) 121 | h_dict = self.encoder(dag_batch) 122 | stage_scores = self.stage_policy_network(dag_batch, h_dict) 123 | exec_scores = self.exec_policy_network(dag_batch, h_dict, job_indices) 124 | 125 | stage_lgprobs, stage_entropies = utils.evaluate( 126 | stage_scores.cpu(), num_stage_acts, stage_selections 127 | ) 128 | 129 | exec_lgprobs, exec_entropies = utils.evaluate( 130 | exec_scores.cpu(), num_exec_acts[job_indices], exec_selections 131 | ) 132 | 133 | # aggregate the evaluations for nodes and dags 134 | action_lgprobs = stage_lgprobs + exec_lgprobs 135 | 136 | action_entropies = stage_entropies + exec_entropies 137 | action_entropies /= (self.num_executors * num_nodes_per_obs).log() 138 | 139 | return {"lgprobs": action_lgprobs, "entropies": action_entropies} 140 | 141 | 142 | class EncoderNetwork(nn.Module): 143 | def __init__( 144 | self, num_node_features: int, embed_dim: int, mlp_kwargs: dict[str, Any] 145 | ) -> None: 146 | super().__init__() 147 | 148 | self.node_encoder = NodeEncoder(num_node_features, embed_dim, mlp_kwargs) 149 | self.dag_encoder = DagEncoder(num_node_features, embed_dim, mlp_kwargs) 150 | self.global_encoder = GlobalEncoder(embed_dim, mlp_kwargs) 151 | 152 | def forward(self, dag_batch: pyg.data.Batch) -> dict[str, Tensor]: 153 | """ 154 | Returns: 155 | a dict of representations at three different levels: 156 | node, dag, and global. 
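            Their shapes are [num_nodes, embed_dim], [num_dags, embed_dim],
            and [num_observations, embed_dim], respectively (see the
            individual encoders below).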
157 | """ 158 | h_node = self.node_encoder(dag_batch) 159 | 160 | h_dag = self.dag_encoder(h_node, dag_batch) 161 | 162 | if "obs_ptr" in dag_batch: 163 | # batch of obsns 164 | obs_ptr = dag_batch["obs_ptr"] 165 | h_glob = self.global_encoder(h_dag, obs_ptr) 166 | else: 167 | # single obs 168 | h_glob = self.global_encoder(h_dag) 169 | 170 | return {"node": h_node, "dag": h_dag, "glob": h_glob} 171 | 172 | 173 | class NodeEncoder(nn.Module): 174 | def __init__( 175 | self, 176 | num_node_features: int, 177 | embed_dim: int, 178 | mlp_kwargs: dict[str, Any], 179 | reverse_flow: bool = True, 180 | ) -> None: 181 | super().__init__() 182 | self.reverse_flow = reverse_flow 183 | self.j, self.i = (1, 0) if reverse_flow else (0, 1) 184 | 185 | self.mlp_prep = utils.make_mlp( 186 | num_node_features, output_dim=embed_dim, **mlp_kwargs 187 | ) 188 | self.mlp_msg = utils.make_mlp(embed_dim, output_dim=embed_dim, **mlp_kwargs) 189 | self.mlp_update = utils.make_mlp(embed_dim, output_dim=embed_dim, **mlp_kwargs) 190 | 191 | def forward(self, dag_batch: pyg.data.Batch) -> Tensor: 192 | """returns a tensor of shape [num_nodes, embed_dim]""" 193 | 194 | edge_masks = dag_batch["edge_masks"] 195 | 196 | if edge_masks.shape[0] == 0: 197 | # no message passing to do 198 | return self._forward_no_mp(dag_batch.x) 199 | 200 | # pre-process the node features into initial representations 201 | h_init = self.mlp_prep(dag_batch.x) 202 | 203 | # will store all the nodes' representations 204 | h = torch.zeros_like(h_init) 205 | 206 | num_nodes = h.shape[0] 207 | 208 | src_node_mask = ~pyg.utils.index_to_mask( 209 | dag_batch.edge_index[self.i], num_nodes 210 | ) 211 | 212 | h[src_node_mask] = self.mlp_update(h_init[src_node_mask]) 213 | 214 | edge_masks_it = ( 215 | iter(reversed(edge_masks)) if self.reverse_flow else iter(edge_masks) 216 | ) 217 | 218 | # target-to-source message passing, one level of the dags at a time 219 | for edge_mask in edge_masks_it: 220 | edge_index_masked = dag_batch.edge_index[:, edge_mask] 221 | adj = utils.make_adj(edge_index_masked, num_nodes) 222 | 223 | # nodes sending messages 224 | src_mask = pyg.utils.index_to_mask(edge_index_masked[self.j], num_nodes) 225 | 226 | # nodes receiving messages 227 | dst_mask = pyg.utils.index_to_mask(edge_index_masked[self.i], num_nodes) 228 | 229 | msg = torch.zeros_like(h) 230 | msg[src_mask] = self.mlp_msg(h[src_mask]) 231 | agg = torch_sparse.matmul(adj if self.reverse_flow else adj.t(), msg) 232 | h[dst_mask] = h_init[dst_mask] + self.mlp_update(agg[dst_mask]) 233 | 234 | return h 235 | 236 | def _forward_no_mp(self, x: Tensor) -> Tensor: 237 | """forward pass without any message passing. Needed whenever 238 | all the active jobs are almost complete and only have a single 239 | layer of nodes remaining. 
240 | """ 241 | return self.mlp_prep(x) 242 | 243 | 244 | class DagEncoder(nn.Module): 245 | def __init__( 246 | self, num_node_features: int, embed_dim: int, mlp_kwargs: dict[str, Any] 247 | ) -> None: 248 | super().__init__() 249 | input_dim = num_node_features + embed_dim 250 | self.mlp = utils.make_mlp(input_dim, output_dim=embed_dim, **mlp_kwargs) 251 | 252 | def forward(self, h_node: Tensor, dag_batch: pyg.data.Batch) -> Tensor: 253 | """returns a tensor of shape [num_dags, embed_dim]""" 254 | # include skip connection from raw input 255 | h_node = torch.cat([dag_batch.x, h_node], dim=1) 256 | h_dag = segment_csr(self.mlp(h_node), dag_batch.ptr) 257 | return h_dag 258 | 259 | 260 | class GlobalEncoder(nn.Module): 261 | def __init__(self, embed_dim: int, mlp_kwargs: dict[str, Any]) -> None: 262 | super().__init__() 263 | self.mlp = utils.make_mlp(embed_dim, output_dim=embed_dim, **mlp_kwargs) 264 | 265 | def forward(self, h_dag: Tensor, obs_ptr: Tensor | None = None) -> Tensor: 266 | """returns a tensor of shape [num_observations, embed_dim]""" 267 | h_dag = self.mlp(h_dag) 268 | 269 | if obs_ptr is not None: 270 | # batch of observations 271 | h_glob = segment_csr(h_dag, obs_ptr) 272 | else: 273 | # single observation 274 | h_glob = h_dag.sum(0).unsqueeze(0) 275 | 276 | return h_glob 277 | 278 | 279 | class StagePolicyNetwork(nn.Module): 280 | def __init__( 281 | self, 282 | num_node_features: int, 283 | emb_dims: dict[str, int], 284 | mlp_kwargs: dict[str, Any], 285 | ) -> None: 286 | super().__init__() 287 | input_dim = ( 288 | num_node_features + emb_dims["node"] + emb_dims["dag"] + emb_dims["glob"] 289 | ) 290 | 291 | self.mlp_score = utils.make_mlp(input_dim, output_dim=1, **mlp_kwargs) 292 | 293 | def forward(self, dag_batch: pyg.data.Batch, h_dict: dict[str, Tensor]) -> Tensor: 294 | """returns a tensor of shape [num_nodes,]""" 295 | 296 | stage_mask = dag_batch["stage_mask"] 297 | 298 | x = dag_batch.x[stage_mask] 299 | 300 | h_node = h_dict["node"][stage_mask] 301 | 302 | batch_masked = dag_batch.batch[stage_mask] 303 | h_dag_rpt = h_dict["dag"][batch_masked] 304 | 305 | if "num_stage_acts" in dag_batch: 306 | # batch of obsns 307 | num_stage_acts = dag_batch["num_stage_acts"] 308 | else: 309 | # single obs 310 | num_stage_acts = stage_mask.sum() 311 | 312 | h_glob_rpt = h_dict["glob"].repeat_interleave( 313 | num_stage_acts, output_size=h_node.shape[0], dim=0 314 | ) 315 | 316 | # residual connections to original features 317 | node_inputs = torch.cat([x, h_node, h_dag_rpt, h_glob_rpt], dim=1) 318 | 319 | node_scores = self.mlp_score(node_inputs).squeeze(-1) 320 | return node_scores 321 | 322 | 323 | class ExecPolicyNetwork(nn.Module): 324 | def __init__( 325 | self, 326 | num_executors: int, 327 | num_dag_features: int, 328 | emb_dims: dict[str, int], 329 | mlp_kwargs: dict[str, Any], 330 | ) -> None: 331 | super().__init__() 332 | self.num_executors = num_executors 333 | self.num_dag_features = num_dag_features 334 | input_dim = num_dag_features + emb_dims["dag"] + emb_dims["glob"] + 1 335 | 336 | self.mlp_score = utils.make_mlp(input_dim, output_dim=1, **mlp_kwargs) 337 | 338 | def forward( 339 | self, dag_batch: pyg.data.Batch, h_dict: dict[str, Tensor], job_indices: Tensor 340 | ) -> Tensor: 341 | exec_mask = dag_batch["exec_mask"] 342 | 343 | dag_start_idxs = dag_batch.ptr[:-1] 344 | x_dag = dag_batch.x[dag_start_idxs, : self.num_dag_features] 345 | x_dag = x_dag[job_indices] 346 | 347 | h_dag = h_dict["dag"][job_indices] 348 | 349 | exec_mask = exec_mask[job_indices] 
350 | 351 | if "num_exec_acts" in dag_batch: 352 | # batch of obsns 353 | num_exec_acts = dag_batch["num_exec_acts"][job_indices] 354 | else: 355 | # single obs 356 | num_exec_acts = exec_mask.sum() 357 | x_dag = x_dag.unsqueeze(0) 358 | h_dag = h_dag.unsqueeze(0) 359 | exec_mask = exec_mask.unsqueeze(0) 360 | 361 | exec_actions = self._get_exec_actions(exec_mask) 362 | 363 | # residual connections to original features 364 | x_h_dag = torch.cat([x_dag, h_dag], dim=1) 365 | 366 | x_h_dag_rpt = x_h_dag.repeat_interleave( 367 | num_exec_acts, output_size=exec_actions.shape[0], dim=0 368 | ) 369 | 370 | h_glob_rpt = h_dict["glob"].repeat_interleave( 371 | num_exec_acts, output_size=exec_actions.shape[0], dim=0 372 | ) 373 | 374 | dag_inputs = torch.cat([x_h_dag_rpt, h_glob_rpt, exec_actions], dim=1) 375 | 376 | dag_scores = self.mlp_score(dag_inputs).squeeze(-1) 377 | return dag_scores 378 | 379 | def _get_exec_actions(self, exec_mask: Tensor) -> Tensor: 380 | exec_actions = torch.arange(self.num_executors) / self.num_executors 381 | exec_actions = exec_actions.to(exec_mask.device) 382 | exec_actions = exec_actions.repeat(exec_mask.shape[0]) 383 | exec_actions = exec_actions[exec_mask.view(-1)] 384 | exec_actions = exec_actions.unsqueeze(1) 385 | return exec_actions 386 | -------------------------------------------------------------------------------- /spark_sched_sim/spark_sched_sim.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_left, bisect_right 2 | from collections import deque 3 | from collections.abc import Iterable, Callable 4 | from typing import Any 5 | 6 | import numpy as np 7 | from gymnasium import Env 8 | import gymnasium.spaces as sp 9 | 10 | from .components import Job, Stage, Task, Executor 11 | from .components.executor_tracker import ExecutorTracker, PoolKey, COMMON_POOL_KEY 12 | from .components.event import Event, EventQueue 13 | from .data_samplers import make_data_sampler, DataSampler 14 | from .utils import subgraph 15 | from . import metrics 16 | 17 | try: 18 | from .components.renderer import Renderer 19 | 20 | PYGAME_AVAILABLE = True 21 | except ImportError: 22 | PYGAME_AVAILABLE = False 23 | 24 | 25 | NUM_NODE_FEATURES = 3 26 | RENDER_FPS = 30 27 | 28 | 29 | class SparkSchedSimEnv(Env): 30 | """A Gymnasium environment that simulates DAG job scheduling in Spark""" 31 | 32 | metadata = {"render_modes": ["human"], "render_fps": RENDER_FPS} 33 | 34 | def __init__(self, env_cfg: dict[str, Any]) -> None: 35 | # number of simulated executors. More executors means a higher possible 36 | # level of parallelism. 37 | self.num_executors: int = env_cfg["num_executors"] 38 | 39 | # time in ms it takes for an executor to move between jobs 40 | self.moving_delay: int = env_cfg["moving_delay"] 41 | 42 | # continuous discount factor in [0,+inf). If set to 0, then rewards are 43 | # not discounted. 44 | self.beta: float = env_cfg.get("beta", 0) 45 | 46 | # limit on the number of jobs that arrive throughout the simulation. If 47 | # set to `None`, then the episode ends when a time limit is reached.
48 | self.job_arrival_cap: int | None = env_cfg.get("job_arrival_cap") 49 | 50 | # if set to 'human', then a visualization of the simulation is rendered 51 | # in real time 52 | self.render_mode: str | None = env_cfg.get("render_mode") 53 | 54 | if self.render_mode == "human" and not PYGAME_AVAILABLE: 55 | raise ValueError("pygame is unavailable") 56 | 57 | self.data_sampler: DataSampler = make_data_sampler(env_cfg) 58 | 59 | # tracks the current time from the start of the simulation in ms 60 | self.wall_time: float = 0 61 | 62 | self.event_queue = EventQueue() 63 | 64 | self.jobs: dict[int, Job] = {} 65 | 66 | self.exec_tracker = ExecutorTracker(self.num_executors) 67 | 68 | self.event_handler_switch: dict[Event.Type, Callable[..., None]] = { 69 | Event.Type.JOB_ARRIVAL: self._handle_job_arrival, 70 | Event.Type.EXECUTOR_READY: self._handle_executor_arrival, 71 | Event.Type.TASK_FINISHED: self._handle_task_completion, 72 | } 73 | 74 | self.renderer: Renderer | None = None 75 | 76 | if self.render_mode == "human": 77 | self.renderer = Renderer( 78 | self.num_executors, 79 | self.job_arrival_cap, 80 | render_fps=self.metadata["render_fps"], 81 | ) 82 | 83 | self.job_duration_buff: deque[float] = deque(maxlen=200) 84 | 85 | self.action_space = sp.Dict( 86 | { 87 | # stage selection 88 | # NOTE: upper bound of this space is dynamic, equal to 89 | # the number of active stages. Initialized to 1. 90 | "stage_idx": sp.Discrete(1, start=-1), 91 | # parallelism limit selection 92 | "num_exec": sp.Discrete(self.num_executors, start=1), 93 | } 94 | ) 95 | 96 | self.observation_space = sp.Dict( 97 | { 98 | # shape: (num active stages) x (num node features) 99 | # stage features: num remaining tasks, most recent task duration, 100 | # is stage schedulable 101 | # edge features: none 102 | "dag_batch": sp.Graph( 103 | node_space=sp.Box(0, np.inf, (NUM_NODE_FEATURES,)), 104 | edge_space=sp.Discrete(1), 105 | ), 106 | # length: num active jobs 107 | # `ptr[job_idx]` returns the index of the first stage associated 108 | # with that job. E.g., the range of stage indices for a job is 109 | # given by `ptr[job_idx], ..., (ptr[job_idx+1]-1)` 110 | # NOTE: upper bound of this space is dynamic, equal to 111 | # the number of active stages. Initialized to 1. 112 | "dag_ptr": sp.Sequence(sp.Discrete(1), stack=True), 113 | # integer that represents how many executors need to be scheduled 114 | "num_committable_execs": sp.Discrete(self.num_executors + 1), 115 | # index of the job that is releasing executors, if any. 116 | # set to `self.num_total_jobs` if source is common pool.
117 | "source_job_idx": sp.Discrete(1), 118 | # length: num active jobs 119 | # count of executors associated with each active job, 120 | # including moving executors and commitments from other jobs 121 | "exec_supplies": sp.Sequence( 122 | sp.Discrete(2 * self.num_executors), stack=True 123 | ), 124 | } 125 | ) 126 | 127 | def reset( 128 | self, seed: int | None = None, options: dict[str, Any] | None = None 129 | ) -> tuple[dict, dict]: 130 | super().reset(seed=seed) 131 | 132 | if options is None: 133 | options = {} 134 | 135 | time_limit = options.get("time_limit", np.inf) 136 | 137 | if time_limit is np.inf and not self.job_arrival_cap: 138 | raise ValueError("must either have a limit on job arrivals or time.") 139 | 140 | # simulation wall time in ms 141 | self.wall_time = 0 142 | 143 | self.data_sampler.reset(self.np_random) 144 | 145 | self.event_queue.reset() 146 | 147 | self.jobs.clear() 148 | 149 | job_sequence = self.data_sampler.job_sequence(time_limit) 150 | assert next(iter(job_sequence))[0] == 0, "first job must arrive at t=0" 151 | 152 | for t, job in job_sequence: 153 | self.event_queue.push(t, Event(Event.Type.JOB_ARRIVAL, data={"job": job})) 154 | self.jobs[job.id_] = job 155 | 156 | self.job_arrival_cap = len(self.jobs.keys()) 157 | self.observation_space["source_job_idx"].n = self.job_arrival_cap + 1 158 | if self.renderer: 159 | self.renderer.num_total_jobs = self.job_arrival_cap 160 | 161 | self.executors = [Executor(i) for i in range(self.num_executors)] 162 | self.exec_tracker.reset() 163 | 164 | # a fast way of obtaining the edge links for an observation is to 165 | # start out with all of them in a big array, and then to induce a 166 | # subgraph based on the current set of active nodes 167 | self._reset_edge_links() 168 | self.num_total_stages = self.all_job_ptr[-1] 169 | 170 | # must be ordered 171 | self.active_job_ids: list[int] = [] 172 | 173 | self.completed_job_ids: set[int] = set() 174 | 175 | # maintains the stages that have already been selected during the 176 | # current scheduling round so that they don't get selected again 177 | # until the next round 178 | self.selected_stages: set[Stage] = set() 179 | 180 | # (index of an active stage) -> (stage object) 181 | # used to trace an action to its corresponding stage 182 | self.stage_selection_map: dict[int, Stage] = {} 183 | 184 | self._load_initial_jobs() 185 | 186 | return self._observe(), self.info 187 | 188 | def step(self, action: dict) -> tuple[dict, float, bool, bool, dict]: 189 | self._take_action(action) 190 | 191 | if self.exec_tracker.num_committable_execs() and self.schedulable_stages: 192 | # there are still scheduling decisions to be made, so consult the agent again 193 | return self._observe(), 0, False, False, self.info 194 | 195 | # commitment round has completed, now schedule the free executors 196 | self._commit_remaining_executors() 197 | self._fulfill_commitments_from_source() 198 | self.exec_tracker.clear_executor_source() 199 | self.selected_stages.clear() 200 | 201 | # save old state attributes for computing reward 202 | wall_time_old = self.wall_time 203 | active_job_ids_old = self.active_job_ids.copy() 204 | 205 | # step through timeline until next scheduling event 206 | self._resume_simulation() 207 | 208 | job_time = self._compute_jobtime(wall_time_old, active_job_ids_old) 209 | reward = -job_time 210 | terminated = self.all_jobs_complete 211 | 212 | if not terminated: 213 | assert ( 214 | self.exec_tracker.num_committable_execs() and self.schedulable_stages 215 | ), "[step]" 
216 | 217 | if self.render_mode == "human": 218 | self._render_frame() 219 | 220 | # if the episode isn't done, then start a new scheduling round at the current executor source 221 | return self._observe(), reward, terminated, False, self.info 222 | 223 | def close(self) -> None: 224 | if self.renderer: 225 | self.renderer.close() 226 | 227 | @property 228 | def all_jobs_complete(self) -> bool: 229 | return self.num_completed_jobs == len(self.jobs.keys()) 230 | 231 | @property 232 | def num_completed_jobs(self) -> int: 233 | return len(self.completed_job_ids) 234 | 235 | @property 236 | def num_active_jobs(self) -> int: 237 | return len(self.active_job_ids) 238 | 239 | @property 240 | def info(self) -> dict: 241 | return {"wall_time": self.wall_time} 242 | 243 | @property 244 | def avg_job_duration(self) -> float: 245 | return np.mean(self.job_duration_buff).item() * 1e-3 246 | 247 | # internal methods 248 | 249 | def _reset_edge_links(self) -> None: 250 | edge_links = [] 251 | job_ptr = [0] 252 | for job in self.jobs.values(): 253 | base_stage_idx = job_ptr[-1] 254 | edges = np.vstack(job.dag.edges) 255 | edge_links += [base_stage_idx + edges] 256 | job_ptr += [base_stage_idx + job.num_stages] 257 | self.all_edge_links = np.vstack(edge_links) 258 | self.all_job_ptr = np.array(job_ptr) 259 | 260 | def _load_initial_jobs(self) -> None: 261 | while q_top := self.event_queue.top(): 262 | wall_time, event = q_top 263 | 264 | if wall_time > 0: 265 | break 266 | 267 | self.event_queue.pop() 268 | 269 | job = event.data["job"] 270 | 271 | self._handle_job_arrival(job) 272 | 273 | self.schedulable_stages = self._find_schedulable_stages() 274 | 275 | def _take_action(self, action: dict) -> None: 276 | if not self.action_space.contains(action): 277 | raise ValueError("invalid action: does not belong to the action space") 278 | 279 | if action["stage_idx"] == -1: 280 | # no stage has been selected 281 | self._commit_remaining_executors() 282 | return 283 | 284 | stage = self.stage_selection_map[action["stage_idx"]] 285 | 286 | if not stage in self.schedulable_stages: 287 | raise ValueError("invalid action: stage is not currently schedulable") 288 | 289 | num_executors = action["num_exec"] 290 | 291 | if not num_executors: 292 | raise ValueError("invalid action: must commit at least one executor") 293 | 294 | if num_executors > self.exec_tracker.num_committable_execs(): 295 | raise ValueError("invalid action: too many executors requested") 296 | 297 | # agent may have requested more executors than are actually needed 298 | # or available 299 | num_executors = self._adjust_num_executors(num_executors, stage) 300 | self.exec_tracker.add_commitment(num_executors, stage.pool_key) 301 | 302 | # mark stage as selected so that it doesn't get selected again during 303 | # this scheduling round 304 | self.selected_stages.add(stage) 305 | 306 | # find remaining schedulable stages 307 | job_ids = [stage.job_id for stage in self.schedulable_stages] 308 | i = bisect_left(job_ids, stage.job_id) 309 | hi = min(len(job_ids), i + len(self.jobs[stage.job_id].active_stages)) 310 | j = bisect_right(job_ids, stage.job_id, lo=i, hi=hi) 311 | self.schedulable_stages = ( 312 | self.schedulable_stages[:i] 313 | + self._find_schedulable_stages([stage.job_id]) 314 | + self.schedulable_stages[j:] 315 | ) 316 | 317 | def _handle_event(self, event: Event) -> None: 318 | self.event_handler_switch[event.type](**event.data) 319 | 320 | def _resume_simulation(self) -> None: 321 | """resumes the simulation until either there are new 
scheduling 322 | decisions to be made, or it's done. 323 | """ 324 | schedulable_stages: list[Stage] = [] 325 | 326 | while q_top := self.event_queue.pop(): 327 | self.wall_time, event = q_top 328 | 329 | self._handle_event(event) 330 | 331 | if not self.exec_tracker.num_committable_execs(): 332 | continue 333 | 334 | schedulable_stages = self._find_schedulable_stages() 335 | if schedulable_stages: 336 | # there are schedulable stages and committable executors, 337 | # so we need to enter the scheduling loop 338 | break 339 | 340 | self._move_idle_executors() 341 | self.exec_tracker.clear_executor_source() 342 | 343 | self.schedulable_stages = schedulable_stages 344 | 345 | def _observe(self) -> dict[str, Any]: 346 | self.stage_selection_map.clear() 347 | 348 | nodes: list[tuple[int, float, bool]] = [] 349 | dag_ptr: list[int] = [0] 350 | active_stage_mask = np.zeros(self.num_total_stages, dtype=bool) 351 | exec_supplies: list[int] = [] 352 | source_job_idx = len(self.active_job_ids) 353 | 354 | for i, stage in enumerate(self.schedulable_stages): 355 | self.stage_selection_map[i] = stage 356 | stage.is_schedulable = True 357 | 358 | for i, job_id in enumerate(self.active_job_ids): 359 | job = self.jobs[job_id] 360 | 361 | if job_id == self.exec_tracker.source_job_id(): 362 | source_job_idx = i 363 | 364 | exec_supplies += [self.exec_tracker.exec_supply(job_id)] 365 | 366 | for stage in job.active_stages: 367 | nodes += [ 368 | ( 369 | stage.num_remaining_tasks, 370 | stage.most_recent_duration, 371 | stage.is_schedulable, 372 | ) 373 | ] 374 | stage.is_schedulable = False 375 | 376 | active_stage_mask[self.all_job_ptr[job_id] + stage.id_] = 1 377 | 378 | dag_ptr += [len(nodes)] 379 | 380 | try: 381 | nodes_stacked = np.vstack(nodes).astype(np.float32) 382 | except ValueError: 383 | # there are no active stages 384 | nodes_stacked = np.zeros((0, NUM_NODE_FEATURES), dtype=np.float32) 385 | 386 | edge_links = subgraph(self.all_edge_links, active_stage_mask) 387 | 388 | # not using edge data, so this array is always zeros 389 | edges = np.zeros(len(edge_links), dtype=int) 390 | 391 | num_committable_execs = self.exec_tracker.num_committable_execs() 392 | 393 | obs = { 394 | "dag_batch": sp.GraphInstance(nodes_stacked, edges, edge_links), 395 | "dag_ptr": dag_ptr, 396 | "num_committable_execs": num_committable_execs, 397 | "source_job_idx": source_job_idx, 398 | "exec_supplies": exec_supplies, 399 | } 400 | 401 | # update stage action space to reflect the current number of active 402 | # stages 403 | self.observation_space["dag_ptr"].feature_space.n = len(nodes) + 1 404 | self.action_space["stage_idx"].n = len(nodes) + 1 405 | 406 | return obs 407 | 408 | def _render_frame(self) -> None: 409 | assert self.renderer, "[_render_frame]" 410 | 411 | executor_histories = (executor.history for executor in self.executors) 412 | job_completion_times = ( 413 | self.jobs[job_id].t_completed for job_id in self.completed_job_ids 414 | ) 415 | average_job_duration = int(metrics.avg_job_duration(self) * 1e-3) 416 | 417 | self.renderer.render_frame( 418 | executor_histories, 419 | job_completion_times, 420 | self.wall_time, 421 | average_job_duration, 422 | self.num_active_jobs, 423 | self.num_completed_jobs, 424 | ) 425 | 426 | # event handlers 427 | 428 | def _handle_job_arrival(self, job: Job) -> None: 429 | self.active_job_ids += [job.id_] 430 | self.exec_tracker.add_job_pool(job.pool_key) 431 | for stage in job.stages: 432 | self.exec_tracker.add_stage_pool(stage.pool_key) 433 | 434 | if 
self.exec_tracker.common_pool_has_executors(): 435 | # if there are any executors that don't belong to any job, then 436 | # the agent might want to schedule them to this job, so start a 437 | # new round at the common pool 438 | self.exec_tracker.update_executor_source(COMMON_POOL_KEY) 439 | 440 | def _handle_executor_arrival(self, executor: Executor, stage: Stage) -> None: 441 | """performs some bookkeeping when an executor arrives""" 442 | job = self.jobs[stage.job_id] 443 | 444 | job.attach_executor(executor) 445 | executor.add_history(self.wall_time, job.id_) 446 | 447 | self.exec_tracker.record_executor_arrival(stage.pool_key) 448 | self.exec_tracker.move_executor_to_pool(executor.id_, job.pool_key) 449 | 450 | self._move_executor_to_stage(executor, stage) 451 | 452 | def _handle_task_completion(self, stage: Stage, task: Task) -> None: 453 | """performs some bookkeeping when a task completes""" 454 | job = self.jobs[stage.job_id] 455 | 456 | assert task.executor_id is not None, "[_handle_task_completion],1" 457 | executor = self.executors[task.executor_id] 458 | 459 | assert not stage.completed, "[_handle_task_completion],2" 460 | stage.record_task_completion() 461 | task.t_completed = self.wall_time 462 | executor.is_executing = False 463 | 464 | if stage.num_remaining_tasks > 0: 465 | # reassign the executor to keep working on this stage if there is more work to do 466 | self._execute_next_task(executor, stage) 467 | return 468 | 469 | did_job_frontier_change = False 470 | 471 | if stage.completed: 472 | did_job_frontier_change = self._process_stage_completion(stage) 473 | 474 | if job.completed: 475 | self._process_job_completion(job) 476 | 477 | # executor may have somewhere to be moved 478 | had_commitment = self._handle_released_executor( 479 | executor, stage, did_job_frontier_change 480 | ) 481 | 482 | # executor source may need to be updated 483 | self._update_executor_source(stage, had_commitment, did_job_frontier_change) 484 | 485 | # Other helper functions 486 | 487 | def _commit_remaining_executors(self) -> None: 488 | """There may be executors at the current source pool that weren't 489 | committed anywhere, e.g. because there were no more stages to 490 | schedule, or because the agent chose not to schedule all of them. 491 | 492 | This function explicitly commits those remaining executors to the 493 | common pool. When those executors get released, they either move to 494 | the job pool or the common pool, depending on whether the job is 495 | saturated at that time. 496 | 497 | It is important to do this, or else the agent could go into a loop, 498 | under-committing executors from the same source pool. 499 | """ 500 | num_uncommitted_executors = self.exec_tracker.num_committable_execs() 501 | 502 | if num_uncommitted_executors > 0: 503 | self.exec_tracker.add_commitment(num_uncommitted_executors, COMMON_POOL_KEY) 504 | 505 | def _find_schedulable_stages( 506 | self, 507 | job_ids: Iterable[int] | None = None, 508 | source_job_id: int | None = None, 509 | ) -> list[Stage]: 510 | """A stage is schedulable if it is ready (see `_is_stage_ready()`), 511 | it hasn't been selected in the current scheduling round, and its job 512 | is not saturated with executors (i.e. can accept more executors). 513 | 514 | returns a union of schedulable stages over all the jobs specified 515 | in `job_ids`. If no job ids are provided, then all active jobs are 516 | searched.
517 | """ 518 | if not job_ids: 519 | job_ids = self.active_job_ids 520 | 521 | if not source_job_id: 522 | source_job_id = self.exec_tracker.source_job_id() 523 | 524 | # filter out saturated jobs. The source job is never considered saturated, because it 525 | # is not gaining any new executors during scheduling 526 | job_ids = [ 527 | job_id 528 | for job_id in job_ids 529 | if job_id == source_job_id 530 | or self.exec_tracker.exec_supply(job_id) < self.num_executors 531 | ] 532 | 533 | schedulable_stages = [ 534 | stage 535 | for job_id in job_ids 536 | for stage in self.jobs[job_id].active_stages 537 | if stage not in self.selected_stages and self._is_stage_ready(stage) 538 | ] 539 | 540 | return schedulable_stages 541 | 542 | def _is_stage_ready(self, stage: Stage) -> bool: 543 | """a stage is ready if 544 | - it is unsaturated, and 545 | - all of its parent stages are saturated 546 | """ 547 | if self._is_stage_saturated(stage): 548 | return False 549 | 550 | job = self.jobs[stage.job_id] 551 | for parent_stage in job.get_parent_stages(stage): 552 | if not self._is_stage_saturated(parent_stage): 553 | return False 554 | 555 | return True 556 | 557 | def _adjust_num_executors(self, num_executors: int, stage: Stage) -> int: 558 | """truncates the numer of executor assigned to `stage` to the stage's 559 | demand, if it's larger 560 | """ 561 | executor_demand = self._get_executor_demand(stage) 562 | num_executors_adjusted = min(num_executors, executor_demand) 563 | assert num_executors_adjusted > 0, "[_adjust_num_executors]" 564 | return num_executors_adjusted 565 | 566 | def _get_executor_demand(self, stage: Stage) -> int: 567 | """a stage's executor demand is the number of executors that it can 568 | accept in addition to the executors currently working on, committed to, 569 | and moving to the stage. Note: demand can be negative if more 570 | resources were assigned to the stage than needed. 
571 | """ 572 | num_executors_moving = self.exec_tracker.num_executors_moving_to_stage( 573 | stage.pool_key 574 | ) 575 | num_commitments = self.exec_tracker.num_commitments_to_stage(stage.pool_key) 576 | 577 | demand = stage.num_remaining_tasks - (num_executors_moving + num_commitments) 578 | return demand 579 | 580 | def _is_stage_saturated(self, stage: Stage) -> bool: 581 | """a stage is saturated if it doesn't need any more executors.""" 582 | return self._get_executor_demand(stage) <= 0 583 | 584 | def _execute_next_task(self, executor: Executor, stage: Stage): 585 | """starts work on another one of `stage`'s tasks, assuming there are 586 | still tasks remaining and the executor is local to the stage 587 | """ 588 | assert stage.num_remaining_tasks > 0, "[_execute_next_task],1" 589 | assert executor.is_at_job(stage.job_id), "[_execute_next_task],2" 590 | assert not executor.is_executing, "[_execute_next_task],3" 591 | 592 | job = self.jobs[stage.job_id] 593 | 594 | task = stage.launch_next_task() 595 | if stage.num_remaining_tasks == 0: 596 | # stage just became saturated 597 | job.saturated_stage_count += 1 598 | 599 | task_duration = self.data_sampler.task_duration( 600 | self.jobs[stage.job_id], stage, task, executor 601 | ) 602 | 603 | executor.task = task 604 | executor.is_executing = True 605 | task.executor_id = executor.id_ 606 | task.t_accepted = self.wall_time 607 | stage.most_recent_duration = task_duration 608 | 609 | self.event_queue.push( 610 | self.wall_time + task_duration, 611 | Event( 612 | type=Event.Type.TASK_FINISHED, 613 | data={"stage": stage, "task": task}, 614 | ), 615 | ) 616 | 617 | def _send_executor(self, executor: Executor, stage: Stage) -> None: 618 | """sends a `executor` to `stage`, assuming that the executor is 619 | currently at a different job 620 | """ 621 | assert stage, "[_send_executor],1" 622 | assert not executor.is_executing, "[_send_executor],2" 623 | assert not executor.is_at_job(stage.job_id), "[_send_executor],3" 624 | 625 | self.exec_tracker.move_executor_to_pool(executor.id_, stage.pool_key, send=True) 626 | 627 | if executor.job_id is not None: 628 | old_job = self.jobs[executor.job_id] 629 | old_job.detach_executor(executor) 630 | 631 | self.event_queue.push( 632 | self.wall_time + self.moving_delay, 633 | Event( 634 | type=Event.Type.EXECUTOR_READY, 635 | data={"executor": executor, "stage": stage}, 636 | ), 637 | ) 638 | 639 | def _handle_released_executor( 640 | self, executor: Executor, stage: Stage, did_job_frontier_change: bool 641 | ) -> bool: 642 | """called upon a task completion. if the executor has been commited to 643 | a next stage, then try assigning it there. Otherwise, if `stage` became 644 | saturated and unlocked new stages within its job dag, then move the 645 | executor to the job's executor pool so that it can be assigned to the 646 | new stages 647 | """ 648 | commitment_pool_key = self.exec_tracker.peek_commitment(stage.pool_key) 649 | 650 | if commitment_pool_key is not None: 651 | self._fulfill_commitment(executor.id_, commitment_pool_key) 652 | return True 653 | 654 | # executor has nowhere to go, so make it idle 655 | executor.task = None 656 | 657 | if did_job_frontier_change: 658 | self._move_idle_executors(stage.pool_key, [executor.id_]) 659 | 660 | return False 661 | 662 | def _update_executor_source( 663 | self, stage: Stage, had_commitment: bool, did_job_frontier_change: bool 664 | ) -> None: 665 | """called upon a task completion. 
If any new stages were unlocked 666 | within this job upon the task completion, then start a new commitment 667 | round at this job's pool so that free executors can be assigned to the 668 | new stages. Otherwise, if the executor has nowhere to go, then start a 669 | new commitment round at this stage's pool to give it somewhere to go. 670 | """ 671 | if did_job_frontier_change: 672 | self.exec_tracker.update_executor_source(stage.job_pool_key) 673 | elif not had_commitment: 674 | self.exec_tracker.update_executor_source(stage.pool_key) 675 | 676 | def _process_stage_completion(self, stage: Stage) -> bool: 677 | """performs some bookkeeping when a stage completes""" 678 | job = self.jobs[stage.job_id] 679 | frontier_changed = job.record_stage_completion(stage) 680 | return frontier_changed 681 | 682 | def _process_job_completion(self, job: Job) -> None: 683 | """performs some bookkeeping when a job completes""" 684 | assert job.id_ in self.jobs, "[_process_job_completion],1" 685 | 686 | # if there are any executors still local to this job, then remove them 687 | if self.exec_tracker.pool_size(job.pool_key) > 0: 688 | self._move_idle_executors(job.pool_key) 689 | 690 | assert ( 691 | self.exec_tracker.pool_size(job.pool_key) == 0 692 | ), "[_process_job_completion],2" 693 | 694 | self.active_job_ids.remove(job.id_) 695 | self.completed_job_ids.add(job.id_) 696 | job.t_completed = self.wall_time 697 | self.job_duration_buff.append(job.t_completed - job.t_arrival) 698 | 699 | def _fulfill_commitment(self, executor_id: int, dst_pool_key: PoolKey) -> None: 700 | src_pool_key = self.exec_tracker.remove_commitment(executor_id, dst_pool_key) 701 | 702 | if dst_pool_key == COMMON_POOL_KEY: 703 | # this executor is free and isn't committed to any actual stage 704 | self._move_idle_executors(src_pool_key, [executor_id]) 705 | return 706 | 707 | job_id, stage_id = dst_pool_key 708 | assert job_id is not None and stage_id is not None, "[_fulfill_commitment]" 709 | stage = self.jobs[job_id].stages[stage_id] 710 | executor = self.executors[executor_id] 711 | 712 | self._move_executor_to_stage(executor, stage) 713 | 714 | def _get_idle_source_executors(self, pool_key: PoolKey | None = None) -> set[int]: 715 | if not pool_key: 716 | executor_ids = self.exec_tracker.get_source_pool() 717 | else: 718 | executor_ids = self.exec_tracker.get_pool(pool_key) 719 | 720 | free_executor_ids = set( 721 | ( 722 | executor_id 723 | for executor_id in executor_ids 724 | if not self.executors[executor_id].is_executing 725 | ) 726 | ) 727 | 728 | return free_executor_ids 729 | 730 | def _fulfill_commitments_from_source(self) -> None: 731 | # only consider the idle executors 732 | idle_executor_ids = self._get_idle_source_executors() 733 | commitments = self.exec_tracker.get_source_commitments() 734 | 735 | for dst_pool_key, num_executors in commitments.items(): 736 | assert dst_pool_key, "[_fulfill_commitments_from_source],1" 737 | assert num_executors, "[_fulfill_commitments_from_source],2" 738 | while num_executors and idle_executor_ids: 739 | executor_id = idle_executor_ids.pop() 740 | self._fulfill_commitment(executor_id, dst_pool_key) 741 | num_executors -= 1 742 | 743 | assert not idle_executor_ids, "[_fulfill_commitments_from_source],3" 744 | 745 | def _move_idle_executors( 746 | self, 747 | src_pool_key: PoolKey | None = None, 748 | executor_ids: list[int] | None = None, 749 | ) -> None: 750 | """When an executor becomes idle, it may need to be moved somewhere.
751 | If it's idle at a stage, it might need to be moved to the job pool. 752 | If it's idle at a job, it might need to be moved to the common pool. 753 | """ 754 | if src_pool_key is None: 755 | src_pool_key = self.exec_tracker.get_source() 756 | assert src_pool_key is not None, "[_move_idle_executors],1" 757 | 758 | if src_pool_key == COMMON_POOL_KEY: 759 | return # no-op 760 | 761 | if executor_ids is None: 762 | executor_ids = list(self._get_idle_source_executors(src_pool_key)) 763 | assert executor_ids, "[_move_idle_executors],2" 764 | 765 | job_id, stage_id = src_pool_key 766 | assert job_id is not None, "[_move_idle_executors],3" 767 | is_job_saturated = self.jobs[job_id].saturated 768 | if stage_id is None and not is_job_saturated: 769 | # source is an unsaturated job's pool 770 | return # no-op 771 | 772 | # if the source is a saturated job's pool, then move the executors to the common 773 | # pool. If it's a stage's pool, then move them to the job's pool. 774 | dst_pool_key = COMMON_POOL_KEY if is_job_saturated else (job_id, None) 775 | 776 | for executor_id in executor_ids: 777 | self.exec_tracker.move_executor_to_pool(executor_id, dst_pool_key) 778 | if dst_pool_key == COMMON_POOL_KEY: 779 | executor = self.executors[executor_id] 780 | job = self.jobs[job_id] 781 | job.detach_executor(executor) 782 | executor.add_history(self.wall_time, -1) 783 | 784 | def _try_backup_schedule(self, executor: Executor) -> None: 785 | """If an executor arrives at a stage that no longer needs any executors, 786 | then greedily try to find a backup stage. 787 | """ 788 | backup_stage = self._find_backup_stage(executor) 789 | if backup_stage: 790 | # found a backup 791 | self._move_executor_to_stage(executor, backup_stage) 792 | return 793 | 794 | # no backup stage found, so move executor to job or common pool 795 | # depending on whether or not the executor's job is saturated 796 | exec_location = self.exec_tracker.executor_location(executor.id_) 797 | self._move_idle_executors(exec_location, [executor.id_]) 798 | 799 | def _move_executor_to_stage(self, executor: Executor, stage: Stage) -> None: 800 | if stage.num_remaining_tasks == 0: 801 | # stage is saturated, so this executor is not needed there anymore 802 | self._try_backup_schedule(executor) 803 | return 804 | 805 | if not executor.is_at_job(stage.job_id): 806 | self._send_executor(executor, stage) 807 | return 808 | 809 | job = self.jobs[stage.job_id] 810 | if stage not in job.frontier_stages: 811 | # stage is not ready yet; make executor idle and move it to the 812 | # job pool 813 | executor.task = None 814 | self.exec_tracker.move_executor_to_pool(executor.id_, stage.job_pool_key) 815 | return 816 | 817 | # stage's dependencies are satisfied, so start working on it.
818 | self.exec_tracker.move_executor_to_pool(executor.id_, stage.pool_key) 819 | self._execute_next_task(executor, stage) 820 | 821 | def _find_backup_stage(self, executor: Executor) -> Stage | None: 822 | # first, try searching within the same job 823 | assert executor.job_id is not None, "[_find_backup_stage]" 824 | 825 | local_stages = self._find_schedulable_stages( 826 | job_ids=[executor.job_id], source_job_id=executor.job_id 827 | ) 828 | 829 | if local_stages: 830 | return local_stages[0] 831 | 832 | # now, try searching all other jobs 833 | other_job_ids = [ 834 | job_id for job_id in self.active_job_ids if not executor.is_at_job(job_id) 835 | ] 836 | 837 | other_stages = self._find_schedulable_stages( 838 | job_ids=other_job_ids, source_job_id=executor.job_id 839 | ) 840 | 841 | if other_stages: 842 | return other_stages[0] 843 | 844 | # out of luck 845 | return None 846 | 847 | def _compute_jobtime( 848 | self, wall_time_step: float, active_job_ids_step: list[int] 849 | ) -> float: 850 | duration = self.wall_time - wall_time_step 851 | if duration == 0.0: 852 | return 0.0 853 | 854 | # include jobs that completed and arrived during the most recent simulation run 855 | all_job_ids = set(active_job_ids_step + self.active_job_ids) 856 | 857 | job_time = 0.0 858 | for job_id in all_job_ids: 859 | job = self.jobs[job_id] 860 | start = max(job.t_arrival, wall_time_step) 861 | end = min(job.t_completed, self.wall_time) 862 | 863 | if self.beta == 0.0: 864 | job_time += end - start 865 | else: 866 | # continuously discounted job-time 867 | job_time += np.exp( 868 | -self.beta * 1e-3 * (start - wall_time_step) 869 | ) - np.exp(-self.beta * 1e-3 * (end - wall_time_step)) 870 | 871 | if self.beta > 0.0: 872 | job_time /= self.beta 873 | 874 | return job_time 875 | --------------------------------------------------------------------------------
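A minimal interaction sketch for SparkSchedSimEnv follows (illustrative only, not a file from this repository). It assumes the environment settings in config/decima_tpch.yaml and that the TPC-H data expected by TPCHDataSampler is available; the config path, the 200-second time limit, and the greedy placeholder policy (always pick the first schedulable stage and commit every committable executor to it) are arbitrary choices for demonstration, not the repository's own scheduler.

from cfg_loader import load
from spark_sched_sim import SparkSchedSimEnv

# assumed: `load` accepts a path to a YAML config, as in test/test_train.py
cfg = load("config/decima_tpch.yaml")
env = SparkSchedSimEnv(cfg["env"])

# a finite time limit (in ms) bounds the sampled job arrival sequence
obs, info = env.reset(seed=0, options={"time_limit": 2e5})

terminated = False
while not terminated:
    # "stage_idx" indexes the schedulable stages exposed in the observation;
    # "num_exec" commits that many of the currently committable executors
    action = {"stage_idx": 0, "num_exec": obs["num_committable_execs"]}
    obs, reward, terminated, truncated, info = env.step(action)

print(f"average job duration: {env.avg_job_duration:.1f} s")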