├── .gitignore ├── LICENSE ├── README.md ├── allocator.py ├── application.py ├── arbiter.py ├── clean.sh ├── cluster.py ├── configs ├── applications │ └── solo.yaml ├── arbiter │ └── noop.yaml ├── cluster │ ├── dgx-a100.yaml │ ├── dgx-h100.yaml │ ├── half_half.yaml │ ├── hhcap_half_half.yaml │ ├── isocost_a100.yaml │ ├── isocost_h100.yaml │ ├── isocost_hybrid.yaml │ ├── isocount_a100.yaml │ ├── isocount_hybrid.yaml │ ├── isopower_a100.yaml │ ├── isopower_hybrid.yaml │ ├── solo_a100.yaml │ ├── solo_h100.yaml │ └── solo_hybrid.yaml ├── config.yaml ├── experiment │ ├── baseline_a100_costopt.yaml │ ├── baseline_h100_costopt.yaml │ ├── isocost_cluster.yaml │ ├── isocount_cluster.yaml │ ├── isopower_cluster.yaml │ ├── splitwise_aa_costopt.yaml │ ├── splitwise_aa_isocost.yaml │ ├── splitwise_aa_isopower.yaml │ ├── splitwise_ha_costopt.yaml │ ├── splitwise_ha_isocost.yaml │ ├── splitwise_ha_isopower.yaml │ ├── splitwise_hh_costopt.yaml │ ├── splitwise_hh_isocost.yaml │ ├── splitwise_hh_isopower.yaml │ ├── splitwise_hhcap_costopt.yaml │ ├── splitwise_hhcap_isocost.yaml │ ├── splitwise_hhcap_isopower.yaml │ ├── traces.yaml │ └── traces_light.yaml ├── hardware_repo │ ├── default.yaml │ ├── interconnects │ │ └── nvlink.yaml │ ├── processors │ │ ├── a100-40gb.yaml │ │ ├── a100-80gb.yaml │ │ ├── h100-80gb-pcap.yaml │ │ └── h100-80gb.yaml │ └── skus │ │ ├── dgx-a100.yaml │ │ ├── dgx-h100-pcap.yaml │ │ └── dgx-h100.yaml ├── model_repo │ ├── architectures │ │ ├── bloom-176b.yaml │ │ ├── gpt3-175b.yaml │ │ ├── llama-13b.yaml │ │ ├── llama-33b.yaml │ │ ├── llama2-70b.yaml │ │ ├── opt-30b.yaml │ │ └── opt-66b.yaml │ ├── default.yaml │ └── sizes │ │ ├── bloom-176b-fp16.yaml │ │ └── llama2-70b-fp16.yaml ├── orchestrator_repo │ ├── allocators │ │ └── noop.yaml │ ├── default.yaml │ └── schedulers │ │ ├── jsq.yaml │ │ ├── kv_jsq.yaml │ │ ├── kv_round_robin.yaml │ │ ├── kv_round_robin_ethernet.yaml │ │ ├── kv_token_jsq.yaml │ │ ├── mixed_pool.yml │ │ ├── overlap_kv_jsq.yaml │ │ ├── overlap_kv_token_jsq.yaml │ │ ├── random.yaml │ │ ├── round_robin.yaml │ │ └── token_jsq.yaml ├── performance_model │ ├── constant.yaml │ └── db.yaml ├── power_model │ └── constant.yaml ├── router │ ├── noop.yaml │ └── overheads │ │ └── zero.yaml ├── start_state │ ├── baseline.yaml │ ├── orca.yaml │ ├── splitwise.yaml │ ├── splitwise_hhcap.yaml │ └── unallocated.yaml └── trace │ └── test_trace.yaml ├── data └── perf_model.csv ├── executor.py ├── flow.py ├── generate_trace.py ├── hardware_repo.py ├── initialize.py ├── instance.py ├── interconnect.py ├── metrics.py ├── model.py ├── model_repo.py ├── node.py ├── notebooks ├── example.ipynb ├── perf_model.py ├── plots.ipynb └── utils.py ├── orchestrator_repo.py ├── performance_model.py ├── power_model.py ├── processor.py ├── request.py ├── requirements.txt ├── router.py ├── run.py ├── scheduler.py ├── scripts ├── run_baseline_a.sh ├── run_baseline_h.sh ├── run_baseline_h_example.sh ├── run_costopt.sh ├── run_isocost.sh ├── run_isopower.sh ├── run_splitwise_aa.sh ├── run_splitwise_ha.sh ├── run_splitwise_ha_example.sh ├── run_splitwise_hh.sh ├── run_splitwise_hhcap.sh ├── run_throughput.sh ├── run_throughput_isocost.sh ├── run_throughput_isopower.sh └── run_traces.sh ├── server.py ├── simulator.py ├── start_state.py ├── sync_scripts ├── sync_configs.sh ├── sync_repos.sh ├── sync_results.sh └── sync_traces.sh ├── task.py ├── trace.py ├── traces └── test_trace.csv └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | multirun.yaml 2 | 
__pycache__ 3 | results/ 4 | .hydra 5 | *.pdf 6 | traces/ 7 | code_distributions.csv 8 | conv_distributions.csv 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Pratyush Patel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SplitwiseSim: LLM Serving Cluster Simulator 2 | 3 | SplitwiseSim is a discrete event simulator that helps evaluate model serving in LLM inference clusters. It was built to evaluate [Splitwise](#reference), a generative LLM inference serving technique that splits LLM inference phases across different machines. SplitwiseSim can easily be extended to other applications and use cases. 4 | 5 | ## Setup 6 | 7 | You can set up SplitwiseSim by installing its Python dependencies. We recommend starting with a fresh Python environment. 8 | 9 | ```bash 10 | # Create and activate new Python environment 11 | conda create -n splitwise-sim python=3.11 12 | conda activate splitwise-sim 13 | 14 | # Install dependencies 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | **NOTE**: SplitwiseSim has only been tested with Python 3.11. However, it will likely also work with other Python versions. 19 | 20 | ## Inputs and Outputs 21 | 22 | SplitwiseSim takes in a hierarchical set of YAML configuration files as input, and it produces several CSV files as output. It uses [Hydra](https://hydra.cc/) for configuration management. You can learn more about configuration management from the [Hydra docs](https://hydra.cc/docs/intro/). 23 | 24 | The top-level configuration file for SplitwiseSim is [`config.yaml`](configs/config.yaml), which points to lower-level configurations specified by other files in the `configs/` directory. Specifically, `config.yaml` captures the following key components: 25 | 26 | - [cluster](configs/cluster/): the provisioned server SKUs in the cluster, along with their respective counts. 27 | - [trace](#request-traces): request trace that specifies the set of requests that arrive into the cluster. 28 | - [router](configs/router/): the cluster-level router that routes incoming requests to application-level schedulers; currently a no-op.
29 | - [arbiter](configs/arbiter/): the cluster-level arbiter that manages compute resources between applications to support autoscaling; currently a no-op. 30 | - [application](configs/applications/): the logical endpoint that the requests target, which specifies the model and the set of instances on which the request runs; currently, we support only one application. 31 | - [model_repo](configs/model_repo/): the set of models (LLMs) available to run in the cluster; used for dynamic model instantiation. 32 | - [orchestrator_repo](configs/orchestrator_repo/): the set of application resource orchestrators (i.e., schedulers and allocators) in the cluster; used for dynamic application management. 33 | - [hardware_repo](configs/hardware_repo/): the set of available SKUs that can be provisioned in the cluster; used for dynamic server instantiation. 34 | - [performance_model](#performance-model): an analytical model that helps estimate request runtimes with different batch, model, and hardware configurations. 35 | - [start_state](configs/start_state/): starting state for the cluster, which helps simplify evaluation. 36 | 37 | Several other aspects can be configured; please see [`config.yaml`](configs/config.yaml) for details. 38 | 39 | SplitwiseSim generates the following key outputs: 40 | 41 | - Summary of application-level metrics (`summary.csv`) 42 | - Per-request metrics for each completed request for each application (`detailed/{application_id}.csv`) 43 | - Request node-level metrics (`request_nodes.csv`) 44 | - Instance-level execution metrics (in `instances/`, with `debug` enabled) 45 | 46 | We provide various [utility functions](notebooks/utils.py) to process outputs, as shown in [`notebooks/example.ipynb`](notebooks/example.ipynb) and [`notebooks/plots.ipynb`](notebooks/plots.ipynb). 47 | 48 | ## Example Run 49 | 50 | The simplest way to run SplitwiseSim is to execute [`run.py`](run.py), which runs with the default configuration parameters specified in [`config.yaml`](configs/config.yaml). The default configurations can be overridden by specifying appropriate command-line parameters using Hydra. Below is an example script, [`scripts/run_baseline_h_example.sh`](scripts/run_baseline_h_example.sh), which overrides the default configuration to execute a simple `Baseline-H100` configuration with a single DGX-H100 server. 51 | 52 | ```bash 53 | # scripts/run_baseline_h_example.sh 54 | 55 | SCHEDULER=token_jsq 56 | NUM_DGX_A100=0 57 | NUM_DGX_H100=1 58 | START_STATE=baseline 59 | TRACE=test_trace 60 | 61 | python run.py \ 62 | applications.0.scheduler=$SCHEDULER \ 63 | cluster=half_half \ 64 | cluster.servers.0.count=$NUM_DGX_A100 \ 65 | cluster.servers.1.count=$NUM_DGX_H100 \ 66 | start_state=$START_STATE \ 67 | performance_model=db \ 68 | trace.filename=$TRACE \ 69 | debug=True \ 70 | seed=0 71 | ``` 72 | 73 | Specifically, each configuration override changes a corresponding default from `config.yaml` as follows: 74 | 75 | - `cluster=half_half` overrides the cluster default from [`dgx-a100`](configs/cluster/dgx-a100.yaml) to [`half_half`](configs/cluster/half_half.yaml), which has 1 DGX-A100 server and 1 DGX-H100 server by default. 76 | - The `cluster.servers.*` overrides set the number of DGX-A100 and DGX-H100 servers within the [`half_half`](configs/cluster/half_half.yaml) cluster to 0 and 1, respectively.
77 | - `applications.0.scheduler=token_jsq` switches the default [`round_robin`](configs/orchestrator_repo/schedulers/round_robin.yaml) scheduler, as specified in [`configs/applications/solo.yaml`](configs/applications/solo.yaml), to the [`token_jsq`](configs/orchestrator_repo/schedulers/token_jsq.yaml) scheduler. 78 | - `start_state=baseline` overrides the starting state from [`orca`](configs/start_state/orca.yaml) to [`baseline`](configs/start_state/baseline.yaml). 79 | - `performance_model=db` overrides the performance model to [`db`](configs/performance_model/db.yaml) instead of the default [`constant`](configs/performance_model/constant.yaml). 80 | - `trace.filename=test_trace` changes the trace file name (same as default, so no effect). 81 | - `debug=True` enables the debug flag (changed from `False`). 82 | - `seed=0` sets the seed to `0` (same as default, so no effect). 83 | 84 | Several of the above overrides configure objects of classes specified by the `_target_` field in the corresponding configuration files. 85 | 86 | To simulate this simple Baseline-H100 configuration with a single DGX-H100 on [`test_trace.csv`](traces/test_trace.csv), we can simply run the bash script: 87 | 88 | ```bash 89 | # run simple Baseline-H100 example 90 | ./scripts/run_baseline_h_example.sh 91 | ``` 92 | 93 | Similarly, we could run a simple Splitwise-HA configuration, which simulates KV-cache transfers from a DGX-H100 machine to a DGX-A100 machine (see [paper](#reference) for more details): 94 | 95 | ```bash 96 | 97 | # run simple Splitwise-HA example 98 | ./scripts/run_splitwise_ha_example.sh 99 | ``` 100 | 101 | **NOTE**: Scripts must be run from the top-level directory. 102 | 103 | Results will be generated in the `results/` directory according to the output path template specified by the `output_dir` field in [`config.yaml`](configs/config.yaml). Open [`notebooks/example.ipynb`](notebooks/example.ipynb) using Jupyter Notebook to see an example of how to easily extract the associated outputs. 104 | 105 | ## Request Traces 106 | 107 | SplitwiseSim expects request traces in a CSV file that contains the following fields for each request: 108 | 109 | - `request_id`: ID of the request, typically a monotonically increasing number. 110 | - `request_type`: Type of the request (e.g., DL inference, LLM inference, etc.). Use `2` for generative LLM inference, which is the only supported type at present. 111 | - `application_id`: ID of the application/endpoint that the request targets. Defaults to `0` for a single application. 112 | - `arrival_timestamp`: Timestamp at which the request arrives into the cluster. 113 | - `batch_size`: If the request is already batched when it arrives, that can be specified here (currently not used). 114 | - `prompt_size`: Number of tokens in the input prompt of the request. 115 | - `token_size`: Number of tokens to be generated as output by the request. 116 | 117 | Many of these fields have limited configurability at present. A typical new trace would change the `request_id`, `arrival_timestamp`, `prompt_size`, and `token_size`. An example trace can be found in [`traces/test_trace.csv`](traces/test_trace.csv).
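For example, a minimal trace with these fields can be written with a short Python snippet. This is only a sketch: the file name, sizes, and arrival timestamps below are arbitrary illustrations, and only the column names follow the list above.

```python
# Sketch: write a small request trace with the fields described above.
# All values are illustrative; only the column names come from the README.
import pandas as pd

trace = pd.DataFrame({
    "request_id": [0, 1, 2],
    "request_type": [2, 2, 2],             # 2 = generative LLM inference
    "application_id": [0, 0, 0],           # single application
    "arrival_timestamp": [0.0, 0.5, 1.0],  # seconds
    "batch_size": [1, 1, 1],               # currently unused
    "prompt_size": [512, 128, 1024],       # input prompt tokens
    "token_size": [64, 256, 128],          # output tokens to generate
})
trace.to_csv("traces/my_trace.csv", index=False)
```

A trace written this way can then be selected at run time with `trace.filename=my_trace` (the file name is just an example).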
118 | 119 | ### Production Traces and Trace Generation 120 | 121 | [Splitwise](#reference) was evaluated with request traces based on [production traces](https://github.com/Azure/AzurePublicDataset/blob/master/AzureLLMInferenceDataset2023.md) from LLM inference services at Microsoft Azure. The [`generate_trace.py`](generate_trace.py) script can automatically download the production traces and use the corresponding prompt/token size distributions to generate request traces with different request rates. It can also help generate custom traces with different kinds of distributions. Modify and run `generate_trace.py` with the desired request rates and other parameters. By default, all generated traces are expected to reside in the `traces/` directory. 122 | 123 | ## Request Processing 124 | 125 | SplitwiseSim processes request traces as follows: 126 | 127 | - All requests first arrive at a [Cluster](cluster.py)-level [Router](router.py), which forwards them to their target [Application](application.py). The Cluster also has an [Arbiter](arbiter.py) which helps reallocate [Servers](server.py) or [Processors](processor.py) between Applications. Currently, the Router and Arbiter act as no-ops, but they could be modified in the future to include smarter routing and autoscaling strategies with overheads. 128 | - Each [Request](request.py) targets a specific [Application](application.py), which may have one or more [Instances](instance.py) that run [Models](model.py). [Applications](application.py) use [Allocators](allocator.py) to spin up/spin down Instances on [Processors](processor.py), and they use [Schedulers](scheduler.py) to load balance Requests across Instances. Currently, we do not support dynamic Instance spin-up/spin-down, but rather use [start states](start_state.py) for specifying the initial set of Cluster Instances. 129 | - [Requests](request.py) are specified as a Directed Acyclic Graph (DAG) of [Nodes](node.py) for flexibility. Request nodes may be either [Tasks](task.py) or [Flows](flow.py). Requests are processed on [Instances](instance.py), which run on [Servers](server.py); specifically, Tasks are run on [Processors](processor.py) and Flows are run on [Links](interconnect.py). 130 | 131 | Note that all simulation times are assumed to be specified in seconds. 132 | 133 | ## Performance Model 134 | 135 | The [performance_model](performance_model.py) helps SplitwiseSim estimate how long requests take to run under different input, output, hardware, and batch configurations. `performance_model.PerformanceModel` is an interface class which exposes the following two estimation functions to the simulator: 136 | 137 | 1. `get_duration()`: used to estimate the runtime of prompt and token tasks. 138 | 2. `get_iteration_duration()`: used to estimate the runtime of each batching iteration (e.g., from continuous batching). 139 | 140 | Since modern LLM serving typically uses [iteration-level scheduling](https://www.usenix.org/conference/osdi22/presentation/yu), we primarily rely on `get_iteration_duration` in the [Instance](instance.py) implementations (e.g., ORCAInstance and SplitwiseInstance). 141 | 142 | Currently, SplitwiseSim provides two concrete performance models: 143 | 144 | 1. `performance_model=constant`: This model assumes that all prompt and token tasks take a constant duration. While unrealistic, it is helpful for testing and debugging. 145 | 2. `performance_model=db`: This model uses extensive profiling data from the DGX-A100 and DGX-H100 machines and is the preferred model for realistic simulations. The associated raw data can be found in [`data/perf_model.csv`](data/perf_model.csv). The `performance_model.DatabasePerformanceModel` class reads this raw data to build a simple linear predictor, which serves as the performance model. To extend SplitwiseSim to different LLMs/platforms, please add your profiling data to the data file and potentially update the performance model predictor.
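A new performance model can also be plugged in as its own class. The sketch below shows the general shape; the method signatures are assumptions for illustration, so consult `performance_model.PerformanceModel` for the actual interface that the simulator calls.

```python
# Sketch of a custom performance model. The signatures here are assumptions;
# see performance_model.PerformanceModel for the real interface.
from performance_model import PerformanceModel


class MyPerformanceModel(PerformanceModel):
    def get_duration(self, task, batch, instance, *args, **kwargs):
        # Estimated runtime (in seconds) of a single prompt or token task.
        return 0.05

    def get_iteration_duration(self, batch, instance, *args, **kwargs):
        # Estimated runtime (in seconds) of one batching iteration,
        # e.g., under continuous batching.
        return 0.01
```

A corresponding YAML file under `configs/performance_model/` whose `_target_` field points to the new class should then make it selectable via a `performance_model=...` override.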
146 | 147 | ## Experiments Workflow 148 | 149 | This section describes how to run larger-scale simulations spanning a variety of configurations. 150 | 151 | ### Parallel Simulations 152 | 153 | SplitwiseSim can be run on multiple cores (on one or more machines) to evaluate many different configurations in parallel. Each simulation configuration is run in a single process on a single core. SplitwiseSim uses [Ray](https://github.com/ray-project/ray) via the [Hydra Ray plugin](https://hydra.cc/docs/plugins/ray_launcher/) for parallelization. 154 | 155 | To start a Ray cluster, run: 156 | 157 | - `ray start --head` on the head machine. 158 | - `ray start --address=xxx` on each of the worker machines. 159 | 160 | See the [Ray docs](https://docs.ray.io/en/latest/cluster/vms/user-guides/launching-clusters/on-premises.html) for more details. 161 | 162 | If you do not want to use Ray, you may alternatively use the Hydra [joblib](https://hydra.cc/docs/plugins/joblib_launcher/) launcher, which only supports multicore parallelization on a single machine. 163 | 164 | Running a Hydra configuration in parallel requires the `--multirun` flag. For example, to sweep over multiple seed values in parallel, run `python run.py --multirun seed=0,1,2,3,4,5,6,7,8,9` after starting the Ray cluster. 165 | 166 | Output from multi-machine runs is stored on different machines, corresponding to where each simulation configuration runs. You may subsequently need to manually collect the results onto a single machine using sync scripts. Example sync scripts can be found in the `sync_scripts` folder. 167 | 168 | ### Experiment Runs 169 | 170 | The `scripts/` directory provides several scripts to run larger experiments, including parallel sweeps over different cluster configurations: 171 | 172 | - To run a baseline configuration, run `./scripts/run_baseline_a.sh` (Baseline-A100) or `./scripts/run_baseline_h.sh` (Baseline-H100). 173 | - To run a Splitwise configuration, run the appropriate Splitwise-XX script under the `scripts/` directory. For example, to run Splitwise-HA, run `./scripts/run_splitwise_ha.sh`. 174 | - Various experiment configurations used in the [Splitwise paper](#reference) are specified in the `configs/experiment/` folder. For example, to run a sweep of iso-cost clusters, you can run `./scripts/run_isocost.sh`, which corresponds to `configs/experiment/*_isocost.yaml` with the appropriate sweep parameters (warning: running this may spin up many configurations in parallel and take a long time; try smaller configurations to begin with). 175 | 176 | ### Experiment Plots and Gantt Charts 177 | 178 | Outputs from experiment sweeps can be visualized using the plotting scripts provided in `notebooks/plots.ipynb`. These scripts were used to plot some of the graphs in the [Splitwise paper](#reference). 179 | 180 | If the `debug` flag is enabled, SplitwiseSim additionally outputs iteration-level metadata per instance (including start/end timestamps), which can be visualized as Gantt charts for analysis and debugging. Check out `notebooks/example.ipynb` for a simple example. Custom markers can be added by modifying the simulator.
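For a quick look at a run's outputs outside of the notebooks, the summary and per-request CSVs can also be loaded directly, for example with pandas. The results path below is only an illustration; substitute the directory that your own run produces under the `output_dir` template.

```python
# Sketch: inspect simulator outputs with pandas. The run_dir below is an
# example path only; it depends on the output_dir template and your overrides.
import pandas as pd

run_dir = "results/0/baseline/test_trace/0_1/bloom-176b/token_jsq"  # example
summary = pd.read_csv(f"{run_dir}/summary.csv")      # application-level summary
requests = pd.read_csv(f"{run_dir}/detailed/0.csv")  # per-request metrics, application 0

print(summary.head())
print(requests.describe())
```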
181 | 182 | ## Reference 183 | 184 | If you use SplitwiseSim in your work, please cite the accompanying [paper](https://www.microsoft.com/en-us/research/publication/splitwise-efficient-generative-llm-inference-using-phase-splitting/): 185 | 186 | > Pratyush Patel, Esha Choukse, Chaojie Zhang, Aashaka Shah, Íñigo Goiri, Saeed Maleki, Ricardo Bianchini. "Splitwise: Efficient Generative LLM Inference Using Phase Splitting", in Proceedings of the International Symposium on Computer Architecture (ISCA 2024). ACM, Buenos Aires, Argentina, 2024. 187 | -------------------------------------------------------------------------------- /allocator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from abc import ABC 4 | from itertools import count 5 | 6 | import model_repo 7 | 8 | from instance import Instance 9 | from simulator import clock, schedule_event, cancel_event, reschedule_event 10 | 11 | 12 | class Allocator(ABC): 13 | """ 14 | Allocator autoscales Application Instances onto Servers. 15 | It receives/releases Servers from/to Arbiters. 16 | """ 17 | def __init__(self, 18 | application, 19 | arbiter, 20 | overheads, 21 | instance_overheads, 22 | debug=False): 23 | self.application = application 24 | self.arbiter = arbiter 25 | self.overheads = overheads 26 | self.instance_overheads = instance_overheads 27 | self.total_instances = count(0) 28 | self.debug = debug 29 | 30 | @property 31 | def application(self): 32 | return self._application 33 | 34 | @application.setter 35 | def application(self, application): 36 | self._application = application 37 | 38 | def start_spin_up_instance(self, 39 | instance_cfg, 40 | processors, 41 | parallelism, 42 | pre_start=False, 43 | tag=None): 44 | """ 45 | Spin up a new instance of the application on specified processors. 46 | Assigns a metadata tag to the instance for orchestration. 47 | # TODO: better way to pass in processors / parallelism 48 | """ 49 | model_architecture = self.application.model_architecture 50 | model_size = self.application.model_size 51 | model = model_repo.get_model(model_architecture=model_architecture, 52 | model_size=model_size, 53 | model_parallelism=parallelism) 54 | instance = Instance.from_config(instance_cfg=instance_cfg, 55 | instance_id=next(self.total_instances), 56 | application=self.application, 57 | name=processors[0].name, 58 | tag=tag, 59 | model=model, 60 | processors=processors, 61 | overheads=self.instance_overheads, 62 | debug=self.debug) 63 | 64 | def finish_spin_up(): 65 | self.finish_spin_up_instance(instance) 66 | if pre_start is True: 67 | finish_spin_up() 68 | else: 69 | schedule_event(self.overheads.spin_up, finish_spin_up) 70 | 71 | def finish_spin_up_instance(self, instance): 72 | """ 73 | Finish spinning up an instance after the spin up delay. 74 | """ 75 | self.application.add_instance(instance) 76 | instance.metrics.spin_up_timestamp = clock() 77 | return instance 78 | 79 | def start_spin_down_instance(self, instance): 80 | """ 81 | Spin down an instance of the application. 82 | """ 83 | pass 84 | 85 | def finish_spin_down_instance(self, instance, processors): 86 | """ 87 | Finish spinning down an instance after the spin down delay. 88 | """ 89 | instance.memory = 0 90 | pass 91 | 92 | def run(self): 93 | """ 94 | Run the allocator. Useful for periodic calls. 
95 | """ 96 | pass 97 | 98 | def get_results(self): 99 | results = {} 100 | 101 | instance_names = [] 102 | utilizations = [] 103 | for instance in self.application.instances: 104 | #assert len(instance.pending_requests) == 0, instance.instance_id 105 | #assert len(instance.pending_queue) == 0, instance.instance_id 106 | #assert instance.memory == instance.model.size.total_size, instance.instance_id 107 | instance.metrics.spin_down_timestamp = clock() 108 | instance.metrics.interval_time = clock() - instance.metrics.spin_up_timestamp 109 | utilization = instance.metrics.busy_time / instance.metrics.interval_time 110 | instance_names.append(instance.processors[0].name) 111 | utilizations.append(utilization) 112 | 113 | results['instance_names'] = instance_names 114 | results['utilizations'] = utilizations 115 | return results 116 | 117 | 118 | class NoOpAllocator(Allocator): 119 | """ 120 | No-op Allocator. 121 | 122 | Assumes that instances are already allocated to servers using start states. 123 | """ 124 | pass 125 | -------------------------------------------------------------------------------- /application.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import model_repo 6 | import orchestrator_repo 7 | 8 | from metrics import ApplicationMetrics, ApplicationSLO 9 | from simulator import clock, schedule_event, cancel_event, reschedule_event 10 | 11 | 12 | class Application(): 13 | """ 14 | An Application is the endpoint that a Request targets. 15 | Applications can have any number of Instances which all serve the same model. 16 | Requests are scheduled to Instances by the Scheduler. 17 | Application Instances can be autoscaled by the Allocator. 18 | """ 19 | def __init__(self, 20 | application_id, 21 | model_architecture, 22 | model_size, 23 | cluster, 24 | router, 25 | arbiter, 26 | overheads, 27 | scheduler=None, 28 | allocator=None, 29 | instances=None): 30 | self.application_id = application_id 31 | 32 | # hardware 33 | self.processors = [] 34 | 35 | # model 36 | self.model_architecture = model_architecture 37 | self.model_size = model_size 38 | 39 | # orchestration 40 | if instances is None: 41 | self.instances = [] 42 | self.cluster = cluster 43 | self.scheduler = scheduler 44 | self.allocator = allocator 45 | self.router = router 46 | self.arbiter = arbiter 47 | 48 | # overheads 49 | self.overheads = overheads 50 | 51 | # metrics 52 | self.metrics = ApplicationMetrics() 53 | self.slo = ApplicationSLO() 54 | 55 | def add_instance(self, instance): 56 | """ 57 | Application-specific method to add an instance to the application. 
58 | """ 59 | self.instances.append(instance) 60 | self.scheduler.add_instance(instance) 61 | 62 | def get_results(self): 63 | allocator_results = self.allocator.get_results() 64 | scheduler_results = self.scheduler.get_results() 65 | self.scheduler.save_all_request_metrics() 66 | return allocator_results, scheduler_results 67 | 68 | @classmethod 69 | def from_config(cls, *args, cluster, arbiter, router, **kwargs): 70 | # parse args 71 | application_cfg = args[0] 72 | 73 | # get model 74 | model_architecture_name = application_cfg.model_architecture 75 | model_size_name = application_cfg.model_size 76 | model_architecture = model_repo.get_model_architecture(model_architecture_name) 77 | model_size = model_repo.get_model_size(model_size_name) 78 | 79 | # get orchestrators 80 | allocator_name = application_cfg.allocator 81 | scheduler_name = application_cfg.scheduler 82 | application = cls(application_id=application_cfg.application_id, 83 | model_architecture=model_architecture, 84 | model_size=model_size, 85 | cluster=cluster, 86 | router=router, 87 | arbiter=arbiter, 88 | overheads=application_cfg.overheads) 89 | allocator = orchestrator_repo.get_allocator(allocator_name, 90 | arbiter=arbiter, 91 | application=application, 92 | debug=application_cfg.debug) 93 | scheduler = orchestrator_repo.get_scheduler(scheduler_name, 94 | router=router, 95 | application=application, 96 | debug=application_cfg.debug) 97 | application.scheduler = scheduler 98 | application.allocator = allocator 99 | return application 100 | -------------------------------------------------------------------------------- /arbiter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from abc import ABC 4 | 5 | from simulator import clock, schedule_event, cancel_event, reschedule_event 6 | 7 | 8 | class Arbiter(ABC): 9 | """ 10 | Arbiter allocates Processors to Application Allocators. 11 | It can be used to support application autoscaling. 12 | """ 13 | def __init__(self, 14 | cluster, 15 | overheads): 16 | self.cluster = cluster 17 | self.overheads = overheads 18 | self.servers = cluster.servers 19 | self.applications = [] 20 | self.allocators = {} 21 | 22 | def add_application(self, application): 23 | self.applications.append(application) 24 | self.allocators[application.application_id] = application.allocator 25 | 26 | def run(self): 27 | pass 28 | 29 | def allocate(self, processors, application): 30 | """ 31 | Allocates processors to the application. 32 | """ 33 | pass 34 | 35 | def deallocate(self, processors, application): 36 | """ 37 | Deallocates processors from the application. 38 | """ 39 | pass 40 | 41 | 42 | class NoOpArbiter(Arbiter): 43 | """ 44 | No-op Arbiter. 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /clean.sh: -------------------------------------------------------------------------------- 1 | rm -rf results/* 2 | -------------------------------------------------------------------------------- /cluster.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from collections import defaultdict 4 | from itertools import count 5 | 6 | from hydra.utils import instantiate 7 | 8 | import hardware_repo 9 | 10 | from simulator import clock, schedule_event, cancel_event, reschedule_event 11 | from server import Server 12 | 13 | 14 | class Cluster: 15 | """ 16 | Cluster is a collection of Servers and interconnected Links. 
17 | """ 18 | def __init__(self, 19 | servers, 20 | interconnects, 21 | power_budget): 22 | self.servers = servers 23 | self.interconnects = interconnects 24 | self.power_budget = power_budget 25 | self.total_power = 0 26 | for sku_name in self.servers: 27 | for server in self.servers[sku_name]: 28 | server.cluster = self 29 | self.total_power += server.power 30 | self.inflight_commands = [] 31 | 32 | # logger for simulated power usage (NOTE: currently unsupported) 33 | #self.power_logger = utils.file_logger("power") 34 | #self.power_logger.info("time,server,power") 35 | 36 | def __str__(self): 37 | return "Cluster:" + str(self.servers) 38 | 39 | def add_server(self, server): 40 | self.servers.append(server) 41 | 42 | def remove_server(self, server): 43 | self.servers.remove(server) 44 | 45 | def models(self): 46 | models = [] 47 | for server in self.servers: 48 | models.extend(server.models) 49 | return models 50 | 51 | @property 52 | def power(self, cached=True, servers=None): 53 | """ 54 | Returns the total power usage of the cluster. 55 | Can return the cached value for efficiency. 56 | TODO: unsupported 57 | """ 58 | if cached and servers is None: 59 | return self.total_power 60 | if servers is None: 61 | servers = self.servers 62 | return sum(server.power() for server in servers) 63 | 64 | def update_power(self, power_diff): 65 | """ 66 | Updates the total power usage of the cluster. 67 | TODO: unsupported 68 | """ 69 | self.total_power += power_diff 70 | 71 | def power_telemetry(self, power): 72 | """ 73 | Logs the power usage of the cluster. 74 | TODO: currently unsupported; make configurable 75 | 76 | Args: 77 | power (float): The power usage. 78 | """ 79 | time_interval = 60 80 | schedule_event(time_interval, 81 | lambda self=self, power=self.total_power: \ 82 | self.power_telemetry(0)) 83 | 84 | def run(self): 85 | """ 86 | Runs servers in the cluster. 
87 | """ 88 | # NOTE: power usage updates not supported 89 | power = 0 90 | for sku in self.servers: 91 | for server in self.servers[sku]: 92 | server.run() 93 | power += server.power 94 | 95 | @classmethod 96 | def from_config(cls, *args, **kwargs): 97 | # args processing 98 | cluster_cfg = args[0] 99 | servers_cfg = cluster_cfg.servers 100 | interconnects_cfg = cluster_cfg.interconnects 101 | 102 | # instantiate servers 103 | server_id = count() 104 | servers = defaultdict(list) 105 | for server_cfg in servers_cfg: 106 | for n in range(server_cfg.count): 107 | sku_cfg = hardware_repo.get_sku_config(server_cfg.sku) 108 | server = Server.from_config(sku_cfg, server_id=next(server_id)) 109 | servers[server_cfg.sku].append(server) 110 | 111 | # instantiate interconnects 112 | # TODO: add better network topology / configuration support 113 | interconnects = [] 114 | for interconnect_cfg in interconnects_cfg: 115 | if interconnect_cfg.topology == "p2p": 116 | continue 117 | interconnect = instantiate(interconnect_cfg) 118 | interconnects.append(interconnect) 119 | 120 | return cls(servers=servers, 121 | interconnects=interconnects, 122 | power_budget=cluster_cfg.power_budget) 123 | 124 | 125 | if __name__ == "__main__": 126 | pass 127 | -------------------------------------------------------------------------------- /configs/applications/solo.yaml: -------------------------------------------------------------------------------- 1 | # list of application_ids and associated model_architectures 2 | 3 | - application_id: 0 4 | model_architecture: bloom-176b 5 | model_size: bloom-176b-fp16 6 | #model_architecture: llama2-70b 7 | #model_size: llama2-70b-fp16 8 | allocator: noop 9 | scheduler: round_robin 10 | overheads: {} 11 | debug: ${debug} 12 | _target_: application.Application 13 | -------------------------------------------------------------------------------- /configs/arbiter/noop.yaml: -------------------------------------------------------------------------------- 1 | _target_: arbiter.NoOpArbiter 2 | overheads: {} 3 | -------------------------------------------------------------------------------- /configs/cluster/dgx-a100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 40 6 | - sku: dgx-h100 7 | count: 0 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/dgx-h100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 0 6 | - sku: dgx-h100 7 | #count: 800 8 | count: 40 9 | 10 | interconnects: 11 | - link: infiniband 12 | topology: p2p 13 | -------------------------------------------------------------------------------- /configs/cluster/half_half.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 1 6 | - sku: dgx-h100 7 | count: 1 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/hhcap_half_half.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-h100-pcap 5 | count: 1 6 | - sku: dgx-h100 7 | count: 1 8 | 9 | interconnects: 10 | - link: infiniband 
11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/isocost_a100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 86 6 | - sku: dgx-h100 7 | count: 0 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/isocost_h100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 0 6 | - sku: dgx-h100 7 | count: 40 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/isocost_hybrid.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | #count: 81 6 | #count: 79 7 | #count: 64 8 | #count: 62 9 | #count: 60 10 | #count: 58 11 | #count: 55 12 | #count: 53 13 | count: 51 14 | #count: 49 15 | - sku: dgx-h100 16 | #count: 2 17 | #count: 3 18 | #count: 10 19 | #count: 11 20 | #count: 12 21 | #count: 13 22 | #count: 14 23 | #count: 15 24 | count: 16 25 | #count: 17 26 | 27 | interconnects: 28 | - link: infiniband 29 | topology: p2p 30 | -------------------------------------------------------------------------------- /configs/cluster/isocount_a100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 40 6 | - sku: dgx-h100 7 | count: 0 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/isocount_hybrid.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 38 6 | - sku: dgx-h100 7 | count: 2 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/isopower_a100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | #count: 800 6 | #count: 40 7 | count: 70 8 | #count: 63 9 | - sku: dgx-h100 10 | count: 0 11 | 12 | interconnects: 13 | - link: infiniband 14 | topology: p2p 15 | -------------------------------------------------------------------------------- /configs/cluster/isopower_hybrid.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | #count: 1 6 | #count: 38 7 | count: 37 8 | #count: 750 9 | - sku: dgx-h100 10 | #count: 1 11 | count: 16 12 | #count: 50 13 | 14 | interconnects: 15 | - link: infiniband 16 | topology: p2p 17 | -------------------------------------------------------------------------------- /configs/cluster/solo_a100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 1 6 | - sku: dgx-h100 7 | count: 0 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- 
/configs/cluster/solo_h100.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 0 6 | - sku: dgx-h100 7 | count: 1 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/cluster/solo_hybrid.yaml: -------------------------------------------------------------------------------- 1 | power_budget: 232000 2 | 3 | servers: 4 | - sku: dgx-a100 5 | count: 1 6 | - sku: dgx-h100 7 | count: 1 8 | 9 | interconnects: 10 | - link: infiniband 11 | topology: p2p 12 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - cluster: dgx-a100 3 | - router: noop 4 | - arbiter: noop 5 | - hardware_repo: default 6 | - model_repo: default 7 | - orchestrator_repo: default 8 | - performance_model: constant 9 | - power_model: constant 10 | - applications: solo 11 | - trace: test_trace 12 | - start_state: orca 13 | - override hydra/launcher: ray 14 | - _self_ 15 | 16 | end_time: 86400 17 | debug: False 18 | seed: 0 19 | 20 | choices: ${hydra:runtime.choices} 21 | output_dir: results/${seed}/${start_state.state_type}/${trace.filename}/${cluster.servers.0.count}_${cluster.servers.1.count}/${applications.0.model_architecture}/${applications.0.scheduler} 22 | 23 | hydra: 24 | # changes the cwd to the output directory 25 | run: 26 | dir: ${output_dir} 27 | sweep: 28 | dir: "" 29 | subdir: ${output_dir} 30 | job: 31 | chdir: True 32 | launcher: 33 | ray: 34 | init: 35 | address: "10.0.0.9:6379" 36 | -------------------------------------------------------------------------------- /configs/experiment/baseline_a100_costopt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: baseline 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: ${sweep} 13 | - sku: dgx-h100 14 | count: 0 15 | 16 | trace: 17 | filename: ${fname}_${trace_sweep}_2min 18 | 19 | seed: 0 20 | 21 | hydra: 22 | mode: MULTIRUN 23 | sweeper: 24 | params: 25 | +sweep: range(1, 140, 1) 26 | +trace_sweep: 70 27 | +fname: rr_code,rr_conv 28 | #+fname: rr_code 29 | applications.0.scheduler: token_jsq 30 | -------------------------------------------------------------------------------- /configs/experiment/baseline_h100_costopt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: baseline 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: 0 13 | - sku: dgx-h100 14 | count: ${sweep} 15 | 16 | trace: 17 | filename: ${fname}_${trace_sweep}_2min 18 | 19 | seed: 0 20 | 21 | hydra: 22 | mode: MULTIRUN 23 | sweeper: 24 | params: 25 | +sweep: range(1, 80, 1) 26 | +trace_sweep: 70 27 | +fname: rr_code,rr_conv 28 | #+fname: rr_code 29 | applications.0.scheduler: token_jsq 30 | -------------------------------------------------------------------------------- /configs/experiment/isocost_cluster.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - _self_ 5 | 6 | cluster: 7 | servers: 8 | - sku: dgx-a100 9 | count: ${eval:'int(4.76 * (40 - ${sweep}) // 2.21)'} 10 | - sku: dgx-h100 11 | count: ${sweep} 12 | 13 | seed: 0 14 | 15 | hydra: 16 | sweeper: 17 | params: 18 | +sweep: range(1, 40, 1) 19 | -------------------------------------------------------------------------------- /configs/experiment/isocount_cluster.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - _self_ 5 | 6 | cluster: 7 | servers: 8 | - sku: dgx-a100 9 | count: ${eval:'40 - ${sweep}'} 10 | - sku: dgx-h100 11 | count: ${sweep} 12 | 13 | seed: 0 14 | 15 | hydra: 16 | sweeper: 17 | params: 18 | +sweep: range(1, 40, 1) 19 | -------------------------------------------------------------------------------- /configs/experiment/isopower_cluster.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - _self_ 5 | 6 | cluster: 7 | servers: 8 | - sku: dgx-a100 9 | count: ${eval:'int(44 * (40 - ${sweep}) // 24.8)'} 10 | - sku: dgx-h100 11 | count: ${sweep} 12 | 13 | seed: 0 14 | 15 | hydra: 16 | sweeper: 17 | params: 18 | +sweep: range(1, 40, 1) 19 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_aa_costopt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | #count: ${eval:'(100 + ${sweep})'} 13 | count: ${eval:'(${prompt_sweep} + ${token_sweep})'} 14 | #count: ${eval:'(62 + ${sweep})'} 15 | #count: ${eval:'(33 + ${sweep})'} 16 | - sku: dgx-h100 17 | count: 0 18 | 19 | start_state: 20 | split_type: homogeneous 21 | prompt: 22 | #num_instances: ${sweep} 23 | num_instances: ${prompt_sweep} 24 | #num_instances: 62 25 | #num_instances: 33 26 | token: 27 | #num_instances: ${sweep} 28 | #num_instances: 100 29 | num_instances: ${token_sweep} 30 | 31 | trace: 32 | filename: ${fname}_${trace_sweep}_2min 33 | 34 | seed: 0 35 | 36 | hydra: 37 | mode: MULTIRUN 38 | sweeper: 39 | max_batch_size: 288 40 | params: 41 | # code 42 | #+sweep: range(30, 70, 1) 43 | #+sweep: range(60, 70, 1) 44 | #+sweep: range(1, 10, 1) 45 | #+trace_sweep: 80 46 | #+fname: rr_code 47 | 48 | # conv 49 | #+sweep: range(30, 70, 1) 50 | #+sweep: range(10, 30, 1) 51 | #+trace_sweep: 80 52 | #+fname: rr_conv 53 | 54 | # both 55 | #+sweep: range(20, 70, 1) 56 | +prompt_sweep: range(1, 70, 1) 57 | +token_sweep: range(1, 70, 1) 58 | +trace_sweep: 70 59 | +fname: rr_code,rr_conv 60 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_aa_isocost.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: 86 13 | - sku: dgx-h100 14 | count: 0 15 | 16 | 
start_state: 17 | split_type: homogeneous 18 | prompt: 19 | num_instances: ${eval:'(86 - ${sweep})'} 20 | token: 21 | num_instances: ${sweep} 22 | 23 | trace: 24 | filename: ${fname}_${trace_sweep} 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | params: 32 | +sweep: range(5, 86, 5) 33 | +trace_sweep: range(50, 181, 10) 34 | #+fname: rr_conv 35 | +fname: rr_code,rr_conv 36 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_aa_isopower.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: 70 13 | - sku: dgx-h100 14 | count: 0 15 | 16 | start_state: 17 | split_type: homogeneous 18 | prompt: 19 | num_instances: ${eval:'(70 - ${sweep})'} 20 | token: 21 | num_instances: ${sweep} 22 | 23 | trace: 24 | filename: ${fname}_${trace_sweep} 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | params: 32 | +sweep: range(5, 70, 5) 33 | #+sweep: 25,30 34 | +trace_sweep: range(50, 251, 10) 35 | +fname: rr_code,rr_conv 36 | #+fname: rr_conv 37 | #+fname: rr_code 38 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_ha_costopt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: ${token_sweep} 13 | - sku: dgx-h100 14 | count: ${prompt_sweep} 15 | 16 | start_state: 17 | split_type: heterogeneous 18 | prompt: 19 | instance_names: ["dgx-h100"] 20 | token: 21 | instance_names: ["dgx-a100"] 22 | 23 | trace: 24 | filename: ${fname}_${trace_sweep}_2min 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | max_batch_size: 288 32 | params: 33 | +prompt_sweep: range(1, 40, 1) 34 | +token_sweep: range(1, 70, 1) 35 | +trace_sweep: 70 36 | +fname: rr_code,rr_conv 37 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_ha_isocost.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: ${eval:'int(4.76 * (40 - ${sweep}) // 2.21)'} 13 | - sku: dgx-h100 14 | count: ${sweep} 15 | 16 | start_state: 17 | split_type: heterogeneous 18 | prompt: 19 | instance_names: ["dgx-h100"] 20 | token: 21 | instance_names: ["dgx-a100"] 22 | 23 | trace: 24 | filename: ${fname}_${trace_sweep} 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | params: 32 | +sweep: range(5, 40, 5) 33 | +trace_sweep: range(50, 181, 10) 34 | #+fname: rr_conv 35 | +fname: rr_code,rr_conv 36 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_ha_isopower.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 
defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | #count: ${eval:'int(10200 * (40 - ${sweep}) // 6500)'} 13 | count: ${eval:'int(44 * (40 - ${sweep}) // 24.8)'} 14 | - sku: dgx-h100 15 | count: ${sweep} 16 | 17 | start_state: 18 | split_type: heterogeneous 19 | prompt: 20 | instance_names: ["dgx-h100"] 21 | token: 22 | instance_names: ["dgx-a100"] 23 | 24 | trace: 25 | #filename: ${fname}_${trace_sweep}_2min 26 | filename: ${fname}_${trace_sweep} 27 | 28 | seed: 0 29 | 30 | hydra: 31 | mode: MULTIRUN 32 | sweeper: 33 | params: 34 | +sweep: range(5, 40, 5) 35 | +trace_sweep: range(50, 251, 10) 36 | #+fname: rr_code 37 | +fname: rr_conv 38 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_hh_costopt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: 0 13 | - sku: dgx-h100 14 | #count: ${eval:'(100 + ${sweep})'} 15 | count: ${eval:'(${prompt_sweep} + ${token_sweep})'} 16 | 17 | start_state: 18 | split_type: homogeneous 19 | prompt: 20 | num_instances: ${prompt_sweep} 21 | token: 22 | num_instances: ${token_sweep} 23 | 24 | trace: 25 | filename: ${fname}_${trace_sweep}_2min 26 | 27 | seed: 0 28 | 29 | hydra: 30 | mode: MULTIRUN 31 | sweeper: 32 | # this is not a simulator parameter, but a hydra parameter 33 | max_batch_size: 288 34 | params: 35 | +prompt_sweep: range(1, 40, 1) 36 | +token_sweep: range(1, 40, 1) 37 | +trace_sweep: 70 38 | +fname: rr_code,rr_conv 39 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_hh_isocost.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: 0 13 | - sku: dgx-h100 14 | count: 40 15 | 16 | start_state: 17 | split_type: homogeneous 18 | prompt: 19 | num_instances: ${eval:'(40 - ${sweep})'} 20 | token: 21 | num_instances: ${sweep} 22 | 23 | trace: 24 | filename: ${fname}_${trace_sweep} 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | params: 32 | +sweep: range(5, 40, 5) 33 | +trace_sweep: range(50, 181, 10) 34 | #+fname: rr_conv 35 | +fname: rr_code,rr_conv 36 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_hh_isopower.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-a100 12 | count: 0 13 | - sku: dgx-h100 14 | count: 40 15 | 16 | start_state: 17 | split_type: homogeneous 18 | prompt: 19 | num_instances: ${eval:'(40 - ${sweep})'} 20 | token: 21 | num_instances: ${sweep} 22 | 23 | trace: 24 | filename: 
${fname}_${trace_sweep} 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | params: 32 | +sweep: range(5, 40, 5) 33 | +trace_sweep: range(50, 251, 10) 34 | #+sweep: 10,11,13,14,15,16,17,18,19,20,21,22,23,24,25 35 | #+fname: rr_code 36 | #+fname: rr_conv 37 | +fname: rr_code,rr_conv 38 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_hhcap_costopt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise_hhcap 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-h100-pcap 12 | count: ${token_sweep} 13 | - sku: dgx-h100 14 | count: ${prompt_sweep} 15 | 16 | start_state: 17 | split_type: heterogeneous 18 | prompt: 19 | instance_names: ["dgx-h100"] 20 | token: 21 | instance_names: ["dgx-h100-pcap"] 22 | 23 | trace: 24 | #filename: rr_mix_${trace_sweep} 25 | #filename: rr_constant_512p_512t_${trace_sweep} 26 | filename: ${fname}_${trace_sweep}_2min 27 | 28 | seed: 0 29 | 30 | hydra: 31 | mode: MULTIRUN 32 | sweeper: 33 | max_batch_size: 288 34 | params: 35 | # code 36 | #+sweep: range(20, 40, 1) 37 | #+sweep: range(1, 10, 1) 38 | #+trace_sweep: 80 39 | #+fname: rr_code 40 | # conv 41 | #+sweep: range(10, 30, 1) 42 | +prompt_sweep: range(1, 40, 1) 43 | +token_sweep: range(1, 40, 1) 44 | +trace_sweep: 70 45 | +fname: rr_code,rr_conv 46 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_hhcap_isocost.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise_hhcap 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-h100-pcap 12 | count: ${eval:'(40 - ${sweep})'} 13 | - sku: dgx-h100 14 | count: ${sweep} 15 | 16 | start_state: 17 | split_type: heterogeneous 18 | prompt: 19 | instance_names: ["dgx-h100"] 20 | token: 21 | instance_names: ["dgx-h100-pcap"] 22 | 23 | trace: 24 | #filename: rr_constant_512p_512t_${trace_sweep} 25 | filename: ${fname}_${trace_sweep} 26 | 27 | seed: 0 28 | 29 | hydra: 30 | mode: MULTIRUN 31 | sweeper: 32 | params: 33 | +sweep: range(5, 40, 5) 34 | +trace_sweep: range(50, 181, 10) 35 | #+fname: rr_conv 36 | +fname: rr_code,rr_conv 37 | -------------------------------------------------------------------------------- /configs/experiment/splitwise_hhcap_isopower.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /cluster: half_half 4 | - override /performance_model: db 5 | - override /trace: test_trace 6 | - override /start_state: splitwise_hhcap 7 | - _self_ 8 | 9 | cluster: 10 | servers: 11 | - sku: dgx-h100-pcap 12 | count: ${eval:'int(44 * (40 - ${sweep}) // 30.8)'} 13 | - sku: dgx-h100 14 | count: ${sweep} 15 | 16 | start_state: 17 | split_type: heterogeneous 18 | prompt: 19 | instance_names: ["dgx-h100"] 20 | token: 21 | instance_names: ["dgx-h100-pcap"] 22 | 23 | trace: 24 | filename: ${fname}_${trace_sweep} 25 | 26 | seed: 0 27 | 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | params: 32 | +sweep: range(5, 40, 5) 33 | +trace_sweep: range(50, 251, 10) 34 | #+fname: rr_code 35 | #+fname: rr_conv 36 | 
+fname: rr_code,rr_conv 37 | -------------------------------------------------------------------------------- /configs/experiment/traces.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /trace: test_trace 4 | - _self_ 5 | 6 | trace: 7 | #filename: rr_mix_${sweep} 8 | #filename: rr_constant_512p_512t_${sweep} 9 | #filename: rr_conv_${sweep} 10 | #filename: rr_code_${sweep} 11 | filename: ${fname}_${sweep} 12 | 13 | hydra: 14 | sweeper: 15 | params: 16 | +sweep: range(50, 251, 10) 17 | #+fname: rr_code,rr_conv 18 | +fname: rr_conv 19 | #+fname: rr_code 20 | -------------------------------------------------------------------------------- /configs/experiment/traces_light.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /trace: test_trace 4 | - _self_ 5 | 6 | rounded_sweep: ${eval:'round(${sweep},1)'} 7 | trace: 8 | filename: ${fname}_${rounded_sweep} 9 | 10 | hydra: 11 | mode: MULTIRUN 12 | sweeper: 13 | params: 14 | #+sweep: range(0.2, 0.3, 0.2) 15 | +sweep: range(0.2, 5, 0.2) 16 | +fname: rr_code 17 | -------------------------------------------------------------------------------- /configs/hardware_repo/default.yaml: -------------------------------------------------------------------------------- 1 | processors: configs/hardware_repo/processors 2 | skus: configs/hardware_repo/skus 3 | interconnects: configs/hardware_repo/interconnects 4 | -------------------------------------------------------------------------------- /configs/hardware_repo/interconnects/nvlink.yaml: -------------------------------------------------------------------------------- 1 | # TODO: better topology / config support for interconnects 2 | _target_: interconnect.NVLink 3 | bandwidth: 1 4 | -------------------------------------------------------------------------------- /configs/hardware_repo/processors/a100-40gb.yaml: -------------------------------------------------------------------------------- 1 | _target_: processor.GPU 2 | name: a100-40gb 3 | memory_size: 42949672960 4 | -------------------------------------------------------------------------------- /configs/hardware_repo/processors/a100-80gb.yaml: -------------------------------------------------------------------------------- 1 | _target_: processor.GPU 2 | name: a100-80gb 3 | memory_size: 85899345920 4 | -------------------------------------------------------------------------------- /configs/hardware_repo/processors/h100-80gb-pcap.yaml: -------------------------------------------------------------------------------- 1 | _target_: processor.GPU 2 | name: h100-80gb-pcap 3 | memory_size: 85899345920 4 | -------------------------------------------------------------------------------- /configs/hardware_repo/processors/h100-80gb.yaml: -------------------------------------------------------------------------------- 1 | _target_: processor.GPU 2 | name: h100-80gb 3 | memory_size: 85899345920 4 | -------------------------------------------------------------------------------- /configs/hardware_repo/skus/dgx-a100.yaml: -------------------------------------------------------------------------------- 1 | _target_: server.Server 2 | name: dgx-a100 3 | tdp: 6500 4 | processors: 5 | - name: a100-80gb 6 | count: 8 7 | interconnects: {} 8 | -------------------------------------------------------------------------------- /configs/hardware_repo/skus/dgx-h100-pcap.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: server.Server 2 | name: dgx-h100-pcap 3 | tdp: 7140 4 | processors: 5 | - name: h100-80gb-pcap 6 | count: 8 7 | interconnects: {} 8 | -------------------------------------------------------------------------------- /configs/hardware_repo/skus/dgx-h100.yaml: -------------------------------------------------------------------------------- 1 | _target_: server.Server 2 | name: dgx-h100 3 | tdp: 10200 4 | processors: 5 | - name: h100-80gb 6 | count: 8 7 | interconnects: {} 8 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/bloom-176b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: bloom-176b 3 | num_layers: 70 4 | hidden_size: 14336 5 | num_heads: 112 6 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/gpt3-175b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: gpt3-175b 3 | num_layers: 96 4 | hidden_size: 12288 5 | num_heads: 96 6 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/llama-13b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: llama-13b 3 | num_layers: 40 4 | hidden_size: 5120 5 | num_heads: 40 6 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/llama-33b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: llama-33b 3 | num_layers: 60 4 | hidden_size: 6656 5 | num_heads: 52 6 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/llama2-70b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: llama2-70b 3 | num_layers: 80 4 | hidden_size: 8192 5 | num_heads: 32 6 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/opt-30b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: opt-30b 3 | num_layers: 48 4 | hidden_size: 7168 5 | num_heads: 56 6 | -------------------------------------------------------------------------------- /configs/model_repo/architectures/opt-66b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.LLMArchitecture 2 | name: opt-66b 3 | num_layers: 64 4 | hidden_size: 9216 5 | num_heads: 72 6 | -------------------------------------------------------------------------------- /configs/model_repo/default.yaml: -------------------------------------------------------------------------------- 1 | architectures: configs/model_repo/architectures 2 | sizes: configs/model_repo/sizes 3 | -------------------------------------------------------------------------------- /configs/model_repo/sizes/bloom-176b-fp16.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.ModelSize 2 | weights: 329000000000 3 | dtype_size: 2 4 | -------------------------------------------------------------------------------- 
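The `_target_` keys in the hardware and model YAML configs above are what make them dynamically instantiable: `hardware_repo.py` and `model_repo.py` (later in this listing) read each YAML file and hand it to Hydra's `instantiate()`. The snippet below is a minimal illustrative sketch only, not code from this repository; it shows how a size config such as `configs/model_repo/sizes/bloom-176b-fp16.yaml` could be turned into a `model.ModelSize` object, with the inline dict simply mirroring that YAML file.

```python
# Hedged sketch (not repository code): dynamic instantiation of a `_target_` config,
# mirroring how hardware_repo.py / model_repo.py load these YAML files via Hydra.
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Values copied from configs/model_repo/sizes/bloom-176b-fp16.yaml
cfg = OmegaConf.create({
    "_target_": "model.ModelSize",   # class to construct (defined in model.py)
    "weights": 329000000000,
    "dtype_size": 2,
})

model_size = instantiate(cfg)        # -> a model.ModelSize dataclass instance
print(model_size.total_size)         # the total_size property returns the weights size
```

Run from the repository root (so that `model.py` is importable), this should print `329000000000`.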
/configs/model_repo/sizes/llama2-70b-fp16.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.ModelSize 2 | weights: 135000000000 3 | dtype_size: 2 4 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/allocators/noop.yaml: -------------------------------------------------------------------------------- 1 | _target_: allocator.NoOpAllocator 2 | overheads: 3 | spin_up: 0 4 | spin_down: 0 5 | instance_overheads: 6 | run: 0 7 | preempt: 0 8 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/default.yaml: -------------------------------------------------------------------------------- 1 | allocators: configs/orchestrator_repo/allocators 2 | schedulers: configs/orchestrator_repo/schedulers 3 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/jsq.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.JSQScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/kv_jsq.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.KVJSQScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 9 | transfer_bandwidth: 200 10 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/kv_round_robin.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.KVRoundRobinScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 9 | transfer_bandwidth: 200 10 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/kv_round_robin_ethernet.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.KVRoundRobinScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 9 | transfer_bandwidth: 12.5 10 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/kv_token_jsq.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.OverlapKVTokenJSQScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 9 | transfer_bandwidth: 200 10 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/mixed_pool.yml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.MixedPoolScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 
9 | prompt_max_pending_batch_tokens: 8192 10 | token_max_pending_batch_tokens: 2048 11 | transfer_bandwidth: 200 12 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/overlap_kv_jsq.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.OverlapKVJSQScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 9 | transfer_bandwidth: 200 10 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/overlap_kv_token_jsq.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.OverlapKVTokenJSQScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | prompt_processors: ["h100-80gb"] 8 | token_processors: ["a100-80gb"] 9 | transfer_bandwidth: 200 10 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/random.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.RandomScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/round_robin.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.RoundRobinScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | -------------------------------------------------------------------------------- /configs/orchestrator_repo/schedulers/token_jsq.yaml: -------------------------------------------------------------------------------- 1 | _target_: scheduler.TokenJSQScheduler 2 | overheads: {} 3 | executor_overheads: 4 | submit_task: 0 5 | submit_flow: 0 6 | finish_request: 0 7 | -------------------------------------------------------------------------------- /configs/performance_model/constant.yaml: -------------------------------------------------------------------------------- 1 | _target_: performance_model.ConstantPerformanceModel 2 | prompt_time: 1 3 | token_time: 10 4 | -------------------------------------------------------------------------------- /configs/performance_model/db.yaml: -------------------------------------------------------------------------------- 1 | _target_: performance_model.DatabasePerformanceModel 2 | 3 | db_path: data/perf_model.csv 4 | #db_path: data/vllm-bloom.csv 5 | #db_path: data/vllm-llama.csv 6 | -------------------------------------------------------------------------------- /configs/power_model/constant.yaml: -------------------------------------------------------------------------------- 1 | _target_: power_model.ConstantPowerModel 2 | idle_power: 3 | a100-80gb: 63 4 | h100-80gb: 75 5 | prompt_power: 6 | a100-80gb: 400 7 | h100-80gb: 700 8 | token_power: 9 | a100-80gb: 250 10 | h100-80gb: 380 11 | -------------------------------------------------------------------------------- /configs/router/noop.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - overheads: zero 3 | - _self_ 4 | 5 | _target_: router.NoOpRouter 6 | 
-------------------------------------------------------------------------------- /configs/router/overheads/zero.yaml: -------------------------------------------------------------------------------- 1 | routing_delay: 0 2 | -------------------------------------------------------------------------------- /configs/start_state/baseline.yaml: -------------------------------------------------------------------------------- 1 | # single application_id is allocated to all servers 2 | 3 | state_type: baseline 4 | application_id: 0 5 | instance: 6 | instance_type: Splitwise 7 | max_batch_size: 512 8 | max_batch_tokens: 2048 9 | max_preemptions: 4 10 | pipeline_parallelism: 1 11 | tensor_parallelism: 8 12 | -------------------------------------------------------------------------------- /configs/start_state/orca.yaml: -------------------------------------------------------------------------------- 1 | # single application_id is allocated to all servers 2 | 3 | state_type: orca 4 | application_id: 0 5 | instance: 6 | instance_type: ORCA 7 | max_batch_size: 512 8 | pipeline_parallelism: 1 9 | tensor_parallelism: 8 10 | -------------------------------------------------------------------------------- /configs/start_state/splitwise.yaml: -------------------------------------------------------------------------------- 1 | # single application_id is allocated to all servers 2 | 3 | state_type: splitwise_${start_state.prompt.num_instances}_${start_state.token.num_instances} 4 | application_id: 0 5 | split_type: homogeneous 6 | prompt: 7 | instance_type: Splitwise 8 | max_batch_size: 512 9 | max_batch_tokens: 2048 10 | max_preemptions: 4 11 | pipeline_parallelism: 1 12 | tensor_parallelism: 8 13 | num_instances: 1 14 | instance_names: ["dgx-h100"] 15 | token: 16 | instance_type: Splitwise 17 | max_batch_size: 512 18 | max_batch_tokens: 2048 19 | max_preemptions: 4 20 | pipeline_parallelism: 1 21 | tensor_parallelism: 8 22 | num_instances: 1 23 | instance_names: ["dgx-a100"] 24 | -------------------------------------------------------------------------------- /configs/start_state/splitwise_hhcap.yaml: -------------------------------------------------------------------------------- 1 | # single application_id is allocated to all servers 2 | 3 | state_type: splitwisehhcap_${start_state.prompt.num_instances}_${start_state.token.num_instances} 4 | application_id: 0 5 | split_type: homogeneous 6 | prompt: 7 | instance_type: Splitwise 8 | max_batch_size: 512 9 | max_batch_tokens: 2048 10 | max_preemptions: 4 11 | pipeline_parallelism: 1 12 | tensor_parallelism: 8 13 | num_instances: 1 14 | instance_names: ["dgx-h100"] 15 | token: 16 | instance_type: Splitwise 17 | max_batch_size: 512 18 | max_batch_tokens: 2048 19 | max_preemptions: 4 20 | pipeline_parallelism: 1 21 | tensor_parallelism: 8 22 | num_instances: 1 23 | instance_names: ["dgx-h100-pcap"] 24 | -------------------------------------------------------------------------------- /configs/start_state/unallocated.yaml: -------------------------------------------------------------------------------- 1 | # applications start without any instances 2 | 3 | state_type: unallocated 4 | -------------------------------------------------------------------------------- /configs/trace/test_trace.yaml: -------------------------------------------------------------------------------- 1 | dir: traces/ 2 | filename: test_trace 3 | path: ${trace.dir}/${trace.filename}.csv 4 | -------------------------------------------------------------------------------- /executor.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from enum import IntEnum 4 | 5 | from flow import Flow 6 | from node import NodeState 7 | from simulator import clock, schedule_event, cancel_event, reschedule_event 8 | from task import Task 9 | 10 | 11 | class ExecutorType(IntEnum): 12 | LocalExecutor = 0 13 | CentralExecutor = 1 14 | 15 | 16 | class Executor(): 17 | """ 18 | Executor orchestrates Request execution on Instances and Interconnects. 19 | Executors can themselves run anywhere, e.g., on the Scheduler, Instance, etc., 20 | with different amounts of overheads. 21 | They could execute multiple Tasks/Flows of the Request in parallel. 22 | 23 | NOTE: We don't ensure predecessors of a node are completed before submit. 24 | Implicitly, we assume that the Request is a tree instead of a DAG. 25 | Could be changed by waiting on Node predecessors. 26 | """ 27 | def __init__(self, 28 | request, 29 | scheduler, 30 | overheads): 31 | self.request = request 32 | self.scheduler = scheduler 33 | self.overheads = overheads 34 | self.submitted = [] 35 | # to cancel any events 36 | self.completion_events = {} 37 | 38 | def successors(self, node): 39 | """ 40 | Returns the successors of the specified node. 41 | """ 42 | nodes = self.request.successors(node) 43 | return nodes 44 | 45 | def check_predecessors(self, node): 46 | """ 47 | Checks if all predecessors of the specified node are completed. 48 | """ 49 | for predecessor in self.request.predecessors(node): 50 | if predecessor.state != NodeState.COMPLETED: 51 | return False 52 | return True 53 | 54 | def submit(self, node=None): 55 | """ 56 | Submits the specified node for execution. 57 | """ 58 | if isinstance(node, Task): 59 | self.submit_task(node) 60 | elif isinstance(node, Flow): 61 | self.submit_flow(node) 62 | else: 63 | raise ValueError(f"Unknown node type: {type(node)}") 64 | 65 | def submit_chain(self, chain): 66 | """ 67 | Submits the specified chain of Nodes for execution. 68 | """ 69 | for node in chain: 70 | self.submit(node) 71 | 72 | def submit_task(self, task, instance=None): 73 | """ 74 | Submits the specified task for execution. 75 | If instance is not specified, uses the task's instance. 76 | """ 77 | if instance is None: 78 | instance = task.instance 79 | task.executor = self 80 | self.submitted.append(task) 81 | schedule_event(self.overheads.submit_task, 82 | lambda instance=instance,task=task: \ 83 | instance.task_arrival(task)) 84 | # if this is the first task in the chain, submit the chain 85 | self.submit_chain(task.chain) 86 | 87 | def finish_task(self, task, instance): 88 | """ 89 | Finishes the specified task. 90 | """ 91 | self.submitted.remove(task) 92 | successor_nodes = list(self.successors(task)) 93 | # NOTE: assumes a single leaf node 94 | if len(successor_nodes) == 0: 95 | self.finish_request() 96 | return 97 | # submit nodes for whom all predecessors have completed 98 | # and are not already submitted 99 | for node in successor_nodes: 100 | if node.state == NodeState.NONE and self.check_predecessors(node): 101 | self.submit(node) 102 | 103 | def submit_flow(self, flow, link=None): 104 | """ 105 | Submits the specified flow for execution. 106 | If link is not specified, uses the flow's link.
107 | """ 108 | if link is None: 109 | link = flow.link 110 | flow.executor = self 111 | self.submitted.append(flow) 112 | schedule_event(self.overheads.submit_flow, 113 | lambda link=link,flow=flow: link.flow_arrival(flow)) 114 | # if this is the first flow in the chain, submit the chain 115 | self.submit_chain(flow.chain) 116 | 117 | def finish_flow(self, flow, link): 118 | """ 119 | Finishes the specified flow. 120 | """ 121 | self.submitted.remove(flow) 122 | successor_nodes = list(self.successors(flow)) 123 | # NOTE: assumes a single leaf node 124 | if len(successor_nodes) == 0: 125 | self.finish_request() 126 | return 127 | # submit nodes for whom all predecessors have completed 128 | # and are not already submitted 129 | for node in successor_nodes: 130 | if node.state == NodeState.NONE and self.check_predecessors(node): 131 | self.submit(node) 132 | 133 | def finish_request(self): 134 | """ 135 | Finishes executing the entire Request. 136 | """ 137 | def fin_req(): 138 | self.scheduler.request_completion(self.request) 139 | schedule_event(self.overheads.finish_request, fin_req) 140 | 141 | def run(self): 142 | """ 143 | Runs the Request by submitting the root node. 144 | """ 145 | self.submit(self.request.root_node) 146 | 147 | @classmethod 148 | def create(cls, executor_type, request, scheduler, overheads): 149 | """ 150 | Creates an Executor instance based on the specified type. 151 | """ 152 | if executor_type == ExecutorType.CentralExecutor: 153 | return CentralExecutor(request, scheduler, overheads) 154 | if executor_type == ExecutorType.LocalExecutor: 155 | return LocalExecutor(request, scheduler, overheads) 156 | raise ValueError(f"Unsupported executor type: {executor_type}") 157 | 158 | 159 | class CentralExecutor(Executor): 160 | """ 161 | CentralExecutor coordinates with Scheduler for each Task. 162 | Logically, it runs within Scheduler itself. 163 | TODO: appropriate overheads 164 | """ 165 | pass 166 | 167 | 168 | class LocalExecutor(Executor): 169 | """ 170 | LocalExecutor logically runs on Servers, alongside Instances. 171 | TODO: appropriate overheads 172 | """ 173 | pass 174 | -------------------------------------------------------------------------------- /flow.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass, field 4 | from enum import IntEnum 5 | 6 | from instance import Instance 7 | from metrics import FlowMetrics, FlowSLO 8 | from model import Model, ModelArchitecture 9 | from node import Node 10 | from simulator import clock, schedule_event, cancel_event, reschedule_event 11 | 12 | 13 | class FlowType(IntEnum): 14 | DEFAULT = 0 15 | KVCacheTransfer = 1 16 | 17 | 18 | @dataclass(kw_only=True) 19 | class Flow(Node): 20 | """ 21 | Flows are communication nodes in the Request DAG that execute on Links. 22 | Flows are the networking counterparts of Tasks. 23 | """ 24 | flow_type: FlowType 25 | src: Instance 26 | dest: Instance 27 | batch_size: int = 1 28 | size: float = 0. 29 | duration: float = 0. 
30 | notify: bool = False 31 | metrics: FlowMetrics = field(default_factory=FlowMetrics) 32 | slo: FlowSLO = field(default_factory=FlowSLO) 33 | executor: 'Executor' = None 34 | links = [] 35 | _link = None 36 | 37 | def __hash__(self): 38 | return hash(self.node_id) 39 | 40 | @property 41 | def link(self): 42 | return self._link 43 | 44 | @link.setter 45 | def link(self, link): 46 | if link is self._link: 47 | return 48 | self._link = link 49 | if link is not None: 50 | self.links.append(link) 51 | 52 | @property 53 | def duration(self): 54 | return self._duration 55 | 56 | @duration.setter 57 | def duration(self, duration): 58 | self._duration = duration 59 | 60 | @property 61 | def memory(self): 62 | return 0 63 | 64 | def run(self): 65 | super().run() 66 | 67 | # manage memory 68 | self.dest.alloc_memory(self.request, self.request.memory) 69 | 70 | def complete(self): 71 | super().complete() 72 | 73 | # manage memory 74 | self.src.free_memory(self.request, self.request.memory) 75 | 76 | @classmethod 77 | def from_type(cls, flow_type, **kwargs): 78 | if flow_type == FlowType.DEFAULT: 79 | return Flow(**kwargs) 80 | elif flow_type == FlowType.KVCacheTransfer: 81 | return KVCacheTransferFlow(**kwargs) 82 | else: 83 | raise ValueError(f"Invalid FlowType {flow_type}") 84 | 85 | 86 | @dataclass(kw_only=True) 87 | class KVCacheTransferFlow(Flow): 88 | """ 89 | Flow for transferring KV cache between instances. 90 | """ 91 | flow_type: FlowType = FlowType.KVCacheTransfer 92 | 93 | def __hash__(self): 94 | return hash(self.node_id) 95 | -------------------------------------------------------------------------------- /generate_trace.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from collections import namedtuple 4 | 5 | import requests 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from scipy import stats 11 | 12 | 13 | Distributions = namedtuple('Distributions', ['application_id', 14 | 'request_type', 15 | 'arrival_process', 16 | 'batch_size', 17 | 'prompt_size', 18 | 'token_size']) 19 | Distribution = namedtuple('Distribution', ['name', 'params']) 20 | 21 | 22 | def generate_samples(distribution, params, size): 23 | """ 24 | Generate random samples from the given distribution. 25 | """ 26 | if distribution == "constant": 27 | return np.ones(size) * params["value"] 28 | elif distribution == "normal": 29 | return stats.norm(**params).rvs(size=size) 30 | elif distribution == "truncnorm": 31 | return stats.truncnorm(**params).rvs(size=size) 32 | elif distribution == "randint": 33 | return stats.uniform(**params).rvs(size=size) 34 | elif distribution == "uniform": 35 | return stats.uniform(**params).rvs(size=size) 36 | elif distribution == "exponential": 37 | return stats.expon(**params).rvs(size=size) 38 | elif distribution == "poisson": 39 | return stats.poisson(**params).rvs(size=size) 40 | elif distribution == "trace": 41 | df = pd.read_csv(params["filename"]) 42 | return df[params["column"]].sample(size, replace=True).values 43 | else: 44 | raise ValueError(f"Invalid distribution: {distribution}") 45 | 46 | 47 | def generate_trace(max_requests, distributions, end_time=None): 48 | """ 49 | Generate a trace of requests based on the given distributions. 
50 | """ 51 | # Generate request IDs 52 | request_ids = np.arange(max_requests) 53 | 54 | # Generate the distributions 55 | arrival_timestamps = generate_samples(distributions.arrival_process.name, 56 | distributions.arrival_process.params, 57 | max_requests) 58 | arrival_timestamps = np.cumsum(arrival_timestamps) 59 | application_ids = generate_samples(distributions.application_id.name, 60 | distributions.application_id.params, 61 | max_requests) 62 | application_ids = map(int, application_ids) 63 | batch_sizes = generate_samples(distributions.batch_size.name, 64 | distributions.batch_size.params, 65 | max_requests) 66 | batch_sizes = map(int, batch_sizes) 67 | prompt_sizes = generate_samples(distributions.prompt_size.name, 68 | distributions.prompt_size.params, 69 | max_requests) 70 | prompt_sizes = map(int, prompt_sizes) 71 | token_sizes = generate_samples(distributions.token_size.name, 72 | distributions.token_size.params, 73 | max_requests) 74 | token_sizes = map(int, token_sizes) 75 | request_type_ids = generate_samples(distributions.request_type.name, 76 | distributions.request_type.params, 77 | max_requests) 78 | request_type_ids = map(int, request_type_ids) 79 | 80 | # Combine the arrays into a DataFrame 81 | trace_df = pd.DataFrame({ 82 | "request_id": request_ids, 83 | "request_type": request_type_ids, 84 | "application_id": application_ids, 85 | "arrival_timestamp": arrival_timestamps, 86 | "batch_size": batch_sizes, 87 | "prompt_size": prompt_sizes, 88 | "token_size": token_sizes, 89 | }) 90 | 91 | if end_time is not None: 92 | trace_df = trace_df[trace_df["arrival_timestamp"] < end_time] 93 | 94 | return trace_df 95 | 96 | 97 | def get_exponential_scale(num_servers, utilization, request_duration): 98 | """ 99 | assumes that request_duration is in seconds 100 | """ 101 | interarrival_time = request_duration / (1.0 * utilization) 102 | exponential_scale = interarrival_time / num_servers 103 | return exponential_scale 104 | 105 | 106 | def generate_trace_from_utilization( 107 | max_requests, 108 | end_time, 109 | num_servers, 110 | utilization, 111 | request_duration, 112 | pt_distributions_file): 113 | """ 114 | Generate request traces for the simulator using prompt and token 115 | size distributions. 116 | """ 117 | exponential_scale = get_exponential_scale(num_servers, utilization, request_duration) 118 | distributions = Distributions( 119 | application_id=Distribution("constant", {"value": 0}), 120 | request_type=Distribution("constant", {"value": 2}), # 2 is for LLM inference 121 | arrival_process=Distribution("exponential", {"scale": exponential_scale}), 122 | prompt_size=Distribution("trace", {"filename": pt_distributions_file, 123 | "column": "ContextTokens"}), 124 | token_size=Distribution("trace", {"filename": pt_distributions_file, 125 | "column": "GeneratedTokens"}), 126 | batch_size=Distribution("constant", {"value": 1}), 127 | ) 128 | 129 | trace_df = generate_trace(max_requests, 130 | distributions, 131 | end_time=end_time) 132 | return trace_df 133 | 134 | 135 | def generate_trace_from_prompt_token_size_distributions( 136 | max_requests, 137 | end_time, 138 | request_rate, 139 | pt_distributions_filename): 140 | """ 141 | Generate request traces for the simulator using prompt and token 142 | size distributions. 
143 | """ 144 | distributions = Distributions( 145 | application_id=Distribution("constant", {"value": 0}), 146 | request_type=Distribution("constant", {"value": 2}), # 2 is for LLM inference 147 | arrival_process=Distribution("exponential", {"scale": 1.0 / request_rate}), 148 | prompt_size=Distribution("trace", {"filename": pt_distributions_filename, 149 | "column": "ContextTokens"}), 150 | #prompt_size=Distribution("truncnorm", {"a": (prompt_min-prompt_mean)/prompt_std, 151 | # "b": (prompt_max-prompt_mean)/prompt_std, 152 | # "loc": prompt_mean, 153 | # "scale": prompt_std}), 154 | token_size=Distribution("trace", {"filename": pt_distributions_filename, 155 | "column": "GeneratedTokens"}), 156 | #token_size=Distribution("truncnorm", {"a": (token_min-token_mean)/token_std, 157 | # "b": (token_max-token_mean)/token_std, 158 | # "loc": token_mean, 159 | # "scale": token_std}), 160 | batch_size=Distribution("constant", {"value": 1}), 161 | ) 162 | trace_df = generate_trace(max_requests, 163 | distributions, 164 | end_time=end_time) 165 | return trace_df 166 | 167 | 168 | def generate_traces(max_requests, 169 | end_time, 170 | request_rates, 171 | pt_distributions_file, 172 | trace_filename_template): 173 | """ 174 | Generate traces with prompt/token size distributions. 175 | """ 176 | for request_rate in request_rates: 177 | trace_df = generate_trace_from_prompt_token_size_distributions( 178 | max_requests, 179 | end_time, 180 | request_rate, 181 | pt_distributions_file) 182 | trace_filename = trace_filename_template.format(request_rate) 183 | trace_df.to_csv(trace_filename, index=False) 184 | 185 | 186 | def generate_code_traces( 187 | max_requests, 188 | end_time, 189 | request_rates, 190 | code_distributions_file, 191 | trace_filename_template="traces/rr_code_{}.csv"): 192 | """ 193 | code traces distribution 194 | prompt_mean = 2048, prompt_std = 1973, prompt_min = 3, prompt_max = 7437 195 | token_mean = 28, token_std = 60, token_min = 6, token_max = 1899 196 | """ 197 | if not os.path.exists(trace_filename_template[:trace_filename_template.rfind("/")]): 198 | os.makedirs(trace_filename_template[:trace_filename_template.rfind("/")]) 199 | 200 | generate_traces(max_requests, 201 | end_time, 202 | request_rates, 203 | code_distributions_file, 204 | trace_filename_template) 205 | 206 | 207 | def generate_conv_traces( 208 | max_requests, 209 | end_time, 210 | request_rates, 211 | conv_distributions_file, 212 | trace_filename_template="traces/rr_conv_{}.csv"): 213 | """ 214 | conv traces distribution 215 | prompt_mean = 1155, prompt_std = 1109, prompt_min = 2, prompt_max = 14050 216 | token_mean = 211, token_std = 163, token_min = 7, token_max = 1000 217 | """ 218 | if not os.path.exists(trace_filename_template[:trace_filename_template.rfind("/")]): 219 | os.makedirs(trace_filename_template[:trace_filename_template.rfind("/")]) 220 | 221 | generate_traces(max_requests, 222 | end_time, 223 | request_rates, 224 | conv_distributions_file, 225 | trace_filename_template) 226 | 227 | 228 | def download_file(url, filename): 229 | """ 230 | Download a file from the given URL. 231 | """ 232 | response = requests.get(url) 233 | with open(filename, "wb") as f: 234 | f.write(response.content) 235 | 236 | 237 | def download_azure_llm_traces(): 238 | """ 239 | Download traces from the given URL. 
240 | """ 241 | if not os.path.exists("data"): 242 | os.makedirs("data") 243 | 244 | url_base = "https://raw.githubusercontent.com/Azure/AzurePublicDataset/master/data/" 245 | 246 | if not os.path.exists("data/code_distributions.csv"): 247 | url = url_base + "AzureLLMInferenceTrace_code.csv" 248 | download_file(url, "data/code_distributions.csv") 249 | print("Downloaded code traces") 250 | 251 | if not os.path.exists("data/conv_distributions.csv"): 252 | url = url_base + "AzureLLMInferenceTrace_conv.csv" 253 | download_file(url, "data/conv_distributions.csv") 254 | print("Downloaded conv traces") 255 | 256 | 257 | if __name__ == "__main__": 258 | # download prompt and token size distributions 259 | download_azure_llm_traces() 260 | 261 | # generate request traces 262 | generate_code_traces( 263 | max_requests=1000000, 264 | end_time=600, 265 | request_rates=list(range(30, 251, 10)), 266 | code_distributions_file="data/code_distributions.csv") 267 | print("Generated code traces") 268 | 269 | generate_conv_traces( 270 | max_requests=1000000, 271 | end_time=600, 272 | request_rates=list(range(30, 251, 10)), 273 | conv_distributions_file="data/conv_distributions.csv") 274 | print("Generated conv traces") 275 | 276 | # generate request traces for 2 min 277 | generate_code_traces( 278 | max_requests=1000000, 279 | end_time=120, 280 | request_rates=list(range(30, 101, 10)), 281 | code_distributions_file="data/code_distributions.csv", 282 | trace_filename_template="traces/rr_code_{}_2min.csv") 283 | print("Generated code 2min traces") 284 | 285 | generate_conv_traces( 286 | max_requests=1000000, 287 | end_time=120, 288 | request_rates=list(range(30, 101, 10)), 289 | conv_distributions_file="data/conv_distributions.csv", 290 | trace_filename_template="traces/rr_conv_{}_2min.csv") 291 | print("Generated conv 2min traces") 292 | -------------------------------------------------------------------------------- /hardware_repo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from hydra.utils import instantiate 5 | 6 | import utils 7 | 8 | # needed by hydra instantiate 9 | import processor 10 | import interconnect 11 | 12 | 13 | hardware_repo = None 14 | 15 | 16 | class HardwareRepo(): 17 | """ 18 | Repository of all hardware configs for dynamic instantiation. 
19 | """ 20 | def __init__(self, 21 | processors_path, 22 | interconnects_path, 23 | skus_path): 24 | global hardware_repo 25 | hardware_repo = self 26 | self.processor_configs = self.get_processor_configs(processors_path) 27 | self.sku_configs = self.get_sku_configs(skus_path) 28 | self.interconnect_configs = self.get_interconnect_configs( 29 | interconnects_path) 30 | 31 | def get_sku_configs(self, skus_path): 32 | return utils.read_all_yaml_cfgs(skus_path) 33 | 34 | def get_processor_configs(self, processors_path): 35 | return utils.read_all_yaml_cfgs(processors_path) 36 | 37 | def get_interconnect_configs(self, interconnects_path): 38 | return utils.read_all_yaml_cfgs(interconnects_path) 39 | 40 | def get_processor(self, processor_name): 41 | cfg = self.processor_configs[processor_name] 42 | return instantiate(cfg) 43 | 44 | def get_interconnect(self, interconnect_name): 45 | cfg = self.interconnect_configs[interconnect_name] 46 | return instantiate(cfg) 47 | 48 | def get_sku_config(self, sku_name): 49 | return self.sku_configs[sku_name] 50 | 51 | 52 | get_processor = lambda *args,**kwargs: \ 53 | hardware_repo.get_processor(*args, **kwargs) 54 | get_interconnect = lambda *args,**kwargs: \ 55 | hardware_repo.get_interconnect(*args, **kwargs) 56 | get_sku_config = lambda *args,**kwargs: \ 57 | hardware_repo.get_sku_config(*args, **kwargs) 58 | -------------------------------------------------------------------------------- /initialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for initializing the simulation environment. 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | from hydra.utils import instantiate 9 | from hydra.utils import get_original_cwd 10 | 11 | from application import Application 12 | from cluster import Cluster 13 | from hardware_repo import HardwareRepo 14 | from model_repo import ModelRepo 15 | from orchestrator_repo import OrchestratorRepo 16 | from start_state import load_start_state 17 | from trace import Trace 18 | 19 | 20 | def init_trace(cfg): 21 | trace_path = os.path.join(get_original_cwd(), cfg.trace.path) 22 | trace = Trace.from_csv(trace_path) 23 | return trace 24 | 25 | 26 | def init_hardware_repo(cfg): 27 | processors_path = os.path.join(get_original_cwd(), 28 | cfg.hardware_repo.processors) 29 | interconnects_path = os.path.join(get_original_cwd(), 30 | cfg.hardware_repo.interconnects) 31 | skus_path = os.path.join(get_original_cwd(), 32 | cfg.hardware_repo.skus) 33 | hardware_repo = HardwareRepo(processors_path, 34 | interconnects_path, 35 | skus_path) 36 | return hardware_repo 37 | 38 | 39 | def init_model_repo(cfg): 40 | model_architectures_path = os.path.join(get_original_cwd(), 41 | cfg.model_repo.architectures) 42 | model_sizes_path = os.path.join(get_original_cwd(), 43 | cfg.model_repo.sizes) 44 | model_repo = ModelRepo(model_architectures_path, model_sizes_path) 45 | return model_repo 46 | 47 | 48 | def init_orchestrator_repo(cfg): 49 | allocators_path = os.path.join(get_original_cwd(), 50 | cfg.orchestrator_repo.allocators) 51 | schedulers_path = os.path.join(get_original_cwd(), 52 | cfg.orchestrator_repo.schedulers) 53 | orchestrator_repo = OrchestratorRepo(allocators_path, schedulers_path) 54 | return orchestrator_repo 55 | 56 | 57 | def init_performance_model(cfg): 58 | performance_model = instantiate(cfg.performance_model) 59 | return performance_model 60 | 61 | 62 | def init_power_model(cfg): 63 | power_model = instantiate(cfg.power_model) 64 | return power_model 65 | 
66 | 67 | def init_cluster(cfg): 68 | cluster = Cluster.from_config(cfg.cluster) 69 | return cluster 70 | 71 | 72 | def init_router(cfg, cluster): 73 | router = instantiate(cfg.router, cluster=cluster) 74 | return router 75 | 76 | 77 | def init_arbiter(cfg, cluster): 78 | arbiter = instantiate(cfg.arbiter, cluster=cluster) 79 | return arbiter 80 | 81 | 82 | def init_applications(cfg, cluster, router, arbiter): 83 | applications = {} 84 | for application_cfg in cfg.applications: 85 | application = Application.from_config(application_cfg, 86 | cluster=cluster, 87 | router=router, 88 | arbiter=arbiter) 89 | applications[application_cfg.application_id] = application 90 | return applications 91 | 92 | 93 | def init_start_state(cfg, **kwargs): 94 | load_start_state(cfg.start_state, **kwargs) 95 | 96 | 97 | if __name__ == "__main__": 98 | pass 99 | -------------------------------------------------------------------------------- /interconnect.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass, field 4 | from enum import IntEnum 5 | 6 | from flow import Flow 7 | from processor import CPU, GPU 8 | from simulator import clock, schedule_event, cancel_event, reschedule_event 9 | from server import Server 10 | 11 | 12 | class LinkType(IntEnum): 13 | DEFAULT = 0 14 | PCIeLink = 1 15 | EthernetLink = 2 16 | IBLink = 3 17 | NVLink = 4 18 | RDMADirectLink = 5 19 | DummyLink = 6 20 | 21 | 22 | @dataclass(kw_only=True) 23 | class Link(): 24 | """ 25 | Links are unidirectional edges in the cluster interconnect topology graph. 26 | They are the lowest-level networking equivalent of Processors. 27 | Instead of Tasks, Links can run (potentially multiple) Flows. 28 | Links have a maximum bandwidth they can support, after which point they become congested. 29 | 30 | TODO: replace with a higher-fidelity network model (e.g., ns-3). 31 | 32 | Attributes: 33 | link_type (LinkType): Type of the Link (e.g., NVLink, IB, etc). 34 | src (object): Source endpoint 35 | dest (object): Destination endpoint 36 | bandwidth (float): The maximum bandwidth supported by the Link. 37 | bandwidth_used (float): The bandwidth used by the Link. 38 | server (Server): The Server that the Processor belongs to. 39 | flows (List[Flow]): Flows running on this Link. 40 | max_flows (int): Maximum number of flows that can run in parallel on the link. 41 | """ 42 | link_type: LinkType = LinkType.DEFAULT 43 | name: str 44 | src: object 45 | dest: object 46 | bandwidth: float 47 | bandwidth_used: float 48 | _bandwidth_used: float = 0 49 | max_flows: int 50 | retry: bool = True 51 | retry_delay: float = 1. 
52 | overheads: dict = field(default_factory=dict) 53 | 54 | # queues 55 | pending_queue: list[Flow] = field(default_factory=list) 56 | executing_queue: list[Flow] = field(default_factory=list) 57 | completed_queue: list[Flow] = field(default_factory=list) 58 | 59 | @property 60 | def bandwidth_used(self): 61 | return self._bandwidth_used 62 | 63 | @bandwidth_used.setter 64 | def bandwidth_used(self, bandwidth_used): 65 | if type(bandwidth_used) is property: 66 | bandwidth_used = 0 67 | if bandwidth_used < 0: 68 | raise ValueError("Bandwidth used cannot be negative") 69 | elif bandwidth_used > self.bandwidth: 70 | raise ValueError("Cannot exceed link bandwidth") 71 | self._bandwidth_used = bandwidth_used 72 | 73 | @property 74 | def bandwidth_free(self): 75 | return self.bandwidth - self.bandwidth_used 76 | 77 | @property 78 | def peers(self): 79 | pass 80 | 81 | def flow_arrival(self, flow): 82 | """ 83 | Flow arrives at the Link. 84 | """ 85 | flow.instance = self 86 | flow.arrive() 87 | self.pending_queue.append(flow) 88 | if len(self.pending_queue) > 0 and len(self.executing_queue) < self.max_flows: 89 | if flow.dest.memory + flow.request.memory <= flow.dest.max_memory: 90 | self.run_flow(flow) 91 | elif self.retry: 92 | schedule_event(self.retry_delay, lambda link=self,flow=flow: link.retry_flow(flow)) 93 | else: 94 | # will lead to OOM 95 | self.run_flow(flow) 96 | 97 | def flow_completion(self, flow): 98 | """ 99 | Flow completes on this Link. 100 | """ 101 | flow.complete() 102 | self.executing_queue.remove(flow) 103 | self.completed_queue.append(flow) 104 | flow.executor.finish_flow(flow, self) 105 | if flow.notify: 106 | flow.src.notify_flow_completion(flow) 107 | self.bandwidth_used -= (self.bandwidth - self.bandwidth_used) 108 | if len(self.pending_queue) > 0 and len(self.executing_queue) < self.max_flows: 109 | next_flow = self.pending_queue[0] 110 | if next_flow.dest.memory + next_flow.request.memory <= next_flow.dest.max_memory: 111 | self.run_flow(next_flow) 112 | elif self.retry: 113 | schedule_event(self.retry_delay, lambda link=self,flow=flow: link.retry_flow(flow)) 114 | else: 115 | # will lead to OOM 116 | self.run_flow(next_flow) 117 | 118 | def retry_flow(self, flow): 119 | """ 120 | Flow is retried on this Link. 121 | """ 122 | if flow not in self.pending_queue: 123 | return 124 | if (len(self.executing_queue) < self.max_flows) and (flow.dest.memory + flow.request.memory <= flow.dest.max_memory): 125 | self.run_flow(flow) 126 | elif self.retry: 127 | schedule_event(self.retry_delay, lambda link=self,flow=flow: link.retry_flow(flow)) 128 | else: 129 | # will lead to OOM 130 | self.run_flow(flow) 131 | 132 | def get_duration(self, flow): 133 | """ 134 | FIXME: this can be shorter than prompt duration 135 | """ 136 | return flow.size / (self.bandwidth - self.bandwidth_used) 137 | 138 | def run_flow(self, flow): 139 | """ 140 | Run a Flow on this Link. 141 | """ 142 | flow.run() 143 | self.pending_queue.remove(flow) 144 | self.executing_queue.append(flow) 145 | flow.duration = self.get_duration(flow) 146 | # TODO: policy on how to allocate bandwidth to multiple flows 147 | self.bandwidth_used += (self.bandwidth - self.bandwidth_used) 148 | schedule_event(flow.duration, 149 | lambda link=self,flow=flow: link.flow_completion(flow)) 150 | 151 | def preempt_flow(self, flow): 152 | """ 153 | Preempt a flow on this Link. 
154 | """ 155 | flow.preempt() 156 | raise NotImplementedError 157 | 158 | 159 | @dataclass(kw_only=True) 160 | class PCIeLink(Link): 161 | """ 162 | PCIeLink is a specific type of Link between CPUs and GPUs. 163 | """ 164 | link_type: LinkType = LinkType.PCIeLink 165 | src: CPU 166 | dest: GPU 167 | 168 | 169 | @dataclass(kw_only=True) 170 | class EthernetLink(Link): 171 | """ 172 | EthernetLink is standard Ethernet between Servers. 173 | """ 174 | link_type: LinkType = LinkType.EthernetLink 175 | src: Server 176 | dest: Server 177 | 178 | 179 | @dataclass(kw_only=True) 180 | class IBLink(Link): 181 | """ 182 | IBLink is the Infiniband Link between Servers. 183 | """ 184 | link_type: LinkType = LinkType.IBLink 185 | src: Server 186 | dest: Server 187 | 188 | 189 | @dataclass(kw_only=True) 190 | class NVLink(Link): 191 | """ 192 | NVLink is a specific type of Link between GPUs. 193 | """ 194 | link_type: LinkType = LinkType.NVLink 195 | src: GPU 196 | dest: GPU 197 | 198 | 199 | @dataclass(kw_only=True) 200 | class RDMADirectLink(Link): 201 | """ 202 | RDMADirect is the Infiniband link between GPUs across/within Servers. 203 | """ 204 | link_type: LinkType = LinkType.RDMADirectLink 205 | src: GPU 206 | dest: GPU 207 | 208 | 209 | @dataclass(kw_only=True) 210 | class DummyLink(Link): 211 | """ 212 | A Link whose bandwidth is never actually used and can hold infinite flows. 213 | Used to simulate delay. 214 | """ 215 | link_type: LinkType = LinkType.DummyLink 216 | src: object = None 217 | dest: object = None 218 | max_flows: float = float("inf") 219 | 220 | @property 221 | def bandwidth_used(self): 222 | return self._bandwidth_used 223 | 224 | @bandwidth_used.setter 225 | def bandwidth_used(self, bandwidth_used): 226 | return 227 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass, field 4 | 5 | 6 | @dataclass(kw_only=True) 7 | class NodeMetrics(): 8 | arrival_timestamp: float = 0. 9 | start_timestamp: float = 0. 10 | completion_timestamp: float = 0. 11 | run_timestamp: float = 0. 12 | preempt_timestamp: float = 0. 13 | queue_time: float = 0. 14 | blocked_time: float = 0. 15 | service_time: float = 0. 16 | response_time: float = 0. 17 | 18 | 19 | @dataclass(kw_only=True) 20 | class FlowMetrics(NodeMetrics): 21 | pass 22 | 23 | 24 | @dataclass(kw_only=True) 25 | class TaskMetrics(NodeMetrics): 26 | pass 27 | 28 | 29 | @dataclass(kw_only=True) 30 | class RequestMetrics(): 31 | request_id: str = '' 32 | router_arrival_timestamp: float = 0. 33 | scheduler_arrival_timestamp: float = 0. 34 | executor_start_timestamp: float = 0. 35 | scheduler_completion_timestamp: float = 0. 36 | router_completion_timestamp: float = 0. 37 | router_queue_time: float = 0. 38 | scheduler_queue_time: float = 0. 39 | queue_time: float = 0. 40 | service_time: float = 0. 41 | scheduler_response_time: float = 0. 42 | router_response_time: float = 0. 43 | 44 | 45 | @dataclass(kw_only=True) 46 | class GenerativeLLMRequestMetrics(RequestMetrics): 47 | prompt_start_timestamp: float = 0. 48 | prompt_end_timestamp: float = 0. 49 | token_start_timestamp: float = 0. 50 | token_end_timestamp: float = 0. 51 | TTFT: float = 0. 52 | 53 | 54 | @dataclass(kw_only=True) 55 | class InstanceMetrics(): 56 | spin_up_timestamp: float = 0. 57 | run_timestamp: float = 0. 58 | spin_down_timestamp: float = 0. 59 | busy_time: float = 0. 
60 | interval_time: float = 0. 61 | 62 | 63 | @dataclass(kw_only=True) 64 | class ApplicationMetrics(): 65 | num_requests: int = 0 66 | num_tasks: int = 0 67 | service_times: list[float] = field(default_factory=list) 68 | response_times: list[float] = field(default_factory=list) 69 | 70 | 71 | @dataclass(kw_only=True) 72 | class RouterMetrics(): 73 | pass 74 | 75 | 76 | @dataclass(kw_only=True) 77 | class ArbiterMetrics(): 78 | pass 79 | 80 | 81 | @dataclass(kw_only=True) 82 | class ServerMetrics(): 83 | pass 84 | 85 | 86 | @dataclass(kw_only=True) 87 | class NodeSLO(): 88 | latency: float = 0. 89 | 90 | 91 | @dataclass(kw_only=True) 92 | class TaskSLO(NodeSLO): 93 | """ 94 | TaskSLOs capture any SLOs that are specific to a task. 95 | """ 96 | pass 97 | 98 | 99 | @dataclass(kw_only=True) 100 | class FlowSLO(NodeSLO): 101 | """ 102 | FlowSLOs capture any SLOs that are specific to a flow. 103 | """ 104 | pass 105 | 106 | 107 | @dataclass(kw_only=True) 108 | class RequestSLO(): 109 | """ 110 | RequestSLO captures any SLOs that are specific to a single request. 111 | """ 112 | TTFT: float = float('inf') 113 | e2e_latency: float = float('inf') 114 | 115 | 116 | @dataclass(kw_only=True) 117 | class ApplicationSLO(): 118 | """ 119 | ApplicationSLO captures any SLOs that apply across all application requests. 120 | """ 121 | TTFT: float = float('inf') 122 | per_token_latency: float = float('inf') 123 | throughput: float = 0. 124 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass 4 | 5 | 6 | @dataclass(kw_only=True) 7 | class ModelArchitecture(): 8 | name: str 9 | num_layers: int 10 | 11 | 12 | @dataclass(kw_only=True) 13 | class LLMArchitecture(ModelArchitecture): 14 | hidden_size: int 15 | num_heads: int 16 | 17 | 18 | @dataclass(kw_only=True) 19 | class ModelParallelism(): 20 | """ 21 | Captures the different parallelisms of a Model. 22 | """ 23 | pipeline_parallelism: int 24 | tensor_parallelism: int 25 | 26 | @property 27 | def num_processors(self): 28 | """ 29 | The number of GPUs required is the product of the parallelisms. 30 | """ 31 | return self.pipeline_parallelism * self.tensor_parallelism 32 | 33 | 34 | @dataclass(kw_only=True) 35 | class ModelSize(): 36 | """ 37 | Captures the various sizes of a Model. 38 | """ 39 | weights: int 40 | dtype_size: int 41 | 42 | @property 43 | def total_size(self): 44 | return self.weights 45 | 46 | 47 | @dataclass(kw_only=True) 48 | class Model(): 49 | name: str 50 | architecture: ModelArchitecture 51 | parallelism: ModelParallelism 52 | size: ModelSize 53 | 54 | @property 55 | def size_per_processor(self): 56 | return self.size.total_size / self.parallelism.num_processors 57 | 58 | 59 | @dataclass(kw_only=True) 60 | class GenerativeLLM(Model): 61 | """ 62 | Generative Large Language Model. 63 | NOTE: We currently don't capture embeddings, variable context lengths, etc.
64 | """ 65 | context_size: int = 0 66 | -------------------------------------------------------------------------------- /model_repo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import utils 4 | 5 | from hydra.utils import instantiate 6 | 7 | from model import Model, GenerativeLLM 8 | 9 | # needed by hydra instantiate 10 | import model 11 | 12 | model_repo = None 13 | 14 | 15 | class ModelRepo(): 16 | """ 17 | Repository of model configurations for dynamic instantiation. 18 | """ 19 | def __init__(self, 20 | model_architectures_path, 21 | model_sizes_path): 22 | global model_repo 23 | model_repo = self 24 | self.model_architecture_configs = self.get_model_architecture_configs( 25 | model_architectures_path) 26 | self.model_size_configs = self.get_model_size_configs(model_sizes_path) 27 | 28 | def get_model_architecture_configs(self, model_architectures_path): 29 | return utils.read_all_yaml_cfgs(model_architectures_path) 30 | 31 | def get_model_size_configs(self, model_sizes_path): 32 | return utils.read_all_yaml_cfgs(model_sizes_path) 33 | 34 | def get_model_architecture(self, model_architecture_name): 35 | cfg = self.model_architecture_configs[model_architecture_name] 36 | return instantiate(cfg) 37 | 38 | def get_model_size(self, model_size_name): 39 | cfg = self.model_size_configs[model_size_name] 40 | return instantiate(cfg) 41 | 42 | def get_model(self, model_architecture, model_size, model_parallelism): 43 | return GenerativeLLM(name=model_architecture.name, 44 | architecture=model_architecture, 45 | size=model_size, 46 | parallelism=model_parallelism) 47 | 48 | 49 | get_model_architecture = lambda *args,**kwargs: \ 50 | model_repo.get_model_architecture(*args, **kwargs) 51 | get_model_size = lambda *args,**kwargs: \ 52 | model_repo.get_model_size(*args, **kwargs) 53 | get_model = lambda *args,**kwargs: \ 54 | model_repo.get_model(*args, **kwargs) 55 | -------------------------------------------------------------------------------- /node.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass, field 4 | from enum import IntEnum 5 | 6 | from metrics import NodeMetrics 7 | from simulator import clock, schedule_event, cancel_event, reschedule_event 8 | 9 | 10 | class NodeState(IntEnum): 11 | NONE = 0 12 | QUEUED = 1 13 | RUNNING = 2 14 | BLOCKED = 3 15 | COMPLETED = 4 16 | ABORTED = 5 17 | 18 | 19 | @dataclass(kw_only=True) 20 | class Node(): 21 | """ 22 | Base class for Tasks and Flows in a Request 23 | Simplest element of the Request DAG 24 | """ 25 | node_id: int 26 | num_preemptions: int = 0 27 | request: 'Request' = None 28 | state: NodeState = NodeState.NONE 29 | metrics: NodeMetrics = field(default_factory=NodeMetrics) 30 | # chain of nodes that must be executed back-to-back 31 | # only stored in the first node of the chain 32 | chain: list = field(default_factory=list) 33 | 34 | def __hash__(self): 35 | """ 36 | NOTE: hash functions get overridden to None in child classes 37 | """ 38 | return hash(self.node_id) 39 | 40 | def __eq__(self, other): 41 | return self.node_id == other.node_id 42 | 43 | def arrive(self): 44 | assert self.state == NodeState.NONE 45 | self.metrics.arrival_timestamp = clock() 46 | self.state = NodeState.QUEUED 47 | 48 | def run(self): 49 | assert self.state == NodeState.QUEUED 50 | self.metrics.run_timestamp = clock() 51 | self.metrics.start_timestamp = clock() 52 | self.metrics.queue_time +=
clock() - self.metrics.arrival_timestamp 53 | if self.request.root_node is self: 54 | self.request.metrics.prompt_start_timestamp = clock() 55 | self.request.metrics.queue_time = clock() - \ 56 | self.request.metrics.router_arrival_timestamp 57 | self.state = NodeState.RUNNING 58 | 59 | def run_after_preempt(self): 60 | assert self.state == NodeState.BLOCKED 61 | self.metrics.run_timestamp = clock() 62 | self.metrics.blocked_time += clock() - self.metrics.preempt_timestamp 63 | self.state = NodeState.RUNNING 64 | 65 | def complete(self): 66 | assert self.state == NodeState.RUNNING 67 | self.metrics.completion_timestamp = clock() 68 | self.metrics.service_time += clock() - self.metrics.run_timestamp 69 | self.metrics.response_time = clock() - self.metrics.arrival_timestamp 70 | self.state = NodeState.COMPLETED 71 | 72 | def preempt(self): 73 | assert self.state == NodeState.RUNNING 74 | self.metrics.preempt_timestamp = clock() 75 | self.metrics.service_time += clock() - self.metrics.run_timestamp 76 | self.state = NodeState.BLOCKED 77 | 78 | def abort(self): 79 | if self.state == NodeState.QUEUED: 80 | self.metrics.queue_time += clock() - self.metrics.arrival_timestamp 81 | if self.request.root_node is self: 82 | self.request.metrics.queue_time = clock() - \ 83 | self.request.metrics.router_arrival_timestamp 84 | elif self.state == NodeState.RUNNING: 85 | self.metrics.service_time += clock() - self.metrics.run_timestamp 86 | elif self.state == NodeState.BLOCKED: 87 | self.metrics.blocked_time += clock() - self.metrics.preempt_timestamp 88 | self.state = NodeState.ABORTED 89 | -------------------------------------------------------------------------------- /notebooks/perf_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | 5 | from scipy.interpolate import interp1d 6 | from sklearn.metrics import mean_absolute_percentage_error 7 | 8 | 9 | class PerfModel: 10 | """ 11 | Performance model independent of the simulator. 12 | TODO: reuse code from simulator. 13 | """ 14 | def __init__(self, db_path, init=False): 15 | self.db = pd.read_csv(db_path, 16 | dtype={"model": "category", "hardware": "category"}) 17 | 18 | # ensure the database has the correct columns 19 | # and remove extraneous columns 20 | self.db = self.db[["model", 21 | "hardware", 22 | "tensor_parallel", 23 | "prompt_size", 24 | "batch_size", 25 | "token_size", 26 | "prompt_time", 27 | "token_time"]] 28 | 29 | # convert to seconds 30 | self.db["prompt_time"] = self.db["prompt_time"] / 1000 31 | self.db["token_time"] = self.db["token_time"] / 1000 32 | 33 | if init: 34 | self.init_predictor_numtokens() 35 | 36 | def init_predictor_numtokens(self): 37 | """ 38 | Predict using number of tokens in the batch. 
39 | """ 40 | self.prompt_time_predictors = {} 41 | self.token_time_predictors = {} 42 | self.prompt_time_cache = {} 43 | self.token_time_cache = {} 44 | 45 | for model in self.db["model"].unique(): 46 | for hardware in self.db["hardware"].unique(): 47 | for tensor_parallel in self.db["tensor_parallel"].unique(): 48 | mask = (self.db["model"] == model) & (self.db["hardware"] == hardware) & (self.db["tensor_parallel"] == tensor_parallel) 49 | db_subset = self.db[mask].copy() 50 | if len(db_subset) == 0: 51 | continue 52 | db_subset["batch_tokens"] = db_subset["prompt_size"] * db_subset["batch_size"] 53 | x = db_subset[["batch_tokens", "prompt_time"]].groupby("batch_tokens").median().index 54 | y = db_subset[["batch_tokens", "prompt_time"]].groupby("batch_tokens").median()["prompt_time"] 55 | self.prompt_time_predictors[(model, hardware, tensor_parallel)] = interp1d( 56 | x, y, fill_value="extrapolate") 57 | x = db_subset[["batch_tokens", "token_time"]].groupby("batch_tokens").median().index 58 | y = db_subset[["batch_tokens", "token_time"]].groupby("batch_tokens").median()["token_time"] 59 | self.token_time_predictors[(model, hardware, tensor_parallel)] = interp1d( 60 | x, y, fill_value="extrapolate") 61 | 62 | def get_prompt_time(self, model, hardware, tensor_parallel, batch_tokens): 63 | prompt_time = self.prompt_time_cache.get((model, hardware, tensor_parallel, batch_tokens), None) 64 | if prompt_time is None: 65 | prompt_time = float(self.prompt_time_predictors[(model, hardware, tensor_parallel)](batch_tokens)) 66 | self.prompt_time_cache[(model, hardware, tensor_parallel, batch_tokens)] = float(prompt_time) 67 | return prompt_time 68 | 69 | def get_token_time(self, model, hardware, tensor_parallel, batch_tokens): 70 | token_time = self.token_time_cache.get((model, hardware, tensor_parallel, batch_tokens), None) 71 | if token_time is None: 72 | token_time = float(self.token_time_predictors[(model, hardware, tensor_parallel)](batch_tokens)) 73 | self.token_time_cache[(model, hardware, tensor_parallel, batch_tokens)] = float(token_time) 74 | return token_time 75 | 76 | def add_baseline_perf(self, 77 | request_df, 78 | model="bloom-176b", 79 | hardware="a100-80gb", 80 | tensor_parallel=8): 81 | """ 82 | Normalize request_df ttft and tbt wrt the model, hardware, and tensor_parallel. 83 | Applies the get_prompt_time and get_token_time functions. 84 | """ 85 | request_df["baseline_ttft"] = request_df.apply(lambda row: 86 | self.get_prompt_time(model, hardware, tensor_parallel, row["prompt_sizes"]), axis=1) 87 | request_df["baseline_tbt"] = request_df.apply(lambda row: 88 | self.get_token_time(model, hardware, tensor_parallel, row["prompt_sizes"]), axis=1) 89 | return request_df 90 | 91 | @staticmethod 92 | def validate_model(db_path, train_test_split=0.8): 93 | """ 94 | Validate the perf model. 
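        Splits the database rows into train/test sets, refits the predictors on the
        train split, and returns one MAPE value per held-out row.
        Illustrative usage (hypothetical path):

            mape = PerfModel.validate_model("data/perf_model.csv", train_test_split=0.8)
            print(sum(mape) / len(mape))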
95 | """ 96 | perf_model = PerfModel(db_path, init=False) 97 | db = perf_model.db 98 | 99 | # split the data 100 | train_size = int(train_test_split * len(db)) 101 | 102 | # randomize the data 103 | db = db.sample(frac=1) 104 | train_db = db.iloc[:train_size] 105 | test_db = db.iloc[train_size:] 106 | 107 | # initialize the model 108 | perf_model.db = train_db 109 | perf_model.init_predictor_numtokens() 110 | 111 | # validate the model 112 | mape = [] 113 | for i, row in test_db.iterrows(): 114 | prompt_time = perf_model.get_prompt_time(row["model"], 115 | row["hardware"], 116 | row["tensor_parallel"], 117 | row["prompt_size"] * row["batch_size"]) 118 | token_time = perf_model.get_token_time(row["model"], 119 | row["hardware"], 120 | row["tensor_parallel"], 121 | row["prompt_size"] * row["batch_size"]) 122 | mape.append(mean_absolute_percentage_error([row["prompt_time"], row["token_time"]], 123 | [prompt_time, token_time])) 124 | return mape 125 | -------------------------------------------------------------------------------- /notebooks/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for the notebooks. 3 | """ 4 | import os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | def baseline_a100_config(num_a100, 11 | start_state="baseline", 12 | scheduler="token_jsq", 13 | h100_cost=4.76, 14 | h100_power=44, 15 | a100_cost=2.21, 16 | a100_power=24.8): 17 | config = { 18 | "name": f"Baseline-A100 ({num_a100}P/T)", 19 | "system": "Baseline-A100", 20 | "scheduler": f"{scheduler}", 21 | "start_state": start_state, 22 | "cluster": f"{num_a100}_0", 23 | "num_servers": num_a100, 24 | "num_a100": num_a100, 25 | "num_h100": 0, 26 | "num_prompts": num_a100, 27 | "num_tokens": num_a100, 28 | "cost": num_a100 * a100_cost, 29 | "power": num_a100 * a100_power, 30 | } 31 | return config 32 | 33 | def baseline_h100_config(num_h100, 34 | start_state="baseline", 35 | scheduler="token_jsq", 36 | h100_cost=4.76, 37 | h100_power=44, 38 | a100_cost=2.21, 39 | a100_power=24.8): 40 | config = { 41 | "name": f"Baseline-H100 ({num_h100}P/T)", 42 | "system": "Baseline-H100", 43 | "scheduler": f"{scheduler}", 44 | "start_state": start_state, 45 | "cluster": f"0_{num_h100}", 46 | "num_servers": num_h100, 47 | "num_a100": 0, 48 | "num_h100": num_h100, 49 | "num_prompts": num_h100, 50 | "num_tokens": num_h100, 51 | "cost": num_h100 * h100_cost, 52 | "power": num_h100 * h100_power, 53 | } 54 | return config 55 | 56 | def splitwise_ha_config(num_prompt, 57 | num_token, 58 | start_state="splitwise", 59 | scheduler="mixed_pool", 60 | h100_cost=4.76, 61 | h100_power=44, 62 | a100_cost=2.21, 63 | a100_power=24.8): 64 | num_h100 = num_prompt 65 | num_a100 = num_token 66 | config = { 67 | "name": f"Splitwise-HA ({num_prompt}P, {num_token}T)", 68 | "system": "Splitwise-HA", 69 | "scheduler": f"{scheduler}", 70 | "start_state": f"{start_state}_1_1", 71 | "cluster": f"{num_token}_{num_prompt}", 72 | "num_servers": num_token + num_prompt, 73 | "num_a100": num_token, 74 | "num_h100": num_prompt, 75 | "num_prompts": num_prompt, 76 | "num_tokens": num_token, 77 | "cost": num_h100 * h100_cost + num_a100 * a100_cost, 78 | "power": num_h100 * h100_power + num_a100 * a100_power, 79 | } 80 | return config 81 | 82 | def splitwise_aa_config(num_prompt, 83 | num_token, 84 | start_state="splitwise", 85 | scheduler="mixed_pool", 86 | h100_cost=4.76, 87 | h100_power=44, 88 | a100_cost=2.21, 89 | a100_power=24.8): 90 | num_a100 = num_prompt + num_token 91 | config 
= { 92 | "name": f"Splitwise-AA ({num_prompt}P, {num_token}T)", 93 | "system": "Splitwise-AA", 94 | "scheduler": f"{scheduler}", 95 | "start_state": f"{start_state}_{num_prompt}_{num_token}", 96 | "cluster": f"{num_a100}_0", 97 | "num_servers": num_a100, 98 | "num_a100": num_a100, 99 | "num_h100": 0, 100 | "num_prompts": num_prompt, 101 | "num_tokens": num_token, 102 | "cost": num_a100 * a100_cost, 103 | "power": num_a100 * a100_power, 104 | } 105 | return config 106 | 107 | def splitwise_hh_config(num_prompt, 108 | num_token, 109 | start_state="splitwise", 110 | scheduler="mixed_pool", 111 | h100_cost=4.76, 112 | h100_power=44, 113 | a100_cost=2.21, 114 | a100_power=24.8): 115 | num_h100 = num_prompt + num_token 116 | config = { 117 | "name": f"Splitwise-HH ({num_prompt}P, {num_token}T)", 118 | "system": "Splitwise-HH", 119 | "scheduler": f"{scheduler}", 120 | "start_state": f"{start_state}_{num_prompt}_{num_token}", 121 | "cluster": f"0_{num_h100}", 122 | "num_servers": num_h100, 123 | "num_a100": 0, 124 | "num_h100": num_h100, 125 | "num_prompts": num_prompt, 126 | "num_tokens": num_token, 127 | "cost": num_h100 * h100_cost, 128 | "power": num_h100 * h100_power, 129 | } 130 | return config 131 | 132 | def splitwise_hhcap_config(num_prompt, 133 | num_token, 134 | start_state="splitwisehhcap", 135 | scheduler="mixed_pool", 136 | h100_cost=4.76, 137 | h100_power=44, 138 | a100_cost=2.21, 139 | a100_power=24.8, 140 | power_cap_scaler=0.7): 141 | num_h100 = num_prompt + num_token 142 | config = { 143 | "name": f"Splitwise-HHcap ({num_prompt}P, {num_token}T)", 144 | "system": "Splitwise-HHcap", 145 | "scheduler": f"{scheduler}", 146 | "start_state": f"{start_state}_1_1", 147 | "cluster": f"{num_token}_{num_prompt}", 148 | "num_servers": num_h100, 149 | "num_a100": 0, 150 | "num_h100": num_h100, 151 | "num_prompts": num_prompt, 152 | "num_tokens": num_token, 153 | "cost": num_h100 * h100_cost, 154 | "power": num_prompt * h100_power + num_token * h100_power * power_cap_scaler, 155 | } 156 | return config 157 | 158 | def get_summary_data(results_dir, scheduler, start_state, cluster, trace, seed, model=""): 159 | try: 160 | summary_df = pd.read_csv(f"{results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/summary.csv") 161 | except Exception as e: 162 | print(e) 163 | print(f"Failed to read {results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/summary.csv") 164 | return None 165 | return summary_df 166 | 167 | def get_request_data(results_dir, scheduler, start_state, cluster, trace, seed, model=""): 168 | try: 169 | request_df = pd.read_csv(f"{results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/detailed/0.csv") 170 | except: 171 | print(f"Failed to read {results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/detailed/0.csv") 172 | return None 173 | return request_df 174 | 175 | def get_request_nodes(results_dir, scheduler, start_state, cluster, trace, seed, model=""): 176 | try: 177 | request_nodes_df = pd.read_csv(f"{results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/request_nodes.csv") 178 | request_nodes_df["start_timestamp_dt"] = pd.to_datetime(request_nodes_df["start_timestamp"], unit="s") 179 | request_nodes_df["completion_timestamp_dt"] = pd.to_datetime(request_nodes_df["completion_timestamp"], unit="s") 180 | except: 181 | print(f"Failed to read {results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/request_nodes.csv") 182 | return None 183 | return request_nodes_df 184 | 185 | 
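# Illustrative helper sketch: combines the config dicts above with the summary loaders
# below to build a results DataFrame that find_within_slo()/find_cheapest() can consume.
# The helper name and argument values are assumptions, not an established API.
def collect_summaries(results_dir, configs, trace, seed, model=""):
    dfs = []
    for cfg in configs:
        summary_df = get_summary_data_with_config(results_dir, cfg, trace, seed, model)
        if summary_df is not None:
            # attach config fields (system, cost, power, num_servers, ...) to each row
            dfs.append(summary_df.assign(**cfg))
    return pd.concat(dfs) if dfs else None
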
def get_instances_data(results_dir, scheduler, start_state, cluster, num_servers, trace, seed, model=""): 186 | try: 187 | instance_dfs = [] 188 | application_id = 0 189 | for idx in range(num_servers): 190 | filename = f"{results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/instances/{application_id}/{idx}.csv" 191 | filepath = os.path.join(results_dir, filename) 192 | df = pd.read_csv(filepath) 193 | df["iteration"] = range(len(df)) 194 | instance_dfs.append(df) 195 | instances_df = pd.concat(instance_dfs) 196 | instances_df["iteration_start_dt"] = pd.to_datetime(instances_df["iteration_start"], unit="s") 197 | instances_df["iteration_end_dt"] = pd.to_datetime(instances_df["iteration_end"], unit="s") 198 | instances_df["duration"] = (instances_df["iteration_end"] - instances_df["iteration_start"]) 199 | instances_df["memory"] /= 1024 * 1024 * 1024 200 | return instances_df 201 | except: 202 | print(f"Failed to read {results_dir}/{seed}/{start_state}/{trace}/{cluster}/{model}/{scheduler}/instances/0/*.csv") 203 | return None 204 | 205 | def get_num_batch_tokens_baseline(instances_df): 206 | num_batch_tokens = [] 207 | for row in instances_df.iterrows(): 208 | num_batch_tokens.extend(int(row[1]["num_contiguous_iterations"]) * [row[1]["batch_tokens"]]) 209 | return num_batch_tokens 210 | 211 | def get_num_batch_tokens_splitwise(instances_df): 212 | num_prompt_batch_tokens = [] 213 | num_token_batch_tokens = [] 214 | for row in instances_df.iterrows(): 215 | if row[1]["tag"] == "prompt": 216 | num_prompt_batch_tokens.extend(int(row[1]["num_contiguous_iterations"]) * [row[1]["batch_tokens"]]) 217 | else: 218 | num_token_batch_tokens.extend(int(row[1]["num_contiguous_iterations"]) * [row[1]["batch_tokens"]]) 219 | return num_prompt_batch_tokens, num_token_batch_tokens 220 | 221 | def get_time_duration_batch_tokens(instances_df): 222 | instances_df = instances_df.copy() 223 | return instances_df.groupby("batch_tokens").sum()["duration"] 224 | 225 | def count_token_on_prompt_servers(instances_df, request_nodes_df): 226 | prompt_nodes = instances_df[instances_df["tag"] == "prompt"]["name"].unique() 227 | count = len(request_nodes_df[(request_nodes_df["node_type"] == "TOKEN") & 228 | (request_nodes_df["runner"].isin(prompt_nodes))]) 229 | num_requests = request_nodes_df["request_id"].nunique() 230 | return count, num_requests, len(prompt_nodes) 231 | 232 | def get_summary_data_with_config(results_dir, config, trace, seed, model=""): 233 | scheduler = config["scheduler"] 234 | start_state = config["start_state"] 235 | cluster = config["cluster"] 236 | return get_summary_data(results_dir, scheduler, start_state, cluster, trace, seed, model) 237 | 238 | def get_request_data_with_config(results_dir, config, trace, seed, model=""): 239 | scheduler = config["scheduler"] 240 | start_state = config["start_state"] 241 | cluster = config["cluster"] 242 | return get_request_data(results_dir, scheduler, start_state, cluster, trace, seed, model) 243 | 244 | def get_request_nodes_with_config(results_dir, config, trace, seed, model=""): 245 | scheduler = config["scheduler"] 246 | start_state = config["start_state"] 247 | cluster = config["cluster"] 248 | return get_request_nodes(results_dir, scheduler, start_state, cluster, trace, seed, model) 249 | 250 | def get_instances_data_with_config(results_dir, config, trace, seed, model=""): 251 | scheduler = config["scheduler"] 252 | start_state = config["start_state"] 253 | cluster = config["cluster"] 254 | num_servers = config["num_servers"] 255 | 
return get_instances_data(results_dir, scheduler, start_state, cluster, num_servers, trace, seed, model) 256 | 257 | def find_within_slo(results_df, slos): 258 | configs_within_slo = [] 259 | for system_name in results_df["system"].unique(): 260 | system_df = results_df[results_df["system"] == system_name] 261 | for key, value in slos.items(): 262 | system_df = system_df[system_df[f"{key}"] < value] 263 | configs_within_slo.append(system_df) 264 | return pd.concat(configs_within_slo) 265 | 266 | def find_cheapest(results_df): 267 | configs = [] 268 | for system_name in results_df["system"].unique(): 269 | system_df = results_df[results_df["system"] == system_name] 270 | cheapest = system_df[system_df["cost"] == system_df["cost"].min()] 271 | configs.append(cheapest) 272 | return pd.concat(configs) 273 | 274 | def find_least_power(results_df): 275 | configs = [] 276 | for system_name in results_df["system"].unique(): 277 | system_df = results_df[results_df["system"] == system_name] 278 | least_power = system_df[system_df["power"] == system_df["power"].min()] 279 | configs.append(least_power) 280 | return pd.concat(configs) 281 | 282 | def find_least_count(results_df): 283 | configs = [] 284 | for system_name in results_df["system"].unique(): 285 | system_df = results_df[results_df["system"] == system_name] 286 | least_count = system_df[system_df["num_servers"] == system_df["num_servers"].min()] 287 | configs.append(least_count) 288 | return pd.concat(configs) 289 | 290 | def find_max_throughput(results_df): 291 | if "throughput" not in results_df.columns: 292 | # add a throughput column using the trace field 293 | results_df["throughput"] = results_df["trace"].apply(lambda x: int(x.split("_")[2])) 294 | configs = [] 295 | for system_name in results_df["system"].unique(): 296 | system_df = results_df[results_df["system"] == system_name] 297 | max_throughput = system_df[system_df["throughput"] == system_df["throughput"].max()] 298 | configs.append(max_throughput) 299 | return pd.concat(configs) 300 | -------------------------------------------------------------------------------- /orchestrator_repo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from hydra.utils import instantiate 5 | 6 | import utils 7 | 8 | # needed by hydra instantiate 9 | import allocator 10 | import scheduler 11 | 12 | 13 | orchestrator_repo = None 14 | 15 | 16 | class OrchestratorRepo(): 17 | """ 18 | Repository of all orchestrator configs (Schedulers and Allocators). 
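    Scheduler/allocator YAMLs are loaded into dicts (presumably keyed by file name,
    e.g. "mixed_pool", "token_jsq") and hydra-instantiated on demand. Illustrative
    usage (the application/router objects are assumed to come from initialization code):

        repo = OrchestratorRepo(allocators_path, schedulers_path)
        sched = repo.get_scheduler("mixed_pool", application=app, router=router, debug=False)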
19 | """ 20 | def __init__(self, 21 | allocators_path, 22 | schedulers_path): 23 | global orchestrator_repo 24 | orchestrator_repo = self 25 | self.allocator_configs = self.get_allocator_configs(allocators_path) 26 | self.scheduler_configs = self.get_scheduler_configs(schedulers_path) 27 | 28 | def get_allocator_configs(self, allocators_path): 29 | return utils.read_all_yaml_cfgs(allocators_path) 30 | 31 | def get_scheduler_configs(self, schedulers_path): 32 | return utils.read_all_yaml_cfgs(schedulers_path) 33 | 34 | def get_allocator(self, allocator_name, application, arbiter, debug, **kwargs): 35 | cfg = self.allocator_configs[allocator_name] 36 | return instantiate(cfg, 37 | application=application, 38 | arbiter=arbiter, 39 | debug=debug) 40 | 41 | def get_scheduler(self, scheduler_name, application, router, debug, **kwargs): 42 | cfg = self.scheduler_configs[scheduler_name] 43 | return instantiate(cfg, 44 | application=application, 45 | router=router, 46 | debug=debug) 47 | 48 | 49 | get_allocator = lambda *args,**kwargs: \ 50 | orchestrator_repo.get_allocator(*args, **kwargs) 51 | get_scheduler = lambda *args,**kwargs: \ 52 | orchestrator_repo.get_scheduler(*args, **kwargs) 53 | -------------------------------------------------------------------------------- /performance_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | import pandas as pd 7 | 8 | from hydra.utils import get_original_cwd 9 | from scipy.interpolate import interp1d 10 | 11 | from task import TaskType, PromptTask, TokenTask 12 | 13 | 14 | performance_model = None 15 | 16 | 17 | class PerformanceModel(ABC): 18 | """ 19 | PerformanceModel helps estimate the duration of tasks or iterations, 20 | under given hardware, model, and parallelism configurations. 21 | Abstract class that must be subclassed. 22 | """ 23 | def __init__(self): 24 | global performance_model 25 | performance_model = self 26 | 27 | @abstractmethod 28 | def get_duration(self, task, batch, instance, *args, **kwargs): 29 | """ 30 | Returns the execution time of the task. 31 | """ 32 | raise NotImplementedError 33 | 34 | @abstractmethod 35 | def get_iteration_duration(self, batch, instance, *args, **kwargs): 36 | """ 37 | Returns the execution time of a contiguous iteration. 38 | """ 39 | raise NotImplementedError 40 | 41 | 42 | class ConstantPerformanceModel(PerformanceModel): 43 | """ 44 | PerformanceModel that returns a constant value regardless of other parameters. 45 | Used for testing purposes. 46 | """ 47 | def __init__(self, prompt_time, token_time): 48 | super().__init__() 49 | self.prompt_time = prompt_time 50 | self.token_time = token_time 51 | 52 | def get_duration(self, task, batch, instance, *args, **kwargs): 53 | if task.task_type == TaskType.PROMPT: 54 | return self.prompt_time 55 | elif task.task_type == TaskType.TOKEN: 56 | return self.token_time 57 | else: 58 | raise NotImplementedError 59 | 60 | def get_iteration_duration(self, batch, instance, *args, **kwargs): 61 | raise NotImplementedError 62 | 63 | 64 | class DatabasePerformanceModel(PerformanceModel): 65 | """ 66 | PerformanceModel based on a CSV database of characterization runs. 67 | Interpolates between data points and updates the database correspondingly. 68 | The underlying predictor could be changed for different interpolation strategies. 
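    Prediction is keyed by (model, hardware, tensor_parallel) and interpolated over the
    total number of tokens in the batch. Illustrative lookup (hypothetical key values;
    assumes matching characterization rows exist in the database):

        predictor = self.prompt_time_predictors[("llama2-70b", "h100-80gb", 8)]
        prompt_seconds = float(predictor(2048))  # prompt latency for a 2048-token batch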
69 | """ 70 | def __init__(self, db_path): 71 | super().__init__() 72 | self.db = pd.read_csv(os.path.join(get_original_cwd(), db_path), 73 | dtype={"model": "category", "hardware": "category"}) 74 | 75 | # ensure the database has the correct columns 76 | # and remove extraneous columns 77 | self.db = self.db[["model", 78 | "hardware", 79 | "tensor_parallel", 80 | "prompt_size", 81 | "batch_size", 82 | "token_size", 83 | "prompt_time", 84 | "token_time"]] 85 | 86 | # convert to seconds 87 | self.db["prompt_time"] = self.db["prompt_time"] / 1000 88 | self.db["token_time"] = self.db["token_time"] / 1000 89 | 90 | self.init_predictor() 91 | 92 | def init_predictor(self): 93 | """ 94 | Predict using number of tokens in the batch. 95 | """ 96 | self.prompt_time_predictors = {} 97 | self.token_time_predictors = {} 98 | self.prompt_time_cache = {} 99 | self.token_time_cache = {} 100 | 101 | for model in self.db["model"].unique(): 102 | for hardware in self.db["hardware"].unique(): 103 | for tensor_parallel in self.db["tensor_parallel"].unique(): 104 | mask = (self.db["model"] == model) & \ 105 | (self.db["hardware"] == hardware) & \ 106 | (self.db["tensor_parallel"] == tensor_parallel) 107 | db_subset = self.db[mask].copy() 108 | if len(db_subset) == 0: 109 | continue 110 | db_subset["batch_tokens"] = db_subset["prompt_size"] * db_subset["batch_size"] 111 | x = db_subset[["batch_tokens", "prompt_time"]].groupby("batch_tokens").median().index 112 | y = db_subset[["batch_tokens", "prompt_time"]].groupby("batch_tokens").median()["prompt_time"] 113 | self.prompt_time_predictors[(model, hardware, tensor_parallel)] = interp1d( 114 | x, y, fill_value="extrapolate") 115 | x = db_subset[["batch_tokens", "token_time"]].groupby("batch_tokens").median().index 116 | y = db_subset[["batch_tokens", "token_time"]].groupby("batch_tokens").median()["token_time"] 117 | self.token_time_predictors[(model, hardware, tensor_parallel)] = interp1d( 118 | x, y, fill_value="extrapolate") 119 | 120 | def _match(self, **kwargs): 121 | """ 122 | Returns a boolean mask for the database from kwargs. 123 | """ 124 | mask = True 125 | for k, v in kwargs.items(): 126 | mask &= (self.db[k] == v) 127 | return mask 128 | 129 | def predict_new_row(self, **kwargs): 130 | """ 131 | Predicts the prompt and token time for a new row. 132 | Inserts the new row into the database. 133 | """ 134 | model = kwargs["model"] 135 | hardware = kwargs["hardware"] 136 | tensor_parallel = kwargs["tensor_parallel"] 137 | batch_tokens = kwargs["batch_tokens"] 138 | new_row = pd.DataFrame(kwargs, index=[0]) 139 | 140 | prompt_time = self.prompt_time_predictors[(model, hardware, tensor_parallel)](batch_tokens) 141 | token_time = self.token_time_predictors[(model, hardware, tensor_parallel)](batch_tokens) 142 | 143 | new_row["prompt_time"] = prompt_time 144 | new_row["token_time"] = token_time 145 | self.db = pd.concat([self.db, new_row], ignore_index=True) 146 | return new_row 147 | 148 | def get_prompt_time(self, **kwargs): 149 | """ 150 | Returns the prompt time from the database. 151 | """ 152 | prompt_time = self.db[self._match(**kwargs)]["prompt_time"].median() 153 | # if not found, predict 154 | if math.isnan(prompt_time): 155 | new_row = self.predict_new_row(**kwargs) 156 | prompt_time = new_row["prompt_time"][0] 157 | return prompt_time 158 | 159 | def get_token_time(self, **kwargs): 160 | """ 161 | Returns the prompt time from the database. 
162 | """ 163 | token_time = self.db[self._match(**kwargs)]["token_time"].median() 164 | # if not found, predict 165 | if math.isnan(token_time): 166 | new_row = self.predict_new_row(**kwargs) 167 | token_time = new_row["token_time"][0] 168 | return token_time 169 | 170 | def get_duration(self, 171 | task, 172 | batch, 173 | instance, 174 | *args, 175 | **kwargs): 176 | model = instance.model.name 177 | hardware = instance.processors[0].name 178 | pipeline_parallel = instance.model.parallelism.pipeline_parallelism 179 | tensor_parallel = instance.model.parallelism.tensor_parallelism 180 | if task.task_type == TaskType.PROMPT: 181 | prompt_size = task.request.prompt_size 182 | token_size = task.request.token_size 183 | batch_size = len(batch) 184 | prompt_time = self.get_prompt_time(model=model, 185 | hardware=hardware, 186 | tensor_parallel=tensor_parallel, 187 | prompt_size=prompt_size, 188 | batch_size=batch_size, 189 | token_size=token_size, 190 | batch=batch) 191 | return prompt_time 192 | elif task.task_type == TaskType.TOKEN: 193 | prompt_size = task.request.prompt_size 194 | token_size = task.request.token_size 195 | batch_size = len(batch) 196 | token_time = self.get_token_time(model=model, 197 | hardware=hardware, 198 | tensor_parallel=tensor_parallel, 199 | prompt_size=prompt_size, 200 | batch_size=batch_size, 201 | token_size=token_size, 202 | batch=batch) 203 | return token_time * task.token_size 204 | else: 205 | raise NotImplementedError 206 | 207 | def get_iteration_duration(self, 208 | batch, 209 | instance, 210 | *args, 211 | **kwargs): 212 | """ 213 | Note: assumes that prompts are always processed fully. 214 | i.e., we currently do not support prompt chunking. 215 | """ 216 | model = instance.model.name 217 | hardware = instance.processors[0].name 218 | pipeline_parallel = instance.model.parallelism.pipeline_parallelism 219 | tensor_parallel = instance.model.parallelism.tensor_parallelism 220 | 221 | prompt_tasks = [] 222 | token_tasks = [] 223 | batch_tokens = 0 224 | for task in batch: 225 | if isinstance(task, PromptTask): 226 | prompt_tasks.append(task) 227 | batch_tokens += task.request.prompt_size 228 | elif isinstance(task, TokenTask): 229 | token_tasks.append(task) 230 | batch_tokens += 1 231 | else: 232 | raise NotImplementedError 233 | 234 | iteration_time = None 235 | cache_key = (model, hardware, tensor_parallel, batch_tokens) 236 | predictors_key = (model, hardware, tensor_parallel) 237 | 238 | if len(prompt_tasks) == len(batch): 239 | iteration_time = self.prompt_time_cache.get(cache_key) 240 | if iteration_time is None: 241 | iteration_time = float(self.prompt_time_predictors[predictors_key](batch_tokens)) 242 | self.prompt_time_cache[cache_key] = float(iteration_time) 243 | elif len(token_tasks) == len(batch): 244 | iteration_time = self.token_time_cache.get(cache_key) 245 | if iteration_time is None: 246 | iteration_time = float(self.token_time_predictors[predictors_key](batch_tokens)) 247 | self.token_time_cache[cache_key] = float(iteration_time) 248 | else: 249 | iteration_time = self.prompt_time_cache.get(cache_key) 250 | if iteration_time is None: 251 | iteration_time = float(self.prompt_time_predictors[predictors_key](batch_tokens)) 252 | self.prompt_time_cache[cache_key] = float(iteration_time) 253 | iteration_time *= 1.1 254 | 255 | assert iteration_time > 0 256 | return iteration_time 257 | 258 | 259 | def get_duration(*args, **kwargs): 260 | """ 261 | Returns the execution time of the task. 
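    Dispatches to the global performance_model singleton registered by
    PerformanceModel.__init__. Illustrative call (the task/batch/instance objects
    come from the executor and instance code):

        seconds = get_duration(task, batch, instance)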
262 | """ 263 | return performance_model.get_duration(*args, **kwargs) 264 | 265 | 266 | def get_iteration_duration(*args, **kwargs): 267 | """ 268 | Returns the execution time of a contiguous iteration. 269 | """ 270 | return performance_model.get_iteration_duration(*args, **kwargs) 271 | -------------------------------------------------------------------------------- /power_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import pandas as pd 6 | 7 | from task import PromptTask, TokenTask 8 | 9 | 10 | power_model = None 11 | 12 | class PowerModel(ABC): 13 | """ 14 | PowerModel helps estimate power draw for Processors, Servers, and more. 15 | Abstract class that must be subclassed. 16 | 17 | TODO: unused 18 | """ 19 | def __init__(self): 20 | global power_model 21 | power_model = self 22 | 23 | @abstractmethod 24 | def get_processors_power(self, task, *args, **kwargs): 25 | """ 26 | Returns the power drawn by a single processor running the task. 27 | """ 28 | raise NotImplementedError 29 | 30 | def get_server_idle_power(self, server): 31 | """ 32 | Returns the idle server power. 33 | """ 34 | if server.name == "dgx-a100": 35 | return 1500 36 | else: 37 | return 2800 38 | 39 | 40 | class ConstantPowerModel(PowerModel): 41 | """ 42 | PowerModel that returns a constant value regardless of other parameters. 43 | """ 44 | def __init__(self, idle_power, prompt_power, token_power): 45 | super().__init__() 46 | self.idle_power = idle_power 47 | self.prompt_power = prompt_power 48 | self.token_power = token_power 49 | 50 | def get_processors_power(self, task, processors, *args, **kwargs): 51 | name = processors[0].name 52 | if task == None: 53 | return [self.idle_power[name]] * len(processors) 54 | elif isinstance(task, PromptTask): 55 | return [self.prompt_power[name]] * len(processors) 56 | elif isinstance(task, TokenTask): 57 | return [self.token_power[name]] * len(processors) 58 | else: 59 | raise NotImplementedError 60 | 61 | 62 | class DatabasePowerModel(PowerModel): 63 | """ 64 | PowerModel based on a CSV database of characterization runs. 65 | """ 66 | def __init__(self, dbfile): 67 | super().__init__() 68 | self.db = pd.read_csv(dbfile) 69 | 70 | def get_power(self, 71 | server, 72 | model, 73 | request): 74 | return self.db[server][model][request] 75 | 76 | 77 | def get_processors_power(task, *args, **kwargs): 78 | """ 79 | Returns the power drawn by a single processor running the task. 80 | """ 81 | return power_model.get_processors_power(task, *args, **kwargs) 82 | 83 | def get_server_power(*args, **kwargs): 84 | """ 85 | Returns the idle server power. 86 | """ 87 | return power_model.get_server_idle_power(*args, **kwargs) 88 | -------------------------------------------------------------------------------- /processor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from dataclasses import dataclass, field 5 | from enum import IntEnum 6 | 7 | from instance import Instance 8 | from simulator import clock, schedule_event, cancel_event, reschedule_event 9 | 10 | 11 | class ProcessorType(IntEnum): 12 | DEFAULT = 0 13 | CPU = 1 14 | GPU = 2 15 | 16 | 17 | @dataclass(kw_only=True) 18 | class Processor(): 19 | """ 20 | Processor is the lowest-level processing unit that can run computations (Tasks). 21 | Multiple Processors constitute a Server and may be linked via Interconnects. 
22 | For example, CPU and GPU are both different types of Processors. 23 | 24 | Each Processor can belong to only one Server 25 | Processor could eventually run multiple Instances/Tasks. 26 | 27 | Attributes: 28 | processor_type (ProcessorType): The type of the Processor. 29 | memory_size (float): The memory size of the Processor. 30 | memory_used (float): The memory used by the Processor. 31 | server (Server): The Server that the Processor belongs to. 32 | instances (list[Instance]): Instances running on this Processor. 33 | interconnects (list[Link]): Peers that this Processor is directly connected to. 34 | """ 35 | processor_type: ProcessorType 36 | name: str 37 | server: 'Server' 38 | memory_size: int 39 | memory_used: int 40 | _memory_used: int = 0 41 | power: float = 0. 42 | _power: float = 0. 43 | instances: list[Instance] = field(default_factory=list) 44 | interconnects: list['Link'] = field(default_factory=list) 45 | 46 | @property 47 | def server(self): 48 | return self._server 49 | 50 | @server.setter 51 | def server(self, server): 52 | if type(server) is property: 53 | server = None 54 | self._server = server 55 | 56 | @property 57 | def memory_used(self): 58 | return self._memory_used 59 | 60 | @memory_used.setter 61 | def memory_used(self, memory_used): 62 | if type(memory_used) is property: 63 | memory_used = 0 64 | if memory_used < 0: 65 | raise ValueError("Memory cannot be negative") 66 | # if OOM, log instance details 67 | if memory_used > self.memory_size: 68 | if os.path.exists("oom.csv") is False: 69 | with open("oom.csv", "w", encoding="UTF-8") as f: 70 | fields = ["time", 71 | "instance_name", 72 | "instance_id", 73 | "memory_used", 74 | "processor_memory", 75 | "pending_queue_length"] 76 | f.write(",".join(fields) + "\n") 77 | with open("oom.csv", "a", encoding="UTF-8") as f: 78 | instance = self.instances[0] 79 | csv_entry = [] 80 | csv_entry.append(clock()) 81 | csv_entry.append(instance.name) 82 | csv_entry.append(instance.instance_id) 83 | csv_entry.append(memory_used) 84 | csv_entry.append(self.memory_size) 85 | csv_entry.append(len(instance.pending_queue)) 86 | f.write(",".join(map(str, csv_entry)) + "\n") 87 | # raise OOM error 88 | #raise ValueError("OOM") 89 | self._memory_used = memory_used 90 | 91 | @property 92 | def memory_free(self): 93 | return self.memory_size - self.memory_used 94 | 95 | @property 96 | def power(self): 97 | return self._power 98 | 99 | @power.setter 100 | def power(self, power): 101 | if type(power) is property: 102 | power = 0. 
103 | if power < 0: 104 | raise ValueError("Power cannot be negative") 105 | self._power = power 106 | 107 | @property 108 | def peers(self): 109 | pass 110 | 111 | 112 | @dataclass(kw_only=True) 113 | class CPU(Processor): 114 | processor_type: ProcessorType = ProcessorType.CPU 115 | 116 | 117 | @dataclass(kw_only=True) 118 | class GPU(Processor): 119 | processor_type: ProcessorType = ProcessorType.GPU 120 | 121 | 122 | if __name__ == "__main__": 123 | pass 124 | -------------------------------------------------------------------------------- /request.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass, field 4 | from enum import IntEnum 5 | from itertools import count 6 | 7 | import networkx as nx 8 | 9 | from executor import Executor 10 | from flow import Flow 11 | from metrics import RequestMetrics, GenerativeLLMRequestMetrics, RequestSLO 12 | from node import Node 13 | from simulator import clock, schedule_event, cancel_event, reschedule_event 14 | from task import Task, TaskType 15 | 16 | 17 | class RequestState(IntEnum): 18 | """ 19 | RequestState describes the states of a Request. 20 | """ 21 | NONE = 0 22 | QUEUED_AT_ROUTER = 1 23 | QUEUED_AT_SCHEDULER = 2 24 | RUNNING_ON_EXECUTOR = 3 25 | COMPLETED_AT_SCHEDULER = 4 26 | COMPLETED_AT_ROUTER = 5 27 | ABORTED = 6 28 | 29 | 30 | class RequestType(IntEnum): 31 | COMPUTE = 0 # Not implemented 32 | DNN = 1 # Not implemented 33 | GENERATIVE_LLM = 2 34 | 35 | 36 | @dataclass(kw_only=True) 37 | class Request(): 38 | """ 39 | Request is a DAG of Tasks and Flows targeting an Application. 40 | Requests must have a single root Node. 41 | """ 42 | request_id: int 43 | node_id: count = field(default_factory=count) 44 | application_id: int 45 | request_type: RequestType 46 | batch_size: int = 1 47 | arrival_timestamp: float = 0. 48 | state: RequestState = field(default=RequestState.NONE) 49 | dag: nx.DiGraph = field(default_factory=nx.DiGraph) 50 | root_node: Node = None 51 | nodes: dict = field(default_factory=dict) 52 | metrics: RequestMetrics = field(default_factory=RequestMetrics) 53 | slo: RequestSLO = field(default_factory=RequestSLO) 54 | executor: Executor = None 55 | 56 | def __post_init__(self): 57 | pass 58 | 59 | def __hash__(self): 60 | """ 61 | NOTE: hash functions get overridden to None in child classes 62 | """ 63 | return hash(self.request_id) 64 | 65 | def __eq__(self, other): 66 | return self.request_id == other.request_id 67 | 68 | def successors(self, node): 69 | """ 70 | Returns the next Task or Flow to be executed after node. 71 | """ 72 | return self.dag.successors(node) 73 | 74 | def predecessors(self, node): 75 | """ 76 | Returns the previous Task or Flow to be executed before node. 77 | """ 78 | return self.dag.predecessors(node) 79 | 80 | def get_node(self, node_id): 81 | """ 82 | Returns the Node with node_id from the DAG. 83 | # NOTE: could alternatively store node_ids in DAG and node as attribute 84 | """ 85 | return self.nodes[node_id] 86 | 87 | def get_node_metrics(self, node_id): 88 | """ 89 | Returns the metrics of the Node with node_id. 
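        Illustrative return value (hypothetical values; "runner" is
        "<instance name>_<instance id>" for Tasks and the Link name for Flows):

            {"request_id": 3, "request_type": RequestType.GENERATIVE_LLM,
             "node_id": 0, "node_type": "PROMPT", "runner": "instance_0",
             "start_timestamp": 1.25, "completion_timestamp": 1.91}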
90 | """ 91 | node = self.get_node(node_id) 92 | if isinstance(node, Task): 93 | node_type = node.task_type.name 94 | runner = f"{node.instance.name}_{node.instance.instance_id}" 95 | elif isinstance(node, Flow): 96 | node_type = node.flow_type.name 97 | runner = node.link.name 98 | else: 99 | raise ValueError("Unsupported node type") 100 | data = { 101 | "request_id": self.request_id, 102 | "request_type": self.request_type, 103 | "node_id": node_id, 104 | "node_type": node_type, 105 | "runner": runner, 106 | "start_timestamp": node.metrics.start_timestamp, 107 | "completion_timestamp": node.metrics.completion_timestamp, 108 | } 109 | return data 110 | 111 | def get_all_node_metrics(self): 112 | data = [] 113 | for node_id in self.nodes: 114 | data.append(self.get_node_metrics(node_id)) 115 | return data 116 | 117 | def arrive_at_router(self): 118 | assert self.state == RequestState.NONE 119 | self.metrics.router_arrival_timestamp = clock() 120 | self.state = RequestState.QUEUED_AT_ROUTER 121 | 122 | def arrive_at_scheduler(self): 123 | """ 124 | NOTE: we don't track routing overheads 125 | """ 126 | assert self.state == RequestState.QUEUED_AT_ROUTER 127 | self.metrics.scheduler_arrival_timestamp = clock() 128 | self.metrics.router_queue_time = clock() - \ 129 | self.metrics.router_arrival_timestamp 130 | self.state = RequestState.QUEUED_AT_SCHEDULER 131 | 132 | def run_on_executor(self): 133 | assert self.state == RequestState.QUEUED_AT_SCHEDULER 134 | self.metrics.executor_start_timestamp = clock() 135 | self.metrics.scheduler_queue_time = clock() - \ 136 | self.metrics.scheduler_arrival_timestamp 137 | self.state = RequestState.RUNNING_ON_EXECUTOR 138 | 139 | def complete_at_scheduler(self): 140 | """ 141 | NOTE: we don't track executor <--> scheduler communication overheads 142 | """ 143 | assert self.state == RequestState.RUNNING_ON_EXECUTOR 144 | self.metrics.scheduler_completion_timestamp = clock() 145 | self.metrics.service_time += clock() - \ 146 | self.metrics.executor_start_timestamp 147 | self.metrics.scheduler_response_time = clock() - \ 148 | self.metrics.scheduler_arrival_timestamp 149 | self.state = RequestState.COMPLETED_AT_SCHEDULER 150 | 151 | def complete_at_router(self): 152 | """ 153 | NOTE: we don't track scheduler <--> router communication overheads 154 | """ 155 | assert self.state == RequestState.COMPLETED_AT_SCHEDULER 156 | self.metrics.router_completion_timestamp = clock() 157 | self.metrics.router_response_time = clock() - \ 158 | self.metrics.router_arrival_timestamp 159 | self.state = RequestState.COMPLETED_AT_ROUTER 160 | 161 | def abort(self): 162 | if self.state == RequestState.QUEUED_AT_ROUTER: 163 | self.metrics.router_queue_time += clock() - \ 164 | self.metrics.router_arrival_timestamp 165 | elif self.state == RequestState.QUEUED_AT_SCHEDULER: 166 | self.metrics.scheduler_queue_time += clock() - \ 167 | self.metrics.scheduler_arrival_timestamp 168 | elif self.state == RequestState.RUNNING_ON_EXECUTOR: 169 | self.metrics.service_time += clock() - \ 170 | self.metrics.executor_start_timestamp 171 | elif self.state == RequestState.COMPLETED_AT_SCHEDULER: 172 | pass 173 | self.state = RequestState.ABORTED 174 | 175 | def get_results(self): 176 | pass 177 | 178 | def create_task(self, task_type, **kwargs): 179 | """ 180 | Creates a Task and adds it to the DAG. 
181 | """ 182 | task = Task.from_type(task_type=task_type, 183 | node_id=next(self.node_id), 184 | request=self, 185 | **kwargs) 186 | self.dag.add_node(task) 187 | self.nodes[task.node_id] = task 188 | return task 189 | 190 | def create_flow(self, flow_type, **kwargs): 191 | """ 192 | Creates a Flow and adds it to the DAG. 193 | """ 194 | flow = Flow.from_type(flow_type=flow_type, 195 | node_id=next(self.node_id), 196 | request=self, 197 | **kwargs) 198 | self.dag.add_node(flow) 199 | self.nodes[flow.node_id] = flow 200 | return flow 201 | 202 | def remove_node(self, node): 203 | """ 204 | Removes a Node from the DAG. 205 | """ 206 | self.dag.remove_node(node) 207 | del self.nodes[node.node_id] 208 | 209 | @classmethod 210 | def from_dict(cls, request_dict): 211 | """ 212 | Returns a Request from a Pandas dictionary. 213 | """ 214 | if request_dict["request_type"] == RequestType.GENERATIVE_LLM: 215 | request = GenerativeLLMRequest(**request_dict) 216 | else: 217 | raise ValueError(f"Unsupported request type: {request_dict['request_type']}") 218 | return request 219 | 220 | 221 | @dataclass(kw_only=True) 222 | class GenerativeLLMRequest(Request): 223 | """ 224 | GenerativeLLMRequests are requests that generate tokens from a prompt. 225 | Prompt processing and token generation are represented as Tasks. 226 | KV-cache shipping is represented using Flows. 227 | NOTE: Assumes that KV-cache is uniformly split across all GPUs. 228 | NOTE: Multi-prompt chat conversations are not supported here. 229 | """ 230 | max_seq_len: int = 0 231 | processed_tokens: int 232 | _processed_tokens: int = 0 233 | generated_tokens: int 234 | _generated_tokens: int = 0 235 | prompt_size: int = 0 236 | token_size: int = 0 237 | kv_cache_size: int = 0 238 | flow_node: Flow = None 239 | cost: float = 0. 240 | memory: float = 0. 241 | metrics: GenerativeLLMRequestMetrics = field( 242 | default_factory=GenerativeLLMRequestMetrics) 243 | 244 | def __post_init__(self): 245 | self.max_seq_len = self.prompt_size + self.token_size 246 | # create prompt and token tasks 247 | prompt_task = self.create_task(task_type=TaskType.PROMPT, 248 | prompt_size=self.prompt_size) 249 | token_task = self.create_task(task_type=TaskType.TOKEN, 250 | token_size=self.token_size - 1) 251 | # update DAG 252 | self.dag.add_edge(prompt_task, token_task) 253 | self.root_node = prompt_task 254 | 255 | def __hash__(self): 256 | return hash(self.request_id) 257 | 258 | @property 259 | def processed_tokens(self): 260 | """ 261 | Returns the number of prompt tokens processed so far. 262 | """ 263 | return self._processed_tokens 264 | 265 | @processed_tokens.setter 266 | def processed_tokens(self, processed_tokens): 267 | """ 268 | Sets the number of prompt tokens processed so far. 269 | """ 270 | if isinstance(processed_tokens, property): 271 | processed_tokens = 0 272 | if processed_tokens > self.prompt_size + self.token_size: 273 | print(processed_tokens, self.prompt_size + self.token_size) 274 | raise ValueError("Processed tokens limit exceeded") 275 | self._processed_tokens = processed_tokens 276 | 277 | @property 278 | def generated_tokens(self): 279 | """ 280 | Returns the number of tokens generated so far. 281 | """ 282 | return self._generated_tokens 283 | 284 | @generated_tokens.setter 285 | def generated_tokens(self, generated_tokens): 286 | """ 287 | Sets the number of tokens generated so far. 
288 | """ 289 | if isinstance(generated_tokens, property): 290 | generated_tokens = 0 291 | if generated_tokens > self.max_seq_len: 292 | raise ValueError("Maximum sequence length exceeded") 293 | self._generated_tokens = generated_tokens 294 | 295 | 296 | def estimate_kv_cache_size(self, num_tokens=None, model=None): 297 | """ 298 | Returns the KV-cache size after generating num_tokens 299 | Requires the Request root node to be allocated on an Instance. 300 | """ 301 | if num_tokens is None: 302 | num_tokens = self.generated_tokens 303 | if model is None: 304 | model = self.root_node.instance.model 305 | return 2 * self.batch_size * num_tokens * model.architecture.hidden_size \ 306 | * model.architecture.num_layers * model.size.dtype_size 307 | 308 | def get_nth_token_overhead(self): 309 | """ 310 | Returns the overhead of generating the nth token. 311 | """ 312 | return self.nodes[1].metrics.start_timestamp - self.nodes[0].metrics.completion_timestamp 313 | 314 | 315 | if __name__ == "__main__": 316 | pass 317 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.3.2 2 | hydra-joblib-launcher==1.2.0 3 | hydra-ray-launcher==1.2.1 4 | matplotlib==3.7.2 5 | networkx==3.1 6 | numpy==1.25.2 7 | omegaconf==2.3.0 8 | pandas==2.0.3 9 | plotly==5.17.0 10 | ray==2.6.3 11 | requests==2.31.0 12 | scikit_learn==1.3.1 13 | scipy==1.11.3 14 | seaborn==0.13.0 15 | nbformat==5.9.2 16 | -------------------------------------------------------------------------------- /router.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | from simulator import clock, schedule_event, cancel_event, reschedule_event 6 | 7 | 8 | class Router(ABC): 9 | """ 10 | Router routes Requests to Application Schedulers. 11 | """ 12 | def __init__(self, 13 | cluster, 14 | overheads): 15 | self.cluster = cluster 16 | self.overheads = overheads 17 | self.applications = [] 18 | self.schedulers = {} 19 | 20 | # request queues 21 | self.pending_queue = [] 22 | self.executing_queue = [] 23 | self.completed_queue = [] 24 | 25 | def add_application(self, application): 26 | self.applications.append(application) 27 | self.schedulers[application.application_id] = application.scheduler 28 | 29 | def run(self): 30 | pass 31 | 32 | @abstractmethod 33 | def route(self, *args): 34 | """ 35 | Main routing logic 36 | """ 37 | raise NotImplementedError 38 | 39 | def request_arrival(self, request): 40 | request.arrive_at_router() 41 | self.pending_queue.append(request) 42 | self.route_request(request) 43 | 44 | def request_completion(self, request): 45 | request.complete_at_router() 46 | self.executing_queue.remove(request) 47 | self.completed_queue.append(request) 48 | 49 | def route_request(self, request): 50 | self.route(request) 51 | self.pending_queue.remove(request) 52 | self.executing_queue.append(request) 53 | 54 | def save_results(self): 55 | #results = [] 56 | #for request in self.completed_queue: 57 | # times = request.time_per_instance_type() 58 | # results.append(times) 59 | #utils.save_dict_as_csv(results, "router.csv") 60 | pass 61 | 62 | 63 | class NoOpRouter(Router): 64 | """ 65 | Forwards request to the appropriate scheduler without any overheads. 
66 | """ 67 | def route(self, request): 68 | scheduler = self.schedulers[request.application_id] 69 | f = lambda scheduler=scheduler,request=request: \ 70 | scheduler.request_arrival(request) 71 | schedule_event(self.overheads.routing_delay, f) 72 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | 6 | import hydra 7 | 8 | from hydra.utils import instantiate 9 | from hydra.utils import get_original_cwd, to_absolute_path 10 | from omegaconf import DictConfig, OmegaConf 11 | 12 | from simulator import TraceSimulator 13 | from initialize import * 14 | 15 | 16 | # register custom hydra resolver 17 | OmegaConf.register_new_resolver("eval", eval) 18 | 19 | 20 | def run_simulation(cfg): 21 | hardware_repo = init_hardware_repo(cfg) 22 | model_repo = init_model_repo(cfg) 23 | orchestrator_repo = init_orchestrator_repo(cfg) 24 | performance_model = init_performance_model(cfg) 25 | power_model = init_power_model(cfg) 26 | cluster = init_cluster(cfg) 27 | router = init_router(cfg, cluster) 28 | arbiter = init_arbiter(cfg, cluster) 29 | applications = init_applications(cfg, cluster, router, arbiter) 30 | trace = init_trace(cfg) 31 | for application in applications.values(): 32 | router.add_application(application) 33 | arbiter.add_application(application) 34 | sim = TraceSimulator(trace=trace, 35 | cluster=cluster, 36 | applications=applications, 37 | router=router, 38 | arbiter=arbiter, 39 | end_time=cfg.end_time) 40 | init_start_state(cfg, 41 | cluster=cluster, 42 | applications=applications, 43 | router=router, 44 | arbiter=arbiter) 45 | sim.run() 46 | 47 | 48 | @hydra.main(config_path="configs", config_name="config", version_base=None) 49 | def run(cfg: DictConfig) -> None: 50 | # print config 51 | #print(OmegaConf.to_yaml(cfg, resolve=False)) 52 | #hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 53 | #print(OmegaConf.to_yaml(hydra_cfg, resolve=False)) 54 | 55 | # initialize random number generator 56 | random.seed(cfg.seed) 57 | 58 | # delete existing oom.csv if any 59 | if os.path.exists("oom.csv"): 60 | os.remove("oom.csv") 61 | 62 | run_simulation(cfg) 63 | 64 | 65 | if __name__ == "__main__": 66 | run() 67 | -------------------------------------------------------------------------------- /scripts/run_baseline_a.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | applications.0.scheduler=token_jsq \ 3 | cluster=half_half \ 4 | cluster.servers.0.count=70 \ 5 | cluster.servers.1.count=0 \ 6 | start_state=baseline \ 7 | performance_model=db \ 8 | trace.filename=rr_conv_80 \ 9 | seed=0 10 | #+experiment=traces_light \ 11 | -------------------------------------------------------------------------------- /scripts/run_baseline_h.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | applications.0.scheduler=token_jsq \ 3 | cluster=half_half \ 4 | cluster.servers.0.count=0 \ 5 | cluster.servers.1.count=40 \ 6 | start_state=baseline \ 7 | performance_model=db \ 8 | trace.filename=rr_conv_80 \ 9 | seed=0 10 | #+experiment=traces_light \ 11 | -------------------------------------------------------------------------------- /scripts/run_baseline_h_example.sh: -------------------------------------------------------------------------------- 1 | NUM_DGX_A100=0 2 | NUM_DGX_H100=1 3 | SCHEDULER=token_jsq 4 | 
START_STATE=baseline 5 | TRACE=test_trace 6 | 7 | python run.py \ 8 | cluster=half_half \ 9 | cluster.servers.0.count=$NUM_DGX_A100 \ 10 | cluster.servers.1.count=$NUM_DGX_H100 \ 11 | applications.0.scheduler=$SCHEDULER \ 12 | start_state=$START_STATE \ 13 | performance_model=db \ 14 | trace.filename=$TRACE \ 15 | debug=True \ 16 | seed=0 17 | -------------------------------------------------------------------------------- /scripts/run_costopt.sh: -------------------------------------------------------------------------------- 1 | python run.py --multirun seed=0 +experiment=baseline_h100_costopt & 2 | python run.py --multirun seed=0 +experiment=baseline_a100_costopt 3 | python run.py --multirun seed=0 applications.0.scheduler=mixed_pool +experiment=splitwise_hh_costopt & 4 | python run.py --multirun seed=0 applications.0.scheduler=mixed_pool +experiment=splitwise_aa_costopt 5 | python run.py --multirun seed=0 applications.0.scheduler=mixed_pool +experiment=splitwise_ha_costopt & 6 | python run.py --multirun seed=0 applications.0.scheduler=mixed_pool +experiment=splitwise_hhcap_costopt 7 | -------------------------------------------------------------------------------- /scripts/run_isocost.sh: -------------------------------------------------------------------------------- 1 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_hh_isocost seed=0 & 2 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_aa_isocost seed=0 & 3 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_ha_isocost seed=0 & 4 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_hhcap_isocost seed=0 5 | -------------------------------------------------------------------------------- /scripts/run_isopower.sh: -------------------------------------------------------------------------------- 1 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_hh_isopower seed=0 & 2 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_ha_isopower seed=0 & 3 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_aa_isopower seed=0 & 4 | python run.py --multirun applications.0.scheduler=mixed_pool +experiment=splitwise_hhcap_isopower seed=0 5 | -------------------------------------------------------------------------------- /scripts/run_splitwise_aa.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | applications.0.scheduler=mixed_pool \ 3 | cluster=half_half \ 4 | cluster.servers.0.count=40 \ 5 | cluster.servers.1.count=0 \ 6 | start_state=splitwise \ 7 | start_state.prompt.num_instances=27 \ 8 | start_state.token.num_instances=13 \ 9 | performance_model=db \ 10 | trace.filename=rr_conv_80 \ 11 | seed=0 12 | #applications.0.scheduler=token_jsq \ 13 | #trace.filename=rr_code_70 \ 14 | -------------------------------------------------------------------------------- /scripts/run_splitwise_ha.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | applications.0.scheduler=mixed_pool \ 3 | cluster=half_half \ 4 | cluster.servers.0.count=26 \ 5 | cluster.servers.1.count=25 \ 6 | start_state=splitwise \ 7 | start_state.split_type=heterogeneous \ 8 | performance_model=db \ 9 | trace.filename=rr_conv_80 \ 10 | seed=0 11 | #applications.0.scheduler=token_jsq \ 12 | #trace.filename=rr_code_70 \ 13 | 14 | 
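# Illustrative note (inferred from the baseline scripts and notebooks/utils.py, not
# stated in this script): in the half_half cluster, servers.0 is the DGX-A100 count
# and servers.1 is the DGX-H100 count, so this run provisions 26 A100 and 25 H100
# servers; with start_state.split_type=heterogeneous the prompt pool appears to run
# on the H100 servers and the token pool on the A100 servers (Splitwise-HA).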
-------------------------------------------------------------------------------- /scripts/run_splitwise_ha_example.sh: -------------------------------------------------------------------------------- 1 | SCHEDULER=mixed_pool 2 | NUM_A100=1 3 | NUM_H100=1 4 | START_STATE=splitwise 5 | TRACE=test_trace 6 | 7 | python run.py \ 8 | applications.0.scheduler=$SCHEDULER \ 9 | cluster=half_half \ 10 | cluster.servers.0.count=$NUM_A100 \ 11 | cluster.servers.1.count=$NUM_H100 \ 12 | start_state=$START_STATE \ 13 | start_state.split_type=heterogeneous \ 14 | performance_model=db \ 15 | trace.filename=$TRACE \ 16 | debug=True \ 17 | seed=0 18 | -------------------------------------------------------------------------------- /scripts/run_splitwise_hh.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | applications.0.scheduler=mixed_pool \ 3 | applications.0.model_architecture=llama2-70b \ 4 | applications.0.model_size=llama2-70b-fp16 \ 5 | cluster=half_half \ 6 | cluster.servers.0.count=1 \ 7 | cluster.servers.1.count=1 \ 8 | start_state=splitwise \ 9 | start_state.prompt.num_instances=1 \ 10 | start_state.token.num_instances=1 \ 11 | performance_model=db \ 12 | trace.filename=rr_code_2 \ 13 | debug=True \ 14 | seed=0 15 | #applications.0.scheduler=token_jsq \ 16 | #trace.filename=rr_code_70 \ 17 | 18 | -------------------------------------------------------------------------------- /scripts/run_splitwise_hhcap.sh: -------------------------------------------------------------------------------- 1 | python run.py \ 2 | applications.0.scheduler=mixed_pool \ 3 | cluster=hhcap_half_half \ 4 | cluster.servers.0.count=5 \ 5 | cluster.servers.1.count=35 \ 6 | start_state=splitwise_hhcap \ 7 | start_state.split_type=heterogeneous \ 8 | trace.filename=rr_conv_80 \ 9 | performance_model=db \ 10 | seed=0 11 | -------------------------------------------------------------------------------- /scripts/run_throughput.sh: -------------------------------------------------------------------------------- 1 | TRACE=rr_code_130 2 | SEED=0 3 | echo "Running throughput experiments for trace $TRACE" 4 | 5 | # Baseline-A100 6 | python run.py \ 7 | applications.0.scheduler=token_jsq \ 8 | cluster=half_half \ 9 | cluster.servers.0.count=113 \ 10 | cluster.servers.1.count=0 \ 11 | start_state=baseline \ 12 | performance_model=db \ 13 | trace.filename=$TRACE \ 14 | seed=$SEED 15 | 16 | 17 | # Baseline-H100 18 | python run.py \ 19 | applications.0.scheduler=token_jsq \ 20 | cluster=half_half \ 21 | cluster.servers.0.count=0 \ 22 | cluster.servers.1.count=51 \ 23 | start_state=baseline \ 24 | performance_model=db \ 25 | trace.filename=$TRACE \ 26 | seed=$SEED 27 | 28 | 29 | # Splitwise-AA 30 | python run.py \ 31 | applications.0.scheduler=mixed_pool \ 32 | cluster=half_half \ 33 | cluster.servers.0.count=64 \ 34 | cluster.servers.1.count=0 \ 35 | start_state=splitwise \ 36 | start_state.prompt.num_instances=54 \ 37 | start_state.token.num_instances=10 \ 38 | performance_model=db \ 39 | trace.filename=$TRACE \ 40 | seed=$SEED 41 | 42 | 43 | # Splitwise-HH 44 | python run.py \ 45 | applications.0.scheduler=mixed_pool \ 46 | cluster=half_half \ 47 | cluster.servers.0.count=0 \ 48 | cluster.servers.1.count=33 \ 49 | start_state=splitwise \ 50 | start_state.prompt.num_instances=23 \ 51 | start_state.token.num_instances=10 \ 52 | performance_model=db \ 53 | trace.filename=$TRACE \ 54 | seed=$SEED 55 | 56 | 57 | # Splitwise-HA 58 | python run.py \ 59 | 
applications.0.scheduler=mixed_pool \ 60 | cluster=half_half \ 61 | cluster.servers.0.count=16 \ 62 | cluster.servers.1.count=23 \ 63 | start_state=splitwise \ 64 | start_state.split_type=heterogeneous \ 65 | performance_model=db \ 66 | trace.filename=$TRACE \ 67 | seed=$SEED 68 | 69 | 70 | # Splitwise-HHcap 71 | python run.py \ 72 | applications.0.scheduler=mixed_pool \ 73 | cluster=hhcap_half_half \ 74 | cluster.servers.0.count=10 \ 75 | cluster.servers.1.count=23 \ 76 | start_state=splitwise_hhcap \ 77 | start_state.split_type=heterogeneous \ 78 | performance_model=db \ 79 | trace.filename=$TRACE \ 80 | seed=$SEED 81 | -------------------------------------------------------------------------------- /scripts/run_throughput_isocost.sh: -------------------------------------------------------------------------------- 1 | TRACE=rr_conv_180 2 | SEED=0 3 | echo "Running throughput experiments for trace $TRACE" 4 | 5 | # Baseline-A100 6 | python run.py \ 7 | applications.0.scheduler=token_jsq \ 8 | cluster=half_half \ 9 | cluster.servers.0.count=86 \ 10 | cluster.servers.1.count=0 \ 11 | start_state=baseline \ 12 | performance_model=db \ 13 | trace.filename=$TRACE \ 14 | seed=$SEED 15 | 16 | 17 | ## Baseline-H100 18 | python run.py \ 19 | applications.0.scheduler=token_jsq \ 20 | cluster=half_half \ 21 | cluster.servers.0.count=0 \ 22 | cluster.servers.1.count=40 \ 23 | start_state=baseline \ 24 | performance_model=db \ 25 | trace.filename=$TRACE \ 26 | seed=$SEED 27 | 28 | 29 | # Splitwise-AA 30 | python run.py \ 31 | applications.0.scheduler=mixed_pool \ 32 | cluster=half_half \ 33 | cluster.servers.0.count=86 \ 34 | cluster.servers.1.count=0 \ 35 | start_state=splitwise \ 36 | start_state.prompt.num_instances=51 \ 37 | start_state.token.num_instances=35 \ 38 | performance_model=db \ 39 | trace.filename=$TRACE \ 40 | seed=$SEED 41 | 42 | 43 | # Splitwise-HH 44 | python run.py \ 45 | applications.0.scheduler=mixed_pool \ 46 | cluster=half_half \ 47 | cluster.servers.0.count=0 \ 48 | cluster.servers.1.count=40 \ 49 | start_state=splitwise \ 50 | start_state.prompt.num_instances=25 \ 51 | start_state.token.num_instances=15 \ 52 | performance_model=db \ 53 | trace.filename=$TRACE \ 54 | seed=$SEED 55 | 56 | 57 | # Splitwise-HA 58 | python run.py \ 59 | applications.0.scheduler=mixed_pool \ 60 | cluster=half_half \ 61 | cluster.servers.0.count=21 \ 62 | cluster.servers.1.count=30 \ 63 | start_state=splitwise \ 64 | start_state.split_type=heterogeneous \ 65 | performance_model=db \ 66 | trace.filename=$TRACE \ 67 | seed=$SEED 68 | 69 | 70 | # Splitwise-HHcap 71 | python run.py \ 72 | applications.0.scheduler=mixed_pool \ 73 | cluster=hhcap_half_half \ 74 | cluster.servers.0.count=10 \ 75 | cluster.servers.1.count=30 \ 76 | start_state=splitwise_hhcap \ 77 | start_state.split_type=heterogeneous \ 78 | performance_model=db \ 79 | trace.filename=$TRACE \ 80 | seed=$SEED 81 | -------------------------------------------------------------------------------- /scripts/run_throughput_isopower.sh: -------------------------------------------------------------------------------- 1 | TRACE=rr_conv_180 2 | SEED=0 3 | echo "Running throughput experiments for trace $TRACE" 4 | 5 | # Baseline-A100 6 | python run.py \ 7 | applications.0.scheduler=token_jsq \ 8 | cluster=half_half \ 9 | cluster.servers.0.count=70 \ 10 | cluster.servers.1.count=0 \ 11 | start_state=baseline \ 12 | performance_model=db \ 13 | trace.filename=$TRACE \ 14 | seed=$SEED 15 | 16 | 17 | # Baseline-H100 18 | python run.py \ 19 | 
applications.0.scheduler=token_jsq \ 20 | cluster=half_half \ 21 | cluster.servers.0.count=0 \ 22 | cluster.servers.1.count=40 \ 23 | start_state=baseline \ 24 | performance_model=db \ 25 | trace.filename=$TRACE \ 26 | seed=$SEED 27 | 28 | 29 | # Splitwise-AA 30 | python run.py \ 31 | applications.0.scheduler=mixed_pool \ 32 | cluster=half_half \ 33 | cluster.servers.0.count=70 \ 34 | cluster.servers.1.count=0 \ 35 | start_state=splitwise \ 36 | start_state.prompt.num_instances=45 \ 37 | start_state.token.num_instances=25 \ 38 | performance_model=db \ 39 | trace.filename=$TRACE \ 40 | seed=$SEED 41 | 42 | 43 | # Splitwise-HH 44 | python run.py \ 45 | applications.0.scheduler=mixed_pool \ 46 | cluster=half_half \ 47 | cluster.servers.0.count=0 \ 48 | cluster.servers.1.count=40 \ 49 | start_state=splitwise \ 50 | start_state.prompt.num_instances=25 \ 51 | start_state.token.num_instances=15 \ 52 | performance_model=db \ 53 | trace.filename=$TRACE \ 54 | seed=$SEED 55 | 56 | 57 | # Splitwise-HA 58 | python run.py \ 59 | applications.0.scheduler=mixed_pool \ 60 | cluster=half_half \ 61 | cluster.servers.0.count=26 \ 62 | cluster.servers.1.count=25 \ 63 | start_state=splitwise \ 64 | start_state.split_type=heterogeneous \ 65 | performance_model=db \ 66 | trace.filename=$TRACE \ 67 | seed=$SEED 68 | 69 | 70 | # Splitwise-HHcap 71 | python run.py \ 72 | applications.0.scheduler=mixed_pool \ 73 | cluster=hhcap_half_half \ 74 | cluster.servers.0.count=21 \ 75 | cluster.servers.1.count=25 \ 76 | start_state=splitwise_hhcap \ 77 | start_state.split_type=heterogeneous \ 78 | performance_model=db \ 79 | trace.filename=$TRACE \ 80 | seed=$SEED 81 | 82 | -------------------------------------------------------------------------------- /scripts/run_traces.sh: -------------------------------------------------------------------------------- 1 | python run.py --multirun \ 2 | applications.0.scheduler=token_jsq \ 3 | cluster=isocost_a100,isopower_a100,dgx-h100 \ 4 | start_state=baseline \ 5 | performance_model=db \ 6 | seed=0 \ 7 | +experiment=traces 8 | #cluster=dgx-h100,isocost_a100,isopower_a100,isocount_a100 \ 9 | #applications.0.scheduler=round_robin \ 10 | #trace.filename=rr_1,rr_2,rr_3,rr_4,rr_5,rr_6,rr_7,rr_8,rr_9,rr_10 \ 11 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from hydra.utils import instantiate 4 | 5 | import utils 6 | import hardware_repo 7 | 8 | from power_model import get_server_power 9 | from simulator import clock, schedule_event, cancel_event, reschedule_event 10 | 11 | # used by hydra instantiate 12 | import processor 13 | import interconnect 14 | 15 | 16 | class Server: 17 | """ 18 | Servers are a collection of Processors that may be connected by 19 | local Interconnects. Servers themselves are also interconnected 20 | by Interconnects. Servers run Instances (partially or fully). 21 | 22 | Attributes: 23 | server_id (str): The unique server_id of the server. 24 | processors (list): A list of Processors. 25 | interconnects (list[Link]): Peers that this Server is 26 | directly connected to. 
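    A minimal construction sketch (illustrative only: "toy-sku", "toy-gpu", and
    "toy-link" are hypothetical names that would need matching entries in the
    hardware_repo YAML files, and the hardware repo must already be loaded):

        from omegaconf import OmegaConf

        # SKU config with the fields that from_config() below actually reads
        sku_cfg = OmegaConf.create({
            "name": "toy-sku",
            "processors": [{"name": "toy-gpu", "count": 8}],
            "interconnects": ["toy-link"],
        })
        server = Server.from_config(sku_cfg, server_id=0)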
27 | """ 28 | servers = {} 29 | # logger for all servers 30 | logger = None 31 | 32 | def __init__(self, 33 | server_id, 34 | name, 35 | processors, 36 | interconnects): 37 | if server_id in Server.servers: 38 | # NOTE: This is a hacky workaround for Hydra 39 | # Hydra multirun has a bug where it tries to instantiate a cluster again 40 | # with the same class, triggering this path. This likely happens because 41 | # Hydra multirun reuses the same classes across threads 42 | Server.servers = {} 43 | Server.logger = None 44 | self.server_id = server_id 45 | self.name = name 46 | self.processors = processors 47 | for proc in self.processors: 48 | proc.server = self 49 | self.interconnects = interconnects 50 | for intercon in self.interconnects: 51 | intercon.server = self 52 | self.cluster = None 53 | Server.servers[server_id] = self 54 | self.instances = [] 55 | self.power = 0 56 | self.update_power(0) 57 | #self._instances = [] 58 | 59 | # initialize server logger 60 | if Server.logger is None: 61 | self.logger = utils.file_logger("server") 62 | Server.logger = self.logger 63 | self.logger.info("time,server") 64 | else: 65 | self.logger = Server.logger 66 | 67 | def __str__(self): 68 | return f"Server:{self.server_id}" 69 | 70 | def __repr__(self): 71 | return self.__str__() 72 | 73 | @property 74 | def instances(self): 75 | return self._instances 76 | 77 | @instances.setter 78 | def instances(self, instances): 79 | self._instances = instances 80 | 81 | def update_power(self, power): 82 | old_power = self.power 83 | self.power = get_server_power(self) + \ 84 | sum(processor.power for processor in self.processors) 85 | if self.cluster: 86 | self.cluster.update_power(self.power - old_power) 87 | 88 | def run(self): 89 | pass 90 | 91 | @classmethod 92 | def load(cls): 93 | pass 94 | 95 | @classmethod 96 | def from_config(cls, *args, server_id, **kwargs): 97 | sku_cfg = args[0] 98 | processors_cfg = sku_cfg.processors 99 | interconnects_cfg = sku_cfg.interconnects 100 | 101 | processors = [] 102 | for processor_cfg in processors_cfg: 103 | for n in range(processor_cfg.count): 104 | processor = hardware_repo.get_processor(processor_cfg.name) 105 | processors.append(processor) 106 | 107 | # TODO: add better network topology / configuration support 108 | interconnects = [] 109 | for interconnect_name in interconnects_cfg: 110 | intercon = hardware_repo.get_interconnect(interconnect_name) 111 | interconnects.append(intercon) 112 | 113 | return cls(server_id=server_id, 114 | name=sku_cfg.name, 115 | processors=processors, 116 | interconnects=interconnects) 117 | 118 | 119 | if __name__ == "__main__": 120 | pass 121 | -------------------------------------------------------------------------------- /simulator.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import logging 3 | 4 | from collections import defaultdict 5 | 6 | import utils 7 | 8 | 9 | # global simulator that drives the simulation 10 | # bad practice, but it works for now 11 | sim = None 12 | 13 | 14 | class Event: 15 | """ 16 | Events are scheduled actions in the simulator. 17 | """ 18 | def __init__(self, time, action): 19 | self.time = time 20 | self.action = action 21 | 22 | def __str__(self): 23 | return f"Event with time {self.time} and action {self.action}" 24 | 25 | def __lt__(self, other): 26 | return self.time < other.time 27 | 28 | 29 | class Simulator: 30 | """ 31 | A discrete event simulator that schedules and runs Events. 
32 | """ 33 | def __init__(self, end_time): 34 | global sim 35 | sim = self 36 | self.time = 0 37 | self.end_time = end_time 38 | self.events = [] 39 | self.deleted_events = [] 40 | logging.info("Simulator initialized") 41 | 42 | # logger for simulator events 43 | self.logger = utils.file_logger("simulator") 44 | self.logger.info("time,event") 45 | 46 | def schedule(self, delay, action): 47 | """ 48 | Schedule an event by specifying delay and an action function. 49 | """ 50 | # run immediately if delay is 0 51 | if delay == 0: 52 | action() 53 | return None 54 | event = Event(self.time + delay, action) 55 | heapq.heappush(self.events, event) 56 | return event 57 | 58 | def cancel(self, event): 59 | """ 60 | Cancel an event. 61 | """ 62 | self.deleted_events.append(event) 63 | 64 | def reschedule(self, event, delay): 65 | """ 66 | Reschedule an event by cancelling and scheduling it again. 67 | """ 68 | self.cancel(event) 69 | return self.schedule(delay, event.action) 70 | 71 | def run(self): 72 | """ 73 | Run the simulation until the end time. 74 | """ 75 | while self.events and self.time < self.end_time: 76 | event = heapq.heappop(self.events) 77 | if event in self.deleted_events: 78 | self.deleted_events.remove(event) 79 | continue 80 | self.time = event.time 81 | event.action() 82 | self.logger.debug(f"{event.time},{event.action}") 83 | 84 | 85 | class TraceSimulator(Simulator): 86 | """ 87 | A discrete event simulator that processes Request arrivals from a Trace. 88 | """ 89 | def __init__(self, 90 | trace, 91 | cluster, 92 | applications, 93 | router, 94 | arbiter, 95 | end_time): 96 | super().__init__(end_time) 97 | self.trace = trace 98 | self.cluster = cluster 99 | self.applications = applications 100 | self.router = router 101 | self.arbiter = arbiter 102 | logging.info("TraceSimulator initialized") 103 | self.load_trace() 104 | 105 | def load_trace(self): 106 | """ 107 | Load requests from the trace as arrival events. 108 | """ 109 | for request in self.trace.requests: 110 | self.schedule(request.arrival_timestamp, 111 | lambda request=request: self.router.request_arrival(request)) 112 | 113 | def run(self): 114 | # start simulation by scheduling a cluster run 115 | self.schedule(0, self.cluster.run) 116 | self.schedule(0, self.router.run) 117 | self.schedule(0, self.arbiter.run) 118 | 119 | # run simulation 120 | super().run() 121 | self.logger.info(f"{self.time},end") 122 | logging.info(f"TraceSimulator completed at {self.time}") 123 | 124 | self.save_results() 125 | 126 | def save_results(self, detailed=True): 127 | """ 128 | Save results at the end of the simulation. 
129 | """ 130 | self.router.save_results() 131 | 132 | sched_results = {} 133 | alloc_results = {} 134 | for application_id, application in self.applications.items(): 135 | allocator_results, scheduler_results = application.get_results() 136 | alloc_results[application_id] = allocator_results 137 | sched_results[application_id] = scheduler_results 138 | 139 | # summary sched results 140 | summary_results = defaultdict(list) 141 | for application_id, results_dict in sched_results.items(): 142 | summary_results["application_id"].append(application_id) 143 | for key, values in results_dict.items(): 144 | summary = utils.get_statistics(values) 145 | # merge summary into summary_results 146 | for metric, value in summary.items(): 147 | summary_results[f"{key}_{metric}"].append(value) 148 | 149 | # save summary results 150 | utils.save_dict_as_csv(summary_results, "summary.csv") 151 | 152 | if detailed: 153 | # create a dataframe of all requests, save as csv 154 | for application_id, result in sched_results.items(): 155 | utils.save_dict_as_csv(result, f"detailed/{application_id}.csv") 156 | for application_id, result in alloc_results.items(): 157 | utils.save_dict_as_csv(result, f"detailed/{application_id}_alloc.csv") 158 | 159 | 160 | # Convenience functions for simulator object 161 | 162 | def clock(): 163 | """ 164 | Return the current time of the simulator. 165 | """ 166 | return sim.time 167 | 168 | def schedule_event(*args): 169 | """ 170 | Schedule an event in the simulator at desired delay. 171 | """ 172 | return sim.schedule(*args) 173 | 174 | def cancel_event(*args): 175 | """ 176 | Cancel existing event in the simulator. 177 | """ 178 | return sim.cancel(*args) 179 | 180 | def reschedule_event(*args): 181 | """ 182 | Reschedule existing event in the simulator. 183 | Equivalent to cancelling and scheduling a new event. 184 | """ 185 | return sim.reschedule(*args) 186 | -------------------------------------------------------------------------------- /start_state.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to initialize the Cluster with a starting state. 3 | """ 4 | 5 | import logging 6 | 7 | from model import ModelParallelism 8 | from simulator import clock, schedule_event, cancel_event, reschedule_event 9 | 10 | 11 | def load_start_state(start_state_cfg, **kwargs): 12 | """ 13 | Load the start state configuration and initialize the cluster. 14 | """ 15 | state_type = start_state_cfg.state_type 16 | if state_type == "unallocated": 17 | pass 18 | elif state_type == "orca": 19 | uniform(start_state_cfg, **kwargs) 20 | elif state_type == "baseline": 21 | uniform(start_state_cfg, **kwargs) 22 | elif "splitwise" in state_type: 23 | splitwise(start_state_cfg, **kwargs) 24 | else: 25 | raise ValueError(f"Unknown start state type: {state_type}") 26 | 27 | 28 | def uniform(start_state_cfg, cluster, applications, **kwargs): 29 | """ 30 | Initialize all servers with a single instance of the application. 
31 | """ 32 | application = applications[start_state_cfg.application_id] 33 | allocator = application.allocator 34 | servers = cluster.servers 35 | 36 | instance_cfg = start_state_cfg.instance 37 | parallelism = ModelParallelism(pipeline_parallelism=instance_cfg.pipeline_parallelism, 38 | tensor_parallelism=instance_cfg.tensor_parallelism) 39 | 40 | for sku_name in servers: 41 | for server in servers[sku_name]: 42 | allocator.start_spin_up_instance(instance_cfg=instance_cfg, 43 | processors=server.processors, 44 | parallelism=parallelism, 45 | pre_start=True) 46 | 47 | 48 | def splitwise(start_state_cfg, cluster, applications, **kwargs): 49 | """ 50 | Initialize all servers with a single instance of the application. 51 | Separate prompt and token instances with different kinds of parallelism. 52 | TODO: use preferences and constraints within scheduler instead 53 | """ 54 | application = applications[start_state_cfg.application_id] 55 | allocator = application.allocator 56 | servers = cluster.servers 57 | 58 | prompt_cfg = start_state_cfg.prompt 59 | token_cfg = start_state_cfg.token 60 | prompt_parallelism = ModelParallelism(pipeline_parallelism=prompt_cfg.pipeline_parallelism, 61 | tensor_parallelism=prompt_cfg.tensor_parallelism) 62 | token_parallelism = ModelParallelism(pipeline_parallelism=token_cfg.pipeline_parallelism, 63 | tensor_parallelism=token_cfg.tensor_parallelism) 64 | 65 | split_type = start_state_cfg.split_type 66 | 67 | if split_type == "homogeneous": 68 | n_prompts = prompt_cfg.num_instances 69 | n_tokens = token_cfg.num_instances 70 | # allocate n_prompt instance of prompt 71 | all_servers = [server for sku_name in servers for server in servers[sku_name]] 72 | for server in all_servers[:n_prompts]: 73 | for proc_id in range(0, len(server.processors), prompt_parallelism.tensor_parallelism): 74 | allocator.start_spin_up_instance(instance_cfg=prompt_cfg, 75 | processors=server.processors[proc_id:proc_id+prompt_parallelism.tensor_parallelism], 76 | parallelism=prompt_parallelism, 77 | pre_start=True, 78 | tag="prompt") 79 | for server in all_servers[n_prompts:n_prompts+n_tokens]: 80 | for proc_id in range(0, len(server.processors), token_parallelism.tensor_parallelism): 81 | allocator.start_spin_up_instance(instance_cfg=token_cfg, 82 | processors=server.processors[proc_id:proc_id+token_parallelism.tensor_parallelism], 83 | parallelism=token_parallelism, 84 | pre_start=True, 85 | tag="token") 86 | 87 | if split_type == "heterogeneous": 88 | prompt_instances = prompt_cfg.instance_names 89 | token_instances = token_cfg.instance_names 90 | for sku_name in servers: 91 | for server in servers[sku_name]: 92 | if sku_name in prompt_instances: 93 | # allocate as many prompt instances as possible 94 | for proc_id in range(0, len(server.processors), prompt_parallelism.tensor_parallelism): 95 | allocator.start_spin_up_instance(instance_cfg=prompt_cfg, 96 | processors=server.processors[proc_id:proc_id+prompt_parallelism.tensor_parallelism], 97 | parallelism=prompt_parallelism, 98 | pre_start=True, 99 | tag="prompt") 100 | elif sku_name in token_instances: 101 | # allocate as many token instances as possible 102 | for proc_id in range(0, len(server.processors), token_parallelism.tensor_parallelism): 103 | allocator.start_spin_up_instance(instance_cfg=token_cfg, 104 | processors=server.processors[proc_id:proc_id+token_parallelism.tensor_parallelism], 105 | parallelism=token_parallelism, 106 | pre_start=True, 107 | tag="token") 108 | else: 109 | raise ValueError(f"Unsupported sku_name: 
{sku_name}") 110 | -------------------------------------------------------------------------------- /sync_scripts/sync_configs.sh: -------------------------------------------------------------------------------- 1 | rsync -avz --delete configs/ -e /bin/ssh sim2:/home/azureuser/splitwise-sim/configs/ 2 | rsync -avz --delete configs/ -e /bin/ssh sim3:/home/azureuser/splitwise-sim/configs/ 3 | 4 | -------------------------------------------------------------------------------- /sync_scripts/sync_repos.sh: -------------------------------------------------------------------------------- 1 | ssh sim2 "cd splitwise-sim;git pull" 2 | ssh sim3 "cd splitwise-sim;git pull" 3 | 4 | -------------------------------------------------------------------------------- /sync_scripts/sync_results.sh: -------------------------------------------------------------------------------- 1 | rsync -avz -e /bin/ssh sim2:/home/azureuser/splitwise-sim/results/ results/ 2 | ssh sim2 "rm -rf splitwise-sim/results/*" 3 | rsync -avz -e /bin/ssh sim3:/home/azureuser/splitwise-sim/results/ results/ 4 | ssh sim3 "rm -rf splitwise-sim/results/*" 5 | 6 | -------------------------------------------------------------------------------- /sync_scripts/sync_traces.sh: -------------------------------------------------------------------------------- 1 | rsync -avz --delete traces/ -e /bin/ssh sim2:/home/azureuser/splitwise-sim/traces/ 2 | rsync -avz --delete traces/ -e /bin/ssh sim3:/home/azureuser/splitwise-sim/traces/ 3 | 4 | -------------------------------------------------------------------------------- /task.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dataclasses import dataclass, field 4 | from enum import IntEnum 5 | 6 | from metrics import TaskMetrics, TaskSLO 7 | from node import Node 8 | from simulator import clock, schedule_event, cancel_event, reschedule_event 9 | 10 | 11 | class TaskType(IntEnum): 12 | COMPUTE = 0 13 | PROMPT = 1 14 | TOKEN = 2 15 | 16 | 17 | @dataclass(kw_only=True) 18 | class Task(Node): 19 | """ 20 | Tasks are computation nodes in the Request DAG. 21 | Tasks execute on Instances. 22 | 23 | Tasks are the computational counterparts of Flows. 24 | """ 25 | task_type: TaskType 26 | batch_size: int = 1 27 | duration: float = 0. 28 | remaining_duration: float = 0. 29 | cleanup_memory: bool = True 30 | metrics: TaskMetrics = field(default_factory=TaskMetrics) 31 | slo: TaskSLO = field(default_factory=TaskSLO) 32 | executor: 'Executor' = None 33 | instances = [] 34 | _instance = None 35 | 36 | def __hash__(self): 37 | return hash(self.node_id) 38 | 39 | @property 40 | def instance(self): 41 | return self._instance 42 | 43 | @instance.setter 44 | def instance(self, instance): 45 | if instance is self._instance: 46 | return 47 | self._instance = instance 48 | if instance is not None: 49 | self.instances.append(instance) 50 | 51 | @property 52 | def memory(self): 53 | return 0 54 | 55 | @classmethod 56 | def from_type(cls, task_type, **kwargs): 57 | if task_type == TaskType.COMPUTE: 58 | return ComputeTask(**kwargs) 59 | elif task_type == TaskType.PROMPT: 60 | return PromptTask(**kwargs) 61 | elif task_type == TaskType.TOKEN: 62 | return TokenTask(**kwargs) 63 | else: 64 | raise ValueError(f"Invalid TaskType {task_type}") 65 | 66 | 67 | @dataclass(kw_only=True) 68 | class ComputeTask(Task): 69 | """ 70 | Compute tasks represent arbitrary computation. 
71 | """ 72 | task_type: TaskType = TaskType.COMPUTE 73 | 74 | def __hash__(self): 75 | return hash(self.node_id) 76 | 77 | @property 78 | def memory(self): 79 | return 0 80 | 81 | 82 | @dataclass(kw_only=True) 83 | class PromptTask(Task): 84 | """ 85 | Prompt tasks are the prompt (prefill) computation in a generative LLM. 86 | They are typically the root task in a GenerativeLLMRequest. 87 | """ 88 | prompt_size: int 89 | tokens_per_iteration: int = 0 90 | processing_tokens: int = 0 91 | processed_tokens: int = 0 92 | generating_tokens: int = 0 93 | generated_tokens: int = 0 94 | task_type: TaskType = TaskType.PROMPT 95 | cleanup_memory: bool = False 96 | 97 | def __post_init__(self): 98 | self.tokens_per_iteration = self.prompt_size 99 | 100 | def __hash__(self): 101 | return hash(self.node_id) 102 | 103 | @property 104 | def memory(self): 105 | num_tokens = self.prompt_size + 1 106 | return self.request.estimate_kv_cache_size(num_tokens=num_tokens, 107 | model=self.instance.model) 108 | 109 | def max_memory(self, instance): 110 | num_tokens = self.prompt_size + 1 111 | return self.request.estimate_kv_cache_size(num_tokens=num_tokens, 112 | model=instance.model) 113 | 114 | def run(self): 115 | super().run() 116 | 117 | # manage memory 118 | self.instance.alloc_memory(self.request, self.memory) 119 | self.request.memory += self.memory 120 | 121 | def complete_iteration(self): 122 | # tokens processing 123 | # TODO: finer-grained memory management 124 | self.processed_tokens += self.processing_tokens 125 | self.request.processed_tokens += self.processing_tokens 126 | self.generated_tokens += self.generating_tokens 127 | self.request.generated_tokens += self.generating_tokens 128 | self.processing_tokens = 0 129 | self.generating_tokens = 0 130 | 131 | def is_complete(self): 132 | return self.generated_tokens == 1 133 | 134 | def complete(self): 135 | super().complete() 136 | 137 | # update scheduler bookkeeping 138 | self.instance.sched_pending_tokens -= self.prompt_size 139 | 140 | # update the TTFT 141 | self.request.metrics.prompt_end_timestamp = clock() 142 | self.request.metrics.TTFT = clock() - \ 143 | self.request.metrics.router_arrival_timestamp 144 | 145 | # ensure that we processed and generated all tokens 146 | assert self.processed_tokens == self.prompt_size 147 | assert self.request.processed_tokens == self.request.prompt_size 148 | assert self.generated_tokens == 1 149 | 150 | # manage memory 151 | if self.cleanup_memory: 152 | self.instance.free_memory(self.request, self.request.memory) 153 | self.request.memory = 0 154 | 155 | 156 | @dataclass(kw_only=True) 157 | class TokenTask(Task): 158 | """ 159 | Token tasks represent the token (decode) phase in a generative LLM. 
160 | """ 161 | token_size: int 162 | tokens_per_iteration: int = 1 163 | processing_tokens: int = 0 164 | processed_tokens: int = 0 165 | generating_tokens: int = 0 166 | generated_tokens: int = 0 167 | task_type: TaskType = TaskType.TOKEN 168 | 169 | def __hash__(self): 170 | return hash(self.node_id) 171 | 172 | @property 173 | def memory(self): 174 | num_tokens = self.token_size 175 | return self.request.estimate_kv_cache_size(num_tokens=num_tokens, 176 | model=self.instance.model) 177 | 178 | def max_memory(self, instance): 179 | num_tokens = self.token_size 180 | return self.request.estimate_kv_cache_size(num_tokens=num_tokens, 181 | model=instance.model) 182 | 183 | def run(self): 184 | super().run() 185 | 186 | # manage memory 187 | self.instance.alloc_memory(self.request, self.memory) 188 | self.request.memory += self.memory 189 | 190 | def complete_iteration(self): 191 | # tokens processing 192 | self.processed_tokens += self.processing_tokens 193 | self.request.processed_tokens += self.processing_tokens 194 | self.generated_tokens += self.generating_tokens 195 | self.request.generated_tokens += self.generating_tokens 196 | self.processing_tokens = 0 197 | self.generating_tokens = 0 198 | 199 | def is_complete(self): 200 | return self.generated_tokens == self.token_size 201 | 202 | def complete(self): 203 | super().complete() 204 | 205 | # update scheduler bookkeeping 206 | self.instance.sched_pending_tokens -= 1 207 | 208 | # ensure that we generated all tokens 209 | assert self.processed_tokens == self.token_size 210 | assert self.generated_tokens == self.token_size 211 | assert self.request.generated_tokens == self.request.token_size 212 | assert self.request.processed_tokens == self.request.prompt_size + \ 213 | self.request.token_size - 1 214 | 215 | # manage memory 216 | if self.cleanup_memory: 217 | self.instance.free_memory(self.request, self.request.memory) 218 | self.request.memory = 0 219 | 220 | 221 | if __name__ == "__main__": 222 | pass 223 | -------------------------------------------------------------------------------- /trace.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from request import Request 4 | 5 | 6 | class Trace(): 7 | def __init__(self, df): 8 | self.num_requests = len(df) 9 | self.requests = [] 10 | self.populate_requests(df) 11 | 12 | def populate_requests(self, df): 13 | for idx, request_dict in df.iterrows(): 14 | request = Request.from_dict(request_dict) 15 | self.requests.append(request) 16 | 17 | @classmethod 18 | def from_csv(cls, path): 19 | df = pd.read_csv(path) 20 | return Trace(df) 21 | -------------------------------------------------------------------------------- /traces/test_trace.csv: -------------------------------------------------------------------------------- 1 | request_id,request_type,application_id,arrival_timestamp,batch_size,prompt_size,token_size 2 | 0,2,0,1,1,4096,32 3 | 1,2,0,3,1,2048,8 4 | 2,2,0,4,1,2048,16 5 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc. 
utility functions 3 | """ 4 | 5 | import logging 6 | import os 7 | 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import pandas as pd 12 | 13 | from hydra.utils import get_original_cwd 14 | from omegaconf import OmegaConf 15 | from scipy import stats 16 | 17 | 18 | def file_logger(name, level=logging.INFO): 19 | """ 20 | returns a custom logger that logs to a file 21 | """ 22 | logger = logging.getLogger(name) 23 | logger.setLevel(level) 24 | 25 | # don't print to console (don't propagate to root logger) 26 | logger.propagate = False 27 | 28 | # create a file handler 29 | handler = logging.FileHandler(f"{name}.csv", mode="w") 30 | handler.setLevel(level) 31 | 32 | # add the handlers to the logger 33 | logger.addHandler(handler) 34 | 35 | return logger 36 | 37 | 38 | def read_all_yaml_cfgs(yaml_cfg_dir): 39 | """ 40 | Read all yaml config files in a directory 41 | Returns a dictionary of configs keyed by the yaml filename 42 | """ 43 | yaml_cfgs = {} 44 | yaml_cfg_files = os.listdir(yaml_cfg_dir) 45 | for yaml_cfg_file in yaml_cfg_files: 46 | if not yaml_cfg_file.endswith((".yaml", ".yml")): 47 | continue 48 | yaml_cfg_path = os.path.join(yaml_cfg_dir, yaml_cfg_file) 49 | yaml_cfg = OmegaConf.load(yaml_cfg_path) 50 | yaml_cfg_name = Path(yaml_cfg_path).stem 51 | yaml_cfgs[yaml_cfg_name] = yaml_cfg 52 | return yaml_cfgs 53 | 54 | 55 | def get_statistics(values, statistics=None): 56 | """ 57 | Compute statistics for a metric 58 | """ 59 | if statistics is None: 60 | statistics = ["mean", 61 | "std", 62 | "min", 63 | "max", 64 | "median", 65 | "p50", 66 | "p90", 67 | "p95", 68 | "p99", 69 | "p999", 70 | "geomean"] 71 | results = {} 72 | if "mean" in statistics: 73 | results["mean"] = np.mean(values) 74 | if "std" in statistics: 75 | results["std"] = np.std(values) 76 | if "min" in statistics: 77 | results["min"] = np.min(values) 78 | if "max" in statistics: 79 | results["max"] = np.max(values) 80 | if "median" in statistics: 81 | results["median"] = np.median(values) 82 | if "p50" in statistics: 83 | results["p50"] = np.percentile(values, 50) 84 | if "p90" in statistics: 85 | results["p90"] = np.percentile(values, 90) 86 | if "p95" in statistics: 87 | results["p95"] = np.percentile(values, 95) 88 | if "p99" in statistics: 89 | results["p99"] = np.percentile(values, 99) 90 | if "p999" in statistics: 91 | results["p999"] = np.percentile(values, 99.9) 92 | if "geomean" in statistics: 93 | results["geomean"] = stats.gmean(values) 94 | return results 95 | 96 | 97 | def save_dict_as_csv(d, filename): 98 | dirname = os.path.dirname(filename) 99 | if dirname != "": 100 | os.makedirs(dirname, exist_ok=True) 101 | df = pd.DataFrame(d) 102 | df.to_csv(filename, index=False) 103 | --------------------------------------------------------------------------------