├── .gitignore ├── README.md ├── alpa_serve ├── __init__.py ├── controller.py ├── http_util.py ├── placement_policy │ ├── __init__.py │ ├── base_policy.py │ ├── model_parallelism.py │ └── selective_replication.py ├── profiling.py ├── run.py ├── simulator │ ├── cluster.py │ ├── controller.py │ ├── event_loop.py │ ├── executable.py │ ├── util.py │ └── workload.py ├── trace │ ├── README.md │ ├── __init__.py │ ├── benchmark_trace.py │ ├── distribution.py │ ├── test_trace.py │ └── trace.py └── util.py ├── benchmarks └── alpa │ ├── README.md │ ├── approximate_one_case.py │ ├── bert_model.py │ ├── compare_waiting_time.py │ ├── equal_model_case.py │ ├── gen_data_goodput_vs_slo.py │ ├── gen_data_notebook.py │ ├── gen_data_simulator_align.py │ ├── gen_data_various_metrics.py │ ├── general_model_case.py │ ├── inspect_profiling_result.py │ ├── interactive_benchmarking.ipynb │ ├── plot_goodput_vs_slo.py │ ├── plot_various_metrics.py │ ├── prepare_trace.py │ ├── run_one_case.py │ ├── simulate_one_case.py │ ├── suite_debug.py │ └── util.py ├── deprecated ├── README.md ├── alpasim │ ├── __init__.py │ ├── cluster.py │ ├── model.py │ ├── scheduler.py │ ├── simulator.py │ ├── utils.py │ └── workload.py ├── azuretrace │ ├── README.md │ └── analyse.ipynb ├── cluster_traces │ ├── README.md │ ├── test_workload_8to2_30Hz_60s_interop_trace.json │ ├── test_workload_8to2_30Hz_60s_intraop_trace.json │ ├── test_workload_8to2_50Hz_60s_interop_trace.json │ ├── test_workload_8to2_50Hz_60s_intraop_trace.json │ ├── test_workload_8to2_6.667Hz_20s_baseline_trace.json │ ├── test_workload_8to2_6.667Hz_20s_interop_trace.json │ └── test_workload_8to2_6.667Hz_20s_intraop_trace.json ├── placements │ ├── README.md │ ├── placement_125M_baseline.json │ ├── placement_125M_interop.json │ ├── placement_125M_intraop.json │ ├── placement_125M_strong_baseline.json │ ├── placement_baseline.json │ ├── placement_interop.json │ ├── placement_intraop.json │ └── placement_test.json ├── scripts │ ├── memory_saving │ │ ├── 
benchmark.py │ │ └── placements │ │ │ ├── placement_baseline_2GPUs.json │ │ │ ├── placement_baseline_4GPUs_memx1.json │ │ │ ├── placement_baseline_4GPUs_memx2.json │ │ │ ├── placement_baseline_4GPUs_memx2_3to1.json │ │ │ ├── placement_baseline_4GPUs_memx3_3to1.json │ │ │ ├── placement_baseline_4GPUs_memx4.json │ │ │ ├── placement_pipeline_2GPUs.json │ │ │ ├── placement_pipeline_4GPUs_memx1.json │ │ │ ├── placement_pipeline_4GPUs_memx1dot5_3to1.json │ │ │ ├── placement_pipeline_4GPUs_memx2.json │ │ │ └── placement_strong_2GPUs.json │ ├── pipeline_latency │ │ ├── 2.6B.png │ │ ├── 6.7B.png │ │ └── plot.py │ └── small_model_benchmark │ │ └── strong_baseline.py ├── simulator.py ├── test.py └── workload │ ├── README.md │ ├── test_workload_8to2_10Hz_20s │ ├── test_workload_8to2_30Hz_60s │ ├── test_workload_8to2_50Hz_60s │ └── test_workload_8to2_6.667Hz_20s ├── experiments ├── ablation │ ├── ablation_general_synthetic_bert_all │ │ └── res_general_vs_all.tsv │ ├── ablation_general_synthetic_bert_all_fix_trace_seed │ │ └── res_general_vs_all.tsv │ ├── ablation_general_synthetic_mixed_all │ │ └── res_general_vs_all.tsv │ ├── ablation_general_synthetic_mixed_all_fix_trace_seed │ │ └── res_general_vs_all.tsv │ ├── align_simulator_2022_12_12 │ │ ├── res_real.tsv │ │ └── res_sim.tsv │ ├── general_synthetic_bert │ │ └── res_general_model_cases.tsv │ └── general_synthetic_mixed │ │ └── res_general_model_cases.tsv ├── batching │ └── gen_data_goodput_vs_slo.py ├── e2e_goodput │ ├── equal_model_exp.py │ ├── equal_model_suite.py │ ├── general_model_exp.py │ ├── general_model_suite.py │ ├── plot_sec6_2.py │ ├── plot_sec6_3.py │ ├── plot_sec6_4.py │ ├── plot_sec6_5.py │ ├── plot_sec6_6.py │ ├── plot_various_metrics.py │ └── visualize.py ├── motivation │ ├── README.md │ ├── changing_pipeline_overhead.py │ ├── changing_rate_cv_slo.py │ ├── illustrative_example.py │ ├── illustrative_example_slides.py │ ├── memory_budget_vs_latency.py │ ├── model_parallel_latency_throughput.py │ ├── 
overhead_decomposition.py │ └── queueing_theory_plot.py └── robustness │ ├── plot_average_performance.py │ ├── robustness_exp.py │ └── robustness_suite.py ├── osdi23_artifact ├── README.md ├── cleanup.sh ├── equal_model_exp.py ├── equal_model_suite.py ├── gen_data_sec6_2_e2e.sh ├── gen_data_sec6_3_large.sh ├── gen_data_sec6_4_robust.sh ├── gen_data_sec6_5_ab.sh ├── general_model_exp.py ├── general_model_suite.py ├── plot_sec6_2_e2e.py ├── plot_sec6_3_large.py ├── plot_sec6_4_robust.py ├── plot_sec6_5_ab.py ├── robustness_exp.py ├── robustness_suite.py └── sec6_2_data │ ├── azure_v1_mixed.tsv │ └── azure_v2_mixed.tsv ├── setup.py └── tests ├── run_all.py └── serve ├── test_controller.py ├── test_placement_policy.py └── test_simulator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # project-specific 2 | workload/ 3 | figures/ 4 | chrome_trace/ 5 | *.csv 6 | *.txt 7 | *.png 8 | *.pdf 9 | *.sh 10 | *.zip 11 | 12 | # Python cache 13 | __pycache__ 14 | *.pyc 15 | dist 16 | *.egg-info 17 | .cache 18 | *env 19 | 20 | # NFS temp files 21 | .nfs* 22 | 23 | # Vim 24 | *.swp 25 | 26 | # pycharm 27 | .idea 28 | 29 | # vscode 30 | *vscode* 31 | 32 | # Build files 33 | alpa/pipeline_parallel/xla_custom_call_marker/build 34 | build/lib 35 | build/bdist* 36 | build_jaxlib/build/bazel* 37 | build_jaxlib/bazel-* 38 | build_jaxlib/.jax_configure.bazelrc 39 | build_jaxlib/dist 40 | 41 | # Examples build and tmp files 42 | examples/build/ 43 | examples/llm_serving/dataset/*.so 44 | examples/llm_serving/dataset/*.c 45 | examples/llm_serving/dataset/*.cpp 46 | examples/llm_serving/weblogs 47 | examples/llm_serving/keys_file.json 48 | examples/llm_serving/benchmark/tmp* 49 | examples/llm_serving/tmp* 50 | examples/opt_finetune/output/ 51 | examples/gpt2/norwegian-gpt2/ 52 | alpa_debug_info 53 | 54 | # Analysis temp files 55 | *.nvprof 56 | *.prof 57 | *.tsv 58 | *.hlo 59 | 60 | # Tests temp files 61 | tests/tmp 62 | 63 | # Dataset 64 | 
benchmark/deepspeed/data 65 | 66 | # plots 67 | benchmarks/**/*.pdf 68 | benchmarks/**/*.png 69 | 70 | # Numpy cache 71 | *.npy 72 | *.pkl 73 | benchmarks/alpa/tmp* 74 | *.log 75 | 76 | # Documentation website build 77 | docs/_build 78 | docs/tutorials 79 | 80 | # Example temp files 81 | examples/imagenet/imagenet 82 | 83 | # macOS temp files 84 | .DS_Store 85 | 86 | # images 87 | alpa_serve/trace/*.png 88 | alpa_serve/trace/plots 89 | dataset/ 90 | 91 | # experiments 92 | experiments/e2e_goodput/*.png 93 | experiments/e2e_goodput/*.tsv 94 | experiments/motivation/**/*.pdf 95 | experiments/motivation/**/*.pkl 96 | 97 | # ipython 98 | .ipynb_checkpoints 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlpaServe 2 | Repo of alpa's multi-model serving system. 3 | 4 | This is the official implementation of our OSDI'23 paper: [AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving](https://www.usenix.org/conference/osdi23/presentation/li-zhouhan). 5 | 6 | To reproduce all the main results in our paper, please check the [artifact folder](./osdi23_artifact/) and follow the instructions in it. 
7 | -------------------------------------------------------------------------------- /alpa_serve/__init__.py: -------------------------------------------------------------------------------- 1 | """Alpa serving backend""" 2 | -------------------------------------------------------------------------------- /alpa_serve/placement_policy/__init__.py: -------------------------------------------------------------------------------- 1 | """The model placement policy""" 2 | 3 | from alpa_serve.placement_policy.base_policy import ModelData, ClusterEnv 4 | from alpa_serve.placement_policy.model_parallelism import ( 5 | ModelParallelismILP, ModelParallelismRR, 6 | ModelParallelismGreedy, ModelParallelismSearch, 7 | ModelParallelismEqual) 8 | from alpa_serve.placement_policy.selective_replication import ( 9 | SelectiveReplicationILP, SelectiveReplicationGreedy, 10 | SelectiveReplicationUniform, SelectiveReplicationSearch, 11 | SelectiveReplicationReplacement) 12 | -------------------------------------------------------------------------------- /alpa_serve/run.py: -------------------------------------------------------------------------------- 1 | """Run a controller.""" 2 | import argparse 3 | 4 | import ray 5 | 6 | from alpa_serve.controller import run_controller 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--host", type=str, default="localhost") 11 | parser.add_argument("--port", type=int) 12 | parser.add_argument("--root-path", type=str, default="/") 13 | args = parser.parse_args() 14 | 15 | ray.init(address="auto", namespace="alpa_serve") 16 | controller = run_controller(args.host, args.port, args.root_path) 17 | 18 | while True: 19 | pass 20 | -------------------------------------------------------------------------------- /alpa_serve/simulator/cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | The cluster and device mesh abstraction. 
3 | 4 | This file simulates `alpa/device_mesh.py`. 5 | """ 6 | 7 | from itertools import count 8 | from typing import Sequence, Tuple 9 | 10 | import numpy as np 11 | 12 | 13 | class GPU: 14 | idx = count() 15 | 16 | def __init__(self): 17 | self.stream_name = next(GPU.idx) 18 | 19 | 20 | class Mesh: 21 | def __init__(self, shape: Tuple[int]): 22 | self.gpus = [] 23 | for i in range(shape[0]): 24 | for j in range(shape[1]): 25 | self.gpus.append(GPU()) 26 | 27 | 28 | class MeshGroup: 29 | def __init__(self, mesh_shapes: Sequence[Tuple[int]]): 30 | self.meshes = [] 31 | 32 | for shape in mesh_shapes: 33 | self.meshes.append(Mesh(shape)) 34 | 35 | 36 | class VirtualMesh: 37 | def __init__(self, shape): 38 | self.shape = shape 39 | 40 | self.submesh_shapes = None 41 | self.launched_mesh_group = None 42 | 43 | def launch_mesh_group(self, submesh_shapes: Sequence[Tuple[int]]): 44 | assert self.launched_mesh_group is None 45 | 46 | assert np.prod(self.shape) == sum(np.prod(x) for x in submesh_shapes) 47 | 48 | self.submesh_shapes = tuple(submesh_shapes) 49 | self.launched_mesh_group = MeshGroup(submesh_shapes) 50 | 51 | return self.launched_mesh_group 52 | -------------------------------------------------------------------------------- /alpa_serve/simulator/executable.py: -------------------------------------------------------------------------------- 1 | """A pipeline executable.""" 2 | from alpa_serve.profiling import ParallelConfig, ProfilingResult 3 | from alpa_serve.simulator.cluster import VirtualMesh 4 | from alpa_serve.simulator.event_loop import clock, timed_coroutine, wait_stream, wait_multi_stream, sleep 5 | 6 | 7 | class Executable: 8 | def __init__(self, 9 | profiling_result: ProfilingResult, 10 | parallel_config: ParallelConfig, 11 | virtual_mesh: VirtualMesh): 12 | self.profile = profiling_result 13 | self.parallel_config = parallel_config 14 | self.latency_mem = profiling_result.para_dict[parallel_config] 15 | 16 | # launch or connect to a mesh group 17 | 
submesh_shapes = ( 18 | (parallel_config.dp, parallel_config.op),) * parallel_config.pp 19 | if virtual_mesh.launched_mesh_group: 20 | assert submesh_shapes == virtual_mesh.submesh_shapes 21 | mesh_group = virtual_mesh.launched_mesh_group 22 | else: 23 | mesh_group = virtual_mesh.launch_mesh_group(submesh_shapes) 24 | 25 | self.mesh_group = mesh_group 26 | 27 | def get_latency_dict(self): 28 | return self.latency_mem.latency 29 | 30 | @timed_coroutine 31 | async def handle_request(self, request): 32 | request.time_stamp["d"] = clock() 33 | batch_size = 1 34 | 35 | stage_latency = self.latency_mem.latency[batch_size] 36 | for mesh, latency in zip(self.mesh_group.meshes, stage_latency): 37 | # SPMD version 38 | stream = mesh.gpus[0].stream_name 39 | await wait_stream(stream, latency) 40 | 41 | # More accurate version 42 | #streams = [g.stream_name for g in mesh.gpus] 43 | #durations = [latency] * len(streams) 44 | #await wait_multi_stream(streams, durations) 45 | request.time_stamp["e"] = clock() 46 | return True 47 | -------------------------------------------------------------------------------- /alpa_serve/simulator/util.py: -------------------------------------------------------------------------------- 1 | """Common utilities""" 2 | import asyncio 3 | from functools import partial 4 | import threading 5 | 6 | import numpy as np 7 | 8 | 9 | def install_remote_methods(x): 10 | """Map obj.func.remote to obj.func, so we can create fake Ray APIs.""" 11 | for key in dir(x): 12 | value = getattr(x, key) 13 | if callable(value) and key[0] != "_": 14 | new_value = partial(value) 15 | setattr(new_value, "remote", new_value) 16 | setattr(x, key, new_value) 17 | 18 | 19 | def async_to_sync(async_def): 20 | """Convert a coroutine function to a normal function.""" 21 | assert asyncio.iscoroutinefunction(async_def) 22 | 23 | def ret_func(*args, **kwargs): 24 | corountine = async_def(*args, **kwargs) 25 | return run_coroutine(corountine) 26 | 27 | return ret_func 28 | 29 | 30 | 
30 | def run_coroutine(corountine): 31 | """Run an asynchronous corountine synchronously in a fresh thread. NOTE(review): if the coroutine raises, the exception is not propagated here; the worker thread exits, ret stays empty and ret[0] raises IndexError instead.""" 32 | ret = [] 33 | 34 | def target(): 35 | ret.append(asyncio.run(corountine)) 36 | 37 | # Start a new thread to allow nested asyncio loops 38 | t = threading.Thread(target=target) 39 | t.start() 40 | t.join() 41 | 42 | return ret[0] 43 | 44 | 45 | # Workload generation utils 46 | 47 | 48 | class MMPPSampler: 49 | """Sample a sequence of requests from a Markov Modulated Poisson Process.""" 50 | def __init__(self, Q, lambda_): 51 | """Initialize an MMPP sampler. 52 | 53 | Args: 54 | Q (np.ndarray): Transition-rate (generator) matrix of the Markov chain, as built by unifrom_mmpp; diagonal entries must be negative because sample() draws dwell times with scale -1 / Q[s, s]. 55 | lambda_ (np.ndarray): Lambdas of the Poisson process of each state. 56 | """ 57 | self.Q = Q 58 | self.lambda_ = lambda_ 59 | self.m = Q.shape[0] 60 | assert Q.shape == (self.m, self.m) 61 | assert lambda_.shape == (self.m,) 62 | self.Pi = np.identity(self.m) - np.diag(1 / np.diag(self.Q)) @ self.Q # next-state probabilities used by sample() 63 | 64 | def sample(self, num_requests, initial_state=0): 65 | """Generate samples using the Markov-modulated Poisson process. 66 | 67 | Args: 68 | num_requests (int): Number of requests to generate. Must be >= 1: the stop check only runs after an arrival is appended, so 0 never terminates. 69 | initial_state (int): Initial state of the Markov chain. 70 | 71 | Returns: (tau, (x, y, ys)) where 72 | tau: Arrival times; tau[0] is the initial time 0, followed by num_requests arrivals. 73 | x: Cumulative state-switch times, starting from x[0] = 0. 74 | y: The state sequence (y[0] is initial_state). 75 | ys: State of each entry in tau (ys[0] is initial_state).
76 | """ 77 | assert 0 <= initial_state < self.m 78 | ys = [initial_state] 79 | x = [0] 80 | y = [initial_state] 81 | tau = [0] 82 | while True: 83 | state = y[-1] 84 | y.append(np.random.choice(self.m, p=self.Pi[state])) # pick the next state from Pi 85 | t = x[-1] 86 | x.append(t + np.random.exponential(-1 / self.Q[state, state])) # time the current state period ends 87 | while True: 88 | t = t + np.random.exponential(1 / self.lambda_[state]) 89 | if t > x[-1]: 90 | break 91 | tau.append(t) 92 | ys.append(state) 93 | if len(tau) == num_requests + 1: 94 | return tau, (x, y, ys) 95 | 96 | def expected_request_rate(self): 97 | """Compute the expected request rate. NOTE(review): lambda_ @ Pi is a length-m vector built from the jump-chain matrix Pi, not a single time-averaged rate; confirm this is the intended quantity.""" 98 | return self.lambda_ @ self.Pi 99 | 100 | @classmethod 101 | def unifrom_mmpp(cls, expected_state_durations, 102 | expected_state_request_rates): 103 | """Special case of MMPP where the transition matrix from one state to 104 | another is uniform. 105 | 106 | Args: 107 | expected_state_durations (np.ndarray): Expected durations of each 108 | state. 109 | expected_state_request_rates (np.ndarray): Expected request rates of 110 | each state. 111 | 112 | Returns: An instance of cls built from the derived Q and lambda_. 113 | """ 114 | m = len(expected_state_durations) 115 | assert len(expected_state_request_rates) == m 116 | Q = np.zeros((m, m)) 117 | for i in range(m): 118 | for j in range(m): 119 | if i == j: 120 | Q[i, j] = -1 / expected_state_durations[i] 121 | else: 122 | Q[i, j] = 1 / expected_state_durations[i] / (m - 1) 123 | lambda_ = np.array(expected_state_request_rates) 124 | return cls(Q, lambda_) 125 | -------------------------------------------------------------------------------- /alpa_serve/trace/README.md: -------------------------------------------------------------------------------- 1 | # Trace Replay 2 | 3 | 4 | ## Dataset 5 | This folder provides methods to generate a TraceReplay from a public trace. 
Supported public trace: 6 | - Microsoft azure_v1 trace.
[[Introduction]](https://github.com/Azure/AzurePublicDataset/blob/master/AzureFunctionsDataset2019.md) [[Download]](https://drive.google.com/file/d/1Kup6JUH523CZZ7OxlkO942nAd5opuro0/view?usp=sharing) 7 | - Microsoft azure_v2 trace. [[Introduction]](https://github.com/Azure/AzurePublicDataset/blob/master/AzureFunctionsInvocationTrace2021.md) [[Download]](https://drive.google.com/file/d/1IOVoUoodBj4aKeyggxMnEVChEPutN4t7/view?usp=sharing) 8 | 9 | 10 | ## How to use 11 | First construct a trace object, which will read one of the two traces: 12 | ```python 13 | trace_name = "azure_v2" 14 | trace_dir = "~/azure_v2.pkl" 15 | trace = Trace(trace_name, trace_dir) 16 | ``` 17 | Provide the models that you want the trace to be replayed for: 18 | ```python 19 | n_model = 5 20 | models = [f"gpt{i}" for i in range(n_model)] 21 | ``` 22 | 23 | 24 | Replay the vanilla `azure_v2` trace on day 1. `azure_v1` cannot be replayed in vanilla mode. 25 | ```python 26 | 27 | replays = trace.replay_vanilla(models, 28 | model_mapping_strategy="stripe", 29 | start_time="0.0.0", 30 | end_time="1.0.0") 31 | ``` 32 | 33 | Replay the `azure_v2` trace on days 1 - 5. Estimate a Gamma arrival distribution using the data from each 3600-second window 34 | and sample the arrivals from the estimated Gamma distributions. 35 | ```python 36 | replays = trace.replay(models, 37 | model_mapping_strategy="stripe", 38 | start_time="0.0.0", 39 | end_time="5.0.0", 40 | arrival_distribution="gamma", 41 | interval_seconds=3600) 42 | ``` 43 | 44 | Replay the vanilla `azure_v2` trace on days 1 - 14, but rescale its timeline as if it happened within 7 days. 45 | ```python 46 | replays = trace.replay(models, 47 | model_mapping_strategy="stripe", 48 | start_time="0.0.0", 49 | end_time="13.23.60", 50 | arrival_distribution="vanilla", 51 | time_scale_factor=2.0) 52 | ``` 53 | 54 | Replay the `azure_v2` trace on day 1 using a Gamma estimator.
But scale the Gamma distributions' rate and CV by 8x: 55 | ```python 56 | replays = trace.replay(models, 57 | model_mapping_strategy="stripe", 58 | start_time="0.0.0", 59 | end_time="1.0.0", 60 | arrival_distribution="gamma", 61 | rate_scale_factor=8.0, 62 | cv_scale_factor=8.0) 63 | ``` 64 | 65 | You can visualize the replayed trace by: 66 | ```python 67 | replays[model_name].report_stats() 68 | replays[model_name].visualize() 69 | ``` 70 | 71 | You can convert a TraceReplay to be a workload: 72 | ```python 73 | replays[model_name].to_workload(slo=1.0) 74 | ``` 75 | -------------------------------------------------------------------------------- /alpa_serve/trace/__init__.py: -------------------------------------------------------------------------------- 1 | from .trace import TraceReplay, Trace, load_trace, preprocess_azure_v1_trace, preprocess_azure_v2_trace, \ 2 | report_group_stats 3 | -------------------------------------------------------------------------------- /alpa_serve/trace/benchmark_trace.py: -------------------------------------------------------------------------------- 1 | from alpa_serve.trace import Trace, TraceReplay, report_group_stats 2 | from scipy.stats import entropy 3 | import numpy as np 4 | 5 | 6 | # trace_name = "azure_v2" 7 | # trace_dir = "/mnt/e/projects/projects/dataset/mms_dataset/azure_v2.pkl" 8 | 9 | trace_name = "azure_v1" 10 | trace_dir = "/mnt/e/projects/projects/dataset/mms_dataset/azure_v1.pkl" 11 | 12 | n_model = 48 13 | models = [f"gpt{i}" for i in range(n_model)] 14 | trace = Trace(trace_name, trace_dir) 15 | 16 | 17 | def cdf(x): 18 | x = np.array(x) 19 | return np.cumsum(np.sort(x)[::-1]) / np.sum(x) 20 | 21 | 22 | entropies = {} 23 | 24 | for day in range(14): 25 | for h in range(24): 26 | start_time = str(day) + "." + str(h) + ".0" 27 | if h < 23: 28 | end_time = str(day) + "." + str(h+1) + ".0" 29 | else: 30 | end_time = str(day) + "." 
+ "23.60" 31 | 32 | 33 | replays = trace.replay(models, 34 | model_mapping_strategy="stripe", 35 | start_time=start_time, 36 | end_time=end_time, 37 | interval_seconds=60, 38 | arrival_distribution="exponential", 39 | rate_scale_factor=1e-3) 40 | # for m in replays: 41 | # replays[m].report_stats() 42 | # replays[m].visualize(n_interval=1000) 43 | report_group_stats(list(replays.values())) 44 | x = [replays[model].arrivals.size for model in replays] 45 | # print(x) 46 | entropies[start_time] = entropy(x) 47 | print(f"Entropy for {start_time} - {end_time}: {entropy(x)}, top-5 {np.sum(cdf(x)[:5]) / np.sum(cdf(x))}") 48 | 49 | print(entropies) 50 | print(max(entropies.values())) 51 | 52 | # for day in range(13, 14): 53 | # start_time = str(day) + ".0.0" 54 | # end_time = str(day+1) + ".0.0" 55 | # 56 | # if day == 13: 57 | # end_time = "13.23.60" 58 | # # replication_factors = [1, 2, 3] 59 | # print(f"Day: {start_time} - {end_time}") 60 | # distributions = ["gamma"] 61 | # for rf in replication_factors: 62 | # replays = trace.replay_vanilla(models, 63 | # model_mapping_strategy="stripe", 64 | # start_time=start_time, 65 | # end_time=end_time) 66 | # for m in replays: 67 | # replays[m].report_stats() 68 | # # replays[m].visualize(n_interval=1000) 69 | # report_group_stats(list(replays.values())) 70 | 71 | # for distribution in distributions: 72 | # replays = trace.replay(models, 73 | # model_mapping_strategy="stripe", 74 | # start_time=start_time, 75 | # end_time=end_time, 76 | # interval_seconds=5400, 77 | # arrival_distribution=distribution) 78 | # # for m in replays: 79 | # # replays[m].report_stats() 80 | # # replays[m].visualize(n_interval=1000) 81 | # report_group_stats(list(replays.values())) 82 | # x = [replays[model].arrivals.size for model in replays] 83 | # print(x) 84 | # print(f"Entropy for {start_time} - {end_time}: {entropy(x)}, CDF: {cdf(x)}") 85 | 86 | # replays = trace.replay(models, 87 | # model_mapping_strategy="stripe", 88 | # start_time="0.0.0", 
89 | # end_time="2.0.0", 90 | # interval_seconds=86400 // 2, 91 | # arrival_distribution="gamma") 92 | # for m in replays: 93 | # replays[m].report_stats() 94 | # replays[m].visualize() 95 | 96 | 97 | # interval_seconds = [600] 98 | # time_scale_factors = [2.0, 4.0, 8.0] 99 | # for interval_secs in interval_seconds: 100 | # # for distribution in ["exponential", "gamma", "vanilla"]: 101 | # for distribution in ["vanilla"]: 102 | # for time_scale_factor in time_scale_factors: 103 | # replays = trace.replay(models, 104 | # model_mapping_strategy="stripe", 105 | # start_time="0.0.0", 106 | # end_time="1.0.0", 107 | # arrival_distribution=distribution, 108 | # interval_seconds=interval_secs, 109 | # time_scale_factor=time_scale_factor) 110 | # for m in replays: 111 | # replays[m].report_stats() 112 | # replays[m].visualize() 113 | -------------------------------------------------------------------------------- /alpa_serve/trace/distribution.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | shape, mode = 3., 2. 
7 | s = (np.random.pareto(shape, 1000) + 1) * mode 8 | 9 | count, bins, _ = plt.hist(s, 100, density=True) 10 | fit = shape * mode ** shape / bins ** (shape + 1) 11 | plt.plot(bins, max(count)*fit/max(fit), linewidth=2, color='r') 12 | plt.show() 13 | fig = plt.gcf() 14 | figure_size = (8, 4) 15 | fig.set_size_inches(figure_size) 16 | fig.savefig("test.png", bbox_inches='tight') -------------------------------------------------------------------------------- /alpa_serve/trace/test_trace.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | 5 | from alpa_serve.trace import Trace, load_trace 6 | 7 | azure_v1_trace_dir = "azure_v1.pkl" 8 | azure_v2_trace_dir = "azure_v2.pkl" 9 | azure_v1_trace = Trace("azure_v1", azure_v1_trace_dir) 10 | azure_v2_trace = Trace("azure_v2", azure_v2_trace_dir) 11 | models = [f"gpt-{i}" for i in range(30)] 12 | 13 | 14 | @pytest.mark.skip(reason="slow test") 15 | @pytest.mark.parametrize("trace_name", ["azure_v1", "azure_v2"]) 16 | def test_read_trace(trace_name): 17 | load_trace(trace_name) 18 | 19 | @pytest.mark.skip(reason="slow test") 20 | @pytest.mark.parametrize("start_time, end_time", 21 | [("0.0.0", "1.0.0"), 22 | ("0.0.0", "0.23.60"), 23 | ("0.0.0", "7.0.0"), 24 | ("0.0.0", "13.23.60"), 25 | ("5.5.5", "8.8.8"), 26 | ("11.0.0", "13.23.60"), 27 | ("12.23.60", "13.23.60"), 28 | ("13.0.0", "13.23.60")]) 29 | def test_slice_azure_v2_trace(start_time, end_time): 30 | trace = azure_v2_trace 31 | arrivals = trace.slice(start_time=start_time, end_time=end_time) 32 | Trace.report_stats(arrivals) 33 | start_d, start_h, start_m = trace.timestr_to_dhm(start_time) 34 | end_d, end_h, end_m = trace.timestr_to_dhm(end_time) 35 | start_timestamp_seconds = start_d * 24 * 60 * 60 + start_h * 60 * 60 + start_m * 60 36 | end_timestamp_seconds = end_d * 24 * 60 * 60 + end_h * 60 * 60 + end_m * 60 37 | for function, arrival in arrivals.items(): 38 | np.all(arrival >= 
start_timestamp_seconds) 39 | np.all(arrival < end_timestamp_seconds) 40 | 41 | @pytest.mark.skip(reason="slow test") 42 | @pytest.mark.parametrize("start_time, end_time", 43 | [("0.0.0", "1.0.0"), 44 | ("0.0.0", "0.23.60"), 45 | ("0.0.0", "7.0.0"), 46 | ("0.0.0", "13.23.60"), 47 | ("5.5.5", "8.8.8"), 48 | ("11.0.0", "13.23.60"), 49 | ("12.23.60", "13.23.60"), 50 | ("13.0.0", "13.23.60")]) 51 | def test_slice_azure_v1_trace(start_time, end_time): 52 | trace = azure_v1_trace 53 | histogram = trace.slice(start_time, end_time) 54 | Trace.report_stats(histogram) 55 | 56 | start_d, start_h, start_m = trace.timestr_to_dhm(start_time) 57 | end_d, end_h, end_m = trace.timestr_to_dhm(end_time) 58 | start_slot = start_d * 24 * 60 + start_h * 60 + start_m 59 | end_slot = end_d * 24 * 60 + end_h * 60 + end_m 60 | n_slot = end_slot - start_slot 61 | for function, h in histogram.items(): 62 | assert h.size == n_slot 63 | 64 | @pytest.mark.parametrize("start_time, end_time", 65 | [("0.0.0", "1.0.0"), 66 | ("0.0.0", "0.23.60"), 67 | ("0.0.0", "7.0.0"), 68 | ("0.0.0", "13.23.60"), 69 | ("5.5.5", "8.8.8"), 70 | ("11.0.0", "13.23.60"), 71 | ("12.23.60", "13.23.60"), 72 | ("13.0.0", "13.23.60")]) 73 | def test_replay_vanilla(start_time, end_time): 74 | trace_replays = azure_v2_trace.replay_vanilla(models, start_time=start_time, 75 | end_time=end_time) 76 | for model, replay in trace_replays.items(): 77 | replay.report_stats() 78 | replay.visualize() 79 | 80 | 81 | @pytest.mark.parametrize("start_time, end_time", 82 | [("0.0.0", "1.0.0"), 83 | ("0.0.0", "0.23.60"), 84 | ("0.0.0", "7.0.0"), 85 | ("0.0.0", "13.23.60"), 86 | ("5.5.5", "8.8.8"), 87 | ("11.0.0", "13.23.60"), 88 | ("12.23.60", "13.23.60"), 89 | ("13.0.0", "13.23.60")]) 90 | @pytest.mark.parametrize("model_mapping_strategy", ["round_robin", "stripe"]) 91 | @pytest.mark.parametrize("arrival_distribution", ["exponential", "gamma"]) 92 | @pytest.mark.parametrize("interval_seconds", [60, 1800, 3600, 14400]) 93 | def 
test_replay_poisson(start_time, end_time, model_mapping_strategy, arrival_distribution, interval_seconds): 94 | azure_v2_trace.replay(models, 95 | model_mapping_strategy=model_mapping_strategy, 96 | start_time=start_time, 97 | end_time=end_time, 98 | arrival_distribution=arrival_distribution, 99 | interval_seconds=interval_seconds) 100 | 101 | 102 | 103 | if __name__ == "__main__": 104 | import pytest 105 | import sys 106 | 107 | sys.exit(pytest.main(["-v", "-x", "-s", __file__])) 108 | -------------------------------------------------------------------------------- /alpa_serve/util.py: -------------------------------------------------------------------------------- 1 | """Common utilities.""" 2 | from collections import namedtuple 3 | import functools 4 | import logging 5 | import math 6 | from typing import Sequence, Any 7 | 8 | import ray 9 | import numpy as np 10 | 11 | # global switch for batching 12 | # enable_batching = True 13 | batchsize_config = [1, 2, 4, 8, 16] 14 | 15 | # A general serving case. 16 | # We can simulate or run such a case. 
17 | ServingCase = namedtuple("ServingCase", 18 | ("register_models", "generate_workload", "placement_policy")) 19 | 20 | 21 | GB = 1 << 30 22 | eps = 1e-6 23 | inf = 1e100 24 | 25 | 26 | def build_logger(name="alpa_serve"): 27 | logger = logging.getLogger(name) 28 | logger.setLevel(logging.INFO) 29 | return logger 30 | 31 | 32 | def add_sync_method(actor, method_names): 33 | """Add a actor.sync method to wait for all calls to methods 34 | listed in method_names.""" 35 | calls = [] 36 | 37 | def sync(): 38 | ray.get(calls) 39 | calls.clear() 40 | 41 | setattr(actor, "sync", sync) 42 | 43 | for name in method_names: 44 | attr = getattr(actor, name) 45 | old_remote = attr.remote 46 | setattr(attr, "remote", functools.partial(wrapped_remote_call, old_remote, calls)) 47 | 48 | 49 | def wrapped_remote_call(old_remote, calls, *args, **kwargs): 50 | ret = old_remote(*args, *kwargs) 51 | calls.append(ret) 52 | return ret 53 | 54 | 55 | def write_tsv(heads: Sequence[str], 56 | values: Sequence[Any], 57 | filename: str, 58 | print_line: bool = True): 59 | """Write tsv data to a file.""" 60 | assert len(heads) == len(values) 61 | 62 | values = [str(x) for x in values] 63 | 64 | with open(filename, "a", encoding="utf-8") as fout: 65 | fout.write("\t".join(values) + "\n") 66 | 67 | if print_line: 68 | line = "" 69 | for i in range(len(heads)): 70 | line += heads[i] + ": " + values[i] + " " 71 | print(line) 72 | 73 | 74 | def get_factors(n: int): 75 | step = 2 if n % 2 else 1 76 | ret = list( 77 | set( 78 | functools.reduce( 79 | list.__add__, 80 | ([i, n // i] for i in range(1, int(math.sqrt(n)) + 1, step) if n % i == 0), 81 | ) 82 | ) 83 | ) 84 | ret.sort() 85 | return ret 86 | 87 | 88 | def to_str_round(x: Any, decimal: int = 6): 89 | """Print a python object but round all floating point numbers.""" 90 | if isinstance(x, str): 91 | return x 92 | if isinstance(x, (list, tuple, np.ndarray)): 93 | tmp_str = ", ".join([to_str_round(y, decimal=decimal) for y in x]) 94 | return "[" 
+ tmp_str + "]" 95 | if isinstance(x, dict): 96 | return str({k: to_str_round(v, decimal=decimal) for k, v in x.items()}) 97 | if isinstance(x, (int, np.int32, np.int64)): 98 | return str(x) 99 | if isinstance(x, (float, np.float32, np.float64)): 100 | format_str = f"%.{decimal}f" 101 | return format_str % x 102 | if x is None: 103 | return str(x) 104 | raise ValueError("Invalid value: " + str(x)) 105 | 106 | # Whether a part of size i is allowed when splitting n (decided by n and i modulo 8). 107 | def is_valid_size(n: int, i: int): 108 | if i <= n % 8 or (n - i) % 8 == 0 or n % 8 == 0: 109 | return True 110 | else: 111 | return False 112 | 113 | # Enumerate ordered splits of n into k positive parts that sum to n. 114 | # Candidate part sizes are filtered by is_valid_size (they are not restricted to powers of 2). 115 | def get_partitions(n: int, k: int): 116 | if k == 1: 117 | return [[n]] 118 | 119 | ret = [] 120 | for i in range(1, n): 121 | if not is_valid_size(n, i): continue 122 | pre_partitions = get_partitions(n - i, k - 1) 123 | ret += [partition + [i] for partition in pre_partitions] 124 | return ret 125 | 126 | 127 | #def get_partitions(n: int, k: int, lb: int = 1): 128 | # if k == 1: 129 | # if n >= lb: 130 | # return [[n]] 131 | # else: 132 | # return [] 133 | # 134 | # ret = [] 135 | # for i in range(lb, n): 136 | # if not is_valid_size(n, i): continue 137 | # pre_partitions = get_partitions(n - i, k - 1, i) 138 | # ret += [partition + [i] for partition in pre_partitions] 139 | # return ret 140 | 141 | # Powers of two from 1 up to and including n, ascending: [1, 2, 4, ...]. 142 | def get2tok(n: int): 143 | assert n > 0 144 | ret = [1] 145 | while True: 146 | if ret[-1] * 2 <= n: 147 | ret.append(ret[-1] * 2) 148 | else: 149 | break 150 | return ret 151 | 152 | # Powers of two in the binary decomposition of n, ascending (e.g. 13 -> [1, 4, 8]). 153 | def decompose2tok(n: int): 154 | ret = [] 155 | i = 1 156 | while n > 0: 157 | if n % 2 == 1: 158 | ret.append(i) 159 | i *= 2 160 | n = n // 2 161 | return ret 162 | 163 | 164 | if __name__ == "__main__": 165 | print(get_partitions(64, 6)) 166 | print(len(get_partitions(64, 6))) 167 | print(get2tok(34)) 168 | print(decompose2tok(13)) 169 | --------------------------------------------------------------------------------
/benchmarks/alpa/README.md: -------------------------------------------------------------------------------- 1 | ## Profiling Database 2 | Download the [profiling results](https://github.com/alpa-projects/mms/issues/14) into this folder. 3 | ## Run one case 4 | ``` 5 | python3 run_one_case.py --case debug_manual_1 6 | ``` 7 | 8 | ## Simulate one case 9 | Get the profiling results file from https://github.com/alpa-projects/mms/issues/14. 10 | 11 | ``` 12 | python3 simulate_one_case.py --case debug_manual_1 13 | ``` 14 | To show the debugging timestamp: 15 | ``` 16 | python3 simulate_one_case.py --case debug_manual_1 --debug 17 | ``` 18 | 19 | -------------------------------------------------------------------------------- /benchmarks/alpa/approximate_one_case.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | from alpa_serve.simulator.controller import approximate_one_case 5 | from alpa_serve.simulator.workload import Workload 6 | 7 | from benchmarks.alpa.suite_debug import suite_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--case", type=str, default="debug_replicate") 13 | parser.add_argument("--debug", action="store_true") 14 | parser.add_argument("--bench-speed", action="store_true") 15 | parser.add_argument("--fast-stats", action="store_true") 16 | args = parser.parse_args() 17 | 18 | stats, placement = approximate_one_case(suite_debug[args.case], debug=args.debug, 19 | fast_stats=args.fast_stats) 20 | Workload.print_stats(stats) 21 | 22 | if args.bench_speed: 23 | tic = time.time() 24 | stats, placement = approximate_one_case(suite_debug[args.case], debug=args.debug) 25 | print(f"time: {time.time() - tic:.4f}") 26 | -------------------------------------------------------------------------------- /benchmarks/alpa/compare_waiting_time.py: -------------------------------------------------------------------------------- 1 
"""Compare Kingman-formula waiting-time estimates for replication vs. pipelining."""
from alpa_serve.simulator.workload import GammaProcess


def kingman_formula(arrival_rate, arrival_CV, service_rate):
    """Approximate the mean queueing delay with Kingman's formula.

    Computes p / (1 - p) * arrival_CV**2 / 2 * (1 / service_rate), where
    p = arrival_rate / service_rate. This is Kingman's G/G/1 approximation
    with the service-time CV taken as 0 (deterministic service).

    Raises:
        ValueError: if the utilization p is outside [0, 1); the formula
            diverges at p == 1 and does not describe an unstable queue.
    """
    p = arrival_rate / service_rate
    if not 0 <= p < 1:
        raise ValueError(f"utilization must be in [0, 1), got {p}")
    return p / (1 - p) * (arrival_CV ** 2) / 2 * (1 / service_rate)


def waiting_time(workload, service_time):
    """Estimated queueing delay plus the service time itself."""
    return kingman_formula(workload.rate, workload.cv, 1 / service_time) + service_time


def pipeline_waiting_time(workload, stage_service_time):
    """Estimated latency of a pipeline.

    The slowest stage bounds the service rate; every stage adds its own time.
    """
    return (kingman_formula(workload.rate, workload.cv, 1 / max(stage_service_time))
            + sum(stage_service_time))


if __name__ == "__main__":
    r_a = GammaProcess(3, 2).generate_workload("a", start=0, duration=1000, seed=10)
    r_b = GammaProcess(3, 2).generate_workload("b", start=0, duration=1000, seed=11)
    r_c = GammaProcess(3, 2).generate_workload("c", start=0, duration=1000, seed=12)
    r_d = GammaProcess(3, 2).generate_workload("d", start=0, duration=1000, seed=13)

    # replication 1x
    w1 = waiting_time(r_a + r_b, 0.1)
    w2 = waiting_time(r_c + r_d, 0.1)
    print(f"w1: {w1: .3f}, w2: {w2:.3f}")

    # replication 2x
    w1 = waiting_time(r_a[::2] + r_b[::2] + r_c[::2] + r_d[::2], 0.1)
    w2 = waiting_time(r_a[1::2] + r_b[1::2] + r_c[1::2] + r_d[1::2], 0.1)
    print(f"w1: {w1: .3f}, w2: {w2:.3f}")
    r = r_a[::2] + r_b[::2] + r_c[::2] + r_d[::2]
    print(f"rate: {r.rate: .3f}, cv: {r.cv: .3f}")

    # pipeline 2x
    w1 = pipeline_waiting_time(r_a + r_b + r_c + r_d, [0.052, 0.050])
    print(f"w1: {w1: .3f}")
    r = r_a + r_b + r_c + r_d
    print(f"rate: {r.rate: .3f}, cv: {r.cv: .3f}")
# ------------------------------------------------------------------------------
# /benchmarks/alpa/gen_data_goodput_vs_slo.py:
# ------------------------------------------------------------------------------
"""Generate goodput-vs-SLO-scale data for one or more placement policies."""
import argparse

from benchmarks.alpa.equal_model_case import EqualModelCase, run_equal_model_cases
from benchmarks.alpa.general_model_case import GeneralModelCase, run_general_model_cases
from alpa_serve.util import GB, batchsize_config


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str, default="res_goodput_vs_slo.tsv")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--policy", type=str)
    parser.add_argument("--slo-scale", type=float)
    parser.add_argument("--trace", choices=["synthetic", "azure_v2"],
                        default="synthetic")
    parser.add_argument("--mode", choices=["simulate", "run"],
                        default="simulate")
    parser.add_argument("--unequal", action="store_true")
    parser.add_argument("--model-type", type=str, default="all_transformers",
                        choices=["all_transformers", "mixed"])
    parser.add_argument("--protocol", type=str, default="http",
                        choices=["http", "ray"])
    parser.add_argument("--relax-slo", action="store_true")
    parser.add_argument("--debug-tstamp", action="store_true")
    parser.add_argument("--enable-batching", action="store_true")
    parser.add_argument("--max-batchsize", type=int, default=2)

    args = parser.parse_args()

    # choices: {"sr-greedy", "sr-ilp", "mp-ilp",
    #           "mp-greedy-2", "mp-greedy-4", "mp-greedy-8",
    #           "mp-search", "mp-search-sep"}
    if args.policy is not None:
        policies = [args.policy]
    else:
        # policies = ["sr-greedy", "mp-search", "sr-replace-30"]
        policies = ["mp-search"]

    if args.enable_batching:
        assert args.max_batchsize == batchsize_config[-1], f"maximum batchsize is not {args.max_batchsize}, set it in alpa_serve/util.py"
        policies = [policy + "-batch-" + str(args.max_batchsize) for policy in policies]

    exp_name = "goodput_vs_slo"
    num_devices = 8
    mem_budget = 13 * GB
    model_type = "bert-2.6b"
    num_models = 8
    total_rate = 32
    if args.trace == "synthetic":
        # choices: {"gamma", "uniform_mmpp"}
        arrival_process = "gamma"
        # choices: {"uniform", "power_law", "triangle_decay"}
        rate_distribution = "power_law"
        arrival_process_kwargs = {"cv": 4}
    elif args.trace == "azure_v2":
        # choices: {"azure_v2"}
        arrival_process = "azure_v2"
        rate_distribution = None
        arrival_process_kwargs = None

    if args.slo_scale is not None:
        slo_scales = [args.slo_scale]
    else:
        slo_scales = [0.5, 1, 2, 3, 4, 5, 6, 8, 10, 12, 14]
    duration = 1000

    if args.unequal:
        # multi-model config
        if args.model_type == "mixed":
            model_set = ["bert-1.3b", "bert-2.6b", "bert-6.7b", "moe-1.3b", "moe-2.4b", "moe-5.3b"]
        else:
            model_set = ["bert-6.7b", "moe-1.3b"]
        num_devices = 64
        total_rate = 70
        fixed_num_modelset = 8
        model_types = model_set * fixed_num_modelset
        # Same order as model_types: every model of set 0, then set 1, ...
        model_names = [f"{name}-{i}" for i in range(fixed_num_modelset)
                       for name in model_set]

        cases = []
        for slo_scale in slo_scales:
            for policy_name in policies:
                cases.append(GeneralModelCase(
                    exp_name, num_devices, mem_budget, model_types, model_names,
                    total_rate, rate_distribution,
                    arrival_process, arrival_process_kwargs,
                    slo_scale, duration, policy_name))

        run_general_model_cases(cases,
                                output_file=args.output,
                                mode=args.mode,
                                debug_tstamp=args.debug_tstamp,
                                parallel=args.parallel,
                                enable_batching=args.enable_batching)
    else:
        cases = []
        for slo_scale in slo_scales:
            for policy_name in policies:
                cases.append(EqualModelCase(
                    exp_name, num_devices, mem_budget, model_type, num_models,
                    total_rate, rate_distribution,
                    arrival_process, arrival_process_kwargs,
                    slo_scale, duration, policy_name,
                    None, None, None, None))

        run_equal_model_cases(cases,
                              output_file=args.output,
                              mode=args.mode,
                              relax_slo=args.relax_slo,
                              protocol=args.protocol,
                              debug_tstamp=args.debug_tstamp,
                              parallel=args.parallel,
                              enable_batching=args.enable_batching)
# ------------------------------------------------------------------------------
# /benchmarks/alpa/gen_data_simulator_align.py:
# ------------------------------------------------------------------------------
"""Generate goodput-vs-SLO data for comparing simulated and real runs."""
import argparse

from benchmarks.alpa.equal_model_case import EqualModelCase, run_equal_model_cases
from benchmarks.alpa.general_model_case import GeneralModelCase, run_general_model_cases
from alpa_serve.util import GB


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str, default="res_goodput_vs_slo.tsv")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--policy", type=str)
    parser.add_argument("--slo-scale", type=float)
    parser.add_argument("--trace", choices=["synthetic", "azure_v2"],
                        default="synthetic")
    parser.add_argument("--mode", choices=["simulate", "run"],
                        default="simulate")
    parser.add_argument("--unequal", action="store_true")
    parser.add_argument("--model-type", type=str, default="all_transformers",
                        choices=["all_transformers", "mixed"])
    parser.add_argument("--protocol", type=str, default="http",
                        choices=["http", "ray"])
    parser.add_argument("--relax-slo", action="store_true")
    parser.add_argument("--debug-tstamp", action="store_true")

    args = parser.parse_args()

    # choices: {"sr-greedy", "sr-ilp", "mp-ilp",
    #           "mp-greedy-2", "mp-greedy-4", "mp-greedy-8",
    #           "mp-search", "mp-search-sep"}
    if args.policy is not None:
        policies = [args.policy]
    else:
        policies = ["sr-greedy", "mp-search"]
    exp_name = "goodput_vs_slo"
    num_devices = 16
    mem_budget = 13 * GB
    model_type = "bert-2.6b"
    num_models = 24
    total_rate = 40
    if args.trace == "synthetic":
        # choices: {"gamma", "uniform_mmpp"}
        arrival_process = "gamma"
        # choices: {"uniform", "power_law", "triangle_decay"}
        rate_distribution = "power_law"
        arrival_process_kwargs = {"cv": 4}
    elif args.trace == "azure_v2":
        # choices: {"azure_v2"}
        arrival_process = "azure_v2"
        rate_distribution = None
        arrival_process_kwargs = None

    if args.slo_scale is not None:
        slo_scales = [args.slo_scale]
    else:
        slo_scales = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 8, 10]
    duration = 200

    if args.unequal:
        # multi-model config
        if args.model_type == "mixed":
            model_set = ["bert-1.3b", "bert-2.6b", "bert-6.7b", "moe-1.3b", "moe-2.4b", "moe-5.3b"]
        else:
            model_set = ["bert-6.7b", "moe-1.3b"]
        num_devices = 64
        total_rate = 70
        fixed_num_modelset = 8
        model_types = model_set * fixed_num_modelset
        # Same order as model_types: every model of set 0, then set 1, ...
        model_names = [f"{name}-{i}" for i in range(fixed_num_modelset)
                       for name in model_set]

        cases = []
        for slo_scale in slo_scales:
            for policy_name in policies:
                cases.append(GeneralModelCase(
                    exp_name, num_devices, mem_budget, model_types, model_names,
                    total_rate, rate_distribution,
                    arrival_process, arrival_process_kwargs,
                    slo_scale, duration, policy_name))

        run_general_model_cases(cases,
                                output_file=args.output,
                                mode=args.mode,
                                debug_tstamp=args.debug_tstamp,
                                parallel=args.parallel)
    else:
        cases = []
        for slo_scale in slo_scales:
            for policy_name in policies:
                cases.append(EqualModelCase(
                    exp_name, num_devices, mem_budget, model_type, num_models,
                    total_rate, rate_distribution,
                    arrival_process, arrival_process_kwargs,
                    slo_scale, duration, policy_name,
                    None, None, None, None))

        run_equal_model_cases(cases,
                              output_file=args.output,
                              mode=args.mode,
                              relax_slo=args.relax_slo,
                              protocol=args.protocol,
                              debug_tstamp=args.debug_tstamp,
                              parallel=args.parallel)
# ------------------------------------------------------------------------------
# /benchmarks/alpa/gen_data_various_metrics.py:
# ------------------------------------------------------------------------------
"""Generate goodput data while sweeping one parameter at a time
(number of devices, number of models, SLO scale, total rate, or CV)."""
import argparse

from benchmarks.alpa.equal_model_case import EqualModelCase, run_equal_model_cases
from alpa_serve.util import GB


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default="all")
    parser.add_argument("--output", type=str, default="res_various_metrics.tsv")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--mode", choices=["simulate", "run"],
                        default="simulate")

    args = parser.parse_args()

    if args.exp_name == "all":
        exps = ["num_devices", "num_models", "slo", "rate", "cv"]
    else:
        exps = [args.exp_name]

    # choices: {"sr-greedy", "sr-ilp", "mp-ilp", "mp-greedy-2", "mp-greedy-8"}
    policies = ["sr-greedy", "sr-search", "mp-greedy-4", "mp-search"]
    mem_budget = 12 * GB
    model_type = "bert-2.6b"
    rate_distribution = "power_law"
    arrival_process = "gamma"
    duration = 200

    fixed_num_devices = 8
    fixed_num_models = 8
    fixed_per_model_rate = 3
    fixed_total_rate = fixed_num_models * fixed_per_model_rate
    fixed_slo_scale = 6
    fixed_cv = {"cv": 4}

    num_devices_list = [4, 8, 12, 16, 20, 24, 28, 32, 36]
    num_models_list = [1, 2, 4, 6, 8, 10, 12, 14, 16]
    total_rates = [1, 2, 4, 8, 12, 16, 20, 24, 28, 32]
    slo_scales = [3, 4, 5, 6, 8, 10, 12, 14, 20]
    cvs = [0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5]

    def make_case(exp_name, num_devices, num_models, total_rate,
                  arrival_process_kwargs, slo_scale, policy_name):
        """Build an EqualModelCase with the field layout used by
        gen_data_goodput_vs_slo.py: `exp_name` first, four trailing fields None.
        NOTE(review): layout inferred from that sibling script — confirm
        against benchmarks/alpa/equal_model_case.py."""
        return EqualModelCase(
            exp_name, num_devices, mem_budget, model_type, num_models,
            total_rate, rate_distribution,
            arrival_process, arrival_process_kwargs,
            slo_scale, duration, policy_name,
            None, None, None, None)

    def run_cases(cases):
        # The experiment name is carried inside each case, as in the sibling
        # scripts, so it is not passed to run_equal_model_cases.
        run_equal_model_cases(cases,
                              output_file=args.output,
                              mode=args.mode,
                              parallel=args.parallel)

    ##### goodput vs num_devices #####
    if "num_devices" in exps:
        run_cases([
            make_case("goodput_vs_num_devices", num_devices, fixed_num_models,
                      fixed_total_rate, fixed_cv, fixed_slo_scale, policy_name)
            for num_devices in num_devices_list
            for policy_name in policies])

    ##### goodput vs num_models #####
    if "num_models" in exps:
        run_cases([
            make_case("goodput_vs_num_models", fixed_num_devices, num_models,
                      fixed_per_model_rate * num_models, fixed_cv,
                      fixed_slo_scale, policy_name)
            for num_models in num_models_list
            for policy_name in policies])

    ##### goodput vs slo #####
    if "slo" in exps:
        run_cases([
            make_case("goodput_vs_slo", fixed_num_devices, fixed_num_models,
                      fixed_total_rate, fixed_cv, slo_scale, policy_name)
            for slo_scale in slo_scales
            for policy_name in policies])

    ##### goodput vs total_rate #####
    if "rate" in exps:
        run_cases([
            make_case("goodput_vs_total_rate", fixed_num_devices, fixed_num_models,
                      total_rate, fixed_cv, fixed_slo_scale, policy_name)
            for total_rate in total_rates
            for policy_name in policies])

    ##### goodput vs cv #####
    if "cv" in exps:
        run_cases([
            make_case("goodput_vs_cv", fixed_num_devices, fixed_num_models,
                      fixed_total_rate, {"cv": cv}, fixed_slo_scale, policy_name)
            for cv in cvs
            for policy_name in policies])
# ------------------------------------------------------------------------------
# /benchmarks/alpa/inspect_profiling_result.py:
# ------------------------------------------------------------------------------
"""Inspect the profiling database."""
import pickle

import numpy as np

import alpa_serve
from alpa_serve.profiling import ParallelConfig

# NOTE: pickle.load can execute arbitrary code; only load a trusted local file.
with open("profiling_result.pkl", "rb") as f:
    prof = pickle.load(f)

bs = 1

for model_name in ["bert-1.3b", "bert-2.6b"]:
    # Both overheads are relative to the ParallelConfig(1, 1, 1) latency.
    base_latency = sum(prof[model_name].para_dict[ParallelConfig(1, 1, 1)].latency[bs])
    for parallel_config in [ParallelConfig(1, 1, 1), ParallelConfig(1, 1, 8)]:
        latency = prof[model_name].para_dict[parallel_config].latency[bs]
        print(f"Model: {model_name}, {parallel_config}, Latency: {sum(latency):.4f}, "
              f"Latency Overhead: {sum(latency) / base_latency:.2f}, "
              f"Throughput Overhead: {np.prod(parallel_config) * max(latency) / base_latency:.2f}")
# ------------------------------------------------------------------------------
# /benchmarks/alpa/plot_goodput_vs_slo.py:
# ------------------------------------------------------------------------------
"""Plot goodput versus SLO scale for each placement policy."""
import argparse
from collections import defaultdict

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from benchmarks.alpa.equal_model_case import read_equal_model_case_tsv
from benchmarks.alpa.plot_various_metrics import show_name, method2color, method2order

11 | 12 | def read_data(filename): 13 | # Dict[policy -> Dict[slo_scale -> goodput]] 14 | data = defaultdict(lambda: defaultdict(dict)) 15 | 16 | rate = cv = None 17 | 18 | for line in read_equal_model_case_tsv(filename): 19 | policy, slo_scale, goodput, total_rate, kwargs, mode = ( 20 | line["policy_name"], line["slo_scale"], line["goodput"], 21 | line["total_rate"], line["arrival_process_kwargs"], line["mode"]) 22 | 23 | if rate is None: 24 | rate = total_rate 25 | cv = kwargs["cv"] if kwargs else 1 26 | 27 | if mode == "simulate": 28 | policy = policy 29 | else: 30 | policy = policy + "-real" 31 | 32 | data[policy][slo_scale] = goodput 33 | 34 | return data, {"total_rate": rate, "per_model_cv": cv} 35 | 36 | 37 | def plot_goodput_vs_slo(data, title, output, show): 38 | fig, ax = plt.subplots() 39 | figure_size = (5, 5) 40 | 41 | methods = list(data.keys()) 42 | methods.sort(key=lambda x: method2order(x)) 43 | 44 | curves = [] 45 | legends = [] 46 | x_max = 0 47 | y_max = 0 48 | for method in methods: 49 | curve = data[method] 50 | xs, ys = zip(*curve.items()) 51 | xs, ys = np.array(xs), np.array(ys) * 100 52 | indices = np.argsort(xs) 53 | xs, ys = xs[indices], ys[indices] 54 | if "batch" in method: 55 | curve = ax.plot(xs, ys, "--*", color=method2color(method)) 56 | else: 57 | curve = ax.plot(xs, ys, "-*", color=method2color(method)) 58 | 59 | curves.append(curve[0]) 60 | legends.append(show_name(method)) 61 | 62 | x_max = max(x_max, *xs) 63 | y_max = max(y_max, *ys) 64 | 65 | ax.set_ylim(bottom=0, top=max(y_max * 1.05, 100)) 66 | ax.set_xlim(left=0.3, right=16) 67 | ax.set_ylabel("Goodput (%)") 68 | ax.set_xlabel("SLO scale (x)") 69 | ax.set_xscale("log") 70 | xticks = [0.3, 0.5, 1, 2, 4, 8, 16] 71 | ax.set_xticks(xticks) 72 | ax.set_xticklabels(xticks) 73 | ax.set_xticks([], minor=True) 74 | ax.legend(curves, legends) 75 | ax.set_title(title) 76 | 77 | #ax.axline([1, 99], [2, 99], color="gray", linestyle='--') 78 | 79 | if show: 80 | plt.show() 81 | 82 | 
fig.set_size_inches(figure_size) 83 | fig.savefig(output, bbox_inches='tight') 84 | print(f"Output the plot to {output}") 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--input", type=str, default="res_goodput_vs_slo.tsv") 90 | parser.add_argument("--output", type=str, default="goodput_vs_slo.png") 91 | parser.add_argument("--show", action="store_true") 92 | args = parser.parse_args() 93 | 94 | data, params = read_data(args.input) 95 | title = ", ".join(f"{k} = {v}" for k, v in params.items()) 96 | plot_goodput_vs_slo(data, title, args.output, args.show) 97 | -------------------------------------------------------------------------------- /benchmarks/alpa/prepare_trace.py: -------------------------------------------------------------------------------- 1 | from alpa_serve.trace import preprocess_azure_v2_trace 2 | 3 | trace_dir = "/home/ubuntu/efs/mms/dataset/" 4 | preprocess_azure_v2_trace(trace_dir) -------------------------------------------------------------------------------- /benchmarks/alpa/simulate_one_case.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | from alpa_serve.simulator.controller import simulate_one_case, approximate_one_case 5 | from alpa_serve.simulator.workload import Workload 6 | 7 | from benchmarks.alpa.suite_debug import suite_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--case", type=str, default="debug_replicate") 13 | parser.add_argument("--debug", action="store_true") 14 | parser.add_argument("--bench-speed", action="store_true") 15 | args = parser.parse_args() 16 | 17 | print("simulate_one_case") 18 | tic = time.time() 19 | stats, placement = simulate_one_case(suite_debug[args.case], debug=args.debug) 20 | print(f"time: {time.time() - tic:.4f}") 21 | Workload.print_stats(stats) 22 | print("") 23 | 24 | 
print("approximate_one_case") 25 | stats, placement = approximate_one_case(suite_debug[args.case], debug=args.debug) 26 | tic = time.time() 27 | stats, placement = approximate_one_case(suite_debug[args.case], debug=args.debug) 28 | print(f"time: {time.time() - tic:.4f}") 29 | Workload.print_stats(stats) 30 | 31 | if args.bench_speed: 32 | tic = time.time() 33 | stats, placement = approximate_one_case(suite_debug[args.case], debug=args.debug) 34 | print(f"time: {time.time() - tic:.4f}") 35 | -------------------------------------------------------------------------------- /benchmarks/alpa/util.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | 4 | from alpa_serve.simulator.executable import Executable 5 | 6 | from benchmarks.alpa.bert_model import BertModel, bert_specs 7 | 8 | 9 | def build_logger(): 10 | logger = logging.getLogger("benchmark") 11 | logger.setLevel(logging.INFO) 12 | logger.addHandler(logging.StreamHandler()) 13 | return logger 14 | 15 | 16 | def get_model_def(name, is_simulator, prof_database): 17 | result = prof_database.get(name) 18 | if result is None: 19 | raise ValueError(f"Invalid model name: {name}") 20 | 21 | if is_simulator: 22 | return partial(Executable, result) 23 | else: 24 | if name == "bert-1.3b": 25 | return partial(BertModel, bert_specs["1.3B"], result) 26 | elif name == "bert-2.6b": 27 | return partial(BertModel, bert_specs["2.6B"], result) 28 | elif name == "bert-6.7b": 29 | return partial(BertModel, bert_specs["6.7B"], result) 30 | elif name == "bert-103.5b": 31 | return partial(BertModel, bert_specs["103.5B"], result) 32 | else: 33 | raise ValueError(f"Invalid model name: {name}") 34 | -------------------------------------------------------------------------------- /deprecated/README.md: -------------------------------------------------------------------------------- 1 | # DLSIM 2 | 3 | ## Input 4 | 5 | Currently the simulator has no 
built-in workload or placement generation, so you must prepare the [workload](./workload/README.md) and [model placement policy](./placements/README.md) before running a simulation. 6 | 7 | ## Usage 8 | 9 | ```txt 10 | usage: simulator.py [-h] --name NAME [-n NUM_NODES] [-d NUM_DEVICES_PER_NODE] [-c MEMORY_CAPACITY] -w WORKLOAD -p PLACEMENT [--chrome_trace] 11 | 12 | Cluster simulator for distributed DL inference tasks 13 | 14 | optional arguments: 15 | -h, --help show this help message and exit 16 | --name NAME simulation name 17 | -n NUM_NODES, --num_nodes NUM_NODES 18 | number of nodes in the cluster 19 | -d NUM_DEVICES_PER_NODE, --num_devices_per_node NUM_DEVICES_PER_NODE 20 | number of devices per node in the cluster 21 | -c MEMORY_CAPACITY, --memory_capacity MEMORY_CAPACITY 22 | GPU memory capacity in GB 23 | -w WORKLOAD, --workload WORKLOAD 24 | Workload Filename 25 | -p PLACEMENT, --placement PLACEMENT 26 | Placement Filename 27 | --chrome_trace Dump chrome trace 28 | ``` 29 | 30 | Example: `python3 simulator.py --name test_sim -n 1 -d 2 -c 16 -w skewed_workload -p placement_baseline.csv --chrome_trace` 31 | 32 | This prints the simulation results to the console and also writes `test_sim.json` to the `chrome_trace` folder. To view it, open `chrome://tracing` in Chrome and load this JSON file.
33 | 34 | ## TODO 35 | 36 | - workload abstraction 37 | - [x] plot 38 | - [x] open-loop poisson generator 39 | - [x] workload save 40 | - [x] workload load 41 | - [ ] closed-loop workload generator 42 | - model abstraction 43 | - [x] executable 44 | - [x] model statistics (stage latencies) 45 | - [ ] memory statistics (params + activation) 46 | - [ ] more general to read in the profile csv directly 47 | - cluster abstraction 48 | - [x] device hierarchy 49 | - [x] submesh 50 | - [x] per-GPU task queue 51 | - [x] per-GPU clock 52 | - placement abstraction 53 | - [x] placement strategy save/load 54 | - [x] submesh <--> executable 55 | - [x] inter-op executor 56 | - [x] intra-op executor 57 | - simulation execution engine 58 | - [x] scheduler abstraction 59 | - [x] logging utilities 60 | - [x] tracing plotter 61 | - [ ] fix the ray overhead magic number 62 | - [ ] memory checker 63 | - check with real-execution trace 64 | - [x] baseline trace 65 | - [x] inter-op only trace 66 | - [x] intra-op only trace 67 | - [ ] inter-op + intra-op trace 68 | -------------------------------------------------------------------------------- /deprecated/alpasim/__init__.py: -------------------------------------------------------------------------------- 1 | from .cluster import * 2 | from .model import * 3 | from .scheduler import * 4 | from .simulator import * 5 | from .utils import * 6 | from .workload import * -------------------------------------------------------------------------------- /deprecated/alpasim/model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | class Executable: 4 | def __init__(self, model_name: str, executable_name: str, stage_shapes: List[tuple], stage_latencies: List[float]): 5 | """Per-stage submesh shapes and latencies of one model executable. 6 | @param executable_name: used in the model_configs to index 7 | @param stage_shapes: list of tuple which contains the shape of submesh for each stage, 8 | each tuple represents (# of nodes, # GPU per node) 9 | 
@param stage_latencies: latency (in seconds) of each stage; must have the same length as stage_shapes 10 | """ 11 | assert len(stage_shapes) == len(stage_latencies) and len(stage_shapes) > 0 12 | self.model_name = model_name 13 | self.executable_name = executable_name 14 | self.stage_shapes = stage_shapes 15 | self.stage_latencies = stage_latencies 16 | self.num_stage = len(stage_shapes) 17 | 18 | def __str__(self): 19 | ret = f"{self.executable_name}:\n" 20 | for idx, (submesh, latency) in enumerate(zip(self.stage_shapes, self.stage_latencies)): 21 | ret += f"stage {idx}: {submesh} => {latency:.5f}s\n" 22 | return ret 23 | 24 | model_configs = { 25 | # Benchmarked on AWS P3 instances with Tesla V100 26 | "Bert_125M": { 27 | "Bert_125M_1_device": Executable("Bert_125M", "Bert_125M_1_device", [(1,1)], [0.01893]), 28 | "Bert_125M_2_device_intra": Executable("Bert_125M", "Bert_125M_2_device_intra", [(1,2)], [0.0139]), 29 | "Bert_125M_2_device_inter": Executable("Bert_125M", "Bert_125M_2_device_inter", [(1,1),(1,1)], [0.010, 0.010]), 30 | }, 31 | "Bert_2.6B": { 32 | "Bert_2.6B_1_device": Executable("Bert_2.6B", "Bert_2.6B_1_device", [(1,1)], [0.14520]), 33 | "Bert_2.6B_2_device_intra": Executable("Bert_2.6B", "Bert_2.6B_2_device_intra", [(1,2)], [0.0985]), 34 | "Bert_2.6B_2_device_inter": Executable("Bert_2.6B", "Bert_2.6B_2_device_inter", [(1,1),(1,1)], [0.0730, 0.0734]), 35 | "Bert_2.6B_4_device_inter": Executable("Bert_2.6B", "Bert_2.6B_4_device_inter", [(1,1),(1,1),(1,1),(1,1)], [0.037, 0.037, 0.037, 0.037]), 36 | } 37 | } 38 | 39 | 40 | if __name__ == '__main__': 41 | for execs in model_configs.values(): 42 | for executable in execs.values(): # iterate the Executables, not their dict keys 43 | print(executable) 44 | -------------------------------------------------------------------------------- /deprecated/alpasim/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from alpasim.workload import WorkLoad 4 | from alpasim.cluster 
import MeshExecutor, ScheduledTask 5 | 6 | 7 | # TODO (Zhong 
Yinmin): generalize the Scheduler and write an abstract class 8 | 9 | class FIFOScheduler: 10 | def __init__(self, workload: WorkLoad, meshexecutors: Dict[str, List[MeshExecutor]], model_id_to_service_name: Dict[int, str]): 11 | """ 12 | @param workload: the workload to be scheduled. 13 | @param meshexecutors: a map from service_name to a list of its corresponding mesh_executors 14 | @param model_id_to_service_name: Workload only contains model_id, so it is independent of the concrete models. 15 | model_id_to_service_name is used to connect the workload with the concrete models. 16 | """ 17 | self.workload = workload 18 | self.meshexecutors = meshexecutors 19 | self.model_id_to_service_name = model_id_to_service_name 20 | self.tasks = list(iter(self.workload)) 21 | self.scheduled_tasks = [] 22 | self.completed_tasks = [] 23 | 24 | def handle_event(self): 25 | task = self.tasks.pop(0) 26 | service_name = self.model_id_to_service_name[task.model_id] 27 | # choose the meshexecutor that becomes idle earliest (smallest next_idle_time) 28 | designated_meshexecutor = min(self.meshexecutors[service_name], key=lambda m: m.next_idle_time) 29 | scheduled_task = ScheduledTask(task.task_id, task.model_id, task.arrive_time, 30 | task.arrive_time, task.SLO, designated_meshexecutor) 31 | self.scheduled_tasks.append(scheduled_task) 32 | scheduled_task.start_execution() 33 | self.completed_tasks.append(scheduled_task) 34 | 35 | @property 36 | def next_event_time(self): 37 | if len(self.tasks) == 0: 38 | return float('inf') 39 | else: 40 | # FIFO scheduler schedules tasks in the order of arrival 41 | return self.tasks[0].arrive_time 42 | 43 | 44 | -------------------------------------------------------------------------------- /deprecated/alpasim/simulator.py: -------------------------------------------------------------------------------- 1 | from alpasim.cluster import Cluster 2 | 3 | class EventMonitor: 4 | """ 5 | All the entities must implement a next_event_time property and a handle_event method in 6 | order to be 
monitored by the EventMonitor. 7 | """ 8 | def __init__(self, entities): 9 | for entity in entities: 10 | self.check_entity(entity) 11 | self.entities = list(entities)  # copy so add_entity never mutates the caller's list 12 | 13 | def check_entity(self, entity): 14 | if getattr(entity, "next_event_time", None) is None: 15 | raise TypeError(f"{entity} does not have a next_event_time property") 16 | if getattr(entity, "handle_event", None) is None or not callable(entity.handle_event): 17 | raise TypeError(f"{entity} does not have a handle_event method") 18 | 19 | def add_entity(self, entity): 20 | self.check_entity(entity) 21 | self.entities.append(entity) 22 | 23 | def next_event_entity(self): 24 | """ 25 | Return the entity with smallest next_event_time. 26 | If all the entities have infinite next_event_time, return None. 27 | """ 28 | assert len(self.entities) > 0, "EventMonitor is empty" 29 | next_entity = self.entities[0] 30 | for entity in self.entities: 31 | if entity.next_event_time < next_entity.next_event_time: 32 | next_entity = entity 33 | if next_entity.next_event_time == float('inf'): 34 | return None 35 | return next_entity 36 | 37 | 38 | class Simulator: 39 | def __init__(self, scheduler, cluster: Cluster): 40 | self.monitor = EventMonitor(cluster.get_all_gpus()) 41 | self.monitor.add_entity(scheduler) 42 | 43 | def start(self): 44 | while True: 45 | entity = self.monitor.next_event_entity() 46 | if entity is None: 47 | break 48 | entity.handle_event() -------------------------------------------------------------------------------- /deprecated/alpasim/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from alpasim.cluster import ScheduledTask 9 | 10 | def compute_statistics_from_cluster_trace(trace_file: str): 11 | with open(trace_file, "r") as f: 12 | traces = json.load(f) 13 | latencies = 
dict.fromkeys(traces.keys()) 14 | for model_id, trace 
in traces.items(): 15 | latencies[model_id] = [] 16 | for a, (_, e, _, _) in zip(trace["arrive"], trace["stage_exec_info"][-1]): 17 | latencies[model_id].append(e - a) 18 | latencies[model_id] = np.array(latencies[model_id]) 19 | overall_latencies = np.concatenate(list(latencies.values()), axis=None) 20 | print("-------------------------------") 21 | print(f"{trace_file} statistics:") 22 | print(f"overall mean latency: {np.mean(overall_latencies):.3f}s") 23 | print(f"overall 99% tail latency: {np.quantile(overall_latencies, 0.99):.3f}s") 24 | for model_id, l in latencies.items(): 25 | print(f"model{model_id} mean latency: {np.mean(l):.3f}s") 26 | print(f"model{model_id} 99% tail latency: {np.quantile(l, 0.99):.3f}s") 27 | print("-------------------------------") 28 | return latencies, overall_latencies 29 | 30 | def compute_statistics_from_simulation(tasks: List[ScheduledTask]): 31 | latencies = {} 32 | for t in tasks: 33 | _, e, _, _ = t.stage_execution_info[-1][0] 34 | if t.model_id not in latencies: 35 | latencies[t.model_id] = [] 36 | latencies[t.model_id].append(e - t.arrive_time) 37 | overall_latencies = np.concatenate(list(latencies.values()), axis=None) 38 | print("-------------------------------") 39 | print("simulation statistics:") 40 | print(f"overall mean latency: {np.mean(overall_latencies):.3f}s") 41 | print(f"overall 99% tail latency: {np.quantile(overall_latencies, 0.99):.3f}s") 42 | for model_id, l in latencies.items(): 43 | print(f"model{model_id} mean latency: {np.mean(l):.3f}s") 44 | print(f"model{model_id} 99% tail latency: {np.quantile(l, 0.99):.3f}s") 45 | return latencies, overall_latencies 46 | 47 | color_list = [ 48 | "thread_state_uninterruptible", 49 | "thread_state_iowait", 50 | "thread_state_running", 51 | "thread_state_runnable", 52 | "thread_state_unknown", 53 | "background_memory_dump", 54 | "light_memory_dump", 55 | "detailed_memory_dump", 56 | "vsync_highlight_color", 57 | "generic_work", 58 | "good", 59 | "bad", 60 | 
"terrible", 61 | "yellow", 62 | "olive", 63 | "rail_response", 64 | "rail_animation", 65 | "rail_idle", 66 | "rail_load", 67 | "startup", 68 | "heap_dump_stack_frame", 69 | "heap_dump_object_type", 70 | "heap_dump_child_node_arrow", 71 | "cq_build_running", 72 | "cq_build_passed", 73 | "cq_build_failed", 74 | "cq_build_attempt_runnig", 75 | "cq_build_attempt_passed", 76 | "cq_build_attempt_failed", 77 | ] 78 | 79 | def dump_chrome_tracing_from_simulation(tasks: List[ScheduledTask], filename: str): 80 | slot_list = [] 81 | def get_color(t): 82 | return color_list[t.request_id % len(color_list)] 83 | 84 | for t in tasks: 85 | for stage_num, stage_exec_info in enumerate(t.stage_execution_info): 86 | for intra_num, (start, end, node_id, gpu_id) in enumerate(stage_exec_info): 87 | slot = {"name": f"r{t.request_id}.s{stage_num}.{intra_num}", 88 | "cat": f"stage{stage_num}, intra{intra_num}", 89 | "ph": "X", 90 | "pid": node_id, 91 | "tid": gpu_id, 92 | "ts": start * 1e6, 93 | "dur": (end - start) * 1e6, 94 | "cname": get_color(t)} 95 | slot_list.append(slot) 96 | 97 | os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)  # dirname is "" for a bare filename and makedirs("") raises 98 | with open(filename, "w") as fout: 99 | fout.write(json.dumps({ 100 | "traceEvents": slot_list, 101 | "displayTimeUnit": "ms", 102 | })) 103 | 104 | 105 | def dump_chrome_tracing_from_cluster_trace(trace_file: str, dumpfile: str): 106 | def get_color(i): 107 | return color_list[i % len(color_list)] 108 | 109 | slot_list = [] 110 | with open(trace_file, "r") as f: 111 | traces = json.load(f) 112 | for model_id, trace in traces.items(): 113 | for i, stage_exec_info in enumerate(trace["stage_exec_info"]): 114 | for rq_id, (s, e, node_ids, devices) in zip(trace["rq_id"], stage_exec_info): 115 | for node_id, devices_per_node in zip(node_ids, devices): 116 | for device in devices_per_node: 117 | slot = {"name": f"r{rq_id}s{i}", 118 | "cat": f"model{model_id}", 119 | "ph": "X", 120 | "pid": node_id, 
121 | "tid": device, 122 | "ts": float(s) * 1e6, 123 | 
"dur": (float(e) - float(s)) * 1e6, 124 | "cname": get_color(rq_id)} 125 | slot_list.append(slot) 126 | 127 | os.makedirs(os.path.dirname(dumpfile) or ".", exist_ok=True)  # dirname is "" for a bare filename and makedirs("") raises 128 | with open(dumpfile, "w") as fout: 129 | fout.write(json.dumps({ 130 | "traceEvents": slot_list, 131 | "displayTimeUnit": "ms", 132 | })) -------------------------------------------------------------------------------- /deprecated/azuretrace/README.md: -------------------------------------------------------------------------------- 1 | # Azure Functions Invocation Trace 2021 2 | 3 | ## Introduction 4 | This is a trace of function invocations in [Microsoft's Azure Functions](https://docs.microsoft.com/en-us/azure/azure-functions/functions-overview) for two weeks starting on 2021-01-31. This trace has been used in the SOSP 2021 paper [**Faster and Cheaper Serverless Computing on Harvested Resources**](https://www.microsoft.com/en-us/research/publication/faster-and-cheaper-serverless-computing-on-harvested-resources/). 5 | 6 | 7 | ## Using the Data 8 | 9 | ### License 10 | The data is made available and licensed under a [CC-BY Attribution License](https://github.com/Azure/AzurePublicDataset/blob/master/LICENSE). By downloading it or using them, you agree to the terms of this license. 11 | 12 | ### Attribution 13 | If you use this data for a publication or project, please cite the accompanying paper: 14 | 15 | > Yanqi Zhang, Íñigo Goiri, Gohar Irfan Chaudhry, Rodrigo Fonseca, Sameh Elnikety, Christina Delimitrou, Ricardo Bianchini. "[**Faster and Cheaper Serverless Computing on Harvested Resources**](https://www.microsoft.com/en-us/research/publication/faster-and-cheaper-serverless-computing-on-harvested-resources/)", in Proceedings of the ACM International Symposium on Operating Systems Principles (SOSP), October 2021. 
16 | 17 | If you have any questions, comments, or concerns, or if you would like to share tools for working with the traces, please contact us at azurepublicdataset@service.microsoft.com 18 | 19 | ### Downloading 20 | You can download the dataset here: [**AzureFunctionsInvocationTraceForTwoWeeksJan2021.rar**](https://github.com/Azure/AzurePublicDataset/raw/master/data/AzureFunctionsInvocationTraceForTwoWeeksJan2021.rar). 21 | 22 | ## Schema and Description 23 | 24 | ### Schema 25 | 26 | - app: application id (encrypted) 27 | - func: function id (encrypted), and unique only within an application 28 | - end_timestamp: function invocation end timestamp (in seconds) 29 | - duration: duration of function invocation (in seconds) 30 | 31 | ### Remarks 32 | 33 | In Azure Functions, the unit of deployment is called an application, and an application has one or more functions. For example, an application could be a binary file with one or more entry points. A function invocation specifies both the app id and the func id within the app. 34 | 35 | Invocation timestamps have been modified from those in the actual production trace. 
36 | 37 | #### Sample 38 | 39 | |app|func|end_timestamp|duration| 40 | |--|--|--|--| 41 | |734272c01926d19690e5ec308bab64ef97950b75b1c7582283e0783fce1751d8|313c03f53a0d31f70aec25f62efb33e7dd779725ca4af579018452d1204beaad|5160.142570018768|0.134| 42 | |17c37a0fdd5d1932b755c0e6447137bc08fd524f455e14fdac414f584de08dc5|c9f8e30e36d1aef62c10b3cfca6e289a93848a148d876dd514753040314f4817|5161.280997037888|0.013| 43 | |7fa05b607ae861b85ec53cea12d3efaed8be0f9a92f5d6e8067244161d491e96|9bc86d6cd1ee254aaa313492f0fd88be8bd7b92d50d4237ff52d7685440c0906|5241.567729949951|42.356| 44 | |c8c43e1a911f29e5506460a2fbef61ff39723d672f3b3b67d12d4c236c6872f7|653cdbc309bc359f3289d3b4df21c4a8e478d22946b35cbfdab05377dcacd3e0|5253.883348941803|42.372| 45 | |db6be4a997f386b37c6246aaeecf81ab81562db84cf4c0d44907d9df2d0ab9fc|9040b71f8a0325ba418c85bcefa3b19c02c781bed6284af487d3f111f369534a|5219.518173933029|0.108| 46 | |f7bfe5bc8d2a37a5c15986fbfc2c477a746e866adcb9663f9df7535b61c3eb9b|34f4775366e51728635af48df1a96d332cf1565eee069a0030f12966ae760274|5220.1072909832|0.093| 47 | -------------------------------------------------------------------------------- /deprecated/cluster_traces/README.md: -------------------------------------------------------------------------------- 1 | # Cluster Trace 2 | 3 | These trace files are generated by running the experiments on physical cluster, and serve as the ground truth to check the accuracy of the simulator. 4 | 5 | ## Json File Format 6 | 7 | Example file: 8 | 9 | ```json 10 | { 11 | "0": { // model_id 12 | "rq_id": [1, 3], // request ids 13 | "e2e_latency": [0.1, 0.1], // end to end latency experienced by client (closed loop) 14 | "arrive": [0.1, 0.6], // request arrival timestamp 15 | "stage_exec_info": [ 16 | [ 17 | [start, end, node_ids, devices] 18 | ... 19 | ], // stage one 20 | [ 21 | [start, end, node_ids, devices] 22 | ... 23 | ], // stage two 24 | ] 25 | }, 26 | "1": { 27 | ... ... 28 | ... ... 
29 | } 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /deprecated/placements/README.md: -------------------------------------------------------------------------------- 1 | # Placement File 2 | 3 | Placement Files are saved/loaded in `alpasim/cluster.py`. 4 | 5 | Currently the files in this folder are crafted by hand. They defines the placement policy for the load-balance demo. 6 | 7 | Finally (not supported now) the placement files should be generated by the scheduling algorithm and fed into the simulator. 8 | 9 | ## Json File Format 10 | 11 | Each element in the List defines a meshexecutor (a wrapper around model executable and meshgroup). 12 | 13 | There are three key-value pairs in each element: 14 | 15 | - model_name: defined in model_configs in model.py 16 | - executable_name: defined in model_configs in model.py 17 | - mesh_group: a list of meshes which defines the placement of the model executable, each mesh contains: 18 | - node_ids: list of node ids 19 | - devices: list of device list, the length must equal the length of node_ids 20 | -------------------------------------------------------------------------------- /deprecated/placements/placement_125M_baseline.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_125M_0", 4 | "model_name": "Bert_125M", 5 | "executable_name": "Bert_125M_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_125M_1", 21 | "model_name": "Bert_125M", 22 | "executable_name": "Bert_125M_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 1 31 | ] 32 | ] 33 | } 34 | ] 35 | } 36 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_125M_interop.json: 
-------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_125M_0", 4 | "model_name": "Bert_125M", 5 | "executable_name": "Bert_125M_2_device_inter", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 1 14 | ] 15 | ] 16 | }, 17 | { 18 | "node_ids": [ 19 | 0 20 | ], 21 | "devices": [ 22 | [ 23 | 0 24 | ] 25 | ] 26 | } 27 | ] 28 | }, 29 | { 30 | "service_name": "Bert_125M_1", 31 | "model_name": "Bert_125M", 32 | "executable_name": "Bert_125M_2_device_inter", 33 | "mesh_group": [ 34 | { 35 | "node_ids": [ 36 | 0 37 | ], 38 | "devices": [ 39 | [ 40 | 1 41 | ] 42 | ] 43 | }, 44 | { 45 | "node_ids": [ 46 | 0 47 | ], 48 | "devices": [ 49 | [ 50 | 0 51 | ] 52 | ] 53 | } 54 | ] 55 | } 56 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_125M_intraop.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_125M_0", 4 | "model_name": "Bert_125M", 5 | "executable_name": "Bert_125M_2_device_intra", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0, 1 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_125M_1", 21 | "model_name": "Bert_125M", 22 | "executable_name": "Bert_125M_2_device_intra", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 0, 1 31 | ] 32 | ] 33 | } 34 | ] 35 | } 36 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_125M_strong_baseline.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_125M_0", 4 | "model_name": "Bert_125M", 5 | "executable_name": "Bert_125M_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 
19 | { 20 | "service_name": "Bert_125M_0", 21 | "model_name": "Bert_125M", 22 | "executable_name": "Bert_125M_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 1 31 | ] 32 | ] 33 | } 34 | ] 35 | }, 36 | { 37 | "service_name": "Bert_125M_1", 38 | "model_name": "Bert_125M", 39 | "executable_name": "Bert_125M_1_device", 40 | "mesh_group": [ 41 | { 42 | "node_ids": [ 43 | 0 44 | ], 45 | "devices": [ 46 | [ 47 | 0 48 | ] 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "service_name": "Bert_125M_1", 55 | "model_name": "Bert_125M", 56 | "executable_name": "Bert_125M_1_device", 57 | "mesh_group": [ 58 | { 59 | "node_ids": [ 60 | 0 61 | ], 62 | "devices": [ 63 | [ 64 | 1 65 | ] 66 | ] 67 | } 68 | ] 69 | } 70 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_baseline.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 1 31 | ] 32 | ] 33 | } 34 | ] 35 | } 36 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_interop.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_2_device_inter", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 1 14 | ] 15 | ] 16 | }, 17 | { 18 | "node_ids": [ 19 | 0 20 | ], 21 | "devices": [ 22 | [ 23 | 0 
24 | ] 25 | ] 26 | } 27 | ] 28 | }, 29 | { 30 | "service_name": "Bert_2.6B_1", 31 | "model_name": "Bert_2.6B", 32 | "executable_name": "Bert_2.6B_2_device_inter", 33 | "mesh_group": [ 34 | { 35 | "node_ids": [ 36 | 0 37 | ], 38 | "devices": [ 39 | [ 40 | 1 41 | ] 42 | ] 43 | }, 44 | { 45 | "node_ids": [ 46 | 0 47 | ], 48 | "devices": [ 49 | [ 50 | 0 51 | ] 52 | ] 53 | } 54 | ] 55 | } 56 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_intraop.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_2_device_intra", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0, 1 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_2_device_intra", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 0, 1 31 | ] 32 | ] 33 | } 34 | ] 35 | } 36 | ] -------------------------------------------------------------------------------- /deprecated/placements/placement_test.json: -------------------------------------------------------------------------------- 1 | [{"service_name": "Bert_2.6B_0", "model_name": "Bert_2.6B", "executable_name": "Bert_2.6B_1_device", "mesh_group": [{"node_ids": [0], "devices": [[0]]}]}, {"service_name": "Bert_2.6B_1", "model_name": "Bert_2.6B", "executable_name": "Bert_2.6B_1_device", "mesh_group": [{"node_ids": [0], "devices": [[1]]}]}] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_baseline_2GPUs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": 
"Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 1 31 | ] 32 | ] 33 | } 34 | ] 35 | } 36 | ] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_baseline_4GPUs_memx1.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 1 31 | ] 32 | ] 33 | } 34 | ] 35 | }, 36 | { 37 | "service_name": "Bert_2.6B_2", 38 | "model_name": "Bert_2.6B", 39 | "executable_name": "Bert_2.6B_1_device", 40 | "mesh_group": [ 41 | { 42 | "node_ids": [ 43 | 0 44 | ], 45 | "devices": [ 46 | [ 47 | 2 48 | ] 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "service_name": "Bert_2.6B_3", 55 | "model_name": "Bert_2.6B", 56 | "executable_name": "Bert_2.6B_1_device", 57 | "mesh_group": [ 58 | { 59 | "node_ids": [ 60 | 0 61 | ], 62 | "devices": [ 63 | [ 64 | 3 65 | ] 66 | ] 67 | } 68 | ] 69 | } 70 | ] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_baseline_4GPUs_memx2.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | 
"executable_name": "Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 0 31 | ] 32 | ] 33 | } 34 | ] 35 | }, 36 | { 37 | "service_name": "Bert_2.6B_0", 38 | "model_name": "Bert_2.6B", 39 | "executable_name": "Bert_2.6B_1_device", 40 | "mesh_group": [ 41 | { 42 | "node_ids": [ 43 | 0 44 | ], 45 | "devices": [ 46 | [ 47 | 1 48 | ] 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "service_name": "Bert_2.6B_1", 55 | "model_name": "Bert_2.6B", 56 | "executable_name": "Bert_2.6B_1_device", 57 | "mesh_group": [ 58 | { 59 | "node_ids": [ 60 | 0 61 | ], 62 | "devices": [ 63 | [ 64 | 1 65 | ] 66 | ] 67 | } 68 | ] 69 | }, 70 | { 71 | "service_name": "Bert_2.6B_2", 72 | "model_name": "Bert_2.6B", 73 | "executable_name": "Bert_2.6B_1_device", 74 | "mesh_group": [ 75 | { 76 | "node_ids": [ 77 | 0 78 | ], 79 | "devices": [ 80 | [ 81 | 2 82 | ] 83 | ] 84 | } 85 | ] 86 | }, 87 | { 88 | "service_name": "Bert_2.6B_3", 89 | "model_name": "Bert_2.6B", 90 | "executable_name": "Bert_2.6B_1_device", 91 | "mesh_group": [ 92 | { 93 | "node_ids": [ 94 | 0 95 | ], 96 | "devices": [ 97 | [ 98 | 2 99 | ] 100 | ] 101 | } 102 | ] 103 | }, 104 | { 105 | "service_name": "Bert_2.6B_2", 106 | "model_name": "Bert_2.6B", 107 | "executable_name": "Bert_2.6B_1_device", 108 | "mesh_group": [ 109 | { 110 | "node_ids": [ 111 | 0 112 | ], 113 | "devices": [ 114 | [ 115 | 3 116 | ] 117 | ] 118 | } 119 | ] 120 | }, 121 | { 122 | "service_name": "Bert_2.6B_3", 123 | "model_name": "Bert_2.6B", 124 | "executable_name": "Bert_2.6B_1_device", 125 | "mesh_group": [ 126 | { 127 | "node_ids": [ 128 | 0 129 | ], 130 | "devices": [ 131 | [ 132 | 3 133 | ] 134 | ] 135 | } 136 | ] 137 | } 138 | ] 
-------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_baseline_4GPUs_memx2_3to1.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 0 31 | ] 32 | ] 33 | } 34 | ] 35 | }, 36 | { 37 | "service_name": "Bert_2.6B_0", 38 | "model_name": "Bert_2.6B", 39 | "executable_name": "Bert_2.6B_1_device", 40 | "mesh_group": [ 41 | { 42 | "node_ids": [ 43 | 0 44 | ], 45 | "devices": [ 46 | [ 47 | 1 48 | ] 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "service_name": "Bert_2.6B_1", 55 | "model_name": "Bert_2.6B", 56 | "executable_name": "Bert_2.6B_1_device", 57 | "mesh_group": [ 58 | { 59 | "node_ids": [ 60 | 0 61 | ], 62 | "devices": [ 63 | [ 64 | 1 65 | ] 66 | ] 67 | } 68 | ] 69 | }, 70 | { 71 | "service_name": "Bert_2.6B_0", 72 | "model_name": "Bert_2.6B", 73 | "executable_name": "Bert_2.6B_1_device", 74 | "mesh_group": [ 75 | { 76 | "node_ids": [ 77 | 0 78 | ], 79 | "devices": [ 80 | [ 81 | 2 82 | ] 83 | ] 84 | } 85 | ] 86 | }, 87 | { 88 | "service_name": "Bert_2.6B_2", 89 | "model_name": "Bert_2.6B", 90 | "executable_name": "Bert_2.6B_1_device", 91 | "mesh_group": [ 92 | { 93 | "node_ids": [ 94 | 0 95 | ], 96 | "devices": [ 97 | [ 98 | 2 99 | ] 100 | ] 101 | } 102 | ] 103 | }, 104 | { 105 | "service_name": "Bert_2.6B_1", 106 | "model_name": "Bert_2.6B", 107 | "executable_name": "Bert_2.6B_1_device", 108 | "mesh_group": [ 109 | { 110 | "node_ids": [ 111 | 0 112 | ], 113 | "devices": [ 114 | [ 115 | 3 116 | ] 117 | ] 
118 | } 119 | ] 120 | }, 121 | { 122 | "service_name": "Bert_2.6B_3", 123 | "model_name": "Bert_2.6B", 124 | "executable_name": "Bert_2.6B_1_device", 125 | "mesh_group": [ 126 | { 127 | "node_ids": [ 128 | 0 129 | ], 130 | "devices": [ 131 | [ 132 | 3 133 | ] 134 | ] 135 | } 136 | ] 137 | } 138 | ] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_baseline_4GPUs_memx3_3to1.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_1", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 0 31 | ] 32 | ] 33 | } 34 | ] 35 | }, 36 | { 37 | "service_name": "Bert_2.6B_2", 38 | "model_name": "Bert_2.6B", 39 | "executable_name": "Bert_2.6B_1_device", 40 | "mesh_group": [ 41 | { 42 | "node_ids": [ 43 | 0 44 | ], 45 | "devices": [ 46 | [ 47 | 0 48 | ] 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "service_name": "Bert_2.6B_0", 55 | "model_name": "Bert_2.6B", 56 | "executable_name": "Bert_2.6B_1_device", 57 | "mesh_group": [ 58 | { 59 | "node_ids": [ 60 | 0 61 | ], 62 | "devices": [ 63 | [ 64 | 1 65 | ] 66 | ] 67 | } 68 | ] 69 | }, 70 | { 71 | "service_name": "Bert_2.6B_1", 72 | "model_name": "Bert_2.6B", 73 | "executable_name": "Bert_2.6B_1_device", 74 | "mesh_group": [ 75 | { 76 | "node_ids": [ 77 | 0 78 | ], 79 | "devices": [ 80 | [ 81 | 1 82 | ] 83 | ] 84 | } 85 | ] 86 | }, 87 | { 88 | "service_name": "Bert_2.6B_2", 89 | "model_name": "Bert_2.6B", 90 | "executable_name": "Bert_2.6B_1_device", 91 | "mesh_group": [ 92 | { 93 | "node_ids": [ 94 | 0 95 | ], 96 | 
"devices": [ 97 | [ 98 | 1 99 | ] 100 | ] 101 | } 102 | ] 103 | }, 104 | { 105 | "service_name": "Bert_2.6B_0", 106 | "model_name": "Bert_2.6B", 107 | "executable_name": "Bert_2.6B_1_device", 108 | "mesh_group": [ 109 | { 110 | "node_ids": [ 111 | 0 112 | ], 113 | "devices": [ 114 | [ 115 | 2 116 | ] 117 | ] 118 | } 119 | ] 120 | }, 121 | { 122 | "service_name": "Bert_2.6B_1", 123 | "model_name": "Bert_2.6B", 124 | "executable_name": "Bert_2.6B_1_device", 125 | "mesh_group": [ 126 | { 127 | "node_ids": [ 128 | 0 129 | ], 130 | "devices": [ 131 | [ 132 | 2 133 | ] 134 | ] 135 | } 136 | ] 137 | }, 138 | { 139 | "service_name": "Bert_2.6B_3", 140 | "model_name": "Bert_2.6B", 141 | "executable_name": "Bert_2.6B_1_device", 142 | "mesh_group": [ 143 | { 144 | "node_ids": [ 145 | 0 146 | ], 147 | "devices": [ 148 | [ 149 | 2 150 | ] 151 | ] 152 | } 153 | ] 154 | }, 155 | { 156 | "service_name": "Bert_2.6B_0", 157 | "model_name": "Bert_2.6B", 158 | "executable_name": "Bert_2.6B_1_device", 159 | "mesh_group": [ 160 | { 161 | "node_ids": [ 162 | 0 163 | ], 164 | "devices": [ 165 | [ 166 | 3 167 | ] 168 | ] 169 | } 170 | ] 171 | }, 172 | { 173 | "service_name": "Bert_2.6B_1", 174 | "model_name": "Bert_2.6B", 175 | "executable_name": "Bert_2.6B_1_device", 176 | "mesh_group": [ 177 | { 178 | "node_ids": [ 179 | 0 180 | ], 181 | "devices": [ 182 | [ 183 | 3 184 | ] 185 | ] 186 | } 187 | ] 188 | }, 189 | { 190 | "service_name": "Bert_2.6B_3", 191 | "model_name": "Bert_2.6B", 192 | "executable_name": "Bert_2.6B_1_device", 193 | "mesh_group": [ 194 | { 195 | "node_ids": [ 196 | 0 197 | ], 198 | "devices": [ 199 | [ 200 | 3 201 | ] 202 | ] 203 | } 204 | ] 205 | } 206 | ] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_pipeline_2GPUs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 
| "executable_name": "Bert_2.6B_2_device_inter", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 1 14 | ] 15 | ] 16 | }, 17 | { 18 | "node_ids": [ 19 | 0 20 | ], 21 | "devices": [ 22 | [ 23 | 0 24 | ] 25 | ] 26 | } 27 | ] 28 | }, 29 | { 30 | "service_name": "Bert_2.6B_1", 31 | "model_name": "Bert_2.6B", 32 | "executable_name": "Bert_2.6B_2_device_inter", 33 | "mesh_group": [ 34 | { 35 | "node_ids": [ 36 | 0 37 | ], 38 | "devices": [ 39 | [ 40 | 1 41 | ] 42 | ] 43 | }, 44 | { 45 | "node_ids": [ 46 | 0 47 | ], 48 | "devices": [ 49 | [ 50 | 0 51 | ] 52 | ] 53 | } 54 | ] 55 | } 56 | ] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_pipeline_4GPUs_memx1.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_4_device_inter", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | }, 17 | { 18 | "node_ids": [ 19 | 0 20 | ], 21 | "devices": [ 22 | [ 23 | 1 24 | ] 25 | ] 26 | }, 27 | { 28 | "node_ids": [ 29 | 0 30 | ], 31 | "devices": [ 32 | [ 33 | 2 34 | ] 35 | ] 36 | }, 37 | { 38 | "node_ids": [ 39 | 0 40 | ], 41 | "devices": [ 42 | [ 43 | 3 44 | ] 45 | ] 46 | } 47 | ] 48 | }, 49 | { 50 | "service_name": "Bert_2.6B_1", 51 | "model_name": "Bert_2.6B", 52 | "executable_name": "Bert_2.6B_4_device_inter", 53 | "mesh_group": [ 54 | { 55 | "node_ids": [ 56 | 0 57 | ], 58 | "devices": [ 59 | [ 60 | 0 61 | ] 62 | ] 63 | }, 64 | { 65 | "node_ids": [ 66 | 0 67 | ], 68 | "devices": [ 69 | [ 70 | 1 71 | ] 72 | ] 73 | }, 74 | { 75 | "node_ids": [ 76 | 0 77 | ], 78 | "devices": [ 79 | [ 80 | 2 81 | ] 82 | ] 83 | }, 84 | { 85 | "node_ids": [ 86 | 0 87 | ], 88 | "devices": [ 89 | [ 90 | 3 91 | ] 92 | ] 93 | } 94 | ] 95 | }, 96 | { 97 | "service_name": "Bert_2.6B_2", 98 
| "model_name": "Bert_2.6B", 99 | "executable_name": "Bert_2.6B_4_device_inter", 100 | "mesh_group": [ 101 | { 102 | "node_ids": [ 103 | 0 104 | ], 105 | "devices": [ 106 | [ 107 | 0 108 | ] 109 | ] 110 | }, 111 | { 112 | "node_ids": [ 113 | 0 114 | ], 115 | "devices": [ 116 | [ 117 | 1 118 | ] 119 | ] 120 | }, 121 | { 122 | "node_ids": [ 123 | 0 124 | ], 125 | "devices": [ 126 | [ 127 | 2 128 | ] 129 | ] 130 | }, 131 | { 132 | "node_ids": [ 133 | 0 134 | ], 135 | "devices": [ 136 | [ 137 | 3 138 | ] 139 | ] 140 | } 141 | ] 142 | }, 143 | { 144 | "service_name": "Bert_2.6B_3", 145 | "model_name": "Bert_2.6B", 146 | "executable_name": "Bert_2.6B_4_device_inter", 147 | "mesh_group": [ 148 | { 149 | "node_ids": [ 150 | 0 151 | ], 152 | "devices": [ 153 | [ 154 | 0 155 | ] 156 | ] 157 | }, 158 | { 159 | "node_ids": [ 160 | 0 161 | ], 162 | "devices": [ 163 | [ 164 | 1 165 | ] 166 | ] 167 | }, 168 | { 169 | "node_ids": [ 170 | 0 171 | ], 172 | "devices": [ 173 | [ 174 | 2 175 | ] 176 | ] 177 | }, 178 | { 179 | "node_ids": [ 180 | 0 181 | ], 182 | "devices": [ 183 | [ 184 | 3 185 | ] 186 | ] 187 | } 188 | ] 189 | } 190 | ] -------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_pipeline_4GPUs_memx1dot5_3to1.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_2_device_inter", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | }, 17 | { 18 | "node_ids": [ 19 | 0 20 | ], 21 | "devices": [ 22 | [ 23 | 1 24 | ] 25 | ] 26 | } 27 | ] 28 | }, 29 | { 30 | "service_name": "Bert_2.6B_0", 31 | "model_name": "Bert_2.6B", 32 | "executable_name": "Bert_2.6B_2_device_inter", 33 | "mesh_group": [ 34 | { 35 | "node_ids": [ 36 | 0 37 | ], 38 | "devices": [ 39 | [ 40 | 2 41 | ] 42 | ] 43 | }, 44 | { 45 | 
"node_ids": [ 46 | 0 47 | ], 48 | "devices": [ 49 | [ 50 | 3 51 | ] 52 | ] 53 | } 54 | ] 55 | }, 56 | { 57 | "service_name": "Bert_2.6B_1", 58 | "model_name": "Bert_2.6B", 59 | "executable_name": "Bert_2.6B_2_device_inter", 60 | "mesh_group": [ 61 | { 62 | "node_ids": [ 63 | 0 64 | ], 65 | "devices": [ 66 | [ 67 | 0 68 | ] 69 | ] 70 | }, 71 | { 72 | "node_ids": [ 73 | 0 74 | ], 75 | "devices": [ 76 | [ 77 | 1 78 | ] 79 | ] 80 | } 81 | ] 82 | }, 83 | { 84 | "service_name": "Bert_2.6B_1", 85 | "model_name": "Bert_2.6B", 86 | "executable_name": "Bert_2.6B_2_device_inter", 87 | "mesh_group": [ 88 | { 89 | "node_ids": [ 90 | 0 91 | ], 92 | "devices": [ 93 | [ 94 | 2 95 | ] 96 | ] 97 | }, 98 | { 99 | "node_ids": [ 100 | 0 101 | ], 102 | "devices": [ 103 | [ 104 | 3 105 | ] 106 | ] 107 | } 108 | ] 109 | }, 110 | { 111 | "service_name": "Bert_2.6B_2", 112 | "model_name": "Bert_2.6B", 113 | "executable_name": "Bert_2.6B_4_device_inter", 114 | "mesh_group": [ 115 | { 116 | "node_ids": [ 117 | 0 118 | ], 119 | "devices": [ 120 | [ 121 | 0 122 | ] 123 | ] 124 | }, 125 | { 126 | "node_ids": [ 127 | 0 128 | ], 129 | "devices": [ 130 | [ 131 | 1 132 | ] 133 | ] 134 | }, 135 | { 136 | "node_ids": [ 137 | 0 138 | ], 139 | "devices": [ 140 | [ 141 | 2 142 | ] 143 | ] 144 | }, 145 | { 146 | "node_ids": [ 147 | 0 148 | ], 149 | "devices": [ 150 | [ 151 | 3 152 | ] 153 | ] 154 | } 155 | ] 156 | }, 157 | { 158 | "service_name": "Bert_2.6B_3", 159 | "model_name": "Bert_2.6B", 160 | "executable_name": "Bert_2.6B_4_device_inter", 161 | "mesh_group": [ 162 | { 163 | "node_ids": [ 164 | 0 165 | ], 166 | "devices": [ 167 | [ 168 | 0 169 | ] 170 | ] 171 | }, 172 | { 173 | "node_ids": [ 174 | 0 175 | ], 176 | "devices": [ 177 | [ 178 | 1 179 | ] 180 | ] 181 | }, 182 | { 183 | "node_ids": [ 184 | 0 185 | ], 186 | "devices": [ 187 | [ 188 | 2 189 | ] 190 | ] 191 | }, 192 | { 193 | "node_ids": [ 194 | 0 195 | ], 196 | "devices": [ 197 | [ 198 | 3 199 | ] 200 | ] 201 | } 202 | ] 203 | } 204 | ] 
-------------------------------------------------------------------------------- /deprecated/scripts/memory_saving/placements/placement_strong_2GPUs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "service_name": "Bert_2.6B_0", 4 | "model_name": "Bert_2.6B", 5 | "executable_name": "Bert_2.6B_1_device", 6 | "mesh_group": [ 7 | { 8 | "node_ids": [ 9 | 0 10 | ], 11 | "devices": [ 12 | [ 13 | 0 14 | ] 15 | ] 16 | } 17 | ] 18 | }, 19 | { 20 | "service_name": "Bert_2.6B_0", 21 | "model_name": "Bert_2.6B", 22 | "executable_name": "Bert_2.6B_1_device", 23 | "mesh_group": [ 24 | { 25 | "node_ids": [ 26 | 0 27 | ], 28 | "devices": [ 29 | [ 30 | 1 31 | ] 32 | ] 33 | } 34 | ] 35 | }, 36 | { 37 | "service_name": "Bert_2.6B_1", 38 | "model_name": "Bert_2.6B", 39 | "executable_name": "Bert_2.6B_1_device", 40 | "mesh_group": [ 41 | { 42 | "node_ids": [ 43 | 0 44 | ], 45 | "devices": [ 46 | [ 47 | 0 48 | ] 49 | ] 50 | } 51 | ] 52 | }, 53 | { 54 | "service_name": "Bert_2.6B_1", 55 | "model_name": "Bert_2.6B", 56 | "executable_name": "Bert_2.6B_1_device", 57 | "mesh_group": [ 58 | { 59 | "node_ids": [ 60 | 0 61 | ], 62 | "devices": [ 63 | [ 64 | 1 65 | ] 66 | ] 67 | } 68 | ] 69 | } 70 | ] -------------------------------------------------------------------------------- /deprecated/scripts/pipeline_latency/2.6B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/mms/dba47b18e95f037aadc8aff336e2e7e337010495/deprecated/scripts/pipeline_latency/2.6B.png -------------------------------------------------------------------------------- /deprecated/scripts/pipeline_latency/6.7B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/mms/dba47b18e95f037aadc8aff336e2e7e337010495/deprecated/scripts/pipeline_latency/6.7B.png 
-------------------------------------------------------------------------------- /deprecated/scripts/pipeline_latency/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pprint import pprint 3 | 4 | gpu_nums = [1, 2, 4, 8, 16, 32] 5 | latencies = { 6 | "2.6B": { 7 | 1: [0.145, 0.145, 0.144, 0.145, 0.146, 0.154], 8 | 2: [0.187, 0.184, 0.182, 0.185, 0.197, 0.213], 9 | 4: [0.368, 0.363, 0.362, 0.373, 0.383, 0.426], 10 | 8: [0.736, 0.733, 0.733, 0.745, 0.78, 0.85], 11 | 16: [1.433, 1.433, 1.444, 1.447, 1.542, 1.65] 12 | }, 13 | "6.7B": { 14 | 1: [0.233, 0.233, 0.232, 0.234, 0.243, 0.253], 15 | 2: [0.349, 0.35, 0.35, 0.353, 0.369, 0.411], 16 | 4: [0.7, 0.694, 0.699, 0.705, 0.739, 0.798], 17 | 8: [1.383, 1.386, 1.39, 1.397, 1.457, 1.558], 18 | 16: [2.823, 2.823, 2.833, 2.852, 2.959, 3.18] 19 | } 20 | } 21 | 22 | # S H L head V 23 | # 2.6B 1024 2560 32 32 51200 24 | # 6.7B 1024 4096 32 32 51200 25 | 26 | CROSS_NODE_RAY_OVERHEAD = 0.010 27 | cross_node_commu_overhead = { 28 | "2.6B": 0.0015625, # 1024 * 2560 * 2 / 1024 / 1024 / (25 / 8 * 1024) # 25Gbps 29 | "6.7B": 0.0025, # 1024 * 4096 * 2 / 1024 / 1024 / (25 / 8 * 1024) # 25Gbps 30 | } 31 | 32 | in_node_commu_overhead = { 33 | # "2.6B": 0.00009766, # 1024 * 2560 * 2 / 1024 / 1024 / (50 * 1024) # 50GB/s 34 | "2.6B": 0.0003, # profiling result 35 | # "6.7B": 0.00015625, # 1024 * 4096 * 2 / 1024 / 1024 / (50 * 1024) # 50GB/s 36 | "6.7B": 0.0004, # profiling result 37 | } 38 | 39 | 40 | predicted_latencies = { 41 | "2.6B": { 42 | 1: [0.145], 2: [0.187], 4: [0.366], 8: [0.736], 16: [1.433] 43 | }, 44 | "6.7B": { 45 | 1: [0.231], 2: [0.349], 4: [0.7], 8: [1.383], 16: [2.823] 46 | } 47 | } 48 | 49 | for model_size in predicted_latencies: 50 | for bs in [1, 2, 4, 8, 16]: 51 | single = predicted_latencies[model_size][bs][0] 52 | predicted_latencies[model_size][bs].append(single + in_node_commu_overhead[model_size] * bs) 53 | 
predicted_latencies[model_size][bs].append(single + in_node_commu_overhead[model_size] * 3 * bs) 54 | predicted_latencies[model_size][bs].append(single + in_node_commu_overhead[model_size] * 7 * bs) 55 | predicted_latencies[model_size][bs].append(single + (in_node_commu_overhead[model_size] * 14 56 | + cross_node_commu_overhead[model_size]) * bs) 57 | predicted_latencies[model_size][bs].append(single + (in_node_commu_overhead[model_size] * 28 58 | + cross_node_commu_overhead[model_size] * 3) * bs) 59 | 60 | pprint(predicted_latencies) 61 | 62 | def plot(model_size): 63 | bs_config = [1, 2, 4, 8, 16] 64 | plt.figure() 65 | plt.title("E2E Latency v.s. #Pipeline Stages") 66 | plt.xlabel("#GPU (pipeline stages)") 67 | plt.ylabel("E2E Latency (s)") 68 | plt.xticks(gpu_nums) 69 | for bs in reversed(bs_config): 70 | plt.plot(gpu_nums, latencies[model_size][bs], "-o", markersize=5, label=f"BS={bs}, real") 71 | plt.plot(gpu_nums, predicted_latencies[model_size][bs], "--o", markersize=5, label=f"BS={bs}, calculated") 72 | plt.legend() 73 | plt.savefig(f"{model_size}.png") 74 | 75 | plot("2.6B") 76 | plot("6.7B") -------------------------------------------------------------------------------- /deprecated/scripts/small_model_benchmark/strong_baseline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./") 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from alpasim.cluster import Cluster, Mesh, load_meshexecutors, save_meshexecutors 7 | from alpasim.scheduler import FIFOScheduler 8 | from alpasim.simulator import Simulator 9 | from alpasim.workload import generate_workload, PossoinWorkLoad 10 | from alpasim.utils import compute_statistics_from_cluster_trace, compute_statistics_from_simulation, \ 11 | dump_chrome_tracing_from_simulation, dump_chrome_tracing_from_cluster_trace 12 | 13 | def plot_cdf(latencies, overall_latencies, is_baseline): 14 | model0_latencies = latencies[0] 15 | 
model1_latencies = latencies[1] 16 | # sort data 17 | x1, x2, x = np.sort(model0_latencies), np.sort(model1_latencies), np.sort(overall_latencies) 18 | # calculate CDF values 19 | y1, y2, y = 1. * np.arange(len(model0_latencies)) / (len(model0_latencies) - 1), \ 20 | 1. * np.arange(len(model1_latencies)) / (len(model1_latencies) - 1), \ 21 | 1. * np.arange(len(overall_latencies)) / (len(overall_latencies) - 1), 22 | # plot CDF 23 | if is_baseline: 24 | plt.plot(x1, y1, ":", color="c", label="baseline model0") 25 | plt.plot(x2, y2, "-.", color="c", label="baseline model1") 26 | plt.plot(x, y, "-", color="c", label="baseline overall") 27 | else: 28 | plt.plot(x1, y1, ":", color="orange", label="parallel model0") 29 | plt.plot(x2, y2, "-.", color="orange", label="parallel model1") 30 | plt.plot(x, y, "-", color="orange", label="parallel overall") 31 | 32 | 33 | def run_strong_baseline(): 34 | workload_name = "test_workload_8to2_50Hz_60s" 35 | placement_filename = "./placements/placement_125M_baseline.json" 36 | model_id_to_service_name = {0: "Bert_125M_0", 1: "Bert_125M_1"} 37 | print("\n========================") 38 | print("Test 125M baseline trace:") 39 | workload_filename = f"./workload/{workload_name}" 40 | workload = PossoinWorkLoad.load(workload_filename) 41 | cluster = Cluster(1, 2, 16) 42 | meshexecutors = load_meshexecutors(placement_filename, cluster) 43 | scheduler = FIFOScheduler(workload, meshexecutors, model_id_to_service_name) 44 | simulator = Simulator(scheduler, cluster) 45 | simulator.start() 46 | latencies, overall_latencies = compute_statistics_from_simulation(scheduler.completed_tasks) 47 | plot_cdf(latencies, overall_latencies, True) 48 | 49 | def run_interop(): 50 | workload_name = "test_workload_8to2_50Hz_60s" 51 | placement_filename = "./placements/placement_125M_interop.json" 52 | model_id_to_service_name = {0: "Bert_125M_0", 1: "Bert_125M_1"} 53 | print("\n========================") 54 | print("Test 125M interop trace:") 55 | 
workload_filename = f"./workload/{workload_name}" 56 | workload = PossoinWorkLoad.load(workload_filename) 57 | cluster = Cluster(1, 2, 16) 58 | meshexecutors = load_meshexecutors(placement_filename, cluster) 59 | scheduler = FIFOScheduler(workload, meshexecutors, model_id_to_service_name) 60 | simulator = Simulator(scheduler, cluster) 61 | simulator.start() 62 | latencies, overall_latencies = compute_statistics_from_simulation(scheduler.completed_tasks) 63 | plot_cdf(latencies, overall_latencies, False) 64 | 65 | def run_intraop(): 66 | workload_name = "test_workload_8to2_50Hz_60s" 67 | placement_filename = "./placements/placement_125M_intraop.json" 68 | model_id_to_service_name = {0: "Bert_125M_0", 1: "Bert_125M_1"} 69 | print("\n========================") 70 | print("Test 125M intraop trace:") 71 | workload_filename = f"./workload/{workload_name}" 72 | workload = PossoinWorkLoad.load(workload_filename) 73 | cluster = Cluster(1, 2, 16) 74 | meshexecutors = load_meshexecutors(placement_filename, cluster) 75 | scheduler = FIFOScheduler(workload, meshexecutors, model_id_to_service_name) 76 | simulator = Simulator(scheduler, cluster) 77 | simulator.start() 78 | latencies, overall_latencies = compute_statistics_from_simulation(scheduler.completed_tasks) 79 | plot_cdf(latencies, overall_latencies, False) 80 | 81 | parallel_method = "interop" 82 | #parallel_method = "intraop" 83 | 84 | plt.figure() 85 | run_strong_baseline() 86 | if parallel_method == "interop": 87 | run_interop() 88 | else: 89 | run_intraop() 90 | 91 | plt.legend() 92 | plt.ylabel("CDF") 93 | plt.xlabel("Latency(s)") 94 | 95 | # savefig 96 | plt.savefig(f"cdf_{parallel_method}") 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /deprecated/simulator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from alpasim.cluster import Cluster, load_meshexecutors 4 | from 
alpasim.scheduler import FIFOScheduler 5 | from alpasim.utils import dump_chrome_tracing_from_simulation 6 | from alpasim.workload import PossoinWorkLoad 7 | from alpasim.simulator import Simulator 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description="Cluster simulator for distributed DL inference tasks") 12 | parser.add_argument('--name', type=str, required=True, 13 | help="simulation name") 14 | parser.add_argument('-n', '--num_nodes', type=int, default=1, 15 | help="number of nodes in the cluster") 16 | parser.add_argument('-d', '--num_devices_per_node', type=int, default=2, 17 | help="number of devices per node in the cluster") 18 | parser.add_argument('-c', '--memory_capacity', type=int, default=16, 19 | help="GPU memory capacity in GB") 20 | parser.add_argument('-w', '--workload', type=str, required=True, 21 | help='Workload Filename') 22 | parser.add_argument('-p', '--placement', type=str, required=True, 23 | help='Placement Filename') 24 | parser.add_argument('--chrome_trace', action='store_true', 25 | help='Dump chrome trace') 26 | args = parser.parse_args() 27 | 28 | cluster = Cluster(args.num_nodes, args.num_devices_per_node, args.memory_capacity) 29 | workload = PossoinWorkLoad.load(f"./workload/{args.workload}") 30 | meshexecutors = load_meshexecutors(f"./placements/{args.placement}", cluster) 31 | scheduler = FIFOScheduler(workload, meshexecutors) 32 | simulator = Simulator(scheduler, cluster) 33 | simulator.start() 34 | if args.chrome_trace: 35 | dump_chrome_tracing_from_simulation(scheduler.completed_tasks, f"./chrome_trace/{args.name}.json") 36 | -------------------------------------------------------------------------------- /deprecated/workload/README.md: -------------------------------------------------------------------------------- 1 | # Workload File 2 | 3 | Workload file is generated/saved/loaded in `alpasim/workload.py`. 
4 | 5 | ## PossoinWorkload File Format 6 | 7 | First row is metadata, which contains: 8 | 9 | - model_num 10 | - tot_arrival_rate 11 | - proportions 12 | - duration 13 | - SLOs 14 | - workload_name 15 | 16 | From second row to the end are requests info, each row contains: 17 | 18 | - model_id 19 | - arrive_timestamp (in second) 20 | -------------------------------------------------------------------------------- /deprecated/workload/test_workload_8to2_10Hz_20s: -------------------------------------------------------------------------------- 1 | 2,10,"[0.5, 0.5]",20,"[0.25, 0.25]",test_workload_8to2_10Hz_20s 2 | 0,0.0 3 | 0,0.016178625090231012 4 | 0,0.08767096368971053 5 | 0,0.10635962879147334 6 | 0,0.11660595731545761 7 | 0,0.14107213914883487 8 | 0,0.5515635626014919 9 | 1,0.6545373678195101 10 | 0,0.7009757663480196 11 | 0,0.8593462821428112 12 | 0,0.927250277046247 13 | 0,0.9787070232864395 14 | 1,1.0682257410858031 15 | 0,1.1947363458498523 16 | 1,1.2402710394609315 17 | 0,1.2625918514774512 18 | 0,1.311242777338951 19 | 0,1.3865060247824668 20 | 0,1.5059008635684212 21 | 0,1.5956036549774821 22 | 0,1.6586399386975912 23 | 0,1.8177813682909112 24 | 0,1.850176333828387 25 | 0,1.8535443098768978 26 | 0,2.229060523027998 27 | 0,2.2424080952703274 28 | 0,2.2996473313509673 29 | 0,2.434802620872246 30 | 0,2.5016443059218467 31 | 0,2.5403042852907225 32 | 1,2.6358581321076215 33 | 1,2.9574492956040133 34 | 0,2.959031275068509 35 | 0,3.055918785274431 36 | 0,3.0716926016789334 37 | 0,3.0812617240131277 38 | 0,3.13775134863309 39 | 0,3.363178676050936 40 | 0,3.3653233955663024 41 | 0,3.528628269606197 42 | 0,3.5590105878100236 43 | 0,3.592914796375641 44 | 0,3.698888782985365 45 | 1,3.7105422388732046 46 | 0,3.8742854844486545 47 | 1,3.9393156987267597 48 | 0,3.9667828070138182 49 | 0,4.137563941398187 50 | 0,4.183380232201867 51 | 1,4.196103456345683 52 | 1,4.36758639273928 53 | 0,4.464439145051358 54 | 0,4.513478492687897 55 | 0,4.530042261076868 56 | 
0,4.530246721611402 57 | 0,4.705968978122072 58 | 1,4.848890322056627 59 | 1,4.9085620095099625 60 | 1,5.025952990142035 61 | 0,5.045221984427619 62 | 0,5.083282392545631 63 | 0,5.084460509274678 64 | 0,5.102590445820243 65 | 0,5.167316570730358 66 | 0,5.210139596158508 67 | 1,5.317152898590351 68 | 0,5.440748659970929 69 | 0,5.513624995744061 70 | 0,5.531697049076725 71 | 0,5.558113582853236 72 | 0,5.579789818762441 73 | 0,5.626446902544927 74 | 0,5.645995573598957 75 | 0,5.780768028552207 76 | 0,6.059464171808437 77 | 0,6.097935416297194 78 | 0,6.107577144000878 79 | 0,6.178232159410493 80 | 0,6.242707761790129 81 | 0,6.269691067936045 82 | 0,6.356104211847679 83 | 0,6.5889189141259585 84 | 0,6.592220137470327 85 | 1,6.763410915916538 86 | 1,6.7696988629603885 87 | 0,6.778637199671151 88 | 0,6.927166845637971 89 | 0,7.024362415341146 90 | 0,7.047641256810061 91 | 0,7.195157479165619 92 | 0,7.6003182753940175 93 | 0,7.669767344683825 94 | 0,8.345562362510261 95 | 0,8.470197860534343 96 | 1,8.501733674773874 97 | 0,8.675783079041972 98 | 0,8.786256405098147 99 | 0,8.911099633851935 100 | 0,8.930394295595928 101 | 0,8.976622168329921 102 | 0,9.013941002161872 103 | 0,9.194190737184353 104 | 0,9.236026463268779 105 | 0,9.270266257319838 106 | 0,9.279304772334196 107 | 1,9.314798957968454 108 | 1,9.465727758250573 109 | 1,9.511792075197617 110 | 1,9.552230952514659 111 | 0,9.742553325163774 112 | 1,9.864721982610872 113 | 0,9.920864667540625 114 | 1,10.224488823633635 115 | 0,10.237613703153064 116 | 0,10.362499471210118 117 | 0,10.420284521938369 118 | 0,10.525029305725162 119 | 0,10.537414383166217 120 | 0,10.64479445385737 121 | 0,11.178914167361564 122 | 0,11.290357531363714 123 | 1,11.371041818113637 124 | 1,11.499298007342874 125 | 0,11.529385829414736 126 | 0,11.641457051986176 127 | 0,11.791255849146523 128 | 0,11.91760936187514 129 | 0,11.93572240428526 130 | 0,12.051773374227047 131 | 0,12.142136333977852 132 | 0,12.153435854396111 133 | 0,12.265513229993584 
134 | 0,12.288334078246889 135 | 0,12.352746557868198 136 | 0,12.374861705793288 137 | 0,12.50708776741476 138 | 0,12.50972815803597 139 | 0,12.613972087079059 140 | 0,12.663877630080718 141 | 1,12.750603110233143 142 | 0,12.752887749244179 143 | 0,12.762615498653409 144 | 1,12.999653687585715 145 | 0,13.008493974400922 146 | 0,13.083892796944156 147 | 0,13.230079233120238 148 | 0,13.244429884164372 149 | 0,13.288296767632172 150 | 1,13.370827354689101 151 | 0,13.476733571016249 152 | 0,13.577734344366066 153 | 1,13.645829642484891 154 | 0,13.650126415728705 155 | 0,13.712428512386637 156 | 1,13.720631615249763 157 | 0,14.051853238809116 158 | 0,14.174375257127423 159 | 0,14.363188329023293 160 | 1,14.392441115855702 161 | 1,14.550183566261422 162 | 0,14.58297761749802 163 | 0,14.595409578159458 164 | 0,14.802746080104896 165 | 0,14.920162027878957 166 | 1,15.100262311766446 167 | 0,15.267737037028697 168 | 0,15.306972083180133 169 | 1,15.356622584275453 170 | 0,15.423114760227804 171 | 0,15.461649126153375 172 | 0,15.51942405610844 173 | 0,15.54433918780734 174 | 0,15.778208775834662 175 | 0,15.976575201538814 176 | 0,16.146881403809186 177 | 1,16.339413517152245 178 | 0,16.349450789033973 179 | 1,16.500218929426094 180 | 1,16.598028602786556 181 | 0,16.739298767487863 182 | 0,16.780996891513276 183 | 0,16.809684611933505 184 | 0,16.82012608793275 185 | 0,16.94451216999574 186 | 1,17.07475685604132 187 | 0,17.205241536477146 188 | 0,17.206834824216816 189 | 0,17.28321693011089 190 | 0,17.300115599643274 191 | 0,17.399782204584337 192 | 0,17.608252641638263 193 | 0,17.652051878514246 194 | 0,17.69053690490229 195 | 0,17.722333783323187 196 | 1,17.761136281026463 197 | 0,17.80973422905783 198 | 0,17.864176258294332 199 | 1,17.919196183756554 200 | 1,18.02275628823749 201 | 0,18.208886576455626 202 | 0,18.312188894129868 203 | 0,18.59166208832001 204 | 1,18.68343824961135 205 | 0,18.808670702762182 206 | 0,18.916294939192085 207 | 0,18.977217830113947 208 | 
0,18.994323400973112 209 | 0,19.003080571663734 210 | 0,19.05717141261364 211 | 1,19.13080111815729 212 | 0,19.16554132216886 213 | 0,19.297639579469863 214 | 0,19.422862785628112 215 | 1,19.57175976651724 216 | 0,19.67084719984937 217 | 0,19.83494648093997 218 | 0,19.87670310493472 219 | 1,19.9409740264135 220 | -------------------------------------------------------------------------------- /deprecated/workload/test_workload_8to2_6.667Hz_20s: -------------------------------------------------------------------------------- 1 | 2,6.667,"[0.8, 0.2]",20,"[0.25, 0.25]",test_workload_8to2_6.667Hz_20s 2 | 0,0.0 3 | 0,0.38785915931808984 4 | 0,0.4142083704734896 5 | 1,0.4958246725899595 6 | 0,0.6900889465892203 7 | 0,0.7807855284167089 8 | 0,0.8455546296017873 9 | 0,0.9123256326421172 10 | 0,0.9650467955591858 11 | 0,1.0791695315175611 12 | 0,1.101846805056513 13 | 1,1.271926810700596 14 | 0,1.6286856606880271 15 | 0,1.714045953684007 16 | 1,1.789523194875095 17 | 1,1.81825415086913 18 | 0,1.8281760683099917 19 | 0,1.9463154703333143 20 | 0,2.0622046337943516 21 | 0,2.2087007814023853 22 | 0,2.3087577684005716 23 | 0,2.3789113482413704 24 | 0,2.42523877195751 25 | 0,2.452131637836243 26 | 0,2.5045703849325776 27 | 0,2.7048990810797946 28 | 0,2.728294420014128 29 | 0,2.984704219401499 30 | 0,3.0634025119642927 31 | 0,3.0912293760482763 32 | 0,3.3851171117709895 33 | 0,3.4385561357239123 34 | 1,3.441707472988317 35 | 0,3.7181556043315 36 | 0,3.7187906763046508 37 | 0,3.833625080628506 38 | 1,3.876707607066355 39 | 0,3.949187287610746 40 | 0,3.970664187047496 41 | 0,4.2215835774006525 42 | 0,4.247329658722286 43 | 1,4.304735582176452 44 | 1,4.59856222371967 45 | 1,4.860087182703689 46 | 0,4.923597986962046 47 | 0,5.013555946121519 48 | 0,5.065055467538393 49 | 1,5.262991016327795 50 | 1,5.332804880560673 51 | 1,5.459157513715694 52 | 1,5.514346822333702 53 | 0,5.752667550251483 54 | 0,5.799325508850001 55 | 0,6.068519046015989 56 | 0,6.20495852365687 57 | 
0,6.324480635294552 58 | 0,6.698375783242967 59 | 0,6.725493705486089 60 | 0,7.150371080161162 61 | 0,7.186992554916359 62 | 0,7.344471181247913 63 | 0,7.425151289821606 64 | 0,7.430642050790597 65 | 0,7.551758264926435 66 | 0,7.752660501515499 67 | 0,7.757931449107972 68 | 0,7.8075988509670475 69 | 0,7.950141468917985 70 | 0,8.059459135277343 71 | 0,8.4182673111153 72 | 0,8.435319629701032 73 | 1,8.444560604633121 74 | 0,8.491855999695133 75 | 0,8.908729732263163 76 | 0,9.116163842924689 77 | 0,9.227535738211746 78 | 1,9.428114621243436 79 | 0,9.467597754902156 80 | 0,9.878176314413306 81 | 0,10.134741549909181 82 | 0,10.176422489186113 83 | 0,10.223021024665087 84 | 0,10.400864651418777 85 | 0,10.754117806111326 86 | 0,10.819185163087026 87 | 0,10.965156717287709 88 | 0,11.027355393758887 89 | 0,11.155121170691984 90 | 1,11.218487788196445 91 | 0,11.310543783406436 92 | 0,11.818603631283505 93 | 0,11.850108237780905 94 | 0,11.903622167230646 95 | 0,12.358302258513909 96 | 0,12.922531712585457 97 | 0,12.932998805006093 98 | 1,13.165721354454323 99 | 0,13.70610439611292 100 | 0,13.753060207762017 101 | 1,13.76864468523469 102 | 1,14.246318860450343 103 | 0,14.260284193962066 104 | 0,14.263263401695378 105 | 1,14.346411925094714 106 | 0,14.405410814682025 107 | 0,14.604300533641007 108 | 0,14.648714358766322 109 | 0,14.701113884830038 110 | 0,14.805004636117687 111 | 1,14.900301629049615 112 | 0,15.05694827507222 113 | 0,15.26258048389176 114 | 0,15.483194251442086 115 | 0,15.935737888137295 116 | 0,16.294040373991415 117 | 1,16.382839446893076 118 | 1,16.4844638835141 119 | 0,16.678116853262026 120 | 0,16.83741567233837 121 | 0,16.85918997604224 122 | 0,16.99604752998396 123 | 0,17.25472934536894 124 | 0,17.352164574919737 125 | 1,17.504191055126096 126 | 0,17.60510186235473 127 | 0,17.747202431840275 128 | 0,17.864081512686734 129 | 0,17.938300848380564 130 | 0,18.119541950218608 131 | 0,18.14777340435394 132 | 0,18.491248873214005 133 | 0,18.66571718606002 134 | 
0,18.682493725549737 135 | 0,18.869920334716017 136 | 0,18.963668570930444 137 | 1,19.028855720815493 138 | 0,19.077331439700547 139 | 0,19.090381200144837 140 | 0,19.28815861347427 141 | 1,19.29265720149321 142 | 0,19.519220711209883 143 | 0,19.824523910312536 144 | 0,19.91213568796626 145 | -------------------------------------------------------------------------------- /experiments/batching/gen_data_goodput_vs_slo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from benchmarks.alpa.equal_model_case import EqualModelCase, run_equal_model_cases 4 | from benchmarks.alpa.general_model_case import GeneralModelCase, run_general_model_cases 5 | from alpa_serve.util import GB, batchsize_config 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--output", type=str, default="res_goodput_vs_slo.tsv") 11 | parser.add_argument("--parallel", action="store_true") 12 | parser.add_argument("--policy", type=str) 13 | parser.add_argument("--slo-scale", type=float) 14 | parser.add_argument("--trace", choices=["synthetic", "azure_v2"], 15 | default="synthetic") 16 | parser.add_argument("--mode", choices=["simulate", "run"], 17 | default="simulate") 18 | parser.add_argument("--unequal", action="store_true") 19 | parser.add_argument("--model-type", type=str, default="all_transformers", 20 | choices=["all_transformers", "mixed"]) 21 | parser.add_argument("--protocol", type=str, default="http", 22 | choices=["http", "ray"]) 23 | parser.add_argument("--relax-slo", action="store_true") 24 | parser.add_argument("--debug-tstamp", action="store_true") 25 | parser.add_argument("--enable-batching", action="store_true") 26 | parser.add_argument("--max-batchsize", type=int, default=2) 27 | 28 | args = parser.parse_args() 29 | 30 | # choices: {"sr-greedy", "sr-ilp", "mp-ilp", 31 | # "mp-greedy-2", "mp-greedy-4", "mp-greedy-8", 32 | # "mp-search", "mp-search-sep"} 33 | if 
args.policy is not None: 34 | policies = [args.policy] 35 | else: 36 | # policies = ["sr-greedy", "mp-search", "sr-replace-30"] 37 | policies = ["mp-search"] 38 | 39 | if args.enable_batching: 40 | assert args.max_batchsize == batchsize_config[-1], f"maximum batchsize is not {args.max_batchsize}, set it in alpa_serve/util.py" 41 | policies = [policy + "-batch-" + str(args.max_batchsize) for policy in policies] 42 | 43 | exp_name = "goodput_vs_slo" 44 | num_devices = 8 45 | mem_budget = 13 * GB 46 | model_type = "bert-1.3b" 47 | num_models = 16 48 | total_rate = 64 49 | if args.trace == "synthetic": 50 | # choices: {"gamma", "uniform_mmpp"} 51 | arrival_process = "gamma" 52 | # choices: {"uniform", "power_law", "triangle_decay"} 53 | rate_distribution = "power_law" 54 | arrival_process_kwargs = {"cv": 4} 55 | elif args.trace == "azure_v2": 56 | # choices: {"azure_v2"} 57 | arrival_process = "azure_v2" 58 | rate_distribution = None 59 | arrival_process_kwargs = None 60 | 61 | if args.slo_scale is not None: 62 | slo_scales = [args.slo_scale] 63 | else: 64 | slo_scales = [0.5, 1, 2, 3, 4, 5, 6, 8, 10, 12, 14] 65 | # slo_scales = [14] 66 | duration = 1000 67 | 68 | if args.unequal: 69 | # multi-model config 70 | if args.model_type == "mixed": 71 | model_set = ["bert-1.3b", "bert-2.6b", "bert-6.7b", "moe-1.3b", "moe-2.4b", "moe-5.3b"] 72 | else: 73 | model_set = ["bert-6.7b", "moe-1.3b"] 74 | num_devices = 64 75 | total_rate = 70 76 | fixed_num_modelset = 8 77 | model_types = model_set * fixed_num_modelset 78 | model_names = sum([[f"{model_type}-{i}" for model_type in model_set] for i in range(fixed_num_modelset)], []) 79 | 80 | cases = [] 81 | for slo_scale in slo_scales: 82 | for policy_name in policies: 83 | cases.append(GeneralModelCase( 84 | exp_name, num_devices, mem_budget, model_types, model_names, 85 | total_rate, rate_distribution, 86 | arrival_process, arrival_process_kwargs, 87 | slo_scale, duration, policy_name)) 88 | 89 | run_general_model_cases(cases, 90 | 
output_file=args.output, 91 | mode=args.mode, 92 | debug_tstamp=args.debug_tstamp, 93 | parallel=args.parallel, 94 | enable_batching=args.enable_batching) 95 | else: 96 | cases = [] 97 | for slo_scale in slo_scales: 98 | for policy_name in policies: 99 | cases.append(EqualModelCase( 100 | exp_name, num_devices, mem_budget, model_type, num_models, 101 | total_rate, rate_distribution, 102 | arrival_process, arrival_process_kwargs, 103 | slo_scale, duration, policy_name, 104 | None, None, None, None)) 105 | 106 | 107 | run_equal_model_cases(cases, 108 | output_file=args.output, 109 | mode=args.mode, 110 | relax_slo=args.relax_slo, 111 | protocol=args.protocol, 112 | debug_tstamp=args.debug_tstamp, 113 | parallel=args.parallel, 114 | enable_batching=args.enable_batching) 115 | -------------------------------------------------------------------------------- /experiments/e2e_goodput/general_model_suite.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | BenchmarkConfig = namedtuple( 4 | "BenchmarkConfig", 5 | [ 6 | "fixed_num_devices", "fixed_num_modelset", "fixed_slo_scale", # general 7 | "fixed_rate_scale", "fixed_cv_scale", # real trace only 8 | "num_devices_list", "num_modelset_list", "slo_scales", 9 | "rate_list", "cv_list", # synthetic trace only 10 | "rate_scales", "cv_scales", # real trace only 11 | ] 12 | ) 13 | 14 | synthetic_suite = { 15 | "all_transformers": BenchmarkConfig( 16 | fixed_num_devices = 24, 17 | fixed_num_modelset = 14, 18 | fixed_slo_scale = 5, 19 | num_devices_list = [16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96], 20 | num_modelset_list = [1, 4, 6, 8, 10, 12, 14, 16], 21 | slo_scales = [1, 2, 4, 8, 12, 16], 22 | rate_list = [1, 8, 16, 24, 32, 48, 64, 80], 23 | cv_list = [0.5, 1, 2, 4, 6], 24 | fixed_rate_scale = None, 25 | fixed_cv_scale = None, 26 | rate_scales = None, 27 | cv_scales = None, 28 | ), 29 | "mixed": BenchmarkConfig( 30 | fixed_num_devices = 32, 31 | 
fixed_num_modelset = 10, 32 | fixed_slo_scale = 5, 33 | num_devices_list = [8, 24, 40, 56, 72, 88, 104, 120, 136, 152, 168], 34 | num_modelset_list = [1, 2, 4, 6, 8, 10, 12, 14, 16], 35 | slo_scales = [1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24], 36 | rate_list = [1, 8, 16, 24, 32, 64, 96, 128], 37 | cv_list = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 8], 38 | fixed_rate_scale = None, 39 | fixed_cv_scale = None, 40 | rate_scales = None, 41 | cv_scales = None, 42 | ), 43 | } 44 | 45 | azure_v1_suite = { 46 | "all_transformers": BenchmarkConfig( 47 | fixed_num_devices = 24, 48 | fixed_num_modelset = 12, 49 | fixed_slo_scale = 5, 50 | fixed_rate_scale = 2e-3, 51 | fixed_cv_scale = 4, 52 | num_devices_list = [8, 16, 24, 32, 40, 48], 53 | num_modelset_list = [8, 12, 16, 20, 24, 28, 32], 54 | slo_scales = [0.75, 1, 2, 3, 4, 5, 7.5, 10], 55 | rate_list = [], 56 | cv_list = [], 57 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 58 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 59 | ), 60 | "mixed": BenchmarkConfig( 61 | fixed_num_devices = 48, 62 | fixed_num_modelset = 12, 63 | fixed_slo_scale = 5, 64 | fixed_rate_scale = 2e-3, 65 | fixed_cv_scale = 4, 66 | num_devices_list = [32, 40, 48, 54, 64, 72, 96], 67 | num_modelset_list = [8, 12, 16, 20, 24, 28, 32], 68 | slo_scales =[0.75, 1, 2, 3, 4, 5, 7.5, 10], 69 | rate_list = [], 70 | cv_list = [], 71 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 72 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 73 | ), 74 | } 75 | 76 | azure_v2_suite = { 77 | "all_transformers": BenchmarkConfig( 78 | fixed_num_devices = 24, 79 | fixed_num_modelset = 12, 80 | fixed_slo_scale = 5, 81 | fixed_rate_scale = 32, 82 | fixed_cv_scale = 1, 83 | num_devices_list = [8, 16, 20, 24, 32, 40, 48], 84 | num_modelset_list = [8, 12, 16, 20, 24, 28, 32], 85 | slo_scales = [0.75, 1, 1.25, 1.5, 2, 2.5, 5, 10], 86 | rate_list = [], 87 | cv_list = [], 88 | rate_scales = [1, 4, 8, 16, 32, 64, 128], 89 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 90 | ), 91 | 
"mixed": BenchmarkConfig( 92 | fixed_num_devices = 48, 93 | fixed_num_modelset = 12, 94 | fixed_slo_scale = 5, 95 | fixed_rate_scale = 32, 96 | fixed_cv_scale = 1, 97 | num_devices_list = [16, 24, 32, 40, 48, 54, 64], 98 | num_modelset_list = [8, 12, 16, 18, 20, 22], 99 | slo_scales = [0.75, 1, 1.25, 1.5, 2, 2.5, 5], 100 | rate_list = [], 101 | cv_list = [], 102 | rate_scales = [1, 4, 8, 16, 32, 64], 103 | cv_scales = [1, 2, 3, 4, 6, 8], 104 | ), 105 | } 106 | -------------------------------------------------------------------------------- /experiments/e2e_goodput/plot_sec6_3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | import os 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | from benchmarks.alpa.equal_model_case import read_equal_model_case_tsv 12 | from benchmarks.alpa.general_model_case import read_general_model_case_tsv 13 | from benchmarks.alpa.plot_various_metrics import show_name, method2color, method2order 14 | 15 | linestyles = ["solid", "dashed", "dashdot", "dotted", (0, (3,5,1,5,1,5))] 16 | methodcolors = ["C2", "C1", "C0", "C3", "C4"] 17 | 18 | def plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom): 19 | methods = list(data.keys()) 20 | methods.sort(key=lambda x: method2order(x)) 21 | 22 | curves = [] 23 | legends = [] 24 | first_good = [] 25 | x_max = 0 26 | y_max = 0 27 | for i, method in enumerate(methods): 28 | curve = data[method] 29 | xs_, ys_ = zip(*curve.items()) 30 | xs = [x for x, _ in sorted(zip(xs_, ys_))] 31 | ys = [y for _, y in sorted(zip(xs_, ys_))] 32 | ys = np.array(ys) * 100 33 | curve = ax.plot(xs, ys, color=methodcolors[i], marker='.', linestyle=linestyles[i], linewidth=4, markersize=15) 34 | curves.append(curve[0]) 35 | legends.append(show_name(method)) 36 | 37 | if increasing: 38 | iterator = range(len(xs)) 39 | else: 40 | iterator = 
reversed(range(len(xs))) 41 | 42 | found = False 43 | for i in iterator: 44 | if ys[i] >= threshold * 100: 45 | first_good.append(xs[i]) 46 | found = True 47 | break 48 | if not found: 49 | first_good.append(0) 50 | 51 | x_max = max(x_max, *xs) 52 | y_max = max(y_max, *ys) 53 | 54 | ax.tick_params(axis='both', which='major', labelsize=20) 55 | ax.tick_params(axis='both', which='minor', labelsize=20) 56 | ax.set_ylim(bottom=ybottom, top=max(y_max * 1.02, 100)) 57 | ax.set_xlabel(xlabel, fontsize=20) 58 | ax.grid() 59 | 60 | for i in range(len(methods)): 61 | if first_good[i] == 0: 62 | continue 63 | ax.axvline(first_good[i], color=methodcolors[i], linestyle=":", linewidth=4) 64 | 65 | return curves, legends 66 | 67 | 68 | def plot_goodput(lines, threshold, folder, pdf): 69 | rate_data = defaultdict(lambda: defaultdict(dict)) 70 | cv_data = defaultdict(lambda: defaultdict(dict)) 71 | slo_data = defaultdict(lambda: defaultdict(dict)) 72 | 73 | for line in lines: 74 | if line["exp_name"] == "goodput_vs_rate": 75 | policy, x, goodput = ( 76 | line["policy_name"], line["total_rate"], line["goodput"]) 77 | rate_data[policy][x] = goodput 78 | if line["exp_name"] == "goodput_vs_cv": 79 | policy, x, goodput = ( 80 | line["policy_name"], line["arrival_process_kwargs"]["cv"], line["goodput"]) 81 | cv_data[policy][x] = goodput 82 | if line["exp_name"] == "goodput_vs_slo": 83 | policy, x, goodput = ( 84 | line["policy_name"], line["slo_scale"], line["goodput"]) 85 | slo_data[policy][x] = goodput 86 | 87 | fig, axs = plt.subplots(1, 3) 88 | 89 | datas = [rate_data, cv_data, slo_data] 90 | xlabels = ["Rate (r/s)", "CV", "SLO Scale"] 91 | ybottoms = [60,60,0] 92 | increasings = [False, False, True] 93 | for data, increasing, ax, xlabel, ybottom in zip(datas, increasings, axs, xlabels, ybottoms): 94 | curves, legends = plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom) 95 | 96 | fig.text(0.07, 0.5, "SLO Attainment (%)", va='center', rotation='vertical', 
fontsize=20) 97 | fig.legend(curves, legends, loc="upper center", ncol=6, bbox_to_anchor=(0.5, 1.1), fontsize=20) 98 | 99 | if pdf: 100 | output = os.path.join(folder, "large_model_exp.pdf") 101 | else: 102 | output = os.path.join(folder, "large_model_exp.png") 103 | 104 | figure_size = (18, 5) 105 | fig.set_size_inches(figure_size) 106 | fig.savefig(output, bbox_inches='tight') 107 | print(f"Output the plot to {output}") 108 | 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--input", type=str, required=True) 114 | parser.add_argument("--output-dir", type=str, default="paper_figures") 115 | parser.add_argument("--show", action="store_true") 116 | parser.add_argument("--pdf", action="store_true") 117 | args = parser.parse_args() 118 | 119 | os.makedirs(args.output_dir, exist_ok=True) 120 | 121 | threshold = 0.99 122 | 123 | lines = read_equal_model_case_tsv(args.input) 124 | 125 | plot_goodput(lines, threshold, args.output_dir, args.pdf) 126 | -------------------------------------------------------------------------------- /experiments/e2e_goodput/plot_sec6_5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | import os 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import pickle 11 | 12 | from benchmarks.alpa.equal_model_case import read_equal_model_case_tsv 13 | from benchmarks.alpa.plot_various_metrics import show_name, method2color, method2order 14 | 15 | linestyles = ["solid", "dashed", "dashdot", "dotted", (0, (3,5,1,5,1,5)), (0, (3,10,1,10,1,10))] 16 | methodcolors = ["C2", "C1", "C0", "C3", "C4", "C5", "C6", "C7", "C8", "C9"] 17 | 18 | def plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom): 19 | methods = list(data.keys()) 20 | methods.sort(key=lambda x: method2order(x)) 21 | 22 | curves = [] 23 | legends = [] 24 | 
first_good = [] 25 | x_max = 0 26 | y_max = 0 27 | for i, method in enumerate(methods): 28 | curve = data[method] 29 | xs_, ys_ = zip(*curve.items()) 30 | xs = [x for x, _ in sorted(zip(xs_, ys_))] 31 | ys = [y for _, y in sorted(zip(xs_, ys_))] 32 | ys = np.array(ys) * 100 33 | curve = ax.plot(xs, ys, color=methodcolors[i], marker='.', linestyle=linestyles[i], linewidth=4, markersize=15) 34 | curves.append(curve[0]) 35 | legends.append(show_name(method)) 36 | 37 | if increasing: 38 | iterator = range(len(xs)) 39 | else: 40 | iterator = reversed(range(len(xs))) 41 | 42 | found = False 43 | for i in iterator: 44 | if ys[i] >= threshold * 100: 45 | first_good.append(xs[i]) 46 | found = True 47 | break 48 | if not found: 49 | first_good.append(0) 50 | 51 | x_max = max(x_max, *xs) 52 | y_max = max(y_max, *ys) 53 | 54 | ax.tick_params(axis='both', which='major', labelsize=20) 55 | ax.tick_params(axis='both', which='minor', labelsize=20) 56 | ax.set_ylim(bottom=ybottom, top=max(y_max * 1.02, 100)) 57 | ax.set_xlabel(xlabel, fontsize=20) 58 | ax.grid() 59 | ax.legend(curves, legends, fontsize=20) 60 | 61 | for i in range(len(methods)): 62 | if first_good[i] == 0: 63 | continue 64 | ax.axvline(first_good[i], color=methodcolors[i], linestyle=":", linewidth=4) 65 | 66 | return curves, legends 67 | 68 | 69 | def plot_goodput(bs_lines, batching_lines, threshold, folder, pdf): 70 | bs_data = defaultdict(lambda: defaultdict(dict)) 71 | batching_data = defaultdict(lambda: defaultdict(dict)) 72 | 73 | for line in bs_lines: 74 | if line["exp_name"] == "goodput_vs_slo": 75 | policy, x, goodput = ( 76 | line["policy_name"], line["slo_scale"], line["goodput"]) 77 | bs_data[policy][x] = goodput 78 | else: 79 | continue 80 | 81 | for line in batching_lines: 82 | if line["exp_name"] == "goodput_vs_slo": 83 | policy, x, goodput = ( 84 | line["policy_name"], line["slo_scale"], line["goodput"]) 85 | batching_data[policy][x] = goodput 86 | else: 87 | continue 88 | 89 | # fig, axs = 
plt.subplots(1, 5) 90 | fig, axs = plt.subplots(1, 2) 91 | 92 | datas = [bs_data, batching_data] 93 | xlabels = ["SLO Scale", "SLO Scale"] 94 | ybottoms = [0,0,0,0] 95 | increasings = [True, True] 96 | for data, increasing, ax, xlabel, ybottom in zip(datas, increasings, axs, xlabels, ybottoms): 97 | curves, legends = plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom) 98 | 99 | fig.text(0.05, 0.5, "SLO Attainment (%)", va='center', rotation='vertical', fontsize=20) 100 | # fig.legend(reversed(curves), reversed(legends), loc="upper center", ncol=4, bbox_to_anchor=(0.5, 1.1), fontsize=20) 101 | 102 | if pdf: 103 | output = os.path.join(folder, "batching.pdf") 104 | else: 105 | output = os.path.join(folder, "batching.png") 106 | 107 | figure_size = (15, 5) 108 | fig.set_size_inches(figure_size) 109 | fig.savefig(output, bbox_inches='tight') 110 | print(f"Output the plot to {output}") 111 | 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--input", type=str, required=True) 117 | parser.add_argument("--output-dir", type=str, default="./paper_figures") 118 | parser.add_argument("--show", action="store_true") 119 | parser.add_argument("--pdf", action="store_true") 120 | args = parser.parse_args() 121 | 122 | os.makedirs(args.output_dir, exist_ok=True) 123 | 124 | threshold = 0.99 125 | 126 | bs_lines = read_equal_model_case_tsv(args.input + "/res_batchsize.tsv") 127 | batching_lines = read_equal_model_case_tsv(args.input + "/res_batching.tsv") 128 | 129 | 130 | plot_goodput(bs_lines, batching_lines, threshold, args.output_dir, args.pdf) 131 | -------------------------------------------------------------------------------- /experiments/e2e_goodput/plot_sec6_6.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | import os 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import matplotlib 9 | 
import matplotlib.pyplot as plt 10 | import pickle 11 | 12 | from benchmarks.alpa.equal_model_case import read_equal_model_case_tsv 13 | from benchmarks.alpa.general_model_case import read_general_model_case_tsv 14 | from benchmarks.alpa.plot_various_metrics import show_name, method2color, method2order 15 | 16 | linestyles = ["solid", "dashed", "dashdot", "dotted", (0, (3,5,1,5,1,5))] 17 | 18 | paper_name = { 19 | "mp-round-robin": "Round robin", 20 | "mp-greedy-4": "Greedy placement", 21 | "mp-search-sep": "Greedy placement + Group partitioning", 22 | } 23 | 24 | def plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom): 25 | methods = list(data.keys()) 26 | # methods.sort(key=lambda x: method2order(x)) 27 | 28 | curves = [] 29 | legends = [] 30 | first_good = [] 31 | x_max = 0 32 | y_max = 0 33 | for i, method in enumerate(methods): 34 | curve = data[method] 35 | xs_, ys_ = zip(*curve.items()) 36 | xs = [x for x, _ in sorted(zip(xs_, ys_))] 37 | ys = [y for _, y in sorted(zip(xs_, ys_))] 38 | ys = np.array(ys) * 100 39 | curve = ax.plot(xs, ys, color=method2color(method), marker='.', linestyle=linestyles[i], linewidth=4, markersize=15) 40 | curves.append(curve[0]) 41 | legends.append(paper_name.get(method)) 42 | 43 | if increasing: 44 | iterator = range(len(xs)) 45 | else: 46 | iterator = reversed(range(len(xs))) 47 | 48 | found = False 49 | for i in iterator: 50 | if ys[i] >= threshold * 100: 51 | first_good.append(xs[i]) 52 | found = True 53 | break 54 | if not found: 55 | first_good.append(0) 56 | 57 | x_max = max(x_max, *xs) 58 | y_max = max(y_max, *ys) 59 | 60 | ax.tick_params(axis='both', which='major', labelsize=20) 61 | ax.tick_params(axis='both', which='minor', labelsize=20) 62 | ax.set_ylim(bottom=ybottom, top=max(y_max * 1.02, 100)) 63 | ax.set_xlabel(xlabel, fontsize=22) 64 | ax.grid() 65 | 66 | ax.legend(curves, legends, fontsize=18.5, loc="lower left") 67 | 68 | for i in range(len(methods)): 69 | if first_good[i] == 0: 70 | continue 
71 | ax.axvline(first_good[i], color=method2color(methods[i]), linestyle=":", linewidth=4) 72 | 73 | 74 | def plot_goodput(lines, threshold, folder, pdf): 75 | rate_data = defaultdict(lambda: defaultdict(dict)) 76 | cv_data = defaultdict(lambda: defaultdict(dict)) 77 | 78 | for line in lines: 79 | if line["exp_name"] == "goodput_vs_rate": 80 | policy, x, goodput = ( 81 | line["policy_name"], line["total_rate"], line["goodput"]) 82 | rate_data[policy][x] = goodput 83 | elif line["exp_name"] == "goodput_vs_cv": 84 | policy, x, goodput = ( 85 | line["policy_name"], line["arrival_process_kwargs"]["cv"], line["goodput"]) 86 | cv_data[policy][x] = goodput 87 | else: 88 | continue 89 | 90 | fig, axs = plt.subplots(1, 2) 91 | 92 | datas = [rate_data, cv_data] 93 | xlabels = ["Rate (r/s)", "CV"] 94 | ybottoms = [60,60] 95 | increasings = [False, False] 96 | for data, increasing, ax, xlabel, ybottom in zip(datas, increasings, axs, xlabels, ybottoms): 97 | plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom) 98 | 99 | fig.text(0.07, 0.5, "SLO Attainment (%)", va='center', rotation='vertical', fontsize=22) 100 | 101 | if pdf: 102 | output = os.path.join(folder, "ablation.pdf") 103 | else: 104 | output = os.path.join(folder, "ablation.png") 105 | 106 | figure_size = (18, 7) 107 | fig.set_size_inches(figure_size) 108 | fig.savefig(output, bbox_inches='tight') 109 | print(f"Output the plot to {output}") 110 | 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--input", type=str, required=True) 116 | parser.add_argument("--output-dir", type=str, default="paper_figures") 117 | parser.add_argument("--show", action="store_true") 118 | parser.add_argument("--pdf", action="store_true") 119 | args = parser.parse_args() 120 | 121 | os.makedirs(args.output_dir, exist_ok=True) 122 | 123 | threshold = 0.99 124 | 125 | lines = read_general_model_case_tsv(args.input) 126 | 127 | plot_goodput(lines, threshold, 
args.output_dir, args.pdf) 128 | -------------------------------------------------------------------------------- /experiments/e2e_goodput/visualize.py: -------------------------------------------------------------------------------- 1 | from alpa_serve.trace import Trace 2 | # azure_v2_trace_dir = "/home/ubuntu/efs/mms/dataset/azure_v2.pkl" 3 | azure_v2_trace_dir = "/home/ubuntu/azure_v2.pkl" 4 | azure_v2_trace = Trace("azure_v2", azure_v2_trace_dir) 5 | num_models = 30 6 | model_names = [f"m{i}" for i in range(num_models)] 7 | # train_replays = azure_v2_trace.replay(model_names, model_mapping_strategy="stripe", 8 | # arrival_distribution="vanilla", 9 | # start_time='5.0.0', end_time='6.0.0', 10 | # replication_factor=1) 11 | train_replays = azure_v2_trace.replay(model_names, 12 | model_mapping_strategy="stripe", 13 | arrival_distribution="gamma", 14 | start_time='5.0.0', 15 | end_time='6.0.0', 16 | interval_seconds=5400, 17 | rate_scale_factor=1, 18 | cv_scale_factor=4) 19 | for model_name in model_names: 20 | replay = train_replays[model_name] 21 | replay.report_stats() 22 | print(replay.rate()) -------------------------------------------------------------------------------- /experiments/motivation/README.md: -------------------------------------------------------------------------------- 1 | # Motivation Experiments (~1hr) 2 | 3 | ## Prepare profiling databases 4 | 5 | Please find the profiling databases used in motivation experiments [here](https://github.com/alpa-projects/mms/issues/14#issuecomment-1521422527). There should be three databases in total: 6 | - `profiling_result.pkl` 7 | - `profiling_result_long_sequence_manual.pkl` 8 | - `profiling_result_long_sequence_dp.pkl` 9 | 10 | Please unzip the files and put them under `experiments/motivation/`. 
11 | 12 | ## Two-model example (Sec 3.1, Figure 2 (a)-(d)) 13 | 14 | To generate the figures: 15 | ```bash 16 | python illustrative_example.py 17 | ``` 18 | 19 | Figure mapping: 20 | - Figure 2 (a): `illustrative_example_1.pdf`. 21 | - Figure 2 (b): `illustrative_example_2.pdf`. 22 | - Figure 2 (c): `illustrative_example_3.pdf`. 23 | - Figure 2 (d): `illustrative_example_utilization_4.pdf`. 24 | 25 | ## Changing the per-GPU memory budget (Sec 3.2, Figure 4) 26 | 27 | To generate the figures: 28 | ```bash 29 | python memory_budget_vs_latency.py 30 | ``` 31 | 32 | Figure mapping: 33 | - Figure 4 (left): `memory_budget_vs_latency_mean_latency_2.pdf`. 34 | - Figure 4 (right): `memory_budget_vs_latency_p99_latency_2.pdf`. 35 | 36 | ## Changing arrival rates, CVs, and SLOs (Sec 3.2, Figures 5, 6, 7(a)) 37 | 38 | To generate the figures: 39 | ```bash 40 | python changing_rate_cv_slo.py 41 | ``` 42 | 43 | Figure mapping: 44 | - Figure 5 (left): `changing_rate_cv_slo_1.pdf`. 45 | - Figure 5 (right): `changing_rate_cv_slo_1.5.pdf`. 46 | - Figure 6 (left): `changing_rate_cv_slo_2.pdf`. 47 | - Figure 6 (right): `changing_rate_cv_slo_2.5.pdf`. 48 | - Figure 7 (a): `changing_rate_cv_slo_3.pdf`. 49 | 50 | ## Changing model parallel overhead (Sec 3.3, Figure 7(b)) 51 | 52 | To generate the figure: 53 | ```bash 54 | python changing_pipeline_overhead.py 55 | ``` 56 | 57 | Figure mapping: 58 | - Figure 7 (b): `changing_pipeline_overhead_1.pdf`. 59 | 60 | ## Model parallel overhead (Sec 3.3, Figure 8 & Sec 6.5, Figure 14) 61 | 62 | To generate the figures: 63 | ```bash 64 | python overhead_decomposition.py 65 | ``` 66 | 67 | Figure mapping: 68 | - Figure 8 (a): `overhead_decomposition_pp.pdf`. 69 | - Figure 8 (b): `overhead_decomposition_op.pdf`.
70 | - Figure 14 (a): `overhead_decomposition_pp_compare_bert-1.3b.pdf` 71 | - Figure 14 (b): `overhead_decomposition_pp_compare_bert-2.6b.pdf` 72 | 73 | 74 | ## Latency, throughput, and memory usage of model parallelism (Sec 3.3, Figure 9) 75 | 76 | To generate the figures: 77 | ```bash 78 | python model_parallel_latency_throughput.py 79 | ``` 80 | 81 | Figures mapping: 82 | - Figure 9 (a): `model_parallel_latency.pdf`. 83 | - Figure 9 (b): `model_parallel_throughput.pdf` 84 | - Figure 9 (c): `model_parallel_memory.pdf` 85 | -------------------------------------------------------------------------------- /experiments/motivation/changing_pipeline_overhead.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import argparse 4 | import matplotlib as mpl 5 | mpl.use('Pdf') 6 | import matplotlib.pyplot as plt 7 | from matplotlib.transforms import Bbox 8 | 9 | from benchmarks.alpa.equal_model_case import EqualModelCase, run_equal_model_cases 10 | from alpa_serve.util import GB 11 | from alpa_serve.profiling import ProfilingDatabase 12 | 13 | color_dict = { 14 | "sr-uniform": "C1", 15 | "mp-greedy-8": "C0", 16 | } 17 | policy_name_to_label = { 18 | "sr-uniform": "Replication", 19 | "mp-greedy-8": "Model Parallelism", 20 | } 21 | 22 | def run_case(case_id=1, mode="simulate", parallel=False): 23 | policies = ["mp-greedy-8", "sr-uniform"] 24 | num_devices = 8 25 | model_type = "bert-2.6b" 26 | mem_budget = 13 * GB 27 | num_models = 8 28 | arrival_process = "gamma" 29 | rate_distribution = "uniform" 30 | slo_scale = np.inf 31 | total_rate = 30 32 | arrival_process_kwargs = {"cv": 3.0} 33 | overheads = list(np.linspace(1.0, 1.5, 6)) + [None] 34 | duration = 500 35 | results = [] 36 | slo_scales = np.linspace(1.0, 20, 20) 37 | for overhead in overheads: 38 | if overhead is None: 39 | prof_database = None 40 | policy_name = "sr-uniform" 41 | else: 42 | prof_database = 
ProfilingDatabase("profiling_result.pkl") 43 | single_device_latency = (prof_database.results[model_type] 44 | .para_dict[(1,1,1)].latency[1][0]) 45 | (prof_database.results[model_type] 46 | .para_dict[(1, 1, num_devices)].latency[1]) = [ 47 | overhead * single_device_latency / num_devices] * num_devices 48 | policy_name = "mp-greedy-8" 49 | cases = [] 50 | for slo_scale in slo_scales: 51 | cases.append(EqualModelCase( 52 | None, 53 | num_devices, mem_budget, model_type, num_models, 54 | total_rate, rate_distribution, 55 | arrival_process, arrival_process_kwargs, 56 | slo_scale, duration, policy_name, 57 | None, None, None, None)) 58 | 59 | all_results = run_equal_model_cases(cases, 60 | output_file=None, 61 | mode=mode, 62 | parallel=parallel, 63 | prof_database=prof_database, 64 | return_stats_and_placement=True) 65 | stats = [result[0] for result in all_results] 66 | results.append((policy_name, overhead, slo_scales, stats)) 67 | 68 | 69 | with open(f"changing_pipeline_overhead_{case_id}.pkl", "wb") as f: 70 | pickle.dump(results, f) 71 | 72 | def plot_case(case_id=1): 73 | with open(f"changing_pipeline_overhead_{case_id}.pkl", "rb") as f: 74 | results = pickle.load(f) 75 | 76 | plt.figure(figsize=(3, 2)) 77 | for policy_name, overhead, slo_scales, stats in results: 78 | x = slo_scales 79 | label = policy_name_to_label[policy_name] + ("" if overhead is None else f" ($\\alpha$={overhead})") 80 | y = [] 81 | for stat in stats: 82 | y.append(stat.goodput * 100) 83 | alpha = 1 - (overhead - 1) * 1.6 if overhead is not None else 1 84 | plt.plot(x, y, '.-', label=label, alpha = alpha, color = color_dict[policy_name]) 85 | plt.xlabel("SLO Scale") 86 | plt.ylabel("SLO Attainment (%)") 87 | plt.grid() 88 | plt.legend(prop={'size': 5.5}) 89 | plt.tight_layout() 90 | plt.savefig(f"changing_pipeline_overhead_{case_id}.pdf", bbox_inches=Bbox([[0, 0], [3, 2.25]])) 91 | # plt.show() 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | 
parser.add_argument("--parallel", action="store_true") 97 | parser.add_argument("--mode", choices=["simulate", "run"], 98 | default="simulate") 99 | 100 | args = parser.parse_args() 101 | run_case(case_id=1, mode=args.mode, parallel=args.parallel) 102 | plot_case(case_id=1) 103 | -------------------------------------------------------------------------------- /experiments/motivation/memory_budget_vs_latency.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import argparse 4 | import matplotlib as mpl 5 | mpl.use('Pdf') 6 | import matplotlib.pyplot as plt 7 | 8 | from benchmarks.alpa.equal_model_case import EqualModelCase, run_equal_model_cases 9 | from alpa_serve.util import GB 10 | from alpa_serve.profiling import ProfilingDatabase 11 | 12 | color_dict = { 13 | "sr-greedy": "C1", 14 | "mp-search": "C0", 15 | } 16 | policy_name_to_label = { 17 | "sr": "Replication", 18 | "mp": "Model Parallelism", 19 | } 20 | 21 | 22 | def run_case(case_id=1, mode="simulate", parallel=False): 23 | prof_database = ProfilingDatabase("profiling_result.pkl") 24 | num_devices = 8 25 | num_models = 8 26 | arrival_process = "gamma" 27 | rate_distribution = "uniform" 28 | arrival_process_kwargs = {"cv": 3.0} 29 | slo_scale = np.inf 30 | duration = 20000 31 | if case_id == 1: 32 | total_rate = 30 33 | model_type = "bert-1.3b" 34 | mp_mem_budgets = [3.5 * GB, 6 * GB, 11 * GB, 21.6 * GB] 35 | mp_policies = ["mp-greedy-8", "mp-greedy-4", "mp-greedy-2", "sr-uniform"] 36 | sr_mem_budgets = [2.7 * GB * i for i in range(1, 9)] 37 | sr_policies = ["sr-uniform"] * len(sr_mem_budgets) 38 | elif case_id == 2: 39 | total_rate = 20 40 | model_type = "bert-2.6b" 41 | mp_mem_budgets = [7 * GB, 12 * GB, 22 * GB, 44 * GB] 42 | mp_policies = ["mp-greedy-8", "mp-greedy-4", "mp-greedy-2", "sr-uniform"] 43 | sr_mem_budgets = [5.5 * GB * i for i in range(1, 9)] 44 | sr_policies = ["sr-uniform"] * len(sr_mem_budgets) 45 | 46 | cases = 
[] 47 | for policy_name, mem_budget in zip(mp_policies, mp_mem_budgets): 48 | cases.append(EqualModelCase( 49 | None, 50 | num_devices, mem_budget, model_type, num_models, 51 | total_rate, rate_distribution, 52 | arrival_process, arrival_process_kwargs, 53 | slo_scale, duration, policy_name, 54 | None, None, None, None)) 55 | 56 | for policy_name, mem_budget in zip(sr_policies, sr_mem_budgets): 57 | cases.append(EqualModelCase( 58 | None, 59 | num_devices, mem_budget, model_type, num_models, 60 | total_rate, rate_distribution, 61 | arrival_process, arrival_process_kwargs, 62 | slo_scale, duration, policy_name, 63 | None, None, None, None)) 64 | 65 | 66 | results = run_equal_model_cases(cases, 67 | output_file=None, 68 | mode=mode, 69 | parallel=parallel, 70 | return_stats_and_placement=True) 71 | 72 | stats = [result[0] for result in results] 73 | results = ((mp_mem_budgets, mp_policies, sr_mem_budgets, sr_policies), stats) 74 | with open(f"memory_budget_vs_latency_results_{case_id}.pkl", "wb") as f: 75 | pickle.dump(results, f) 76 | 77 | def get_latency_percentile(stat, p=99): 78 | all_latencies = [] 79 | for per_model_stat in stat.per_model_stats: 80 | all_latencies.extend(per_model_stat.latency) 81 | return np.percentile(all_latencies, p) 82 | 83 | def plot_case(case_id=1): 84 | with open(f"memory_budget_vs_latency_results_{case_id}.pkl", "rb") as f: 85 | ((mp_mem_budgets, mp_policies, sr_mem_budgets, sr_policies), stats) = pickle.load(f) 86 | mp_x = np.array(mp_mem_budgets) / GB 87 | mp_y = [] 88 | mp_y_p99 = [] 89 | mp_stats = stats[:len(mp_policies)] 90 | for stat in mp_stats: 91 | mp_y.append(stat.latency_mean) 92 | mp_y_p99.append(get_latency_percentile(stat, 99)) 93 | 94 | sr_x = np.array(sr_mem_budgets) / GB 95 | sr_y = [] 96 | sr_stats = stats[len(mp_policies):] 97 | sr_y_p99 = [] 98 | for stat in sr_stats: 99 | sr_y.append(stat.latency_mean) 100 | sr_y_p99.append(get_latency_percentile(stat, 99)) 101 | 102 | plt.figure(figsize=(3, 2)) 103 | 
plt.plot(mp_x, mp_y, '.-', label=policy_name_to_label["mp"]) 104 | plt.plot(sr_x, sr_y, '.-', label=policy_name_to_label["sr"]) 105 | plt.axvline(13, linestyle='--', color = "black", label = "GPU Memory Bound", linewidth=0.75) 106 | plt.xlabel("Memory Budget (GB)") 107 | plt.ylabel("Mean Latency (s)") 108 | plt.grid() 109 | plt.legend(prop={'size': 8}) 110 | plt.tight_layout() 111 | plt.savefig(f"memory_budget_vs_latency_mean_latency_{case_id}.pdf") 112 | 113 | plt.figure(figsize=(3, 2)) 114 | plt.plot(mp_x, mp_y_p99, '.-', label=policy_name_to_label["mp"]) 115 | plt.plot(sr_x, sr_y_p99, '.-', label=policy_name_to_label["sr"]) 116 | plt.axvline(13, linestyle='--', color = "black", label = "GPU Memory Bound", linewidth=0.75) 117 | plt.xlabel("Memory Budget (GB)") 118 | plt.ylabel("P99 Latency (s)") 119 | plt.grid() 120 | plt.legend(prop={'size': 8}) 121 | plt.tight_layout() 122 | plt.savefig(f"memory_budget_vs_latency_p99_latency_{case_id}.pdf") 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("--parallel", action="store_true") 128 | parser.add_argument("--mode", choices=["simulate", "run"], 129 | default="simulate") 130 | 131 | args = parser.parse_args() 132 | run_case(case_id=1, mode=args.mode, parallel=args.parallel) 133 | plot_case(case_id=1) 134 | run_case(case_id=2, mode=args.mode, parallel=args.parallel) 135 | plot_case(case_id=2) 136 | -------------------------------------------------------------------------------- /experiments/motivation/model_parallel_latency_throughput.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import argparse 4 | import matplotlib as mpl 5 | mpl.use('Pdf') 6 | import matplotlib.pyplot as plt 7 | from alpa_serve.util import GB 8 | 9 | def get_latency_and_throughput(all_latency): 10 | latency = sum(all_latency) 11 | throughput = 1 / max(all_latency) 12 | return latency, throughput 13 | 14 | 
def plot_database(database, model_name="bert-2.6b"): 15 | n_gpus = [1, 2, 4, 8] 16 | pp_latency = [] 17 | pp_throughput = [] 18 | pp_memory = [] 19 | op_latency = [] 20 | op_throughput = [] 21 | op_memory = [] 22 | dp_latency = [] 23 | dp_throughput = [] 24 | dp_memory = [] 25 | for n in n_gpus: 26 | pp_all_latency = database[model_name].para_dict[(1, 1, n)].latency[1] 27 | pp_latency.append(get_latency_and_throughput(pp_all_latency)[0]) 28 | pp_throughput.append(get_latency_and_throughput(pp_all_latency)[1]) 29 | pp_memory.append(sum(database[model_name].para_dict[(1, 1, n)].weight_mem) / GB) 30 | op_all_latency = database[model_name].para_dict[(1, n, 1)].latency[1] 31 | op_latency.append(get_latency_and_throughput(op_all_latency)[0]) 32 | op_memory.append(sum(database[model_name].para_dict[(1, n, 1)].weight_mem) * n / GB) 33 | op_throughput.append(get_latency_and_throughput(op_all_latency)[1]) 34 | dp_all_latency = database[model_name].para_dict[(1, 1, 1)].latency[1] 35 | dp_latency.append(dp_all_latency) 36 | dp_throughput.append(n / max(dp_all_latency)) 37 | dp_memory.append(sum(database[model_name].para_dict[(1, 1, 1)].weight_mem) * n / GB) 38 | 39 | plt.figure(figsize=(3, 2)) 40 | plt.plot(n_gpus, pp_latency, '.-', label="Inter-op Parallelism") 41 | plt.plot(n_gpus, op_latency, '.-', label="Intra-op Parallelism") 42 | plt.plot(n_gpus, dp_latency, '.-', label="Replication") 43 | # plt.axvline(8, linestyle='--', color = "black", label = "Single Node Boundary", linewidth=0.75) 44 | plt.xlabel("#GPUs") 45 | plt.ylim(bottom=0) 46 | plt.ylabel("Latency (s)") 47 | plt.grid() 48 | plt.legend(prop={'size': 7}) 49 | plt.tight_layout() 50 | plt.savefig(f"model_parallel_latency.pdf") 51 | 52 | plt.figure(figsize=(3, 2)) 53 | plt.plot(n_gpus, pp_throughput, '.-', label="Inter-op Parallelism") 54 | plt.plot(n_gpus, op_throughput, '.-', label="Intra-op Parallelism") 55 | plt.plot(n_gpus, dp_throughput, '.-', label="Replication") 56 | # plt.axvline(8, linestyle='--', color = 
"black", label = "Single Node Boundary", linewidth=0.75) 57 | plt.xlabel("#GPUs") 58 | plt.ylabel("Throughput (req/s)") 59 | plt.grid() 60 | plt.legend(prop={'size': 7}) 61 | plt.tight_layout() 62 | plt.savefig(f"model_parallel_throughput.pdf") 63 | 64 | plt.figure(figsize=(3, 2)) 65 | plt.plot(n_gpus, pp_memory, '.-', label="Inter-op Parallelism") 66 | plt.plot(n_gpus, op_memory, '.-', label="Intra-op Parallelism") 67 | plt.plot(n_gpus, dp_memory, '.-', label="Replication") 68 | # plt.axvline(8, linestyle='--', color = "black", label = "Single Node Boundary", linewidth=0.75) 69 | plt.xlabel("#GPUs") 70 | plt.ylabel("Memory (GB)") 71 | plt.grid() 72 | plt.legend(prop={'size': 7}) 73 | plt.tight_layout() 74 | plt.savefig(f"model_parallel_memory.pdf") 75 | 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--database_filename", type=str, default="profiling_result_long_sequence_manual.pkl") 80 | args = parser.parse_args() 81 | with open(args.database_filename, "rb") as f: 82 | database = pickle.load(f) 83 | 84 | plot_database(database) 85 | -------------------------------------------------------------------------------- /experiments/motivation/overhead_decomposition.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import argparse 4 | import matplotlib as mpl 5 | mpl.use('Pdf') 6 | import matplotlib.pyplot as plt 7 | from alpa_serve.util import GB 8 | 9 | def get_database_data(database, model_name="bert-2.6b", op=True): 10 | n_gpus = [1, 2, 4, 8] 11 | n_gpus_str = [str(x) for x in n_gpus] 12 | base_latency = pp_all_latency = database[model_name].para_dict[(1, 1, 1)].latency[1][0] 13 | pp_comm = [] 14 | pp_uneven = [] 15 | op_comm = [] 16 | for n in n_gpus: 17 | pp_all_latency = database[model_name].para_dict[(1, 1, n)].latency[1] 18 | comm = max(sum(pp_all_latency) - base_latency, 0) 19 | uneven = max(max(pp_all_latency) * n - comm - 
base_latency, 0) 20 | pp_comm.append(comm) 21 | pp_uneven.append(uneven) 22 | if op: 23 | op_all_latency = database[model_name].para_dict[(1, n, 1)].latency[1] 24 | comm = op_all_latency[0] - base_latency / n 25 | op_comm.append(comm) 26 | return n_gpus, n_gpus_str, base_latency, pp_comm, pp_uneven, op_comm 27 | 28 | 29 | def plot_one_database(n_gpus, n_gpus_str, base_latency, pp_comm, pp_uneven, op_comm): 30 | plt.figure(figsize=(3, 2.5)) 31 | plt.grid(axis="y") 32 | plt.bar(n_gpus_str, [base_latency] * len(n_gpus), label="Compuation", width=0.3) 33 | plt.bar(n_gpus_str, pp_comm, label="Communication Overhead", bottom = [base_latency] * len(n_gpus), width=0.3) 34 | plt.bar(n_gpus_str, pp_uneven, label="Uneven Partition Overhead", bottom = [base_latency + x for x in pp_comm], width=0.3) 35 | plt.xlabel("Number of GPUs") 36 | plt.ylabel("Latency (s)") 37 | plt.legend(prop={'size': 7}) 38 | plt.tight_layout() 39 | plt.savefig(f"overhead_decomposition_pp.pdf") 40 | 41 | plt.figure(figsize=(3, 2.5)) 42 | plt.grid(axis="y") 43 | plt.bar(n_gpus_str, [base_latency / n for n in n_gpus], label="Compuation", width=0.3) 44 | plt.bar(n_gpus_str, op_comm, label="Communication Overhead", bottom = [base_latency / n for n in n_gpus], width=0.3) 45 | plt.xlabel("Number of GPUs") 46 | plt.ylabel("Latency (s)") 47 | plt.ylim(0, 0.25) 48 | plt.legend(prop={'size': 7}) 49 | plt.tight_layout() 50 | plt.savefig(f"overhead_decomposition_op.pdf") 51 | 52 | def plot_two_databases(manual_result, dp_result, model_name="bert-2.6b", ylim=(0.2, 0.3)): 53 | plt.figure(figsize=(3, 2.5)) 54 | n_gpus, n_gpus_str, base_latency, pp_comm, pp_uneven, _ = manual_result 55 | x = np.arange(len(n_gpus)) 56 | width = 0.3 57 | ax = plt.gca() 58 | print(model_name, "manual", pp_comm[-1] + pp_uneven[-1]) 59 | ax.bar(x - width/2, [base_latency] * len(n_gpus), width=0.3, color="C0", alpha=0.5) 60 | ax.bar(x - width/2, pp_comm, bottom = [base_latency] * len(n_gpus), width=0.3, color="C1", alpha=0.5) 61 | ax.bar(x 
- width/2, pp_uneven, bottom = [base_latency + x for x in pp_comm], width=0.3, color="C2", alpha=0.5) 62 | _, _, _, pp_comm, pp_uneven, _ = dp_result 63 | print(model_name, "dp", pp_comm[-1] + pp_uneven[-1]) 64 | ax.bar(x + width/2, [base_latency] * len(n_gpus), width=0.3, color="C0", label="Compuation") 65 | ax.bar(x + width/2, pp_comm, bottom = [base_latency] * len(n_gpus), width=0.3, color="C1", label="Communication Overhead") 66 | ax.bar(x + width/2, pp_uneven, bottom = [base_latency + x for x in pp_comm], width=0.3, color="C2", label="Uneven Partition Overhead") 67 | ax.set_xticks(x) 68 | ax.set_xticklabels(n_gpus_str) 69 | ax.grid(axis="y") 70 | plt.ylim(*ylim) 71 | plt.xlabel("Number of GPUs") 72 | plt.ylabel("Latency (s)") 73 | plt.legend(loc="upper left", prop={'size': 7}) 74 | plt.tight_layout() 75 | plt.savefig(f"overhead_decomposition_pp_compare_{model_name}.pdf") 76 | 77 | if __name__ == "__main__": 78 | with open("profiling_result_long_sequence_manual.pkl", "rb") as f: 79 | manual_database = pickle.load(f) 80 | with open("profiling_result_long_sequence_dp.pkl", "rb") as f: 81 | dp_database = pickle.load(f) 82 | manual_result = get_database_data(manual_database, "bert-2.6b") 83 | dp_result = get_database_data(dp_database, "bert-2.6b") 84 | plot_one_database(*manual_result) 85 | plot_two_databases(manual_result, dp_result, "bert-2.6b", ylim=(0.2, 0.3)) 86 | 87 | manual_result = get_database_data(manual_database, "bert-1.3b", op=False) 88 | dp_result = get_database_data(dp_database, "bert-1.3b", op=False) 89 | plot_two_databases(manual_result, dp_result, "bert-1.3b", ylim=(0.1, 0.25)) 90 | -------------------------------------------------------------------------------- /experiments/motivation/queueing_theory_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def alpha(r): 5 | return ((-8 + r ** 2 + np.sqrt(64 - 96 * r + 56 * (r ** 2) - 12 * (r ** 3) 
+ r ** 4)) 6 | /(3 * (-2 * r + r ** 2))) 7 | 8 | def beta(r): 9 | return (r - np.sqrt(8 - 4 * r + r ** 2))/(-2 + r) 10 | 11 | if __name__ == "__main__": 12 | x = np.linspace(0.0, 2.0, 1000) 13 | y_alpha = alpha(x) 14 | y_beta = beta(x) 15 | # print("y_alpha", y_alpha) 16 | # print("y_beta", y_beta) 17 | plt.figure(figsize=(3, 2)) 18 | plt.plot(x, y_alpha, label=r"$\alpha$") 19 | plt.plot(x, y_beta, label=r"$\beta$") 20 | plt.xlabel(r"$\lambda D$") 21 | plt.legend() 22 | plt.grid() 23 | plt.tight_layout() 24 | plt.savefig("queueing_theory.pdf") -------------------------------------------------------------------------------- /experiments/robustness/robustness_suite.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | BenchmarkConfig = namedtuple( 5 | "BenchmarkConfig", 6 | [ 7 | "fixed_num_devices", "fixed_num_models", "fixed_slo_scale", # general 8 | "fixed_rate_scale", "fixed_cv_scale", # real trace only 9 | "num_devices_list", "num_models_list", "slo_scales", # general 10 | "rate_list", "cv_list", # synthetic trace only 11 | "rate_scales", "cv_scales", # real trace only 12 | ] 13 | ) 14 | 15 | azure_v1_suite = { 16 | "bert-1.3b": BenchmarkConfig( 17 | fixed_num_devices = 16, 18 | fixed_num_models = 48, 19 | fixed_slo_scale = 5, 20 | fixed_rate_scale = 2e-3, 21 | fixed_cv_scale = 4, 22 | num_devices_list = [8, 16, 24, 32, 40, 48], 23 | num_models_list = [36, 48, 60, 72, 84, 96], 24 | slo_scales = [0.75, 1, 2, 3, 4, 5, 7.5, 10], 25 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 26 | cv_list = [1, 2, 4, 6, 8], 27 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 28 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 29 | ), 30 | "bert-2.6b": BenchmarkConfig( 31 | fixed_num_devices = 32, 32 | fixed_num_models = 48, 33 | fixed_slo_scale = 5, 34 | fixed_rate_scale = 2e-3, 35 | fixed_cv_scale = 4, 36 | num_devices_list = [16, 24, 32, 40, 48, 56, 64], 37 | num_models_list = [32, 40, 48, 
56, 72, 96, 128], 38 | slo_scales = [0.75, 1, 2, 3, 4, 5, 7.5, 10], 39 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 40 | cv_list = [1, 2, 4, 6, 8], 41 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 42 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 43 | ), 44 | "bert-6.7b": BenchmarkConfig( 45 | fixed_num_devices = 64, 46 | fixed_num_models = 48, 47 | fixed_slo_scale = 5, 48 | fixed_rate_scale = 2e-3, 49 | fixed_cv_scale = 4, 50 | num_devices_list=[48, 56, 64, 72, 80, 96, 128], 51 | num_models_list=[40, 48, 64, 72, 84, 96, 108], 52 | slo_scales=[0.75, 1, 1.5, 2, 2.5, 3, 4, 5, 7.5, 10], 53 | rate_list=[8, 16, 32, 64, 96, 128, 160, 192], 54 | cv_list=[1, 2, 4, 6, 8], 55 | rate_scales=[5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 56 | cv_scales=[1, 2, 3, 4, 5, 6, 8], 57 | ), 58 | } 59 | 60 | azure_v2_suite = { 61 | "bert-1.3b": BenchmarkConfig( 62 | fixed_num_devices = 16, 63 | fixed_num_models = 48, 64 | fixed_slo_scale = 5, 65 | fixed_rate_scale = 32, 66 | fixed_cv_scale = 1, 67 | num_devices_list = [4, 8, 12, 16, 20], 68 | num_models_list = [32, 40, 56, 64, 72, 80], 69 | slo_scales = [0.75, 1, 1.5, 2, 3, 4, 5, 7.5], 70 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 71 | cv_list = [1, 2, 4, 6, 8], 72 | rate_scales = [1, 4, 16, 32, 48, 72, 96, 128, 256], 73 | cv_scales = [1, 1.5, 2, 2.5, 3, 4, 5], 74 | ), 75 | "bert-2.6b": BenchmarkConfig( 76 | fixed_num_devices = 32, 77 | fixed_num_models = 48, 78 | fixed_slo_scale = 5, 79 | fixed_rate_scale = 32, 80 | fixed_cv_scale = 1, 81 | num_devices_list = [16, 24, 32, 40, 48, 56, 64], 82 | num_models_list = [32, 40, 56, 64, 72, 80], 83 | slo_scales = [0.75, 1, 1.5, 2, 3, 4, 5, 7.5, 10, 20], 84 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 85 | cv_list = [1, 2, 4, 6, 8], 86 | rate_scales = [1, 8, 16, 32, 48, 72, 96, 128], 87 | cv_scales = [1, 1.5, 2, 2.5, 3, 4, 5], 88 | ), 89 | "bert-6.7b": BenchmarkConfig( 90 | fixed_num_devices = 64, 91 | fixed_num_models = 48, 92 | fixed_slo_scale = 5, 93 | 
fixed_rate_scale = 32, 94 | fixed_cv_scale = 1, 95 | num_devices_list=[48, 56, 64, 72, 80, 96, 128], 96 | num_models_list=[40, 48, 56, 72, 84, 96, 108], 97 | slo_scales=[0.75, 1, 1.5, 2, 2.5, 3, 4, 5, 7.5, 10], 98 | rate_list=[8, 16, 32, 64, 96, 128, 160, 192], 99 | cv_list=[1, 2, 4, 6, 8], 100 | rate_scales=[1, 8, 16, 32, 64, 96, 128], 101 | cv_scales=[1, 1.5, 2, 2.5, 3, 3.5, 4, 5], 102 | ), 103 | } 104 | -------------------------------------------------------------------------------- /osdi23_artifact/README.md: -------------------------------------------------------------------------------- 1 | # AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving 2 | 3 | This is the artifact for the paper "AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving". We are going to reproduce the main results in the paper. 4 | 5 | ## Setup the environment 6 | 7 | Install alpa_serve package by running 8 | 9 | ```shell 10 | pip install -e . 11 | ``` 12 | in the root folder of `mms` project. 13 | 14 | Launch the ray runtime 15 | 16 | ```shell 17 | ray start --head 18 | ``` 19 | 20 | Now you are ready to reproduce all the main results in the paper. 21 | 22 | ### Dataset 23 | 24 | To get and use Azure Function Trace Dataset, read [this instruction](../alpa_serve/trace/README.md). 25 | 26 | ### Profiling results 27 | 28 | Our algorithm relies on the profiling results, which is provided as `profiling_result.pkl` in [this issue](https://github.com/alpa-projects/mms/issues/14). If you want reproduce the profiling results yourself, plese follow [this benchmarking script](https://github.com/alpa-projects/alpa/blob/main/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py) and [this conversion script](https://github.com/alpa-projects/alpa/blob/main/benchmark/alpa/gen_serving_database.py). 29 | 30 | ## End-to-end Results (Section 6.2, Figure. 
11) 31 | 32 | Generate data under `sec6_2_data/` (1 hour) 33 | 34 | ```shell 35 | bash gen_data_sec6_2_e2e.sh 36 | ``` 37 | 38 | Currently, the script above only produces the data for the 1st, 2nd, 4th, and 5th columns in Figure 11. The data for the 3rd and 6th columns, i.e., `S3@MAF1` and `S3@MAF2`, take a long time to generate (~5 hours), so we provide the data in advance. If you want to reproduce the results, please uncomment the last two commands in `gen_data_sec6_2_e2e.sh`. 39 | 40 | Plot figures under `paper_figures/`. There are four figures in total, `goodput_vs_num_devices.pdf`, `goodput_vs_rate_scale.pdf`, `goodput_vs_cv_scale.pdf`, and `goodput_vs_slo_scale.pdf`, each corresponding to one row of Figure 11. 41 | 42 | ``` 43 | python3 plot_sec6_2_e2e.py --pdf 44 | ``` 45 | 46 | ## Serving Very Large Models (Section 6.3, Figure. 12) 47 | 48 | This experiment was originally done on eight p3.16xlarge AWS instances with 64 GPUs and ran for over 20 hours. 49 | Due to budget limits, we provide a simulated version instead; the accuracy of the simulator is verified in the paper. 50 | 51 | Generate data into `sec6_3_data/large_model_exp.tsv` (5 sec) 52 | 53 | ``` 54 | bash gen_data_sec6_3_large.sh 55 | ``` 56 | 57 | Plot figure `paper_figures/large_model_exp.pdf` 58 | 59 | ``` 60 | python3 plot_sec6_3_large.py --pdf 61 | ``` 62 | 63 | ## Robustness to Changing Traffic Patterns (Section 6.4, Figure. 13) 64 | 65 | Generate data under `sec6_4_data/` (10 min) 66 | 67 | ``` 68 | bash gen_data_sec6_4_robust.sh 69 | ``` 70 | 71 | Plot figure `paper_figures/robustness.pdf` 72 | 73 | ``` 74 | python3 plot_sec6_4_robust.py --pdf 75 | ``` 76 | 77 | 78 | ## Ablation Study (Section 6.5, Figure.
14) 79 | 80 | Generate data into `sec6_5_data/ablation.tsv` (5 min) 81 | 82 | ``` 83 | bash gen_data_sec6_5_ab.sh 84 | ``` 85 | 86 | Plot figure `paper_figures/ablation.pdf` 87 | 88 | ``` 89 | python3 plot_sec6_5_ab.py --pdf 90 | ``` 91 | 92 | ## Motivation results 93 | 94 | Please refer to [this instruction](../experiments/motivation/README.md) to reproduce the results in the motivation section. -------------------------------------------------------------------------------- /osdi23_artifact/cleanup.sh: -------------------------------------------------------------------------------- 1 | rm paper_figures/*.pdf 2 | rm sec6_2_data/azure_v1_1dot3b.tsv 3 | rm sec6_2_data/azure_v1_6dot7b.tsv 4 | rm sec6_2_data/azure_v2_1dot3b.tsv 5 | rm sec6_2_data/azure_v2_6dot7b.tsv 6 | rm sec6_3_data/*.tsv 7 | rm sec6_4_data/*.tsv 8 | rm sec6_5_data/*.tsv -------------------------------------------------------------------------------- /osdi23_artifact/gen_data_sec6_2_e2e.sh: -------------------------------------------------------------------------------- 1 | python equal_model_exp.py --trace-dir /home/ubuntu/mms/dataset/azure_v1.pkl --exp-ids all --output azure_v1_1dot3b.tsv --exp-name sec6_2_data --workload=azure_v1 --model-type=bert-1.3b --parallel 2 | python equal_model_exp.py --trace-dir /home/ubuntu/mms/dataset/azure_v1.pkl --exp-ids all --output azure_v1_6dot7b.tsv --exp-name sec6_2_data --workload=azure_v1 --model-type=bert-6.7b --parallel 3 | python equal_model_exp.py --trace-dir /home/ubuntu/mms/dataset/azure_v2.pkl --exp-ids all --output azure_v2_1dot3b.tsv --exp-name sec6_2_data --workload=azure_v2 --model-type=bert-1.3b --parallel 4 | python equal_model_exp.py --trace-dir /home/ubuntu/mms/dataset/azure_v2.pkl --exp-ids all --output azure_v2_6dot7b.tsv --exp-name sec6_2_data --workload=azure_v2 --model-type=bert-6.7b --parallel 5 | ## warning, the two following commands take a long time (about 5 hours) to run 6 | # python general_model_exp.py --trace-dir 
/home/ubuntu/mms/dataset/azure_v1.pkl --exp-ids all --output azure_v1_mixed.tsv --exp-name sec6_2_data --workload=azure_v1 --model-type=mixed --parallel 7 | # python general_model_exp.py --trace-dir /home/ubuntu/mms/dataset/azure_v2.pkl --exp-ids all --output azure_v2_mixed.tsv --exp-name sec6_2_data --workload=azure_v2 --model-type=mixed --parallel -------------------------------------------------------------------------------- /osdi23_artifact/gen_data_sec6_3_large.sh: -------------------------------------------------------------------------------- 1 | python equal_model_exp.py --output large_model_exp.tsv --exp-name sec6_3_data --workload=synthetic --rate 8 --duration 1200 --model-type=bert-103.5b --large-models -------------------------------------------------------------------------------- /osdi23_artifact/gen_data_sec6_4_robust.sh: -------------------------------------------------------------------------------- 1 | python3 robustness_exp.py --trace-dir /home/ubuntu/mms/dataset/azure_v1.pkl --workload azure_v1 --exp-name sec6_4_data --output robustness_exp --model-type bert-1.3b --parallel 2 | -------------------------------------------------------------------------------- /osdi23_artifact/gen_data_sec6_5_ab.sh: -------------------------------------------------------------------------------- 1 | python3 general_model_exp.py --exp-name sec6_5_data --output ablation.tsv --workload=synthetic --model-type=mixed --policy=mp-round-robin --ablation --parallel 2 | python3 general_model_exp.py --exp-name sec6_5_data --output ablation.tsv --workload=synthetic --model-type=mixed --policy=mp-greedy-4 --ablation --parallel 3 | python3 general_model_exp.py --exp-name sec6_5_data --output ablation.tsv --workload=synthetic --model-type=mixed --policy=mp-search-sep --ablation --parallel -------------------------------------------------------------------------------- /osdi23_artifact/general_model_suite.py: 
-------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | BenchmarkConfig = namedtuple( 4 | "BenchmarkConfig", 5 | [ 6 | "fixed_num_devices", "fixed_num_modelset", "fixed_slo_scale", # general 7 | "fixed_rate_scale", "fixed_cv_scale", # real trace only 8 | "num_devices_list", "num_modelset_list", "slo_scales", 9 | "rate_list", "cv_list", # synthetic trace only 10 | "rate_scales", "cv_scales", # real trace only 11 | ] 12 | ) 13 | 14 | synthetic_suite = { 15 | "all_transformers": BenchmarkConfig( 16 | fixed_num_devices = 24, 17 | fixed_num_modelset = 14, 18 | fixed_slo_scale = 5, 19 | num_devices_list = [16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96], 20 | num_modelset_list = [1, 4, 6, 8, 10, 12, 14, 16], 21 | slo_scales = [1, 2, 4, 8, 12, 16], 22 | rate_list = [1, 8, 16, 24, 32, 48, 64, 80], 23 | cv_list = [0.5, 1, 2, 4, 6], 24 | fixed_rate_scale = None, 25 | fixed_cv_scale = None, 26 | rate_scales = None, 27 | cv_scales = None, 28 | ), 29 | "mixed": BenchmarkConfig( 30 | fixed_num_devices = 32, 31 | fixed_num_modelset = 10, 32 | fixed_slo_scale = 5, 33 | num_devices_list = [8, 24, 40, 56, 72, 88, 104, 120, 136, 152, 168], 34 | num_modelset_list = [1, 2, 4, 6, 8, 10, 12, 14, 16], 35 | slo_scales = [1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24], 36 | rate_list = [8, 16, 24, 32, 64, 96, 128], 37 | cv_list = [0.5, 1, 1.5, 2, 3, 4, 5, 6, 7], 38 | fixed_rate_scale = None, 39 | fixed_cv_scale = None, 40 | rate_scales = None, 41 | cv_scales = None, 42 | ), 43 | } 44 | 45 | azure_v1_suite = { 46 | "all_transformers": BenchmarkConfig( 47 | fixed_num_devices = 24, 48 | fixed_num_modelset = 12, 49 | fixed_slo_scale = 5, 50 | fixed_rate_scale = 2e-3, 51 | fixed_cv_scale = 4, 52 | num_devices_list = [8, 16, 24, 32, 40, 48], 53 | num_modelset_list = [8, 12, 16, 20, 24, 28, 32], 54 | slo_scales = [0.75, 1, 2, 3, 4, 5, 7.5, 10], 55 | rate_list = [], 56 | cv_list = [], 57 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 
5e-3, 6e-3, 7e-3, 8e-3], 58 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 59 | ), 60 | "mixed": BenchmarkConfig( 61 | fixed_num_devices = 48, 62 | fixed_num_modelset = 12, 63 | fixed_slo_scale = 5, 64 | fixed_rate_scale = 2e-3, 65 | fixed_cv_scale = 4, 66 | num_devices_list = [32, 40, 48, 54, 64, 72, 96], 67 | num_modelset_list = [8, 12, 16, 20, 24, 28, 32], 68 | slo_scales =[0.75, 1, 2, 3, 4, 5, 7.5, 10], 69 | rate_list = [], 70 | cv_list = [], 71 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 72 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 73 | ), 74 | } 75 | 76 | azure_v2_suite = { 77 | "all_transformers": BenchmarkConfig( 78 | fixed_num_devices = 24, 79 | fixed_num_modelset = 12, 80 | fixed_slo_scale = 5, 81 | fixed_rate_scale = 32, 82 | fixed_cv_scale = 1, 83 | num_devices_list = [8, 16, 20, 24, 32, 40, 48], 84 | num_modelset_list = [8, 12, 16, 20, 24, 28, 32], 85 | slo_scales = [0.75, 1, 1.25, 1.5, 2, 2.5, 5, 10], 86 | rate_list = [], 87 | cv_list = [], 88 | rate_scales = [1, 4, 8, 16, 32, 64, 128], 89 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 90 | ), 91 | "mixed": BenchmarkConfig( 92 | fixed_num_devices = 48, 93 | fixed_num_modelset = 12, 94 | fixed_slo_scale = 5, 95 | fixed_rate_scale = 32, 96 | fixed_cv_scale = 1, 97 | num_devices_list = [16, 24, 32, 40, 48, 54, 64], 98 | num_modelset_list = [8, 12, 16, 18, 20, 22], 99 | slo_scales = [0.75, 1, 1.25, 1.5, 2, 2.5, 5], 100 | rate_list = [], 101 | cv_list = [], 102 | rate_scales = [1, 4, 8, 16, 32, 64], 103 | cv_scales = [1, 2, 3, 4, 6, 8], 104 | ), 105 | } 106 | -------------------------------------------------------------------------------- /osdi23_artifact/plot_sec6_3_large.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | import os 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | from benchmarks.alpa.equal_model_case import 
read_equal_model_case_tsv 12 | from benchmarks.alpa.general_model_case import read_general_model_case_tsv 13 | from benchmarks.alpa.plot_various_metrics import show_name, method2color, method2order 14 | 15 | linestyles = ["solid", "dashed", "dashdot", "dotted", (0, (3,5,1,5,1,5))] 16 | methodcolors = ["C2", "C1", "C0", "C3", "C4"] 17 | 18 | def plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom): 19 | methods = list(data.keys()) 20 | methods.sort(key=lambda x: method2order(x)) 21 | 22 | curves = [] 23 | legends = [] 24 | first_good = [] 25 | x_max = 0 26 | y_max = 0 27 | for i, method in enumerate(methods): 28 | curve = data[method] 29 | xs_, ys_ = zip(*curve.items()) 30 | xs = [x for x, _ in sorted(zip(xs_, ys_))] 31 | ys = [y for _, y in sorted(zip(xs_, ys_))] 32 | ys = np.array(ys) * 100 33 | curve = ax.plot(xs, ys, color=methodcolors[i], marker='.', linestyle=linestyles[i], linewidth=4, markersize=15) 34 | curves.append(curve[0]) 35 | legends.append(show_name(method)) 36 | 37 | if increasing: 38 | iterator = range(len(xs)) 39 | else: 40 | iterator = reversed(range(len(xs))) 41 | 42 | found = False 43 | for i in iterator: 44 | if ys[i] >= threshold * 100: 45 | first_good.append(xs[i]) 46 | found = True 47 | break 48 | if not found: 49 | first_good.append(0) 50 | 51 | x_max = max(x_max, *xs) 52 | y_max = max(y_max, *ys) 53 | 54 | ax.tick_params(axis='both', which='major', labelsize=20) 55 | ax.tick_params(axis='both', which='minor', labelsize=20) 56 | ax.set_ylim(bottom=ybottom, top=max(y_max * 1.02, 100)) 57 | ax.set_xlabel(xlabel, fontsize=20) 58 | ax.grid() 59 | 60 | for i in range(len(methods)): 61 | if first_good[i] == 0: 62 | continue 63 | ax.axvline(first_good[i], color=methodcolors[i], linestyle=":", linewidth=4) 64 | 65 | return curves, legends 66 | 67 | 68 | def plot_goodput(lines, threshold, folder, pdf): 69 | rate_data = defaultdict(lambda: defaultdict(dict)) 70 | cv_data = defaultdict(lambda: defaultdict(dict)) 71 | slo_data = 
defaultdict(lambda: defaultdict(dict)) 72 | 73 | for line in lines: 74 | if line["exp_name"] == "goodput_vs_rate": 75 | policy, x, goodput = ( 76 | line["policy_name"], line["total_rate"], line["goodput"]) 77 | rate_data[policy][x] = goodput 78 | if line["exp_name"] == "goodput_vs_cv": 79 | policy, x, goodput = ( 80 | line["policy_name"], line["arrival_process_kwargs"]["cv"], line["goodput"]) 81 | cv_data[policy][x] = goodput 82 | if line["exp_name"] == "goodput_vs_slo": 83 | policy, x, goodput = ( 84 | line["policy_name"], line["slo_scale"], line["goodput"]) 85 | slo_data[policy][x] = goodput 86 | 87 | fig, axs = plt.subplots(1, 3) 88 | 89 | datas = [rate_data, cv_data, slo_data] 90 | xlabels = ["Rate (r/s)", "CV", "SLO Scale"] 91 | ybottoms = [60,60,0] 92 | increasings = [False, False, True] 93 | for data, increasing, ax, xlabel, ybottom in zip(datas, increasings, axs, xlabels, ybottoms): 94 | curves, legends = plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom) 95 | 96 | fig.text(0.07, 0.5, "SLO Attainment (%)", va='center', rotation='vertical', fontsize=20) 97 | fig.legend(curves, legends, loc="upper center", ncol=6, bbox_to_anchor=(0.5, 1.1), fontsize=20) 98 | 99 | if pdf: 100 | output = os.path.join(folder, "large_model_exp.pdf") 101 | else: 102 | output = os.path.join(folder, "large_model_exp.png") 103 | 104 | figure_size = (18, 5) 105 | fig.set_size_inches(figure_size) 106 | fig.savefig(output, bbox_inches='tight') 107 | print(f"Output the plot to {output}") 108 | 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--input", type=str, default="sec6_3_data/large_model_exp.tsv") 114 | parser.add_argument("--output-dir", type=str, default="paper_figures") 115 | parser.add_argument("--show", action="store_true") 116 | parser.add_argument("--pdf", action="store_true") 117 | args = parser.parse_args() 118 | 119 | os.makedirs(args.output_dir, exist_ok=True) 120 | 121 | threshold = 0.99 
122 | 123 | lines = read_equal_model_case_tsv(args.input) 124 | 125 | plot_goodput(lines, threshold, args.output_dir, args.pdf) 126 | -------------------------------------------------------------------------------- /osdi23_artifact/plot_sec6_5_ab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | import os 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import pickle 11 | 12 | from benchmarks.alpa.equal_model_case import read_equal_model_case_tsv 13 | from benchmarks.alpa.general_model_case import read_general_model_case_tsv 14 | from benchmarks.alpa.plot_various_metrics import show_name, method2color, method2order 15 | 16 | linestyles = ["solid", "dashed", "dashdot", "dotted", (0, (3,5,1,5,1,5))] 17 | 18 | paper_name = { 19 | "mp-round-robin": "Round robin", 20 | "mp-greedy-4": "Greedy placement", 21 | "mp-search-sep": "Greedy placement + Group partitioning", 22 | } 23 | 24 | def plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom): 25 | methods = list(data.keys()) 26 | # methods.sort(key=lambda x: method2order(x)) 27 | 28 | curves = [] 29 | legends = [] 30 | first_good = [] 31 | x_max = 0 32 | y_max = 0 33 | for i, method in enumerate(methods): 34 | curve = data[method] 35 | xs_, ys_ = zip(*curve.items()) 36 | xs = [x for x, _ in sorted(zip(xs_, ys_))] 37 | ys = [y for _, y in sorted(zip(xs_, ys_))] 38 | ys = np.array(ys) * 100 39 | curve = ax.plot(xs, ys, color=method2color(method), marker='.', linestyle=linestyles[i], linewidth=4, markersize=15) 40 | curves.append(curve[0]) 41 | legends.append(paper_name.get(method)) 42 | 43 | if increasing: 44 | iterator = range(len(xs)) 45 | else: 46 | iterator = reversed(range(len(xs))) 47 | 48 | found = False 49 | for i in iterator: 50 | if ys[i] >= threshold * 100: 51 | first_good.append(xs[i]) 52 | found = True 53 | break 54 | if not found: 55 
| first_good.append(0) 56 | 57 | x_max = max(x_max, *xs) 58 | y_max = max(y_max, *ys) 59 | 60 | ax.tick_params(axis='both', which='major', labelsize=20) 61 | ax.tick_params(axis='both', which='minor', labelsize=20) 62 | ax.set_ylim(bottom=ybottom, top=max(y_max * 1.02, 100)) 63 | ax.set_xlabel(xlabel, fontsize=22) 64 | ax.grid() 65 | 66 | ax.legend(curves, legends, fontsize=18.5, loc="lower left") 67 | 68 | for i in range(len(methods)): 69 | if first_good[i] == 0: 70 | continue 71 | ax.axvline(first_good[i], color=method2color(methods[i]), linestyle=":", linewidth=4) 72 | 73 | 74 | def plot_goodput(lines, threshold, folder, pdf): 75 | rate_data = defaultdict(lambda: defaultdict(dict)) 76 | cv_data = defaultdict(lambda: defaultdict(dict)) 77 | 78 | for line in lines: 79 | if line["exp_name"] == "goodput_vs_rate": 80 | policy, x, goodput = ( 81 | line["policy_name"], line["total_rate"], line["goodput"]) 82 | rate_data[policy][x] = goodput 83 | elif line["exp_name"] == "goodput_vs_cv": 84 | policy, x, goodput = ( 85 | line["policy_name"], line["arrival_process_kwargs"]["cv"], line["goodput"]) 86 | cv_data[policy][x] = goodput 87 | else: 88 | continue 89 | 90 | fig, axs = plt.subplots(1, 2) 91 | 92 | datas = [rate_data, cv_data] 93 | xlabels = ["Rate (r/s)", "CV"] 94 | ybottoms = [60,60] 95 | increasings = [False, False] 96 | for data, increasing, ax, xlabel, ybottom in zip(datas, increasings, axs, xlabels, ybottoms): 97 | plot_goodput_common(data, threshold, increasing, ax, xlabel, ybottom) 98 | 99 | fig.text(0.07, 0.5, "SLO Attainment (%)", va='center', rotation='vertical', fontsize=22) 100 | 101 | if pdf: 102 | output = os.path.join(folder, "ablation.pdf") 103 | else: 104 | output = os.path.join(folder, "ablation.png") 105 | 106 | figure_size = (18, 7) 107 | fig.set_size_inches(figure_size) 108 | fig.savefig(output, bbox_inches='tight') 109 | print(f"Output the plot to {output}") 110 | 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = 
argparse.ArgumentParser() 115 | parser.add_argument("--input", type=str, default="sec6_5_data/ablation.tsv") 116 | parser.add_argument("--output-dir", type=str, default="paper_figures") 117 | parser.add_argument("--show", action="store_true") 118 | parser.add_argument("--pdf", action="store_true") 119 | args = parser.parse_args() 120 | 121 | os.makedirs(args.output_dir, exist_ok=True) 122 | 123 | threshold = 0.99 124 | 125 | lines = read_general_model_case_tsv(args.input) 126 | 127 | plot_goodput(lines, threshold, args.output_dir, args.pdf) 128 | -------------------------------------------------------------------------------- /osdi23_artifact/robustness_suite.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | BenchmarkConfig = namedtuple( 5 | "BenchmarkConfig", 6 | [ 7 | "fixed_num_devices", "fixed_num_models", "fixed_slo_scale", # general 8 | "fixed_rate_scale", "fixed_cv_scale", # real trace only 9 | "num_devices_list", "num_models_list", "slo_scales", # general 10 | "rate_list", "cv_list", # synthetic trace only 11 | "rate_scales", "cv_scales", # real trace only 12 | ] 13 | ) 14 | 15 | azure_v1_suite = { 16 | "bert-1.3b": BenchmarkConfig( 17 | fixed_num_devices = 16, 18 | fixed_num_models = 48, 19 | fixed_slo_scale = 5, 20 | fixed_rate_scale = 2e-3, 21 | fixed_cv_scale = 4, 22 | num_devices_list = [8, 16, 24, 32, 40, 48], 23 | num_models_list = [36, 48, 60, 72, 84, 96], 24 | slo_scales = [0.75, 1, 2, 3, 4, 5, 7.5, 10], 25 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 26 | cv_list = [1, 2, 4, 6, 8], 27 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 28 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 29 | ), 30 | "bert-2.6b": BenchmarkConfig( 31 | fixed_num_devices = 32, 32 | fixed_num_models = 48, 33 | fixed_slo_scale = 5, 34 | fixed_rate_scale = 2e-3, 35 | fixed_cv_scale = 4, 36 | num_devices_list = [16, 24, 32, 40, 48, 56, 64], 37 | num_models_list = [32, 40, 48, 56, 
72, 96, 128], 38 | slo_scales = [0.75, 1, 2, 3, 4, 5, 7.5, 10], 39 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 40 | cv_list = [1, 2, 4, 6, 8], 41 | rate_scales = [5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 42 | cv_scales = [1, 2, 3, 4, 5, 6, 8], 43 | ), 44 | "bert-6.7b": BenchmarkConfig( 45 | fixed_num_devices = 64, 46 | fixed_num_models = 48, 47 | fixed_slo_scale = 5, 48 | fixed_rate_scale = 2e-3, 49 | fixed_cv_scale = 4, 50 | num_devices_list=[48, 56, 64, 72, 80, 96, 128], 51 | num_models_list=[40, 48, 64, 72, 84, 96, 108], 52 | slo_scales=[0.75, 1, 1.5, 2, 2.5, 3, 4, 5, 7.5, 10], 53 | rate_list=[8, 16, 32, 64, 96, 128, 160, 192], 54 | cv_list=[1, 2, 4, 6, 8], 55 | rate_scales=[5e-4, 1e-3, 2e-3, 4e-3, 5e-3, 6e-3, 7e-3, 8e-3], 56 | cv_scales=[1, 2, 3, 4, 5, 6, 8], 57 | ), 58 | } 59 | 60 | azure_v2_suite = { 61 | "bert-1.3b": BenchmarkConfig( 62 | fixed_num_devices = 16, 63 | fixed_num_models = 48, 64 | fixed_slo_scale = 5, 65 | fixed_rate_scale = 32, 66 | fixed_cv_scale = 1, 67 | num_devices_list = [4, 8, 12, 16, 20], 68 | num_models_list = [32, 40, 56, 64, 72, 80], 69 | slo_scales = [0.75, 1, 1.5, 2, 3, 4, 5, 7.5], 70 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 71 | cv_list = [1, 2, 4, 6, 8], 72 | rate_scales = [1, 4, 16, 32, 48, 72, 96, 128, 256], 73 | cv_scales = [1, 1.5, 2, 2.5, 3, 4, 5], 74 | ), 75 | "bert-2.6b": BenchmarkConfig( 76 | fixed_num_devices = 32, 77 | fixed_num_models = 48, 78 | fixed_slo_scale = 5, 79 | fixed_rate_scale = 32, 80 | fixed_cv_scale = 1, 81 | num_devices_list = [16, 24, 32, 40, 48, 56, 64], 82 | num_models_list = [32, 40, 56, 64, 72, 80], 83 | slo_scales = [0.75, 1, 1.5, 2, 3, 4, 5, 7.5, 10, 20], 84 | rate_list = [8, 16, 32, 64, 96, 128, 160, 192], 85 | cv_list = [1, 2, 4, 6, 8], 86 | rate_scales = [1, 8, 16, 32, 48, 72, 96, 128], 87 | cv_scales = [1, 1.5, 2, 2.5, 3, 4, 5], 88 | ), 89 | "bert-6.7b": BenchmarkConfig( 90 | fixed_num_devices = 64, 91 | fixed_num_models = 48, 92 | fixed_slo_scale = 5, 93 | 
fixed_rate_scale = 32, 94 | fixed_cv_scale = 1, 95 | num_devices_list=[48, 56, 64, 72, 80, 96, 128], 96 | num_models_list=[40, 48, 56, 72, 84, 96, 108], 97 | slo_scales=[0.75, 1, 1.5, 2, 2.5, 3, 4, 5, 7.5, 10], 98 | rate_list=[8, 16, 32, 64, 96, 128, 160, 192], 99 | cv_list=[1, 2, 4, 6, 8], 100 | rate_scales=[1, 8, 16, 32, 64, 96, 128], 101 | cv_scales=[1, 1.5, 2, 2.5, 3, 3.5, 4, 5], 102 | ), 103 | } 104 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | install_requires = [ 4 | "numba", 5 | "scipy", 6 | "pulp", 7 | "matplotlib", 8 | "starlette", 9 | "uvicorn", 10 | "fastapi", 11 | ] 12 | 13 | setup(name="alpa_serve", 14 | install_requires=install_requires, 15 | packages=find_packages(exclude=["simulator"])) 16 | -------------------------------------------------------------------------------- /tests/run_all.py: -------------------------------------------------------------------------------- 1 | """Run all test cases. 2 | Run each file in a separate process to avoid GPU memory conflicts. 
3 | 4 | Usages: 5 | # Run all files 6 | python3 run_all.py 7 | 8 | # Run files whose names contain "pipeline" 9 | python3 run_all.py --run-pattern pipeline 10 | 11 | # Run files whose names contain "shard_parallel" 12 | python3 run_all.py --run-pattern shard_parallel 13 | 14 | # Run files whose names do not contain "torch" 15 | python3 run_all.py --skip-pattern torch 16 | """ 17 | 18 | import argparse 19 | import glob 20 | import multiprocessing 21 | import os 22 | import numpy as np 23 | import time 24 | from typing import Sequence 25 | import unittest 26 | 27 | slow_testcases = set([ 28 | "pipeline_parallel/test_stage_construction_slow.py", 29 | "torch_frontend/test_zhen.py", 30 | ]) 31 | 32 | 33 | def run_unittest_files(files, args): 34 | """Run unit test files one by one in separates processes.""" 35 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str( 36 | args.xla_client_mem_fraction) 37 | # Must import alpa after setting the global env 38 | from alpa.util import run_with_timeout 39 | 40 | for filename in files: 41 | if args.run_pattern is not None and args.run_pattern not in filename: 42 | continue 43 | if args.skip_pattern is not None and args.skip_pattern in filename: 44 | continue 45 | if not args.enable_slow_tests and filename in slow_testcases: 46 | continue 47 | 48 | def func(): 49 | ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) 50 | 51 | p = multiprocessing.Process(target=func) 52 | 53 | def run_one_file(): 54 | p.start() 55 | p.join() 56 | 57 | try: 58 | run_with_timeout(run_one_file, timeout=args.time_limit_per_file) 59 | if p.exitcode != 0: 60 | return False 61 | except TimeoutError: 62 | p.terminate() 63 | time.sleep(5) 64 | print(f"\nTimeout after {args.time_limit_per_file} seconds " 65 | f"when running {filename}") 66 | return False 67 | 68 | return True 69 | 70 | 71 | if __name__ == "__main__": 72 | arg_parser = argparse.ArgumentParser() 73 | arg_parser.add_argument( 74 | "--run-pattern", 75 | type=str, 76 | default=None, 77 
| help="Run files whose names contain the provided string") 78 | arg_parser.add_argument( 79 | "--skip-pattern", 80 | type=str, 81 | default=None, 82 | help="Do not run files whose names contain the provided string") 83 | arg_parser.add_argument( 84 | "--enable-slow-tests", 85 | action="store_true", 86 | help="Run test cases including profiling, which takes a long time") 87 | arg_parser.add_argument( 88 | "--xla-client-mem-fraction", 89 | type=float, 90 | default=0.25, 91 | help="The fraction of GPU memory used to run unit tests") 92 | arg_parser.add_argument( 93 | "--time-limit-per-file", 94 | type=int, 95 | default=1000, 96 | help="The time limit for running one file in seconds.") 97 | arg_parser.add_argument("--order", 98 | type=str, 99 | default="sorted", 100 | choices=["sorted", "random", "reverse_sorted"]) 101 | args = arg_parser.parse_args() 102 | 103 | files = glob.glob("**/test_*.py", recursive=True) 104 | if args.order == "sorted": 105 | files.sort() 106 | elif args.order == "random": 107 | files = [files[i] for i in np.random.permutation(len(files))] 108 | elif args.order == "reverse_sorted": 109 | files.sort() 110 | files = reversed(files) 111 | 112 | tic = time.time() 113 | success = run_unittest_files(files, args) 114 | 115 | if success: 116 | print(f"Success. Time elapsed: {time.time() - tic:.2f}s") 117 | else: 118 | print(f"Fail. 
Time elapsed: {time.time() - tic:.2f}s") 119 | 120 | exit(0 if success else -1) 121 | -------------------------------------------------------------------------------- /tests/serve/test_controller.py: -------------------------------------------------------------------------------- 1 | """Test alpa.serve controller.""" 2 | import unittest 3 | 4 | import numpy as np 5 | import ray 6 | import requests 7 | from tokenizers import Tokenizer 8 | 9 | from alpa.api import parallelize 10 | from alpa_serve.controller import run_controller 11 | 12 | 13 | class EchoModel: 14 | 15 | async def handle_request(self, request): 16 | obj = await request.json() 17 | return obj 18 | 19 | 20 | class AddOneModel: 21 | 22 | def __init__(self): 23 | 24 | def func(x): 25 | return x + 1 26 | 27 | self.add_one = parallelize(func) 28 | 29 | async def handle_request(self, request): 30 | obj = await request.json() 31 | x = np.array(obj["x"]) 32 | y = self.add_one(x) 33 | return await y.to_np_async() 34 | 35 | 36 | class TokenizerModel: 37 | 38 | def __init__(self): 39 | self.tokenizer = Tokenizer.from_pretrained("bert-base-uncased") 40 | 41 | async def handle_request(self, request): 42 | obj = await request.json() 43 | x = obj["input"] 44 | y = self.tokenizer.encode(x) 45 | return y.ids 46 | 47 | 48 | class ControllerTest(unittest.TestCase): 49 | 50 | def setUp(self): 51 | ray.init(address="auto") 52 | 53 | def tearDown(self): 54 | ray.shutdown() 55 | 56 | def test_query(self): 57 | controller = run_controller("localhost") 58 | 59 | info = ray.get(controller.get_info.remote()) 60 | host, port, root_path = info["host"], info["port"], info["root_path"] 61 | 62 | controller.register_model.remote("echo", EchoModel) 63 | controller.register_model.remote("add_one", AddOneModel) 64 | controller.register_model.remote("tokenizer", TokenizerModel) 65 | group_id = 0 66 | controller.create_mesh_group_manager.remote(group_id, [1, 4]) 67 | controller.create_replica.remote("echo", group_id) 68 | 
controller.create_replica.remote("add_one", group_id) 69 | controller.create_replica.remote("tokenizer", group_id) 70 | 71 | controller.sync() 72 | 73 | url = f"http://{host}:{port}{root_path}" 74 | 75 | json = { 76 | "model": "echo", 77 | "task": "completions", 78 | "input": "Paris is the capital city of", 79 | } 80 | resp = requests.post(url=url, json=json) 81 | assert resp.json() == json, f"{resp.json()}" 82 | 83 | resp = requests.post(url=url, 84 | json={ 85 | "model": "add_one", 86 | "x": list(range(16)), 87 | }) 88 | assert resp.text == str(list(range(1, 17))) 89 | 90 | src = "Paris is the capital city of" 91 | resp = requests.post(url=url, json={"model": "tokenizer", "input": src}) 92 | tokenizer = Tokenizer.from_pretrained("bert-base-uncased") 93 | assert resp.text == str(tokenizer.encode(src).ids) 94 | 95 | 96 | def suite(): 97 | suite = unittest.TestSuite() 98 | suite.addTest(ControllerTest("test_query")) 99 | return suite 100 | 101 | 102 | if __name__ == "__main__": 103 | runner = unittest.TextTestRunner() 104 | runner.run(suite()) 105 | -------------------------------------------------------------------------------- /tests/serve/test_placement_policy.py: -------------------------------------------------------------------------------- 1 | """Test placement policy""" 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | from alpa_serve.simulator.controller import Controller 7 | from alpa_serve.placement_policy import (ModelData, ClusterEnv, 8 | SelectiveReplicationGreedy, SelectiveReplicationSearch, 9 | ModelParallelismGreedy, ModelParallelismSearch) 10 | from alpa_serve.profiling import ParallelConfig, load_test_prof_result 11 | from alpa.util import GB 12 | 13 | 14 | class EchoModel: 15 | def __init__(self, parallel_config, virtual_mesh): 16 | pass 17 | 18 | async def handle_request(self, request): 19 | return request 20 | 21 | 22 | class PlacementPolicyTest(unittest.TestCase): 23 | 24 | def test_selective_replication(self): 25 | cluster_env = 
ClusterEnv(num_devices=4, mem_budget=4.5*GB) 26 | model_datas = [ 27 | ModelData("m0", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 28 | ModelData("m1", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 29 | ModelData("m2", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 30 | ModelData("m3", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 31 | ] 32 | 33 | for policy in [SelectiveReplicationGreedy(), 34 | SelectiveReplicationSearch(verbose=1)]: 35 | placement, _ = policy.solve_placement( 36 | model_datas, cluster_env) 37 | 38 | # Check result 39 | assert all(g == ParallelConfig(1, 1, 1) for g in placement.group_configs) 40 | for i in range(4): 41 | assert sum(x.count(i) for x in placement.group_models) == 2 42 | 43 | def test_model_parallelism(self): 44 | cluster_env = ClusterEnv(num_devices=4, mem_budget=4.5*GB) 45 | model_datas = [ 46 | ModelData("m0", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 47 | ModelData("m1", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 48 | ModelData("m2", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 49 | ModelData("m3", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 50 | ] 51 | 52 | for policy in [ModelParallelismGreedy(group_size=2)]: 53 | placement, _ = policy.solve_placement( 54 | model_datas, cluster_env) 55 | 56 | assert len(placement.group_configs) == 2 57 | assert placement.group_configs[0].pp == 2 58 | assert placement.group_configs[1].pp == 2 59 | assert placement.group_models[0] == [0, 1, 2, 3] 60 | assert placement.group_models[1] == [0, 1, 2, 3] 61 | 62 | def test_model_parallelism_search(self): 63 | cluster_env = ClusterEnv(num_devices=4, mem_budget=2.5*GB) 64 | model_datas = [ 65 | ModelData("m0", 0.4, 4, 8, load_test_prof_result("test-2GB-100ms")), 66 | ModelData("m1", 0.4, 4, 8, load_test_prof_result("test-2GB-100ms")), 67 | ModelData("m2", 0.4, 4, 8, load_test_prof_result("test-2GB-100ms")), 68 | ModelData("m3", 0.4, 4, 8, load_test_prof_result("test-2GB-100ms")), 69 | ] 70 | 71 | for 
policy in [ModelParallelismSearch(verbose=2)]: 72 | placement, _ = policy.solve_placement( 73 | model_datas, cluster_env) 74 | 75 | assert len(placement.group_configs) == 1 76 | assert placement.group_configs[0].pp == 4 77 | assert list(placement.group_models[0]) == [0, 1, 2, 3] 78 | 79 | def test_placement_api(self): 80 | for policy in [SelectiveReplicationGreedy(), ModelParallelismGreedy()]: 81 | controller = Controller() 82 | controller.register_model.remote("m0", EchoModel) 83 | controller.register_model.remote("m1", EchoModel) 84 | controller.register_model.remote("m2", EchoModel) 85 | controller.register_model.remote("m3", EchoModel) 86 | 87 | cluster_env = ClusterEnv(num_devices=4, mem_budget=4.5*GB) 88 | model_datas = [ 89 | ModelData("m0", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 90 | ModelData("m1", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 91 | ModelData("m2", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 92 | ModelData("m3", 1, 5, 1, load_test_prof_result("test-2GB-100ms")), 93 | ] 94 | policy.place_models(controller, cluster_env, model_datas) 95 | 96 | 97 | def suite(): 98 | suite = unittest.TestSuite() 99 | suite.addTest(PlacementPolicyTest("test_selective_replication")) 100 | suite.addTest(PlacementPolicyTest("test_model_parallelism")) 101 | suite.addTest(PlacementPolicyTest("test_model_parallelism_search")) 102 | suite.addTest(PlacementPolicyTest("test_placement_api")) 103 | return suite 104 | 105 | 106 | if __name__ == "__main__": 107 | runner = unittest.TextTestRunner() 108 | runner.run(suite()) 109 | -------------------------------------------------------------------------------- /tests/serve/test_simulator.py: -------------------------------------------------------------------------------- 1 | """Test alpa.serve controller.""" 2 | import asyncio 3 | from functools import partial 4 | import unittest 5 | 6 | import ray 7 | 8 | from alpa_serve.profiling import ParallelConfig, load_test_prof_result 9 | from 
alpa_serve.controller import run_controller 10 | from alpa_serve.simulator.controller import Controller, Client 11 | from alpa_serve.simulator.event_loop import run_event_loop 12 | from alpa_serve.simulator.executable import Executable 13 | from alpa_serve.simulator.workload import Workload, Request, PoissonProcess 14 | 15 | 16 | class EchoModel: 17 | def __init__(self, virtual_mesh=None): 18 | pass 19 | 20 | async def handle_request(self, request, delay=None): 21 | return request 22 | 23 | 24 | class SimulatorTest(unittest.TestCase): 25 | 26 | async def main_test_query(self, controller): 27 | controller.register_model.remote("echo", EchoModel) 28 | 29 | group_id = 0 30 | controller.create_mesh_group_manager.remote(group_id, [1, 4]) 31 | controller.create_replica.remote("echo", group_id) 32 | 33 | controller.sync() 34 | 35 | request = Request("echo", None, None, 0, {}) 36 | ret = controller.handle_request.remote(request) 37 | assert request == await ret 38 | 39 | def test_query(self): 40 | # Test the simulator 41 | controller = Controller() 42 | run_event_loop(self.main_test_query(controller)) 43 | 44 | # Test the real system 45 | ray.init(address="auto") 46 | controller = run_controller("localhost") 47 | asyncio.run(self.main_test_query(controller)) 48 | 49 | async def main_test_client(self): 50 | controller = Controller() 51 | controller.register_model.remote( 52 | "a", partial(Executable, load_test_prof_result("test-2GB-100ms"))) 53 | 54 | group_id = 0 55 | controller.create_mesh_group_manager.remote(group_id, [1, 2]) 56 | controller.create_replica.remote("a", group_id, 57 | [ParallelConfig(1, 1, 2)]) 58 | 59 | w = PoissonProcess(10).generate_workload("a", 0, 60, slo=0.15) 60 | client = Client(controller) 61 | client.submit_workload(w) 62 | 63 | return client, w 64 | 65 | def test_client(self): 66 | client, w = run_event_loop(self.main_test_client()) 67 | stats = client.compute_stats(w, warmup=10) 68 | Workload.print_stats(stats) 69 | 70 | 71 | def suite(): 72 | 
suite = unittest.TestSuite() 73 | suite.addTest(SimulatorTest("test_query")) 74 | suite.addTest(SimulatorTest("test_client")) 75 | return suite 76 | 77 | 78 | if __name__ == "__main__": 79 | runner = unittest.TextTestRunner() 80 | runner.run(suite()) 81 | --------------------------------------------------------------------------------