├── .gitignore ├── Dockerfile ├── License ├── README.md ├── aicb.py ├── core ├── __init__.py └── grouped_gemm_util.py ├── download └── AICB_v1.0.deb ├── images ├── detail_log.png ├── readme_01.png ├── result_log.png ├── simai_dingtalk.jpg ├── simai_wechat.jpg ├── time_log.png ├── tutorial_1.png ├── tutorial_2.png ├── tutorial_3.png ├── tutorial_4.png ├── tutorial_5.png ├── tutorial_6.png └── tutorial_7.png ├── log_analyzer ├── __init__.py ├── analyze_res_csv.py ├── ds_comm_log_analyzer.py ├── log.py ├── plot.py └── utils.py ├── results └── visual_output │ └── A100_example.html ├── run_suites.py ├── scripts ├── coll_comm_check.sh ├── deepspeed_llama.sh ├── megatron_gpt.sh ├── megatron_workload_with_aiob.sh └── run_in_cluster.py ├── training └── tutorial.md ├── utils ├── benchmark_logger.py ├── timer.py └── utils.py ├── visualize ├── __init__.py ├── example.html ├── generate.py └── inputs │ └── A100_example.csv ├── workload ├── Workload_spec_v1.1.csv ├── aiob_inputs │ └── Example.txt ├── physical │ ├── micro_test │ │ ├── all_gather_workload.csv │ │ ├── all_reduce_workload.csv │ │ ├── all_to_all_workload.csv │ │ ├── multi_all_reduce_workload.csv │ │ └── reduce_scatter_workload.csv │ └── model_workload │ │ ├── G13B-M1-C01_GPT13B_megatron_tp8_pp1_mbs1.csv │ │ ├── G13B-M1-C02_GPT13B_megatron_tp8_pp1_mbs1_sp.csv │ │ ├── G175B-M1-C03_GPT175B_megatron_tp8_pp16_mbs1.csv │ │ ├── L13B-D1-C03_Llama13B_zero2_mbs1.csv │ │ ├── L13B-D1-C04_Llama13B_zero3_mbs1.csv │ │ ├── L30B-D1-C05_Llama30B_zero2_mbs1.csv │ │ ├── L30B-D1-C06_Llama30B_zero3_mbs1.csv │ │ ├── L65B-D1-C07_Llama65B_zero2_mbs1.csv │ │ ├── L65B-D1-C07_Llama65B_zero3_mbs1.csv │ │ ├── L65B-M1-C05_Llama65B_megatron_tp2_pp8_mbs1.csv │ │ ├── L7B-D1-C01_Llama7B_zero2_mbs1.csv │ │ ├── L7B-D1-C02_Llama7B_zero3_mbs1.csv │ │ └── L7B-M1-C04_Llama7B_megatron_tp2_pp1_mbs1.csv └── simAI │ ├── micro_test │ ├── all_gather.txt │ ├── all_reduce.txt │ ├── all_to_all.txt │ └── muti_all_reduce.txt │ └── model_workload │ ├── G13B-M1-C01_GPT13B_megatron_tp8_pp1_mbs1_A100.txt │ ├── G13B-M1-C02_GPT13B_megatron_tp8_pp1_mbs1_sp_A100.txt │ ├── G175B-M1-C03_GPT175B_megatron_tp8_pp1_mbs1_A100.txt │ ├── L65B-M1-C05_Llama65B_megatron_tp8_pp1_mbs1_A100.txt │ ├── L65B_D1_C08_Llama65B_deepspeed_zero3_A100.txt │ ├── L7B-D1-C02_Llama7B_deepspeed_zero3_A100.txt │ └── L7B-M1-C04_Llama7B_megatron_tp2_pp1_mbs1_A100.txt ├── workload_applyer.py └── workload_generator ├── AIOB_simAI_workload_generator.py ├── __init__.py ├── analysis_pytorch_trace.py ├── generate_collective_test.py ├── generate_deepspeed_stage1_2_workload.py ├── generate_deepspeed_stage3_workload.py ├── generate_ds_trace_replay_workload.py ├── generate_megatron_workload.py ├── mocked_model ├── AiobMegatron.py ├── MockedDeepspeed.py ├── MockedMegatron.py ├── MockedModel.py └── __init__.py └── workload_generator.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *local* 3 | .ipynb_checkpoints 4 | .pytest_cache 5 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.08-py3 2 | 3 | WORKDIR /workspace/AICB 4 | 5 | Copy . 
/workspace/AICB 6 | 7 | RUN mv ./workload_generator /usr/local/lib/python3.10/dist-packages &&\ 8 | mv ./utils /usr/local/lib/python3.10/dist-packages &&\ 9 | mv ./log_analyzer /usr/local/lib/python3.10/dist-packages &&\ 10 | pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0 &&\ 11 | pip3 install einops 12 | 13 | 14 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | /* 2 | *Copyright (c) 2021, Alibaba Group; 3 | *Licensed under the Apache License, Version 2.0 (the "License"); 4 | *you may not use this file except in compliance with the License. 5 | *You may obtain a copy of the License at 6 | 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | *Unless required by applicable law or agreed to in writing, software 10 | *distributed under the License is distributed on an "AS IS" BASIS, 11 | *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | *See the License for the specific language governing permissions and 13 | *limitations under the License. 14 | */ -------------------------------------------------------------------------------- /aicb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License.
12 | """ 13 | import torch 14 | from utils.utils import get_args, get_comp_out, extract_averages, Comp_with_aiob 15 | from utils.benchmark_logger import bench_logger 16 | from workload_generator.mocked_model.MockedDeepspeed import DeepspeedForCausalLM 17 | from workload_generator.mocked_model.MockedMegatron import MegatronModel 18 | from workload_generator.generate_deepspeed_stage1_2_workload import ( 19 | DeepSpeedStage1, 20 | DeepSpeedStage2, 21 | ) 22 | from workload_generator.generate_deepspeed_stage3_workload import DeepSpeedStage3 23 | from workload_generator.generate_megatron_workload import MegatronWorkload 24 | from workload_generator.generate_collective_test import Collective_Test 25 | from workload_applyer import WorkloadApplyer 26 | from utils.utils import * 27 | 28 | if __name__ == "__main__": 29 | args = get_args() 30 | if not hasattr(args, "backend"): 31 | args.backend = "nccl" 32 | torch.distributed.init_process_group(backend=args.backend) 33 | args.world_size = torch.distributed.get_world_size() 34 | args.rank = torch.distributed.get_rank() 35 | if args.frame == "Megatron": 36 | model = MegatronModel(args) 37 | workload_generator = MegatronWorkload(args, model) 38 | elif args.frame == "DeepSpeed": 39 | model = DeepspeedForCausalLM(args) 40 | if args.stage == 1: 41 | workload_generator = DeepSpeedStage1(args, model) 42 | elif args.stage == 2: 43 | workload_generator = DeepSpeedStage2(args, model) 44 | elif args.stage == 3: 45 | workload_generator = DeepSpeedStage3(args, model) 46 | elif args.frame == "collective_test": 47 | workload_generator = Collective_Test(args, None) 48 | workload = workload_generator() 49 | if args.aiob_enable and args.frame == "Megatron": 50 | 51 | params = model.parameters() 52 | args.model_param = sum(p.numel() for p in params) 53 | args.activation_memory = 0 54 | for sub_module in model.child_modules(): 55 | if hasattr(sub_module, "activation_memory"): 56 | args.activation_memory += sub_module.activation_memory() 57 | print("model_param:", args.model_param) 58 | if args.comp_filepath == None: 59 | local_rank = torch.distributed.get_rank() % torch.cuda.device_count() 60 | if local_rank == 0: 61 | filepath = get_comp_out(args) 62 | else: 63 | filepath = get_aiob_path(args) 64 | torch.distributed.barrier() 65 | compute_cache = extract_averages(filepath,args) 66 | else: 67 | print("comp_filepath:", args.comp_filepath) 68 | compute_cache = extract_averages(args.comp_filepath,args) 69 | workload = Comp_with_aiob(workload, compute_cache) 70 | if torch.distributed.get_rank() == 0: 71 | filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv" 72 | workload.dump(filename) 73 | if not args.workload_only : 74 | applyer = WorkloadApplyer(workload=workload, args=args) 75 | cpu_time = applyer.apply_workload() 76 | if torch.distributed.get_rank() == 0: 77 | bench_logger.analyze_comm_log() 78 | if args.frame != "collective_test": 79 | bench_logger.analyze_comm_time() 80 | csv_filename = bench_logger.dump_log(filename) 81 | if args.enable_visual: 82 | try: 83 | from visualize.generate import visualize_output 84 | visualize_output(csv_filename,False) 85 | except ImportError: 86 | print("visualize_output is not available because required library is not found") 87 | 88 | print( 89 | f"total time for {args.frame} and {args.epoch_num} iterations is {cpu_time:.4f} s" 90 | ) 91 | 
-------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/core/__init__.py -------------------------------------------------------------------------------- /core/grouped_gemm_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | 3 | try: 4 | import grouped_gemm 5 | except ImportError: 6 | grouped_gemm = None 7 | 8 | 9 | def grouped_gemm_is_available(): 10 | return grouped_gemm is not None 11 | 12 | 13 | def assert_grouped_gemm_is_available(): 14 | assert grouped_gemm_is_available(), ( 15 | "Grouped GEMM is not available. Please run " 16 | "`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0`." 17 | ) 18 | 19 | 20 | ops = grouped_gemm.ops if grouped_gemm_is_available() else None 21 | -------------------------------------------------------------------------------- /download/AICB_v1.0.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/download/AICB_v1.0.deb -------------------------------------------------------------------------------- /images/detail_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/detail_log.png -------------------------------------------------------------------------------- /images/readme_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/readme_01.png -------------------------------------------------------------------------------- /images/result_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/result_log.png -------------------------------------------------------------------------------- /images/simai_dingtalk.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/simai_dingtalk.jpg -------------------------------------------------------------------------------- /images/simai_wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/simai_wechat.jpg -------------------------------------------------------------------------------- /images/time_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/time_log.png -------------------------------------------------------------------------------- /images/tutorial_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_1.png -------------------------------------------------------------------------------- /images/tutorial_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_2.png -------------------------------------------------------------------------------- /images/tutorial_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_3.png -------------------------------------------------------------------------------- /images/tutorial_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_4.png -------------------------------------------------------------------------------- /images/tutorial_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_5.png -------------------------------------------------------------------------------- /images/tutorial_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_6.png -------------------------------------------------------------------------------- /images/tutorial_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/images/tutorial_7.png -------------------------------------------------------------------------------- /log_analyzer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/log_analyzer/__init__.py -------------------------------------------------------------------------------- /log_analyzer/analyze_res_csv.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from log_analyzer.utils import convert_msg_to_size, convert_size_to_msg 3 | import sys 4 | 5 | def analyze_csv(file_path): 6 | df = pd.read_csv(file_path) 7 | 8 | df = df.dropna(subset=['busbw']) 9 | 10 | df['busbw'] = pd.to_numeric(df['busbw'], errors='coerce') 11 | 12 | df = df.dropna(subset=['busbw']) 13 | 14 | def exclude_min(group): 15 | if len(group) > 1: 16 | group = group.sort_values(by='busbw') 17 | return group.iloc[2:] # drops the two lowest-busbw samples of each group 18 | return group 19 | 20 | df_excluded_min = df.groupby(['comm_type', 'comm_group', 'msg_size']).apply(exclude_min).reset_index(drop=True) 21 | grouped = df_excluded_min.groupby(['comm_type', 'comm_group', 'msg_size']).agg( 22 | busbw_mean=('busbw', 'mean'), 23 | busbw_max=('busbw', 'max'), 24 | busbw_min=('busbw', 'min'), 25 | busbw_std=('busbw', 'std'), 26 | occurrence_count=('busbw', 'size') 27 | ).reset_index() 28 | grouped['msg_size'] = grouped['msg_size'].apply(convert_size_to_msg) 29 | return grouped 30 | 31 | 32 | if __name__ == '__main__': 33 | if len(sys.argv) < 2: 34 | print("Usage: python -m log_analyzer.analyze_res_csv <result_csv_path>") 35 | sys.exit(1) 36 | grouped = analyze_csv(sys.argv[1]) 37 | print(grouped) 38 | -------------------------------------------------------------------------------- /log_analyzer/ds_comm_log_analyzer.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | # /usr/bin/python3 15 | from utils.utils import CommType, CommGroup 16 | from log_analyzer.utils import convert_msg_to_size, convert_size_to_msg 17 | from utils.benchmark_logger import BenchLogger 18 | from log_analyzer.log import LogItem, Log 19 | 20 | COMM_OP = "comm op" 21 | CALLER_FUNC = "Caller Func" 22 | TIME_MS = "time (ms)" 23 | MSG_SIZE = "msg size" 24 | LOG_STARTER = "[rank 0]" 25 | WORLD_SIZE = 16 26 | TP_SIZE = 4 27 | DP_SIZE = 4 28 | # LOG_STARTER = "[INFO] " 29 | 30 | 31 | def clean_s(s): 32 | return s.strip("[]\n\t ") 33 | 34 | 35 | def string2comm_type(s): 36 | if "all_gather" in s: 37 | return CommType.all_gather 38 | if "reduce_scatter" in s: 39 | return CommType.reduce_scatter 40 | if "all_reduce" in s: 41 | return CommType.all_reduce 42 | if "broadcast" in s: 43 | return CommType.broadcast 44 | if "barrier" in s: 45 | return CommType.barrier 46 | if "reduce" in s: 47 | return CommType.reduce 48 | print(f"WARNING cannot convert {s} to CommType") 49 | return CommType.epoch_end 50 | 51 | 52 | def parse_ds_log_item(line): 53 | index = line.lower().find(LOG_STARTER) 54 | if index == -1: 55 | return None 56 | item_list = line[index + len(LOG_STARTER) :].split("|") 57 | item = {} 58 | for raw_item in item_list: 59 | if "epoch" in raw_item: 60 | split_text = raw_item.split() 61 | numbers = [word for word in split_text if word.isdigit()] 62 | item["epoch_num"] = int(numbers[0]) 63 | continue 64 | if "micro_step" in raw_item: 65 | split_text = raw_item.split() 66 | numbers = [word for word in split_text if word.replace(".", "").isdigit()] 67 | item["iter_time"] = float(numbers[0]) 68 | continue 69 | if ":" not in raw_item: 70 | continue 71 | key, value = raw_item.split(":") 72 | key, value = clean_s(key), clean_s(value) 73 | if key == COMM_OP: 74 | item["comm_type"] = string2comm_type(value) 75 | elif key == MSG_SIZE or MSG_SIZE in key: 76 | item["msg_size"] = convert_msg_to_size(value) 77 | elif key == CALLER_FUNC: 78 | item["stage"] = value 79 | elif key == TIME_MS or TIME_MS in key: 80 | item["elapsed_time"] = float(value) 81 | if key == "group": 82 | group = eval(value) 83 | if len(group) == WORLD_SIZE: 84 | item["group"] = CommGroup.all 85 | elif len(group) == TP_SIZE: 86 | item["group"] = CommGroup.tp_group 87 | elif len(group) == DP_SIZE: 88 | item["group"] = CommGroup.dp_group 89 | elif "algbw" in key: 90 | item["algbw"] = float(value) 91 | elif "busbw" in key: 92 | item["busbw"] = float(value) 93 | else: 94 | try: 95 | item[key] = float(value) 96 | except: 97 | item[key] = value 98 | return item 99 | 100 | 101 | def parse_ds_comm_log(filename): 102 | comm_log = Log() 103 | with open(filename, "r") as f: 104 | lines = f.read().split("\n") 105 | for line in lines: 106 | if "After initializing ZeRO optimizer" in line: 107 | comm_log.add_comm_log(LogItem(comm_type=CommType.epoch_end)) 
108 | continue 109 | elif "microstep" in line: 110 | comm_log.add_comm_log(LogItem(comm_type=CommType.epoch_end)) 111 | continue 112 | log = parse_ds_log_item(line) 113 | if log is None: 114 | continue 115 | if "comm_type" in log: 116 | log_item = LogItem( 117 | comm_type=log["comm_type"], 118 | comm_group=log.get("group", CommGroup.dp_group), 119 | msg_size=log["msg_size"], 120 | ) 121 | log_item._elapsed_time = log.get("elapsed_time", -1) 122 | log_item.algbw, log_item.busbw = log.get("algbw", -1), log.get( 123 | "busbw", -1 124 | ) 125 | comm_log.add_comm_log(log_item) 126 | return comm_log 127 | 128 | 129 | if __name__ == "__main__": 130 | import sys 131 | 132 | filename = sys.argv[1] 133 | comm_log = parse_ds_comm_log(filename) 134 | comm_log.analyze() 135 | -------------------------------------------------------------------------------- /log_analyzer/log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import os,math 15 | import pickle 16 | import csv 17 | import dataclasses 18 | import numpy as np 19 | from typing import Union, Dict, List 20 | from utils.utils import CommType, CommGroup 21 | from log_analyzer.utils import convert_size_to_msg, calc_bw_log 22 | import copy 23 | 24 | @dataclasses.dataclass 25 | class LogItem: 26 | comm_type: CommType = dataclasses.field(default=None) 27 | comm_group: CommGroup = dataclasses.field(default=None) 28 | comm_group_size: int = dataclasses.field(default=None) 29 | msg_size: float = dataclasses.field(default=0) 30 | 31 | stage: str = dataclasses.field(default="") 32 | dst: int = dataclasses.field(default=None) 33 | src: int = dataclasses.field(default=None) 34 | additional: str = dataclasses.field(default="") 35 | 36 | _elapsed_time: float = dataclasses.field(default=None) 37 | algbw: float = dataclasses.field(default=None) 38 | busbw: float = dataclasses.field(default=None) 39 | count: float = dataclasses.field(default=1) 40 | 41 | @property 42 | def elapsed_time(self) -> float: 43 | return self._elapsed_time 44 | 45 | @elapsed_time.setter 46 | def elapsed_time(self, elapsed_time): 47 | self._elapsed_time = elapsed_time 48 | self.algbw, self.busbw = calc_bw_log( 49 | self.comm_type, self.msg_size, elapsed_time, self.comm_group_size 50 | ) 51 | 52 | def is_epoch_end(self): 53 | return self.comm_type == CommType.epoch_end 54 | 55 | def is_workload(self): 56 | return self.elapsed_time is None 57 | 58 | def view_as_ds_log(self): 59 | log_str = f"[RANK 0] comm op: {self.comm_type} | comm group: {self.comm_group}" 60 | log_str += " | time (ms): {:.2f}".format(self.elapsed_time) 61 | if self.comm_type == CommType.computation or self.additional == 'overlap': 62 | log_str += " | msg size: " + '0' 63 | log_str += " | algbw (GB): " + '0' 64 | log_str += " | busbw (GB): " + '0' 65 | else: 66 | log_str += " | msg size: " + convert_size_to_msg(self.msg_size) 67 | log_str += " | algbw (GB): {:.2f} 
".format(self.algbw) 68 | log_str += " | busbw (GB): {:.2f} ".format(self.busbw) 69 | return log_str 70 | 71 | def csv_header(self): 72 | return ",".join([k for k in self.__dict__.keys()]) 73 | 74 | def view_as_csv_line(self): 75 | return ",".join([str(getattr(self, k)) for k in self.__dict__.keys()]) 76 | 77 | def __str__(self): 78 | if self.is_workload(): 79 | return "None" 80 | return "None" 81 | 82 | 83 | def _print_stage_log(stage_name: str, stage_count: int, comm_type_info: Dict, primary_key: List[str], agg_key: List[str], performance_key: List[str], busbw_key: List[str]): 84 | header = f"{'Comm_Type':<15} {'Comm_Group':<12} {'Message_Size':<12} {'Count':<12} {'Avg_Elapsed_Time ± Std ':<24} {'Avg_BusBw ± Std':<24}\n" 85 | separator = "-" * len(header) + "\n" 86 | log_str = separator + header + separator 87 | 88 | for pkey in sorted(comm_type_info.keys()): 89 | row_str = "" 90 | values = {} 91 | for i, pkey_name in enumerate(primary_key): 92 | value = pkey[i] if pkey_name != "msg_size" else convert_size_to_msg(pkey[i]) 93 | values[pkey_name] = value 94 | for key in agg_key: 95 | value = comm_type_info[pkey][key] 96 | value = convert_size_to_msg(value) if key == "msg_size" else f"{value:.2f}" 97 | values[key] = value 98 | for key in performance_key: 99 | performance_value_list = sorted(comm_type_info[pkey][key]) 100 | values[f'avg_{key}'] = f"{np.mean(performance_value_list):.2f}±{np.std(performance_value_list):.2f}" 101 | values[f'min_{key}'] = f"{performance_value_list[0]:.2f}" 102 | values[f'max_{key}'] = f"{performance_value_list[-1]:.2f}" 103 | 104 | for key in busbw_key: 105 | busbw_value_list = sorted(comm_type_info[pkey][key]) 106 | values[f'avg_{key}'] = f"{np.mean(busbw_value_list):.2f}±{np.std(busbw_value_list):.2f}" 107 | 108 | row_str += f"{values['comm_type']:<15} {values['comm_group']:<12} {values['msg_size']:<12} {values['count']:<16} {values['avg__elapsed_time']:<24} {values['avg_busbw']:<18}\n" 109 | log_str += row_str 110 | 111 | return log_str 112 | 113 | 114 | def _analyze_stage_log(comm_log: List[Dict], stage: str, comm_info: Dict[str, Dict]): 115 | def __update_info( 116 | info_dict, 117 | log, 118 | primary_key: List[str], 119 | agg_key: List[str], 120 | performance_key: List[str], 121 | busbw_key: List[str], 122 | ): 123 | primary_key = tuple(log[key] for key in primary_key) 124 | if primary_key not in info_dict: 125 | info_dict[primary_key] = dict((key, 0) for key in agg_key) 126 | info_dict[primary_key].update(dict((key, []) for key in performance_key)) 127 | info_dict[primary_key].update(dict((key, []) for key in busbw_key)) 128 | for key in agg_key: 129 | info_dict[primary_key][key] += log[key] 130 | for key in performance_key: 131 | info_dict[primary_key][key].append(log[key]) 132 | for key in busbw_key: 133 | info_dict[primary_key][key].append(log[key]) 134 | 135 | if stage not in comm_info: 136 | comm_info[stage] = { 137 | "count": 0, 138 | "comm_type_info": {}, 139 | "detailed_comm_type_info": {}, 140 | } 141 | comm_info[stage]["count"] += 1 142 | # key: comm_type, value: count, time_ms 143 | comm_type_info = comm_info[stage]["comm_type_info"] 144 | # key: comm_type, msg_size, value: count, time_ms 145 | detailed_comm_type_info = comm_info[stage]["detailed_comm_type_info"] 146 | for log in comm_log: 147 | if log.comm_type != CommType.computation: 148 | __update_info( 149 | comm_type_info, 150 | log.__dict__, 151 | ["comm_type", "comm_group"], 152 | ["count", "msg_size"], 153 | ["_elapsed_time"], 154 | ["busbw"], 155 | ) 156 | __update_info( 157 | 
detailed_comm_type_info, 158 | log.__dict__, 159 | ["comm_type", "comm_group", "msg_size"], 160 | ["count"], 161 | ["_elapsed_time"], 162 | ["busbw"], 163 | ) 164 | 165 | 166 | class Log: 167 | def __init__(self) -> None: 168 | self.comm_logs = [] 169 | self.comm_log_each_epoch = [[]] 170 | self.epoch_times = [] 171 | 172 | def add_comm_log(self, comm_log: LogItem): 173 | if ( 174 | comm_log.is_epoch_end() 175 | and len(self.comm_logs) > 0 176 | and not self.comm_logs[-1].is_epoch_end() 177 | ): 178 | self.comm_logs.append(comm_log) 179 | self.comm_log_each_epoch.append([]) 180 | self.epoch_times.append(comm_log.elapsed_time) 181 | return 182 | self.comm_logs.append(comm_log) 183 | self.comm_log_each_epoch[-1].append(comm_log) 184 | 185 | def analyze(self, print_fn=print): 186 | comm_info: Dict[str, Dict] = {} 187 | _analyze_stage_log(self.comm_log_each_epoch[0], "init", comm_info) 188 | for e_log in self.comm_log_each_epoch[1:]: 189 | _analyze_stage_log(e_log, "train", comm_info) 190 | for stage in comm_info.keys(): 191 | if stage != "init": 192 | stage_count = comm_info[stage]["count"] 193 | comm_type_info = comm_info[stage]["comm_type_info"] 194 | detailed_comm_type_info = comm_info[stage]["detailed_comm_type_info"] 195 | 196 | log_str = _print_stage_log(stage, stage_count, detailed_comm_type_info, ["comm_type", "comm_group", "msg_size"], ["count"], ["_elapsed_time"], ["busbw"]) 197 | print_fn(f"\n\tDetailed comm info for AICB {stage} stage\n{log_str}") 198 | return comm_info 199 | 200 | def dump(self, filename): 201 | default_comm_folder_path = "results/comm_logs/" 202 | if not os.path.exists(default_comm_folder_path): 203 | os.makedirs(default_comm_folder_path, exist_ok=True) 204 | if "." in filename: 205 | filename = filename.split(".")[0] 206 | filename = os.path.join("results/comm_logs/", filename) 207 | csv_filename = filename + "_log.csv" 208 | with open(csv_filename, "w") as f: 209 | f.write(self.comm_logs[0].csv_header() + "\n") 210 | for log_item in self.comm_logs: 211 | log_item_write = copy.deepcopy(log_item) 212 | if log_item_write.comm_type == CommType.computation: 213 | msg_size_str = "("+' '.join(str(shape).replace(',', '') for shape in log_item_write.msg_size)+")" 214 | log_item_write.msg_size = msg_size_str 215 | f.write(log_item_write.view_as_csv_line() + "\n") 216 | del log_item_write 217 | return csv_filename 218 | 219 | @staticmethod 220 | def load(filename): 221 | filename = filename.split(".") 222 | filename[-1] = "pkl" 223 | filename = ".".join(filename) 224 | return pickle.load(open(filename, "rb")) 225 | 226 | def _get_elapsed_time(self): 227 | return self.epoch_times 228 | 229 | def analyze_time(self, print_fn=print): 230 | init_time = self.epoch_times.pop(0) 231 | max_val = max(self.epoch_times) 232 | min_val = min(self.epoch_times) 233 | mean_val = sum(self.epoch_times) / len(self.epoch_times) 234 | 235 | variance = sum((x - mean_val) ** 2 for x in self.epoch_times) / len( 236 | self.epoch_times 237 | ) 238 | variance = math.sqrt(variance) 239 | 240 | sorted_list = sorted(self.epoch_times) 241 | p90_val = sorted_list[int(len(sorted_list) * 0.9)] 242 | p99_val = sorted_list[int(len(sorted_list) * 0.99)] 243 | header = f"{'Init time':<18} {'Max iteration time':<20} {'Min iteration time':<20} {'Avg iteration time':<20} {'P90 iteration time ':<20} {'Iteration time Std ':<20}\n" 244 | separator = "-" * len(header) + "\n" 245 | log_str = separator + header + separator 246 | iteration_result = f"{init_time:<18.2f} {max_val:<20.2f} {min_val:<20.2f} {mean_val:<20.2f}
{p90_val:<20.2f} {variance:<20.2f}\n" 247 | log_str += iteration_result 248 | print_fn(f"\n\tDetailed info for AICB iteration time\n{log_str}") 249 | 250 | 251 | class Workload: 252 | def __init__(self) -> None: 253 | self.workload = [] 254 | 255 | def append(self, log_item: Union[LogItem, Dict]): 256 | if isinstance(log_item, LogItem): 257 | self.workload.append(log_item) 258 | return 259 | if "stage" not in log_item: 260 | log_item["stage"] = log_item["operation"] if "operation" in log_item else "" 261 | if "comm_group" not in log_item: 262 | assert ( 263 | log_item["comm_type"] == CommType.computation 264 | ), "comm_group is required for non-computation comm_type" 265 | log_item["comm_group"] = CommGroup.all 266 | self.workload.append( 267 | LogItem( 268 | comm_type=log_item["comm_type"], 269 | comm_group=log_item["comm_group"], 270 | comm_group_size=log_item["comm_group_size"], 271 | msg_size=log_item["msg_size"], 272 | stage=log_item["stage"], 273 | src=log_item.get("src", None), 274 | dst=log_item.get("dst", None), 275 | additional=log_item.get("additional", None), 276 | ) 277 | ) 278 | 279 | def extend(self, new_workload): 280 | self.workload.extend(new_workload.workload) 281 | 282 | def dump(self, filename): 283 | folder_path = os.path.dirname(filename) 284 | if folder_path and not os.path.exists(folder_path): 285 | os.makedirs(folder_path) 286 | default_folder_path = "results/mocked_workload/" 287 | if not os.path.exists(default_folder_path): 288 | os.makedirs(default_folder_path, exist_ok=True) 289 | if "." in filename: 290 | filename = os.path.basename(filename).split(".")[0] 291 | filename = os.path.join("results/mocked_workload/", filename) 292 | csv_filename = filename + "_workload.csv" 293 | with open(csv_filename, "w") as f: 294 | f.write(self.workload[0].csv_header() + "\n") 295 | for log_item in self.workload: 296 | log_item_write = copy.deepcopy(log_item) 297 | if(log_item_write.comm_type == CommType.computation): 298 | msg_size_str = "("+' '.join(str(shape).replace(',', '') for shape in log_item_write.msg_size)+")" 299 | log_item_write.msg_size = msg_size_str 300 | f.write(log_item_write.view_as_csv_line() + "\n") 301 | del log_item_write 302 | print(f"Workload file generated:{csv_filename}") 303 | 304 | @staticmethod 305 | def load(filename): 306 | filename = filename.split(".") 307 | filename[-1] = "pkl" 308 | filename = ".".join(filename) 309 | workload, args = pickle.load(open(filename, "rb")) 310 | return workload, args 311 | -------------------------------------------------------------------------------- /log_analyzer/plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from log_analyzer.ds_comm_log_analyzer import parse_ds_comm_log 17 | from utils.benchmark_logger import BenchLogger 18 | from typing import Dict, List 19 | from log_analyzer.utils import convert_size_to_msg 20 | 21 | 22 | def log_boxplot(detailed_comm_info: Dict): 23 | MAX_ITEMS, COLS = 5, 2 24 | comm_type2msg_size2time_cost = {} 25 | for comm_type, comm_group, msg_size in sorted(detailed_comm_info.keys()): 26 | if (comm_type, comm_group) not in comm_type2msg_size2time_cost: 27 | comm_type2msg_size2time_cost[(comm_type, comm_group)] = {} 28 | elasped_time = np.array( 29 | detailed_comm_info[(comm_type, comm_group, msg_size)]["_elapsed_time"] 30 | ) 31 | comm_type2msg_size2time_cost[(comm_type, comm_group)][ 32 | msg_size 33 | ] = elasped_time # [elasped_time < 3000] 34 | fig_num = sum( 35 | [ 36 | (len(comm_info.keys()) + MAX_ITEMS - 1) // MAX_ITEMS 37 | for comm_info in comm_type2msg_size2time_cost.values() 38 | ] 39 | ) 40 | 41 | fig_rows, fig_idx = (fig_num + COLS - 1) // COLS, 0 42 | fig, axes = plt.subplots(nrows=fig_rows, ncols=COLS, figsize=(8, 6)) 43 | fig.tight_layout() 44 | fig.suptitle("for deepspeed Zero3 llama 13B") 45 | for (comm_type, comm_group), comm_info in comm_type2msg_size2time_cost.items(): 46 | values, labels = list(comm_info.values()), [ 47 | convert_size_to_msg(msg) for msg in comm_info.keys() 48 | ] 49 | for j in range(0, len(values), MAX_ITEMS): 50 | ax = axes[fig_idx // COLS][fig_idx % COLS] 51 | fig_idx += 1 52 | ax.set_title("%s %s msg info" % (comm_type.value, comm_group.value)) 53 | ax.boxplot( 54 | values[j : j + MAX_ITEMS], 55 | labels=labels[j : j + MAX_ITEMS], 56 | flierprops=dict( 57 | marker="o", markerfacecolor="black", markersize=2, linestyle="none" 58 | ), 59 | ) 60 | for k in range(j, min(j + MAX_ITEMS, len(values))): 61 | ax.text( 62 | x=k - j + 1, 63 | y=np.max(values[k]) * 1.01, 64 | s=len(values[k]), 65 | horizontalalignment="center", 66 | size="x-small", 67 | color="r", 68 | weight="semibold", 69 | ) 70 | plt.show() 71 | 72 | 73 | def log_time_plotter(epoch_times: List[float]): 74 | plt.plot(epoch_times) 75 | plt.show() 76 | 77 | 78 | if __name__ == "__main__": 79 | filename = "/Users/yikaizhu/alicode/models-perf/deepspeed_baichuan_exp/results/baichuan_13B_zero3/55n_915_comm_log.txt" 80 | filename = "/Users/yikaizhu/Desktop/AIBC_clean.txt" 81 | comm_log = parse_ds_comm_log(filename) 82 | comm_info = comm_log.analyze() 83 | if "train" in comm_info: 84 | log_boxplot(comm_info["train"]["detailed_comm_type_info"]) 85 | else: 86 | log_boxplot(comm_info["init"]["detailed_comm_type_info"]) 87 | -------------------------------------------------------------------------------- /log_analyzer/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import math 15 | from utils.utils import CommGroup, CommType, get_args 16 | 17 | 18 | def convert_size_to_msg(size_bytes): 19 | if size_bytes == 0: 20 | return "0 B" 21 | size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") 22 | i = int(math.floor(math.log(size_bytes, 1024))) 23 | p = math.pow(1024, i) 24 | s = round(size_bytes / p, 2) 25 | return "%s %s" % (s, size_name[i]) 26 | 27 | 28 | def convert_msg_to_size(msg): 29 | if msg == "0B": 30 | return 0 31 | try: 32 | num, name = msg.split(" ") 33 | except: 34 | print(f"cannot convert msg into int") 35 | return 0 36 | num, name = float(num), name.strip() 37 | size_name = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"] 38 | if name not in size_name: 39 | return None 40 | p = math.pow(1024, size_name.index(name)) 41 | return num * p 42 | 43 | 44 | def calc_bw_log(comm_type: CommType, size, duration,group_size): # size: Bytes; duration: ms 45 | n = group_size if group_size else 1 46 | duration /= 1000 47 | if comm_type in [CommType.all_gather, CommType.reduce_scatter]: 48 | # size *= n 49 | tput = size / duration 50 | busbw = (size / duration) * ((n - 1) / n) 51 | elif comm_type == CommType.all_reduce: 52 | tput = size / duration 53 | busbw = (size / duration) * (2 * (n - 1) / n) 54 | elif comm_type in [CommType.barrier, CommType.computation]: 55 | return 0, 0 56 | else: # [CommType.broadcast, CommType.reduce, "gather", "scatter", "isend", "irecv"] 57 | tput = size / duration 58 | busbw = tput 59 | tput /= 1024*1024*1024 60 | busbw /= 1024*1024*1024 61 | tput = round(tput, 2) 62 | busbw = round(busbw, 2) 63 | return tput, busbw 64 | -------------------------------------------------------------------------------- /run_suites.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | import sys 14 | import subprocess 15 | import os 16 | import configparser 17 | import argparse 18 | 19 | running_command = {} 20 | default_config = { 21 | "deepspeed": { 22 | "llama7b_zero2": 0, 23 | "llama65b_zero3": 0, 24 | }, 25 | "megatron": { 26 | "llama_7B": 0, 27 | "gpt_13B_sp": 0, 28 | "gpt_175B_tp":0, 29 | "gpt_175B": 0, 30 | "gpt_22B": 0, 31 | "llama_405B": 0, 32 | "Mixtral_8*7B": 0, 33 | }, 34 | "aiob" : { #aicb workload suites with computation 35 | "llama_7B_aiob": 0, 36 | "gpt_13B_sp_aiob": 0, 37 | "gpt_175B_aiob": 0, 38 | "gpt_22B_aiob": 0, 39 | "gpt_175B_tp_aiob": 0, 40 | "llama_405B_aiob": 0, 41 | "Mixtral_8*7B_aiob": 0, 42 | "llama7B_zero2_aiob": 0, 43 | "llama65B_zero3_aiob": 0, 44 | }, 45 | "coll_comm_check": {"all_reduce": 0, "all_gather": 0, "muti_all_reduce": 0}, 46 | } 47 | 48 | 49 | def get_params(): 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("--output", help="output directory", default="./results") 52 | 53 | 54 | def read_config(config): 55 | ds_conf = config["deepspeed"] 56 | megatron_conf = config["megatron"] 57 | cc_conf = config["coll_comm_check"] 58 | aiob_conf = config["aiob"] 59 | if int(ds_conf["llama7b_zero2"]): 60 | running_command["deepspeed2_llama13b"] = ( 61 | f"bash scripts/deepspeed_llama.sh --zero_stage 2 -m 13 --epoch_num 10 " 62 | ) 63 | if int(ds_conf["llama65b_zero3"]): 64 | running_command["deepspeed3_llama65b"] = ( 65 | f"bash scripts/deepspeed_llama.sh --zero_stage 3 -m 65 --epoch_num 10 \ 66 | --reduce_bucket_size 1000000000 --allgather_bucket_size 500000000 \ 67 | --param_persistence_threshold 1000000" 68 | ) 69 | if int(megatron_conf["llama_7B"]): 70 | running_command["megatron_llama7B"] = ( 71 | f"bash scripts/megatron_gpt.sh -m 7 --tensor_model_parallel_size 1 --epoch_num 10 --seq_length 4096" 72 | ) 73 | if int(megatron_conf["gpt_13B_sp"]): 74 | running_command["megatron_gpt13b_sp"] = ( 75 | f"bash scripts/megatron_gpt.sh -m 13 --tensor_model_parallel_size 2 --epoch_num 10 --sp" 76 | ) 77 | if int(megatron_conf["gpt_175B"]): 78 | running_command["megatron_gpt175B"] = ( 79 | f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --pipeline_model_parallel 2 --sp" 80 | ) 81 | if int(megatron_conf["gpt_175B_tp"]): 82 | running_command["megatron_gpt175B_tp"] = ( 83 | f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --pipeline_model_parallel 2" 84 | ) 85 | if int(megatron_conf["gpt_22B"]): 86 | running_command["megatron_gpt_22B"] = ( 87 | f"bash scripts/megatron_gpt.sh -m 22 --tensor_model_parallel_size 4 --epoch_num 10 --sp" 88 | ) 89 | if int(megatron_conf["Mixtral_8*7B"]): 90 | running_command["megatron_moe"] = ( 91 | f"bash scripts/megatron_gpt.sh -m moe --tensor_model_parallel_size 2 --epoch_num 10 --sp --ep 4 --num_experts 16 --topk 4 " 92 | ) 93 | if int(megatron_conf["llama_405B"]): 94 | running_command["megatron_llama_405B"] = ( 95 | f"bash scripts/megatron_gpt.sh -m 405 --tensor_model_parallel_size 8 --epoch_num 10 --sp --seq_length 8192" 96 | ) 97 | if int(aiob_conf["llama_7B_aiob"]): 98 | running_command["megatron_llama7b_aiob"] = ( 99 | f"bash scripts/megatron_gpt.sh -m 7 --epoch_num 10 --aiob_enable " 100 | ) 101 | if int(aiob_conf["gpt_13B_sp_aiob"]): 102 | running_command["megatron_gpt13b_sp_aiob"] = ( 103 | f"bash scripts/megatron_gpt.sh -m 13 --tensor_model_parallel_size 4 --epoch_num 10 --aiob_enable --sp" 104 | ) 105 | if int(aiob_conf["gpt_175B_aiob"]): 106 | running_command["megatron_gpt175B_aiob"] = ( 107 | f"bash 
scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --aiob_enable --pipeline_model_parallel 2 --sp" 108 | ) 109 | if int(aiob_conf["gpt_175B_tp_aiob"]): 110 | running_command["megatron_gpt175B_tp_aiob"] = ( 111 | f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --aiob_enable --pipeline_model_parallel 2 " 112 | ) 113 | if int(aiob_conf["gpt_22B_aiob"]): 114 | running_command["megatron_gpt_22B_aiob"] = ( 115 | f"bash scripts/megatron_gpt.sh -m 22 --tensor_model_parallel_size 4 --epoch_num 10 --aiob_enable --sp" 116 | ) 117 | if int(aiob_conf["Mixtral_8*7B_aiob"]): 118 | running_command["megatron_moe_aiob"] = ( 119 | f"bash scripts/megatron_gpt.sh -m moe --tensor_model_parallel_size 2 --epoch_num 10 --sp --ep 4 --num_experts 16 --topk 4 --aiob_enable " 120 | ) 121 | if int(aiob_conf["llama_405B_aiob"]): 122 | running_command["megatron_llama_405B"] = ( 123 | f"bash scripts/megatron_gpt.sh -m 405 --tensor_model_parallel_size 8 --epoch_num 10 --sp --seq_length 8192 --aiob_enable " 124 | ) 125 | if int(aiob_conf["llama7B_zero2_aiob"]): 126 | running_command["deepspeed2_llama7b_aiob"] = ( 127 | f"bash scripts/deepspeed_llama.sh --zero_stage 2 -m 7 --epoch_num 10 --aiob_enable " 128 | ) 129 | if int(aiob_conf["llama65B_zero3_aiob"]): 130 | running_command["deepspeed3_llama65b_aiob"] = ( 131 | f"bash scripts/deepspeed_llama.sh --zero_stage 3 -m 65 --epoch_num 10 \ 132 | --reduce_bucket_size 1000000000 --allgather_bucket_size 500000000 \ 133 | --param_persistence_threshold 1000000 --aiob_enable" 134 | ) 135 | if int(cc_conf["all_reduce"]): 136 | running_command["all_reduce_check"] = ( 137 | f"bash scripts/coll_comm_check.sh --iter_num 100 --test_comm all_reduce --model_name all_reduce" 138 | ) 139 | if int(cc_conf["all_gather"]): 140 | running_command["all_gather_check"] = ( 141 | f"bash scripts/coll_comm_check.sh --iter_num 100 --test_comm all_gather --model_name all_gather" 142 | ) 143 | if int(cc_conf["muti_all_reduce"]): 144 | running_command["muti_all_reduce_check"] = ( 145 | f"bash scripts/coll_comm_check.sh --iter_num 100 --test_comm all_reduce --model_name muti_all_reduce --muti_all_reduce_enable 1" 146 | ) 147 | 148 | 149 | 150 | if __name__ == "__main__": 151 | # config = configparser.ConfigParser() 152 | # config.read('config.ini') 153 | read_config(config=default_config) 154 | result = {} 155 | print(running_command) 156 | for name, command in running_command.items(): 157 | result_dir = "./results/log/" 158 | if not os.path.isdir(result_dir): 159 | os.makedirs(result_dir) 160 | output_file = f"./results/log/{name}.txt" 161 | 162 | command += f" 2>&1 | tee {output_file}" 163 | print(name) 164 | ret = subprocess.run(command, shell=True, text=True) 165 | 166 | # if ret.returncode != 0: 167 | # print(f"ERROR when running {name}: {command}") 168 | # print( 169 | # f"return state is {ret.returncode}, got err{ret.stderr}, get output{ret.stdout}" 170 | # ) 171 | # exit(-1) 172 | # command_out = ret.stdout 173 | # print(command_out) 174 | -------------------------------------------------------------------------------- /scripts/coll_comm_check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | 5 | begin_size=4096 6 | end_size=8589934592 7 | epoch_num=1 8 | iter_num=500 9 | test_comm=all_reduce 10 | frame=collective_test 11 | model_name=all_reduce 12 | multi_all_reduce_enable=0 13 | 14 | usage() { 15 | echo "Usage: $0 [options] 16 | options: 17 | --iter_num num of 
iterations: $iter_num 18 | --begin_size start message size of test: $begin_size 19 | --end_size end message size of test: $end_size 20 | --test_comm collective communication type: $test_comm 21 | --multi_all_reduce_enable enable multi all_reduce operation: $multi_all_reduce_enable 22 | -h, --help" 1>&2; exit 1; 23 | } 24 | 25 | while [ $# -gt 0 ] 26 | do 27 | case $1 in 28 | --model_name|--model-name) 29 | model_name=$2; shift;; 30 | --iter-num|--iter_num) 31 | iter_num=$2 ; shift;; 32 | --begin_size|--begin-size) 33 | begin_size=$2 ; shift;; 34 | --end_size|--end-size) 35 | end_size=$2 ; shift;; 36 | --test_comm|--test-comm) 37 | test_comm=$2 ; shift;; 38 | --multi_all_reduce_enable|--muti-all-reduce-enable) 39 | multi_all_reduce_enable=$2 ; shift;; 40 | -h|--help) 41 | usage ;; 42 | (*) 43 | break;; 44 | esac 45 | # Fetch next argument as 1st 46 | shift 47 | done 48 | 49 | 50 | script="./aicb.py" 51 | 52 | if [ "$multi_all_reduce_enable" -eq 0 ]; then 53 | echo "torchrun \ 54 | --nnodes $WORLD_SIZE \ 55 | --node_rank $RANK \ 56 | --nproc_per_node gpu \ 57 | --master_addr $MASTER_ADDR \ 58 | --master_port $MASTER_PORT \ 59 | $script --iter_num=$iter_num --world_size=$((WORLD_SIZE*8))\ 60 | --begin_size=$begin_size --end_size=$end_size --test_comm=$test_comm --model_name=$model_name\ 61 | --frame=collective_test --multi_all_reduce_enable=$multi_all_reduce_enable" 62 | else 63 | echo "torchrun \ 64 | --nnodes $WORLD_SIZE \ 65 | --node_rank $RANK \ 66 | --nproc_per_node gpu \ 67 | --master_addr $MASTER_ADDR \ 68 | --master_port $MASTER_PORT \ 69 | $script --iter_num=$iter_num --world_size=$((WORLD_SIZE*8))\ 70 | --begin_size=$begin_size --end_size=$end_size --test_comm=$test_comm --model_name=$model_name\ 71 | --frame=collective_test --multi_all_reduce_enable=$multi_all_reduce_enable --pipeline_model_parallel=$WORLD_SIZE" 72 | fi 73 | 74 | if [ "$multi_all_reduce_enable" -eq 0 ]; then 75 | torchrun \ 76 | --nnodes $WORLD_SIZE \ 77 | --node_rank $RANK \ 78 | --nproc_per_node gpu \ 79 | --master_addr $MASTER_ADDR \ 80 | --master_port $MASTER_PORT \ 81 | $script --iter_num=$iter_num --world_size=$((WORLD_SIZE*8))\ 82 | --begin_size=$begin_size --end_size=$end_size --test_comm=$test_comm --model_name=$model_name\ 83 | --frame=collective_test --multi_all_reduce_enable=$multi_all_reduce_enable 84 | else 85 | torchrun \ 86 | --nnodes $WORLD_SIZE \ 87 | --node_rank $RANK \ 88 | --nproc_per_node gpu \ 89 | --master_addr $MASTER_ADDR \ 90 | --master_port $MASTER_PORT \ 91 | $script --iter_num=$iter_num --world_size=$((WORLD_SIZE*8))\ 92 | --begin_size=$begin_size --end_size=$end_size --test_comm=$test_comm --model_name=$model_name\ 93 | --frame=collective_test --multi_all_reduce_enable=$multi_all_reduce_enable --pipeline_model_parallel=$WORLD_SIZE 94 | fi -------------------------------------------------------------------------------- /scripts/deepspeed_llama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | set -x 5 | : ${WORLD_SIZE:=1} 6 | : ${RANK:=0} 7 | : ${MASTER_ADDR:="localhost"} 8 | : ${MASTER_PORT:=29500} 9 | model_name=llama_7b 10 | zero_stage=3 11 | model_size=7 12 | num_layers=32 13 | epoch_num=10 14 | num_attention_heads=32 15 | hidden_size=4096 16 | ffn_hidden_size=11008 17 | reduce_bucket_size=500000000 18 | allgather_bucket_size=500000000 19 | prefetch_bucket_size=1000000000 20 | max_live_parameters=1000000000 21 | param_persistence_threshold=100000 22 | seq_len=2048 23 | batch_size=4 24 | contiguous_gradients= 25 | aiob_enable=
26 | enable_visual= 27 | workload_only= 28 | 29 | usage() { 30 | echo "Usage: $0 [options] 31 | options: 32 | --model_name model_name: $model_name 33 | --zero_stage zero_stage: $zero_stage 34 | --epoch_num num of iterations: $epoch_num 35 | --batch_size micro batch_size: $batch_size 36 | --enable_visual enable visual html output files 37 | --workload_only generate workload only 38 | -m, --model-size llama model size (7/13/30/65): $model_size 39 | --reduce-bucket-size size of reduce bucket: $reduce_bucket_size 40 | --allgather-bucket-size size of all_gather bucket(only used in stage1,2): $allgather_bucket_size 41 | --prefetch-bucket-size size of all_gather prefetch bucket(only used in stage3): $prefetch_bucket_size 42 | --max-live-parameters max size of params that have been all_gather(only used in stage3): $max_live_parameters 43 | --param-persistence-threshold threshold of param that is all-gather before forward(only used in stage3): $param_persistence_threshold 44 | --seq-len seq-len: $seq_len 45 | --contiguous-gradients use reduce instead of all_reduce (only used in stage2) 46 | -h, --help" 1>&2; exit 1; 47 | } 48 | 49 | while [ $# -gt 0 ] 50 | do 51 | echo "Processing argument: $1" 52 | case $1 in 53 | --model_name|--model-name) 54 | model_name=$2 ; shift;; 55 | --stage|--zero-stage|--zero_stage) 56 | zero_stage=$2 ; shift;; 57 | --epoch-num|--epoch_num) 58 | epoch_num=$2 ; shift;; 59 | --batch-size|--micro_batch|--batch_size) 60 | batch_size=$2 ; shift;; 61 | -m|--model-size) 62 | model_size=$2 ; shift;; 63 | --reduce-bucket-size|--reduce_bucket_size) 64 | reduce_bucket_size=$2 ; shift;; 65 | --param-persistence-threshold|--param_persistence_threshold) 66 | param_persistence_threshold=$2 ; shift;; 67 | --max-live-parameters|--max_live_parameters) 68 | max_live_parameters=$2 ; shift;; 69 | --allgather-bucket-size|--allgather_bucket_size) 70 | allgather_bucket_size=$2 ; shift;; 71 | --seq-len|--seq_len) 72 | seq_len=$2 ; shift;; 73 | --aiob_enable) 74 | aiob_enable=--aiob_enable;; 75 | --enable_visual) 76 | enable_visual=--enable_visual;; 77 | --workload_only) 78 | workload_only=--workload_only;; 79 | --contiguous-gradients|--contiguous_gradients) 80 | contiguous_gradients=--contiguous_gradients;; 81 | -h|--help) 82 | usage ;; 83 | (*) 84 | break;; 85 | esac 86 | # Fetch next argument as 1st 87 | shift 88 | done 89 | 90 | case $model_size in 91 | 13) 92 | model_name=llama_13b; hidden_size=5120; ffn_hidden_size=13824; num_layers=40; num_attention_heads=40;; 93 | 30) 94 | model_name=llama_30b; hidden_size=6656; ffn_hidden_size=17920; num_layers=60; num_attention_heads=52;; 95 | 65) 96 | model_name=llama_65b; hidden_size=8192; ffn_hidden_size=22016; num_layers=80; num_attention_heads=64;; 97 | 7) 98 | ;; 99 | (*) 100 | echo "only support model size 7b, 13b, 30b, 65b, got $model_size";; 101 | esac 102 | 103 | script="./aicb.py" 104 | 105 | torchrun \ 106 | --nnodes ${WORLD_SIZE} \ 107 | --node_rank $RANK \ 108 | --nproc_per_node gpu \ 109 | --master_addr $MASTER_ADDR \ 110 | --master_port $MASTER_PORT \ 111 | $script --frame=DeepSpeed --model_name=$model_name --stage=$zero_stage --world_size=$((WORLD_SIZE*8)) \ 112 | --micro_batch=$batch_size --global_batch=$((WORLD_SIZE*8*batch_size)) --epoch_num=$epoch_num \ 113 | --num_layers=$num_layers --hidden_size=$hidden_size --ffn_hidden_size=$ffn_hidden_size --num_attention_heads=$num_attention_heads \ 114 | --reduce_bucket_size=$reduce_bucket_size --allgather_bucket_size=$allgather_bucket_size --seq_len=$seq_len \ 115 |
--max_live_parameters=$max_live_parameters --param_persistence_threshold=$param_persistence_threshold $contiguous_gradients $aiob_enable $enable_visual $workload_only -------------------------------------------------------------------------------- /scripts/megatron_gpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | : ${WORLD_SIZE:=1} 5 | : ${RANK:=0} 6 | : ${MASTER_ADDR:="localhost"} 7 | : ${MASTER_PORT:=29500} 8 | NUM_GPUS=$(nvidia-smi -L | wc -l) # Get the number of GPUs on a single node 9 | model_size=13 10 | num_layers=40 11 | num_attention_heads=40 12 | hidden_size=5120 13 | seq_length=2048 14 | micro_batch=1 15 | epoch_num=1 16 | tensor_model_parallel_size=8 17 | pipeline_model_parallel=1 18 | vocab_size=50257 19 | model_name=gpt_13b 20 | ga_num=2 21 | sp_enable= 22 | frame=Megatron 23 | aiob_enable= 24 | max_position_embeddings=4096 25 | num_experts=1 26 | moe_enable= 27 | enable_visual= 28 | workload_only= 29 | usage() { 30 | echo "Usage: $0 [options] 31 | options: 32 | --frame Communication framework: $frame 33 | --world_size World size (number of nodes): $WORLD_SIZE 34 | --tensor_model_parallel_size Tensor parallelism size: $tensor_model_parallel_size 35 | --pipeline_model_parallel Pipeline parallelism size: $pipeline_model_parallel 36 | --global_batch Global batch size: $global_batch 37 | --micro_batch Micro batch size: $micro_batch 38 | --num_layers Number of layers: $num_layers 39 | --seq_length Sequence length: $seq_length 40 | --hidden_size Hidden size: $hidden_size 41 | --epoch_num Number of epochs: $epoch_num 42 | --num_attention_heads Number of attention heads: $num_attention_heads 43 | --aiob_enable Enable AIOB: $aiob_enable 44 | --enable_visual Enable Visualization: $enable_visual 45 | --workload_only generate workload only 46 | --use_flash_attn Use flash attention: $use_flash_attn 47 | --swiglu Use SWIGLU: $swiglu 48 | --ffn_hidden_size FFN hidden size: $ffn_hidden_size 49 | --comp_filepath Computation file path: $comp_filepath 50 | --model_name Model name: $model_name 51 | -m, --model_size model size, defaults to $model_size (possible values: 405, 175, 65, 22, 13, 7, moe) 52 | --max_position_embeddings Max position embeddings: $max_position_embeddings 53 | --nnodes Number of nodes: $WORLD_SIZE 54 | --node_rank Rank of the node: $RANK 55 | --nproc_per_node Number of GPUs per node: $NUM_GPUS 56 | --master_addr Master address: $MASTER_ADDR 57 | --master_port Master port: $MASTER_PORT 58 | --moe_enable enable moe 59 | --moe_router_topk Number of experts to route to for each token. 60 | --expert_model_parallel_size Degree of expert model parallelism 61 | --num_experts Number of experts in the MoE model.
62 | --moe_grouped_gemm apply grouped gemm 63 | -h, --help Display this help and exit" 1>&2; exit 1; 64 | } 65 | while [ $# -gt 0 ] 66 | do 67 | echo "Processing argument: $1" 68 | case $1 in 69 | --frame) 70 | frame=$2; shift;; 71 | --world_size) 72 | world_size=$2; shift;; 73 | --tensor_model_parallel_size|--tp_num) 74 | tensor_model_parallel_size=$2; shift;; 75 | --pipeline_model_parallel|--pp_num) 76 | pipeline_model_parallel=$2; shift;; 77 | --global_batch) 78 | global_batch=$2; shift;; 79 | --micro_batch) 80 | micro_batch=$2; shift;; 81 | --num_layers) 82 | num_layers=$2; shift;; 83 | --seq_length) 84 | seq_length=$2; shift;; 85 | --hidden_size) 86 | hidden_size=$2; shift;; 87 | --epoch_num) 88 | epoch_num=$2; shift;; 89 | --num_attention_heads) 90 | num_attention_heads=$2; shift;; 91 | --aiob_enable) 92 | aiob_enable=--aiob_enable;; 93 | --enable_visual) 94 | enable_visual=--enable_visual;; 95 | --workload_only) 96 | workload_only=--workload_only;; 97 | --use_flash_attn) 98 | use_flash_attn=--use_flash_attn;; 99 | --swiglu) 100 | swiglu=--swiglu;; 101 | --ffn_hidden_size) 102 | ffn_hidden_size=$2; shift;; 103 | --sp|--sp-enable|--enable_sequence_parallel) 104 | sp_enable=--enable_sequence_parallel;; 105 | --comp_filepath) 106 | comp_filepath=$2; shift;; 107 | -m|--model_size) 108 | model_size=$2; shift;; 109 | --moe_enable) 110 | moe_enable=--moe_enable;; 111 | --moe_router_topk|--topk) 112 | moe_router_topk=$2; shift;; 113 | --num_experts|--experts) 114 | num_experts=$2; shift;; 115 | --expert_model_parallel_size|--ep) 116 | expert_model_parallel_size=$2; shift;; 117 | --grouped_gemm|--moe_grouped_gemm) 118 | grouped_gemm=--moe_grouped_gemm;; 119 | --nnodes) 120 | WORLD_SIZE=$2;shift;; 121 | --node_rank) 122 | RANK=$2;shift;; 123 | --nproc_per_node) 124 | NUM_GPUS=$2;shift;; 125 | --master_addr) 126 | MASTER_ADDR=$2;shift;; 127 | --master_port) 128 | MASTER_PORT=$2;shift;; 129 | -h|--help) 130 | usage ;; 131 | (*) 132 | break;; 133 | esac 134 | 135 | shift 136 | done 137 | 138 | case $model_size in 139 | 175) 140 | model_name=gpt_175B 141 | num_layers=96 142 | hidden_size=12288 143 | num_attention_heads=96 144 | tensor_model_parallel_size=8 145 | ;; 146 | 22) 147 | model_name=gpt_22B 148 | num_layers=48 149 | hidden_size=6144 150 | num_attention_heads=64 151 | tensor_model_parallel_size=8 152 | ;; 153 | 13) 154 | model_name=gpt_13B 155 | num_layers=40 156 | hidden_size=5120 157 | num_attention_heads=40 158 | ;; 159 | 7) 160 | model_name=gpt_7B 161 | num_layers=36 162 | hidden_size=4096 163 | num_attention_heads=32 164 | ;; 165 | 405) 166 | model_name=llama_405B 167 | num_layers=128 168 | hidden_size=16384 169 | ffn_hidden_size=53248 170 | num_attention_heads=128 171 | tensor_model_parallel_size=8 172 | pipeline_model_parallel=16 173 | ;; 174 | 65) 175 | model_name=llama_65B 176 | num_layers=80 177 | hidden_size=8192 178 | ffn_hidden_size=28672 179 | num_attention_heads=64 180 | tensor_model_parallel_size=8 181 | pipeline_model_parallel=2 182 | ;; 183 | moe) 184 | model_name=Mixtral_8*7B 185 | num_layers=32 186 | hidden_size=4096 187 | num_attention_heads=32 188 | ffn_hidden_size=14336 189 | tensor_model_parallel_size=2 190 | moe_enable=--moe_enable 191 | grouped_gemm=--moe_grouped_gemm 192 | ;; 193 | (*) 194 | echo "Only support model size 405,175,65,22,13,7 or moe; using default size 13" 195 | model_name=gpt_13B 196 | num_layers=40 197 | hidden_size=5120 198 | num_attention_heads=40 199 | ;; 200 | esac 201 | 202 | world_size=${world_size:-$((WORLD_SIZE*NUM_GPUS))}; dp_num=$((world_size/tensor_model_parallel_size/pipeline_model_parallel))
203 | global_batch=$((ga_num*dp_num*micro_batch)) 204 | if [ $workload_only ]; then 205 | script="python -m workload_generator.generate_megatron_workload" 206 | else 207 | script="./aicb.py" 208 | fi 209 | 210 | cmd="$script \ 211 | --frame=$frame \ 212 | --model_name=$model_name \ 213 | --world_size=$(($WORLD_SIZE * $NUM_GPUS)) \ 214 | --tensor_model_parallel_size=$tensor_model_parallel_size \ 215 | --micro_batch=$micro_batch \ 216 | --global_batch=$global_batch \ 217 | --epoch_num=$epoch_num \ 218 | --num_layers=$num_layers \ 219 | --hidden_size=$hidden_size \ 220 | --num_attention_heads=$num_attention_heads \ 221 | --seq_length=$seq_length \ 222 | --vocab_size=$vocab_size \ 223 | --pipeline_model_parallel=$pipeline_model_parallel \ 224 | --use-distributed-optimizer \ 225 | --max_position_embeddings=$max_position_embeddings \ 226 | ${aiob_enable} \ 227 | ${enable_visual} \ 228 | ${workload_only} \ 229 | ${sp_enable} \ 230 | ${use_flash_attn} \ 231 | ${swiglu} \ 232 | ${ffn_hidden_size:+--ffn_hidden_size=$ffn_hidden_size} \ 233 | ${comp_filepath:+--comp_filepath=$comp_filepath} \ 234 | ${moe_enable} \ 235 | ${moe_router_topk:+--moe_router_topk=$moe_router_topk} \ 236 | ${num_experts:+--num_experts=$num_experts} \ 237 | ${expert_model_parallel_size:+--expert_model_parallel_size=$expert_model_parallel_size} \ 238 | ${grouped_gemm}" 239 | echo $cmd 240 | 241 | if [ $workload_only ]; then 242 | $cmd 243 | else 244 | torchrun \ 245 | --nnodes $WORLD_SIZE \ 246 | --node_rank $RANK \ 247 | --nproc_per_node $NUM_GPUS \ 248 | --master_addr $MASTER_ADDR \ 249 | --master_port $MASTER_PORT \ 250 | $cmd 251 | fi 252 | -------------------------------------------------------------------------------- /scripts/megatron_workload_with_aiob.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | frame=Megatron 5 | world_size=32 6 | tensor_model_parallel_size=8 7 | pipeline_model_parallel=1 8 | global_batch=1024 9 | micro_batch=1 10 | num_layers=40 11 | seq_length=4096 12 | hidden_size=5120 13 | epoch_num=1 14 | num_attention_heads=40 15 | aiob_enable= 16 | use_flash_attn= 17 | swiglu= 18 | sp_enable= 19 | ffn_hidden_size= 20 | comp_filepath= 21 | model_size=13 22 | max_position_embeddings=4096 23 | vocab_size=50257 24 | num_experts=1 25 | moe_enable= 26 | recompute_activations= 27 | gpu_type=None 28 | usage() { 29 | echo "Usage: \$0 [options] 30 | options: 31 | --frame communication framework, defaults to $frame 32 | --world_size world size, defaults to $world_size 33 | --tensor_model_parallel_size tensor parallelism size, defaults to $tensor_model_parallel_size 34 | --pipeline_model_parallel pipeline parallelism size, defaults to $pipeline_model_parallel 35 | --global_batch global batch size, defaults to $global_batch 36 | --micro_batch micro batch size, defaults to $micro_batch 37 | --num_layers number of layers, defaults to $num_layers 38 | --seq_length sequence length, defaults to $seq_length 39 | --hidden_size hidden size, defaults to $hidden_size 40 | --epoch_num number of epochs, defaults to $epoch_num 41 | --use_distributed_optimizer use distributed optimizer 42 | --num_attention_heads number of attention heads, defaults to $num_attention_heads 43 | --aiob_enable enable AIOB 44 | --use_flash_attn use flash attention 45 | --swiglu use swiglu 46 | --ffn_hidden_size FFN hidden size 47 | --comp_filepath computation file path 48 | --max_position_embeddings max position embeddings, defaults to $max_position_embeddings 49 | -m, --model_size model size, 
defaults to $model_size (possible values: 405, 175, 22, 13, 7, moe)
50 | --moe_enable enable moe
51 | --moe_router_topk Number of experts to route to for each token.
52 | --expert_model_parallel_size Degree of expert model parallelism
53 | --num_experts Number of experts in the MoE model.
54 | --moe_grouped_gemm apply grouped gemm
55 | -h, --help display this help and exit" 1>&2; exit 1;
56 | }
57 |
58 |
59 | while [ $# -gt 0 ]
60 | do
61 |
62 | case $1 in
63 | --gpu_type)
64 | gpu_type=$2; shift;;
65 | --frame)
66 | frame=$2; shift;;
67 | --world_size)
68 | world_size=$2; shift;;
69 | --tensor_model_parallel_size|--tp)
70 | tensor_model_parallel_size=$2; shift;;
71 | --pipeline_model_parallel|--pp)
72 | pipeline_model_parallel=$2; shift;;
73 | --global_batch)
74 | global_batch=$2; shift;;
75 | --micro_batch)
76 | micro_batch=$2; shift;;
77 | --num_layers)
78 | num_layers=$2; shift;;
79 | --seq_length)
80 | seq_length=$2; shift;;
81 | --hidden_size)
82 | hidden_size=$2; shift;;
83 | --epoch_num)
84 | epoch_num=$2; shift;;
85 | --num_attention_heads)
86 | num_attention_heads=$2; shift;;
87 | --aiob_enable|--aiob)
88 | aiob_enable=--aiob_enable;;
89 | --use_flash_attn|--flash_attn)
90 | use_flash_attn=--use_flash_attn;;
91 | --swiglu)
92 | swiglu=--swiglu;;
93 | --ffn_hidden_size)
94 | ffn_hidden_size=$2; shift;;
95 | --sp|--sp-enable)
96 | sp_enable=--enable_sequence_parallel;;
97 | --comp_filepath)
98 | comp_filepath=$2; shift;;
99 | -m|--model_size)
100 | model_size=$2; shift;;
101 | --max_position_embeddings)
102 | max_position_embeddings=$2; shift;;
103 | --moe_enable)
104 | moe_enable=--moe_enable;;
105 | --moe_router_topk|--topk)
106 | moe_router_topk=$2; shift;;
107 | --num_experts|--experts)
108 | num_experts=$2; shift;;
109 | --expert_model_parallel_size|--ep)
110 | expert_model_parallel_size=$2; shift;;
111 | --grouped_gemm|--moe_grouped_gemm)
112 | grouped_gemm=--moe_grouped_gemm;;
113 | --recompute_activations|--recompute)
114 | recompute_activations=--recompute_activations;;
115 | -h|--help)
116 | usage;;
117 | (*)
118 | break;;
119 | esac
120 | shift
121 | done
122 |
123 |
124 | case $model_size in
125 | 175)
126 | model_name=gpt_175B
127 | num_layers=96
128 | hidden_size=12288
129 | num_attention_heads=96
130 | tensor_model_parallel_size=8
131 | ;;
132 | 22)
133 | model_name=gpt_22B
134 | num_layers=48
135 | hidden_size=6144
136 | num_attention_heads=64
137 | tensor_model_parallel_size=8
138 | ;;
139 | 13)
140 | model_name=gpt_13B
141 | num_layers=40
142 | hidden_size=5120
143 | num_attention_heads=40
144 | ;;
145 | 7)
146 | model_name=gpt_7B
147 | num_layers=36
148 | hidden_size=4096
149 | num_attention_heads=32
150 | tensor_model_parallel_size=4
151 | ;;
152 | 405)
153 | model_name=llama_405B
154 | num_layers=128
155 | hidden_size=16384
156 | ffn_hidden_size=53248
157 | num_attention_heads=128
158 | ;;
159 | moe)
160 | model_name=Mixtral_8x7B
161 | num_layers=32
162 | hidden_size=4096
163 | num_attention_heads=32
164 | ffn_hidden_size=14336
165 | tensor_model_parallel_size=4
166 | moe_enable=--moe_enable
167 | grouped_gemm=--moe_grouped_gemm
168 | ;;
169 | (*)
170 | echo "Only support model size 405, 175, 22, 13, 7 or moe; using default size 13"
171 | model_name=gpt_13B
172 | num_layers=40
173 | hidden_size=5120
174 | num_attention_heads=40
175 | ;;
176 | esac
177 |
178 |
179 | cmd="python -m workload_generator.AIOB_simAI_workload_generator \
180 | --gpu_type=$gpu_type \
181 | --frame=$frame \
182 | --world_size=$world_size \
183 | --tensor_model_parallel_size=$tensor_model_parallel_size \
184 | --pipeline_model_parallel=$pipeline_model_parallel \
185 | --global_batch=$global_batch \
186 | --micro_batch=$micro_batch \
187 | --num_layers=$num_layers \
188 | --seq_length=$seq_length \
189 | --hidden_size=$hidden_size \
190 | --epoch_num=$epoch_num \
191 | --num_attention_heads=$num_attention_heads \
192 | --model_name=$model_name \
193 | --max_position_embeddings=$max_position_embeddings \
194 | --vocab_size=$vocab_size \
195 | --use-distributed-optimizer \
196 | ${aiob_enable} \
197 | ${use_flash_attn} \
198 | ${swiglu} \
199 | ${sp_enable} \
200 | ${recompute_activations} \
201 | ${ffn_hidden_size:+--ffn_hidden_size=$ffn_hidden_size} \
202 | ${comp_filepath:+--comp_filepath=$comp_filepath} \
203 | ${moe_enable} \
204 | ${moe_router_topk:+--moe_router_topk=$moe_router_topk} \
205 | ${num_experts:+--num_experts=$num_experts} \
206 | ${expert_model_parallel_size:+--expert_model_parallel_size=$expert_model_parallel_size} \
207 | ${grouped_gemm}"
208 |
209 | echo $cmd
210 |
211 |
212 | $cmd
213 |
--------------------------------------------------------------------------------
/scripts/run_in_cluster.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | """
3 | Usage [{filename}]:
4 | 1. Change IMAGE_NAME from DUMMY to your real image name
5 | 2. Change IPLIST from DUMMY to your real /path/to/iplist (absolute path)
6 | 3. Change AICB_DIR from DUMMY to your real /path/to/aicb (absolute path)
7 | 4. Change the settings in run_suites.py to select the workload you want
8 | 5. Copy iplist and aicb to all participating servers at /path/to/iplist and /path/to/aicb, e.g., using `pscp` command like `pscp.pssh -h iplist iplist /path/to/iplist` and `pscp.pssh -h iplist -r aicb /path/to/aicb`
9 | 6. Run simulation on all participating servers, e.g., using `pssh` command like `pssh -i -h /path/to/iplist -o out -e err -t 0 "cd /path/to/aicb && python scripts/run_in_cluster.py"`
10 | """
11 |
12 | import subprocess
13 | import os
14 | import re
15 | import sys
16 |
17 | filename = os.path.basename(__file__)
18 | __doc__ = __doc__.format(filename=filename)
19 |
20 |
21 | def get_local_ip():
22 |     output = os.popen("ifconfig").read().strip()
23 |     pattern = r"inet (\d+\.\d+\.\d+\.\d+) "  # dots escaped so only real IPv4 addresses match
24 |     return re.findall(pattern, output)
25 |
26 |
27 | def get_world_id_list(filename):
28 |     with open(filename, "r") as f:
29 |         return f.read().strip().split("\n")
30 |
31 |
32 | def get_docker_env_rank(filename):
33 |     ip_list = get_world_id_list(filename)
34 |     local_ip = get_local_ip()
35 |     for ip in local_ip:
36 |         if ip in ip_list:
37 |             return len(ip_list), ip_list.index(ip), ip_list[0], 12345
38 |     return -1, -1, -1, -1
39 |
40 |
41 | IPLIST = "DUMMY_IPLIST"  # Change it to /path/to/iplist, e.g., /root/iplist
42 | AICB_DIR = "DUMMY_AICB_DIR"  # Change it to /path/to/aicb, e.g., /root/aicb
43 | IMAGE_NAME = "DUMMY_IMAGE_NAME"  # Change it to your docker image name, e.g., nvcr.io/nvidia/pytorch:xx.xx-py3
44 |
45 | if IPLIST == "DUMMY_IPLIST" or AICB_DIR == "DUMMY_AICB_DIR" or IMAGE_NAME == "DUMMY_IMAGE_NAME":
46 |     sys.stderr.write(__doc__)
47 |     sys.exit(1)
48 |
49 | WORLD_SIZE, RANK, MASTER_ADDR, MASTER_PORT = get_docker_env_rank(IPLIST)
50 | AICB_DIR_base = os.path.basename(AICB_DIR)
51 | command = f"""docker run --name aicb_test --gpus all --privileged \
52 |     --ulimit memlock=-1 --ulimit stack=67108864 \
53 |     --init -i --shm-size=4g --network=host --rm \
54 |     -e WORLD_SIZE={WORLD_SIZE} \
55 |     -e RANK={RANK} \
56 |     -e MASTER_ADDR={MASTER_ADDR} \
57 |     -e MASTER_PORT={MASTER_PORT} \
58 |     -v {AICB_DIR}:/workspace/{AICB_DIR_base} \
59 |     {IMAGE_NAME} /bin/sh -c 'cd /workspace/{AICB_DIR_base} && pwd && python run_suites.py'
60 | """  # Change the settings in run_suites.py to select the workload you want
61 |
62 | ret = subprocess.run(command, shell=True)
63 | print(ret)
64 |
--------------------------------------------------------------------------------
/utils/benchmark_logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2021, Alibaba Group;
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |     http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 | """ 13 | 14 | import sys 15 | import torch 16 | import logging 17 | from utils.timer import Timer 18 | from log_analyzer.log import Log, LogItem 19 | 20 | 21 | class LoggerFactory: 22 | 23 | @staticmethod 24 | def create_logger(name=None, level=logging.INFO): 25 | """create a logger 26 | 27 | Args: 28 | name (str): name of the logger 29 | level: level of logger 30 | 31 | Raises: 32 | ValueError is name is None 33 | """ 34 | 35 | if name is None: 36 | raise ValueError("name for logger cannot be None") 37 | 38 | formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s") 39 | 40 | logger_ = logging.getLogger(name) 41 | logger_.setLevel(level) 42 | logger_.propagate = False 43 | ch = logging.StreamHandler(stream=sys.stdout) 44 | ch.setLevel(level) 45 | ch.setFormatter(formatter) 46 | logger_.addHandler(ch) 47 | return logger_ 48 | 49 | 50 | logger = LoggerFactory.create_logger(name="LLM_Comm_Benchmark", level=logging.INFO) 51 | 52 | 53 | class BenchLogger: 54 | def __init__(self): 55 | self.comm_log = Log() 56 | self.enable = True 57 | self.timer = Timer() 58 | self.epoch_timer = Timer(use_host_timer=True) 59 | self.epoch = 0 60 | self.epoch_timer.start() 61 | 62 | def log_timing(self, name): 63 | def decorator(func): 64 | def wrapper(*args, **kwargs): 65 | self.timer.start() 66 | result = func(*args, **kwargs) 67 | elapsed_time_ms = self.timer.stop() 68 | 69 | log_item = next((item for item in args if isinstance(item, LogItem))) 70 | if log_item.additional == 'overlap': 71 | log_item.elapsed_time = 0 72 | else: 73 | log_item.elapsed_time = elapsed_time_ms 74 | self.comm_log.add_comm_log(log_item) 75 | if torch.distributed.get_rank() == 0: 76 | logger.info(log_item.view_as_ds_log()) 77 | return result 78 | 79 | return wrapper 80 | 81 | return decorator 82 | 83 | def end_epoch(self, log_item): 84 | torch.cuda.synchronize() 85 | elapsed_time_ms = self.epoch_timer.stop() 86 | if torch.distributed.get_rank() == 0: 87 | logger.info( 88 | f"[RANK 0] --------epoch {self.epoch} | micro_step time {elapsed_time_ms:.2f} ---------\n" 89 | ) 90 | log_item.elapsed_time = elapsed_time_ms 91 | self.comm_log.add_comm_log(log_item) 92 | self.epoch += 1 93 | self.epoch_timer.start() 94 | 95 | def dump_log(self, filename): 96 | csv_filename = self.comm_log.dump(filename) 97 | return csv_filename 98 | 99 | def analyze_comm_log(self, print_fn=logger.info): 100 | return self.comm_log.analyze(print_fn) 101 | 102 | def analyze_comm_time(self, print_fn=logger.info): 103 | return self.comm_log.analyze_time(print_fn) 104 | 105 | 106 | bench_logger = BenchLogger() 107 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import time 15 | import torch 16 | 17 | 18 | class CudaEventTimer(object): 19 | """Borrowed from Deepspeed""" 20 | def __init__(self, start_event: torch.cuda.Event, end_event: torch.cuda.Event): 21 | self.start_event = start_event 22 | self.end_event = end_event 23 | 24 | def get_elapsed_msec(self): 25 | #torch.cuda.current_stream().wait_event(self.end_event) 26 | self.end_event.synchronize() 27 | return self.start_event.elapsed_time(self.end_event) 28 | 29 | 30 | class Timer: 31 | def __init__(self, use_host_timer=False): 32 | self.started_ = False 33 | self.use_host_timer = use_host_timer 34 | self.start_event = None 35 | self.start_time = 0.0 36 | 37 | def start(self): 38 | """Start the timer.""" 39 | assert not self.started_, f"timer has already been started" 40 | if self.use_host_timer: 41 | self.start_time = time.time() 42 | else: 43 | self.start_event = torch.cuda.Event(enable_timing=True) 44 | self.start_event.record() 45 | self.started_ = True 46 | 47 | def stop(self): 48 | """Stop the timer.""" 49 | assert self.started_, "timer is not started" 50 | self.started_ = False 51 | if self.use_host_timer: 52 | end_time = time.time() 53 | return (end_time - self.start_time) * 1000 54 | else: 55 | end_event = torch.cuda.Event(enable_timing=True) 56 | end_event.record() 57 | event_timer = CudaEventTimer(self.start_event, end_event) 58 | self.start_event = None 59 | return event_timer.get_elapsed_msec() 60 | -------------------------------------------------------------------------------- /visualize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/visualize/__init__.py -------------------------------------------------------------------------------- /visualize/example.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Data visualization report 6 | 7 | 50 | 51 | 52 |

[example.html is a Jinja-style HTML report template; its markup and chart scripts
did not survive this dump. Recoverable structure: a page titled
"Data visualization report" with six chart sections --
"Communication type pie chart", "Communication type scatter", "message size CDF",
"Communication group scatter", "Computation Communication Pattern" (one chart per
iteration, emitted via {% for i in range(iteration_count) %} ... {% endfor %}),
and "Ratio Pie".]
89 | 90 | 143 | 144 | -------------------------------------------------------------------------------- /workload/Workload_spec_v1.1.csv: -------------------------------------------------------------------------------- 1 | id Name Parameter_size Hidden_size Num_of_layers Attention_heads Sequence_length FFN_hidden_size Name World_size TP DP PP SP Zero_level expert parallel number Expert num TopK group_gemm reduce_bucket_size allgather_bucket_size prefetch_bucket_size max_live_parameters param_persistence_threshold 2 | 1 LLaMA_7B 7B 4096 32 32 2048 16384 Megatron 128 1 world_size/(PP*TP) 1 - - - - - - - - - - - 3 | 2 GPT_13B 13B 5120 40 32 2048 20480 Megatron 128 2 world_size/(PP*TP) 1 enable - - - - - - - - - - 4 | 3 GPT_22B 22B 6144 48 64 2048 24576 Megatron 128 4 world_size/(PP*TP) 1 - - - - - - - - - - - 5 | 4 LLaMA_65B 65B 8192 80 64 4096 28672 Megatron 128 8 world_size/(PP*TP) 2 enable - - - - - - - - - - 6 | 5 GPT_175B 175B 12288 96 96 2048 49152 Megatron 128 8 world_size/(PP*TP) 8 enable - - - - - - - - - - 7 | 6 GPT_175B 175B 12288 96 96 2048 49152 Megatron 128 8 world_size/(PP*TP) 8 disable - - - - - - - - - - 8 | 7 Llama3 405B 405B 16384 126 128 8192 53248 Megetron 128 8 world_size/(PP*TP) 16 enable - - - - - - - - 9 | 8 LLaMA_7B 7B 4096 32 32 4096 11008 Deepspeed 128 1 world_size 1 - 2 - - - - 1.00E+09 1.00E+09 - - - 10 | 9 LLaMA_65B 65B 8192 80 64 4096 28672 Deepspeed 128 1 world_size 1 - 3 - - - - 1.00E+09 - 1.00E+09 6.00E+08 1.00E+06 11 | 10 Mistral_8*7B 56B 4096 32 32 2048 14336 Megatron 128 2 world_size/(PP*TP) 1 enable - 8 8 2 true - - - - - -------------------------------------------------------------------------------- /workload/aiob_inputs/Example.txt: -------------------------------------------------------------------------------- 1 | train_iter:10 2 | Emb: 3 | time_gpu_max: 857 4 | time_gpu_min: 782 5 | time_gpu_avg: 799 6 | layernorm: 7 | time_gpu_max: 84 8 | time_gpu_min: 43 9 | time_gpu_avg: 70 10 | atten_qkv: 11 | time_gpu_max: 1255 12 | time_gpu_min: 862 13 | time_gpu_avg: 889 14 | atten_flash: 15 | time_gpu_max: 786 16 | time_gpu_min: 460 17 | time_gpu_avg: 512 18 | atten_linear: 19 | time_gpu_max: 939 20 | time_gpu_min: 333 21 | time_gpu_avg: 349 22 | layernorm2: 23 | time_gpu_max: 116 24 | time_gpu_min: 67 25 | time_gpu_avg: 71 26 | mlp_linear_1: 27 | time_gpu_max: 1876 28 | time_gpu_min: 1372 29 | time_gpu_avg: 1408 30 | mlp_gelu: 31 | time_gpu_max: 560 32 | time_gpu_min: 246 33 | time_gpu_avg: 257 34 | mlp_linear_2: 35 | time_gpu_max: 1183 36 | time_gpu_min: 720 37 | time_gpu_avg: 742 38 | layernorm_post: 39 | time_gpu_max: 73 40 | time_gpu_min: 68 41 | time_gpu_avg: 70 42 | logit_time: 43 | time_gpu_max: 15476 44 | time_gpu_min: 12435 45 | time_gpu_avg: 12824 46 | param_time: 47 | time_gpu_max: 23746 48 | time_gpu_min: 16647 49 | time_gpu_avg: 17374 50 | -------------------------------------------------------------------------------- /workload/physical/model_workload/G13B-M1-C01_GPT13B_megatron_tp8_pp1_mbs1.csv: -------------------------------------------------------------------------------- 1 | [rank0 ~ rank127] 2 | comm_type,comm_group,comm_group_size,msg_size,stage,dst,src,additional,_elapsed_time,algbw,busbw,count 3 | CommType.all_reduce,CommGroup.dp_group,16,8,init.model_setup,None,None,,None,None,None,1 4 | CommType.all_reduce,CommGroup.dp_group,16,8,init.model_setup,None,None,,None,None,None,1 5 | CommType.all_reduce,CommGroup.dp_group,16,8,init.model_setup,None,None,,None,None,None,1 6 | 
CommType.all_reduce,CommGroup.dp_group,16,8,init.model_setup,None,None,,None,None,None,1 7 | CommType.all_gather,CommGroup.dp_group,16,32,init.model_setup,None,None,,None,None,None,1 8 | CommType.broadcast,CommGroup.tp_group,8,24,init.model_setup,None,0,,None,None,None,1 9 | CommType.all_gather,CommGroup.dp_group,16,64,init.model_setup,None,None,,None,None,None,1 10 | CommType.epoch_end,None,None,0,,None,None,,None,None,None,1 11 | CommType.broadcast,CommGroup.tp_group,8,40,forward_step,None,0,,None,None,None,1 12 | CommType.broadcast,CommGroup.tp_group,8,17408,forward_step,None,0,,None,None,None,1 13 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronEmbedding,None,None,,None,None,None,1 14 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 15 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 16 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 17 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 18 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 19 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 20 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 21 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 22 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 23 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 24 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 25 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 26 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 27 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 28 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 29 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 30 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 31 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 32 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 33 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 34 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 35 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 36 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 37 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 38 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 39 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 40 | 
CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 41 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 42 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 43 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 44 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 45 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 46 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 47 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 48 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 49 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 50 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 51 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 52 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 53 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 54 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 55 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 56 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 57 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 58 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 59 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 60 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 61 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 62 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 63 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 64 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 65 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 66 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 67 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 68 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 69 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 70 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 71 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 72 | 
CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 73 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 74 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 75 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 76 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 77 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 78 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 79 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 80 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 81 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 82 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 83 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 84 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 85 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 86 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 87 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 88 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 89 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 90 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 91 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 92 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 93 | CommType.all_reduce,CommGroup.tp_group,8,20971520,forward.MegatronRowLinear,None,None,,None,None,None,1 94 | CommType.all_reduce,CommGroup.tp_group,8,8192,forward_step._VocabParallelCrossEntropy,None,None,,None,None,None,1 95 | CommType.all_reduce,CommGroup.tp_group,8,8192,forward_step._VocabParallelCrossEntropy,None,None,,None,None,None,1 96 | CommType.all_reduce,CommGroup.tp_group,8,8192,forward_step._VocabParallelCrossEntropy,None,None,,None,None,None,1 97 | CommType.all_reduce,CommGroup.dp_group,16,4,forward_step.average_losses_across_data_parallel_group,None,None,,None,None,None,1 98 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 99 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 100 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 101 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 102 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 103 | 
CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 104 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 105 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 106 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 107 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 108 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 109 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 110 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 111 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 112 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 113 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 114 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 115 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 116 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 117 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 118 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 119 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 120 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 121 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 122 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 123 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 124 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 125 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 126 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 127 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 128 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 129 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 130 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 131 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 132 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 133 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 134 | 
CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 135 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 136 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 137 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 138 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 139 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 140 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 141 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 142 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 143 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 144 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 145 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 146 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 147 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 148 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 149 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 150 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 151 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 152 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 153 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 154 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 155 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 156 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 157 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 158 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 159 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 160 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 161 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 162 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 163 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 164 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 165 | 
CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 166 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 167 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 168 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 169 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 170 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 171 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 172 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 173 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 174 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 175 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 176 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 177 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronColumnLinear,None,None,,None,None,None,1 178 | CommType.all_reduce,CommGroup.tp_group,8,20971520,backward.MegatronEmbedding,None,None,,None,None,None,1 179 | CommType.reduce_scatter,CommGroup.dp_group,16,6422958080,step,None,None,,None,None,None,1 180 | CommType.all_gather,CommGroup.dp_group,16,3211479040,step,None,None,,None,None,None,1 181 | CommType.all_reduce,CommGroup.tp_group,8,0,step._allreduce_layernorm_grads,None,None,,None,None,None,1 182 | CommType.all_reduce,CommGroup.tp_group,8,4,step.check_for_nan,None,None,,None,None,None,1 183 | CommType.epoch_end,None,None,0,,None,None,,None,None,None,1 184 | -------------------------------------------------------------------------------- /workload/physical/model_workload/L7B-M1-C04_Llama7B_megatron_tp2_pp1_mbs1.csv: -------------------------------------------------------------------------------- 1 | [rank0 ~ rank127] 2 | comm_type,comm_group,comm_group_size,msg_size,stage,dst,src,additional,_elapsed_time,algbw,busbw,count 3 | CommType.all_reduce,CommGroup.dp_group,64,8,init.model_setup,None,None,,None,None,None,1 4 | CommType.all_reduce,CommGroup.dp_group,64,8,init.model_setup,None,None,,None,None,None,1 5 | CommType.all_reduce,CommGroup.dp_group,64,8,init.model_setup,None,None,,None,None,None,1 6 | CommType.all_reduce,CommGroup.dp_group,64,8,init.model_setup,None,None,,None,None,None,1 7 | CommType.all_gather,CommGroup.dp_group,64,32,init.model_setup,None,None,,None,None,None,1 8 | CommType.broadcast,CommGroup.tp_group,2,24,init.model_setup,None,0,,None,None,None,1 9 | CommType.all_gather,CommGroup.dp_group,64,64,init.model_setup,None,None,,None,None,None,1 10 | CommType.epoch_end,None,None,0,,None,None,,None,None,None,1 11 | CommType.broadcast,CommGroup.tp_group,2,40,forward_step,None,0,,None,None,None,1 12 | CommType.broadcast,CommGroup.tp_group,2,17408,forward_step,None,0,,None,None,None,1 13 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronEmbedding,None,None,,None,None,None,1 14 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 15 | 
CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 16 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 17 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 18 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 19 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 20 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 21 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 22 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 23 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 24 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 25 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 26 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 27 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 28 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 29 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 30 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 31 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 32 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 33 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 34 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 35 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 36 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 37 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 38 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 39 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 40 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 41 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 42 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 43 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 44 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 45 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 46 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 47 | 
CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 48 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 49 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 50 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 51 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 52 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 53 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 54 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 55 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 56 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 57 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 58 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 59 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 60 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 61 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 62 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 63 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 64 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 65 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 66 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 67 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 68 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 69 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 70 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 71 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 72 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 73 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 74 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 75 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 76 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 77 | CommType.all_reduce,CommGroup.tp_group,2,16777216,forward.MegatronRowLinear,None,None,,None,None,None,1 78 | CommType.all_reduce,CommGroup.tp_group,2,8192,forward_step._VocabParallelCrossEntropy,None,None,,None,None,None,1 79 | 
CommType.all_reduce,CommGroup.tp_group,2,8192,forward_step._VocabParallelCrossEntropy,None,None,,None,None,None,1 80 | CommType.all_reduce,CommGroup.tp_group,2,8192,forward_step._VocabParallelCrossEntropy,None,None,,None,None,None,1 81 | CommType.all_reduce,CommGroup.dp_group,64,4,forward_step.average_losses_across_data_parallel_group,None,None,,None,None,None,1 82 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 83 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 84 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 85 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 86 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 87 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 88 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 89 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 90 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 91 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 92 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 93 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 94 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 95 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 96 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 97 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 98 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 99 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 100 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 101 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 102 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 103 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 104 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 105 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 106 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 107 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 108 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 109 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 110 | 
CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 111 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 112 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 113 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 114 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 115 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 116 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 117 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 118 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 119 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 120 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 121 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 122 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 123 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 124 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 125 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 126 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 127 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 128 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 129 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 130 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 131 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 132 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 133 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 134 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 135 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 136 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 137 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 138 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 139 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 140 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 141 | 
CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 142 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 143 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 144 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 145 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronColumnLinear,None,None,,None,None,None,1 146 | CommType.all_reduce,CommGroup.tp_group,2,16777216,backward.MegatronEmbedding,None,None,,None,None,None,1 147 | CommType.reduce_scatter,CommGroup.dp_group,64,10366697472,step,None,None,,None,None,None,1 148 | CommType.all_gather,CommGroup.dp_group,64,5183348736,step,None,None,,None,None,None,1 149 | CommType.all_reduce,CommGroup.tp_group,2,0,step._allreduce_layernorm_grads,None,None,,None,None,None,1 150 | CommType.all_reduce,CommGroup.tp_group,2,4,step.check_for_nan,None,None,,None,None,None,1 151 | CommType.epoch_end,None,None,0,,None,None,,None,None,None,1 152 | -------------------------------------------------------------------------------- /workload/simAI/micro_test/all_gather.txt: -------------------------------------------------------------------------------- 1 | MICRO 2 | 19 3 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 4096 1 4 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 8192 1 5 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 16384 1 6 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 32768 1 7 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 65536 1 8 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 131072 1 9 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 262144 1 10 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 524288 1 11 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 1048576 1 12 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 2097152 1 13 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 4194304 1 14 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 8388608 1 15 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 16777216 1 16 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 33554432 1 17 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 67108864 1 18 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 134217728 1 19 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 268435456 1 20 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 536870912 1 21 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLGATHER 1073741824 1 22 | -------------------------------------------------------------------------------- /workload/simAI/micro_test/all_reduce.txt: -------------------------------------------------------------------------------- 1 | MICRO 2 | 19 3 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 4096 1 4 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 8192 1 5 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 16384 1 6 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 32768 1 7 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 65536 1 8 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 131072 1 9 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 262144 1 10 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 524288 1 11 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 1048576 1 12 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 2097152 1 13 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 4194304 1 14 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 8388608 1 15 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 16777216 1 16 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 33554432 1 17 | micro_test -1 1 NONE 0 1 NONE 0 1 
ALLREDUCE 67108864 1 18 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 134217728 1 19 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 268435456 1 20 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 536870912 1 21 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 1073741824 1 22 | -------------------------------------------------------------------------------- /workload/simAI/micro_test/all_to_all.txt: -------------------------------------------------------------------------------- 1 | MICRO 2 | 19 3 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 4096 1 4 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 8192 1 5 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 16384 1 6 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 32768 1 7 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 65536 1 8 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 131072 1 9 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 262144 1 10 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 524288 1 11 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 1048576 1 12 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 2097152 1 13 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 4194304 1 14 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 8388608 1 15 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 16777216 1 16 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 33554432 1 17 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 67108864 1 18 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 134217728 1 19 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 268435456 1 20 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 536870912 1 21 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLTOALL 1073741824 1 22 | -------------------------------------------------------------------------------- /workload/simAI/micro_test/muti_all_reduce.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 8 checkpoints: 0 checkpoint_initiates: 0 2 | 19 3 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 4096 1 4 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 8192 1 5 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 16384 1 6 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 32768 1 7 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 65536 1 8 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 131072 1 9 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 262144 1 10 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 524288 1 11 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 1048576 1 12 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 2097152 1 13 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 4194304 1 14 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 8388608 1 15 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 16777216 1 16 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 33554432 1 17 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 67108864 1 18 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 134217728 1 19 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 268435456 1 20 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 536870912 1 21 | micro_test -1 1 NONE 0 1 NONE 0 1 ALLREDUCE 1073741824 1 22 | -------------------------------------------------------------------------------- /workload/simAI/model_workload/G13B-M1-C01_GPT13B_megatron_tp8_pp1_mbs1_A100.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 8 checkpoints: 0 checkpoint_initiates: 0 2 | 92 3 | norm -1 0 BROADCAST 16384 1 NONE 0 1 NONE 0 100 4 | grad_norm -1 14565000 ALLGATHER 3248619520 13391000 NONE 0 1 REDUCESCATTER 6497239040 100 5 | layernorm -1 1 NONE 0 1 
ALLREDUCE 3248619520 1 NONE 0 100 6 | embedding_layer -1 1 ALLREDUCE 20971520 1 ALLREDUCE 20971520 1 NONE 0 100 7 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 8 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 9 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 10 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 11 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 12 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 13 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 14 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 15 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 16 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 17 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 18 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 19 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 20 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 21 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 22 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 23 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 24 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 25 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 26 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 27 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 28 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 29 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 30 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 31 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 32 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 33 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 34 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 35 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 36 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 37 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 38 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 39 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 40 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 41 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 42 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 43 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 44 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 
750000 NONE 0 100 45 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 46 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 47 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 48 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 49 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 50 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 51 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 52 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 53 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 54 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 55 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 56 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 57 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 58 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 59 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 60 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 61 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 62 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 63 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 64 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 65 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 66 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 67 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 68 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 69 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 70 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 71 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 72 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 73 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 74 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 75 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 76 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 77 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 78 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 79 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 80 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 81 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 82 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 83 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 
752000 NONE 0 100 84 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 85 | attention_layer -1 752000 ALLREDUCE 20971520 752000 ALLREDUCE 20971520 752000 NONE 0 100 86 | mlp_layer -1 750000 ALLREDUCE 20971520 750000 ALLREDUCE 20971520 750000 NONE 0 100 87 | embedding_norm -1 1 ALLREDUCE 327680000 1 NONE 0 1 NONE 0 100 88 | cross_entropy1 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 89 | cross_entropy2 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 90 | cross_entropy3 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 91 | optimizer1 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 92 | optimizer2 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 93 | optimizer3 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 94 | optimizer4 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 95 | -------------------------------------------------------------------------------- /workload/simAI/model_workload/G13B-M1-C02_GPT13B_megatron_tp8_pp1_mbs1_sp_A100.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 8 checkpoints: 0 checkpoint_initiates: 0 2 | 172 3 | norm -1 0 BROADCAST 16384 1 NONE 0 1 NONE 0 100 4 | grad_norm -1 14565000 ALLGATHER 3248619520 13391000 NONE 0 1 REDUCESCATTER 6497239040 100 5 | layernorm -1 1 NONE 0 1 ALLREDUCE 3248619520 1 NONE 0 100 6 | embedding_layer -1 1 ALLREDUCE 20971520 1 ALLREDUCE 20971520 1 NONE 0 100 7 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 8 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 9 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 10 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 11 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 12 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 13 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 14 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 15 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 16 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 17 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 18 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 19 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 20 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 21 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 22 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 23 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 24 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 25 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 26 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 27 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 28 | 
attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 29 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 30 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 31 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 32 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 33 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 34 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 35 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 36 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 37 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 38 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 39 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 40 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 41 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 42 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 43 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 44 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 45 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 46 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 47 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 48 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 49 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 50 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 51 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 52 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 53 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 54 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 55 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 56 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 57 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 58 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 59 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 60 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 61 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 62 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 63 | 
attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 64 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 65 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 66 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 67 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 68 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 69 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 70 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 71 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 72 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 73 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 74 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 75 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 76 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 77 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 78 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 79 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 80 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 81 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 82 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 83 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 84 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 85 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 86 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 87 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 88 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 89 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 90 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 91 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 92 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 93 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 94 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 95 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 96 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 97 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 98 | 
mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 99 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 100 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 101 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 102 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 103 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 104 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 105 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 106 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 107 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 108 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 109 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 110 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 111 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 112 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 113 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 114 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 115 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 116 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 117 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 118 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 119 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 120 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 121 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 122 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 123 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 124 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 125 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 126 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 127 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 128 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 129 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 130 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 131 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 132 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 
20971520 376000 NONE 0 100 133 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 134 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 135 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 136 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 137 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 138 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 139 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 140 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 141 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 142 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 143 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 144 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 145 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 146 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 147 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 148 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 149 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 150 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 151 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 152 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 153 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 154 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 155 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 156 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 157 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 158 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 159 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 160 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 161 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 162 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 163 | attention_column_layer -1 375000 ALLGATHER 20971520 375000 REDUCESCATTER 20971520 375000 NONE 0 100 164 | attention_row_layer -1 376000 REDUCESCATTER 20971520 376000 ALLGATHER 20971520 376000 NONE 0 100 165 | mlp_column_layer -1 923000 ALLGATHER 20971520 923000 REDUCESCATTER 20971520 923000 NONE 0 100 166 | mlp_row_layer -1 1125000 REDUCESCATTER 20971520 1125000 ALLGATHER 20971520 1125000 NONE 0 100 167 | embedding_norm -1 1 ALLREDUCE 327680000 1 NONE 0 
1 NONE 0 100 168 | cross_entropy1 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 169 | cross_entropy2 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 170 | cross_entropy3 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 171 | optimizer1 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 172 | optimizer2 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 173 | optimizer3 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 174 | optimizer4 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 175 | -------------------------------------------------------------------------------- /workload/simAI/model_workload/G175B-M1-C03_GPT175B_megatron_tp8_pp1_mbs1_A100.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 8 checkpoints: 0 checkpoint_initiates: 0 2 | 33 3 | grad_gather -1 1 NONE 0 13848000 NONE 0 1 ALLGATHER 1 100 4 | grad_param -1 1 NONE 0 13740000 NONE 0 1 REDUCESCATTER 11175321600 100 5 | layernorm -1 1 NONE 0 1 ALLREDUCE 5587660800 1 NONE 0 100 6 | embedding_grads -1 1 NONE 0 1 ALLREDUCE 150994944 1 NONE 0 100 7 | embedding_layer -1 523895 ALLREDUCE 150994944 1 NONE 0 1 NONE 0 100 8 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 9 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 10 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 11 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 12 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 13 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 14 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 15 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 16 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 17 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 18 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 19 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 20 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 21 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 22 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 23 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 24 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 25 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 26 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 27 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 28 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 29 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 30 | attention_norm -1 2571000 ALLREDUCE 50331648 2454000 NONE 0 2454000 NONE 0 100 31 | mlp_norm -1 1584000 ALLREDUCE 50331648 2081000 NONE 0 2081000 NONE 0 100 32 | embedding_norm -1 1 ALLREDUCE 786432000 1 NONE 0 1 NONE 0 100 33 | cross_entropy1 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 34 | cross_entropy2 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 35 | cross_entropy3 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 -------------------------------------------------------------------------------- /workload/simAI/model_workload/L65B-M1-C05_Llama65B_megatron_tp8_pp1_mbs1_A100.txt: 
-------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 8 checkpoints: 0 checkpoint_initiates: 0 2 | 839 3 | grad_gather -1 1 NONE 0 4284000 NONE 0 1 ALLGATHER 1 100 4 | grad_param -1 1 NONE 0 7300000 NONE 0 1 REDUCESCATTER 6236471296 100 5 | layernorm -1 1 NONE 0 1 ALLREDUCE 2315808768 1 NONE 0 100 6 | embedding_grads -1 1 NONE 0 1 ALLREDUCE 65536000 1 NONE 0 100 7 | embedding_layer -1 382000 ALLREDUCE 65536000 1 NONE 0 1 NONE 0 100 8 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 9 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 10 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 11 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 12 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 13 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 14 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 15 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 16 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 17 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 18 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 19 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 20 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 21 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 22 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 23 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 24 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 25 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 26 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 27 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 28 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 29 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 30 | attention_norm -1 1711000 ALLREDUCE 33554432 1500000 NONE 0 1500000 NONE 0 100 31 | mlp_norm -1 862000 ALLREDUCE 33554432 1200000 NONE 0 1200000 NONE 0 100 32 | embedding_norm -1 1 ALLREDUCE 524288000 1 NONE 0 1 NONE 0 100 33 | cross_entropy1 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 34 | cross_entropy2 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 35 | cross_entropy3 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 -------------------------------------------------------------------------------- /workload/simAI/model_workload/L65B_D1_C08_Llama65B_deepspeed_zero3_A100.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 1 checkpoints: 0 checkpoint_initiates: 0 2 | 12 3 | Llama_layer1 -1 14356000 NONE 0 4756000 ALLGATHER 357468 1 REDUCESCATTER 1889785610 100 4 | Llama_layer2 -1 10948000 NONE 0 19502000 ALLGATHER 119162 1 NONE 0 100 5 | Llama_layer3 -1 7168000 NONE 0 13844000 ALLGATHER 317843 1 NONE 0 100 6 | Llama_layer4 -1 17976000 NONE 0 14664000 ALLGATHER 317843 1 NONE 0 100 7 | Llama_layer5 -1 14752000 NONE 0 4607000 ALLGATHER 317843 1 NONE 0 100 8 | Llama_layer6 -1 8000000 NONE 0 46148000 ALLGATHER 357468 1 NONE 0 100 9 | Llama_layer7 -1 7168000 NONE 
0 20196000 ALLGATHER 119162 1 NONE 0 100 10 | Llama_layer8 -1 1760000 NONE 0 13844000 ALLGATHER 317843 1 NONE 0 100 11 | Llama_layer9 -1 15236000 NONE 0 15236000 ALLGATHER 317843 1 NONE 0 100 12 | Llama_layer10 -1 7008000 NONE 0 16036000 ALLGATHER 317843 1 NONE 0 100 13 | Llama_layer11 -1 16036000 NONE 0 4607000 ALLGATHER 317843 1 NONE 0 100 14 | Llama_layer12 -1 7064500 NONE 0 21750000 ALLGATHER 317843 1 NONE 0 100 -------------------------------------------------------------------------------- /workload/simAI/model_workload/L7B-D1-C02_Llama7B_deepspeed_zero3_A100.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 1 checkpoints: 0 checkpoint_initiates: 0 2 | 12 3 | Llama_layer1 -1 14356000 NONE 0 4756000 ALLGATHER 357468 1 REDUCESCATTER 1889785610 100 4 | Llama_layer2 -1 10948000 NONE 0 19502000 ALLGATHER 119162 1 NONE 0 100 5 | Llama_layer3 -1 7168000 NONE 0 13844000 ALLGATHER 317843 1 NONE 0 100 6 | Llama_layer4 -1 17976000 NONE 0 14664000 ALLGATHER 317843 1 NONE 0 100 7 | Llama_layer5 -1 14752000 NONE 0 4607000 ALLGATHER 317843 1 NONE 0 100 8 | Llama_layer6 -1 8000000 NONE 0 46148000 ALLGATHER 357468 1 NONE 0 100 9 | Llama_layer7 -1 7168000 NONE 0 20196000 ALLGATHER 119162 1 NONE 0 100 10 | Llama_layer8 -1 1760000 NONE 0 13844000 ALLGATHER 317843 1 NONE 0 100 11 | Llama_layer9 -1 15236000 NONE 0 15236000 ALLGATHER 317843 1 NONE 0 100 12 | Llama_layer10 -1 7008000 NONE 0 16036000 ALLGATHER 317843 1 NONE 0 100 13 | Llama_layer11 -1 16036000 NONE 0 4607000 ALLGATHER 317843 1 NONE 0 100 14 | Llama_layer12 -1 7064500 NONE 0 21750000 ALLGATHER 317843 1 NONE 0 100 -------------------------------------------------------------------------------- /workload/simAI/model_workload/L7B-M1-C04_Llama7B_megatron_tp2_pp1_mbs1_A100.txt: -------------------------------------------------------------------------------- 1 | HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 2 checkpoints: 0 checkpoint_initiates: 0 2 | 76 3 | norm -1 0 BROADCAST 16384 1 NONE 0 1 NONE 0 100 4 | grad_norm -1 1 ALLGATHER 6754926592 1 NONE 0 1 REDUCESCATTER 13509853184 100 5 | layernorm -1 1 NONE 0 1 ALLREDUCE 6754926592 1 NONE 0 100 6 | embedding_layer -1 1 ALLREDUCE 16777216 1 ALLREDUCE 16777216 1 NONE 0 100 7 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 8 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 9 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 10 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 11 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 12 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 13 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 14 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 15 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 16 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 17 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 18 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 19 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 20 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 
ALLREDUCE 16777216 529500 NONE 0 100 21 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 22 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 23 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 24 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 25 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 26 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 27 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 28 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 29 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 30 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 31 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 32 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 33 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 34 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 35 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 36 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 37 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 38 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 39 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 40 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 41 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 42 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 43 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 44 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 45 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 46 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 47 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 48 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 49 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 50 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 51 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 52 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 53 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 54 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 55 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 56 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 57 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 58 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 59 | attention_layer -1 539986 ALLREDUCE 16777216 592000 
ALLREDUCE 16777216 592000 NONE 0 100 60 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 61 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 62 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 63 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 64 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 65 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 66 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 67 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 68 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 69 | attention_layer -1 539986 ALLREDUCE 16777216 592000 ALLREDUCE 16777216 592000 NONE 0 100 70 | mlp_layer -1 486330 ALLREDUCE 16777216 529500 ALLREDUCE 16777216 529500 NONE 0 100 71 | embedding_norm -1 1 ALLREDUCE 262144000 1 NONE 0 1 NONE 0 100 72 | cross_entropy1 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 73 | cross_entropy2 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 74 | cross_entropy3 -1 1 ALLREDUCE 8192 1 NONE 0 1 NONE 0 100 75 | optimizer1 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 76 | optimizer2 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 77 | optimizer3 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 78 | optimizer4 -1 1 ALLREDUCE 4 1 NONE 0 1 NONE 0 100 79 | -------------------------------------------------------------------------------- /workload_applyer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
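
AICB workload applyer: replays a previously generated communication workload
on a real cluster by mapping each workload item's CommType to the matching
torch.distributed collective, rebuilding the tp/dp/pp/ep process groups from
the recorded parallelism settings, and timing every call through bench_logger.
A minimal usage sketch, mirroring the __main__ block at the bottom of this
file (it assumes torch.distributed has already been initialized, e.g. via
init_process_group, and that the CSV path points at a generated workload):

    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    applyer = WorkloadApplyer(filename="results/model_workload/local_deepspeed_stage3.csv")
    elapsed_seconds = applyer.apply_workload()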
12 | """ 13 | import torch 14 | import sys 15 | import math 16 | import time 17 | from utils.utils import WorkloadWriter, CommGroup, CommType, ReduceOp 18 | from utils.benchmark_logger import bench_logger 19 | import utils.utils as utils 20 | 21 | 22 | class WorkloadApplyer: 23 | def __init__(self, workload=None, args=None, filename=None) -> None: 24 | if workload is None or args is None: 25 | assert ( 26 | filename is None 27 | ), f"you should either pass workload,args or filename to init WorkloadApplyer" 28 | workload, args = WorkloadWriter.load_workload(filename) 29 | # if not hasattr(args, "backend"): 30 | # args.backend = "nccl" 31 | # torch.distributed.init_process_group(backend=args.backend) 32 | self.args = args 33 | world_size = torch.distributed.get_world_size() 34 | # args.rank = torch.distributed.get_rank() 35 | if args.world_size != world_size: 36 | print( 37 | f"WARNNING: world_size is {args.world_size} when generating workload, but now world size is {world_size}" 38 | ) 39 | args.world_size = torch.distributed.get_world_size() 40 | device_count = torch.cuda.device_count() 41 | self.device = args.rank % device_count 42 | torch.cuda.set_device(self.device) 43 | self.device = torch.cuda.current_device() 44 | self.comm_group_info, self.pp_global_rank_info = ( 45 | self._generate_dp_tp_pp_ep_groups() 46 | ) 47 | self.workload = workload 48 | self.comm_type_function = { 49 | CommType.barrier: self._apply_barrier, 50 | CommType.broadcast: self._apply_broadcast, 51 | CommType.reduce: self._apply_reduce, 52 | CommType.all_reduce: self._apply_all_reduce, 53 | CommType.all_gather: self._apply_all_gather, 54 | CommType.reduce_scatter: self._apply_reduce_scatter, 55 | CommType.isend: self._apply_p2pcommunication, 56 | CommType.irecv: self._apply_p2pcommunication, 57 | CommType.all_gather_into_tensor: self._apply_all_gather, 58 | CommType.reduce_scatter_tensor: self._apply_reduce_scatter, 59 | CommType.computation: self._apply_computation, 60 | CommType.all_to_all: self._apply_all_to_all, 61 | CommType.epoch_end: bench_logger.end_epoch, 62 | 63 | } 64 | 65 | cal_tuple_num = lambda t: math.prod(t[0]) + math.prod(t[1]) 66 | max_msg_size = max( 67 | [ 68 | ( 69 | item.msg_size 70 | if isinstance(item.msg_size, int) 71 | else cal_tuple_num(item.msg_size) 72 | ) 73 | for item in self.workload.workload 74 | ] 75 | ) 76 | self.gemm_cache = {} 77 | self.computation_aiob = False 78 | if args.aiob_enable and args.frame == "Megatron": 79 | self.computation_aiob = True 80 | 81 | self.skip_computation = False 82 | self.always_apply_gemm = False 83 | self.gemm_iters = 1 if self.always_apply_gemm else 50 84 | self.buffer = torch.empty( 85 | (max_msg_size,), dtype=torch.bfloat16, device=self.device 86 | ) 87 | def _generate_dp_tp_pp_ep_groups(self): 88 | """Borrow from Megatron-LM""" 89 | all_data_parallel_group_ranks = [] 90 | world_size = self.args.world_size 91 | rank = torch.distributed.get_rank() 92 | self.rank = rank 93 | tensor_model_parallel_size, pipeline_model_parallel_size, data_parallel_size,expert_model_parallel_size = ( 94 | self.args.tensor_model_parallel_size, 95 | self.args.pipeline_model_parallel, 96 | self.args.dp_num, 97 | self.args.expert_model_parallel_size, 98 | ) 99 | rank_generator = utils.RankGenerator( 100 | tp=tensor_model_parallel_size, 101 | ep=expert_model_parallel_size, 102 | dp=data_parallel_size, 103 | pp=pipeline_model_parallel_size, 104 | cp=self.args.context_parallel_size, 105 | order='tp-cp-ep-dp-pp', 106 | ) 107 | for ranks in rank_generator.get_ranks('ep', 
independent_ep=True): 108 | group = torch.distributed.new_group( 109 | ranks 110 | ) 111 | if rank in ranks: 112 | ep_group = group 113 | for ranks in rank_generator.get_ranks('tp'): 114 | group = torch.distributed.new_group( 115 | ranks 116 | ) 117 | if rank in ranks: 118 | tp_group = group 119 | for ranks in rank_generator.get_ranks('pp'): 120 | group = torch.distributed.new_group( 121 | ranks 122 | ) 123 | if rank in ranks: 124 | pp_group = group 125 | pp_global_rank = ranks 126 | # Setup embedding group (to exchange gradients between 127 | # first and last stages). 128 | # if len(ranks) > 1: 129 | # embedding_ranks = [ranks[0], ranks[-1]] 130 | # position_embedding_ranks = [ranks[0]] 131 | # if self.args.pipeline_model_parallel_split_rank is not None: 132 | # if ranks[self.args.pipeline_model_parallel_split_rank] not in embedding_ranks: 133 | # embedding_ranks = [ 134 | # ranks[0], 135 | # ranks[self.args.pipeline_model_parallel_split_rank], 136 | # ranks[-1], 137 | # ] 138 | # if ranks[self.args.pipeline_model_parallel_split_rank] not in position_embedding_ranks: 139 | # position_embedding_ranks = [ranks[0], ranks[self.args.pipeline_model_parallel_split_rank]] 140 | # else: 141 | # embedding_ranks = ranks 142 | # position_embedding_ranks = ranks 143 | 144 | # group = torch.distributed.new_group( 145 | # embedding_ranks 146 | # ) 147 | # if rank in embedding_ranks: 148 | # _EMBEDDING_GROUP = group 149 | # if rank in ranks: 150 | # _EMBEDDING_GLOBAL_RANKS = embedding_ranks 151 | 152 | # group = torch.distributed.new_group( 153 | # position_embedding_ranks, 154 | 155 | # ) 156 | # if rank in position_embedding_ranks: 157 | # _POSITION_EMBEDDING_GROUP = group 158 | # if rank in ranks: 159 | # _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks 160 | for ranks in rank_generator.get_ranks('dp'): 161 | group = torch.distributed.new_group( 162 | ranks 163 | ) 164 | if rank in ranks: 165 | dp_group = group 166 | for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True): 167 | group = torch.distributed.new_group( 168 | ranks 169 | ) 170 | if rank in ranks: 171 | ep_tp_group = group 172 | for ranks in rank_generator.get_ranks('dp', independent_ep=True): 173 | group = torch.distributed.new_group( 174 | ranks 175 | ) 176 | if rank in ranks: 177 | ep_dp_group = group 178 | return { 179 | CommGroup.tp_group: tp_group, 180 | CommGroup.dp_group: dp_group, 181 | CommGroup.pp_group: pp_group, 182 | CommGroup.ep_group: ep_group, 183 | CommGroup.ep_tp_group: ep_tp_group, 184 | CommGroup.ep_dp_group: ep_dp_group, 185 | }, pp_global_rank 186 | 187 | def _get_pipeline_parallel_size(self): 188 | group = self.comm_group_info["pp_group"] 189 | pp_group_size = torch.distributed.get_world_size(group) 190 | return pp_group_size 191 | 192 | def _get_pipeline_parallel_rank(self): 193 | group = self.comm_group_info["pp_group"] 194 | pp_rank = torch.distributed.get_rank(group) 195 | return pp_rank 196 | 197 | def _get_pipeline_prev_rank(self): 198 | rank_in_pipeline = self._get_pipeline_parallel_rank() 199 | world_size = self._get_pipeline_parallel_size() 200 | return self.pp_global_rank_info[(rank_in_pipeline - 1) % world_size] 201 | 202 | def _get_pipeline_next_rank(self): 203 | rank_in_pipeline = self._get_pipeline_parallel_rank() 204 | world_size = self._get_pipeline_parallel_size() 205 | return self.pp_global_rank_info[(rank_in_pipeline + 1) % world_size] 206 | 207 | @bench_logger.log_timing("comm") 208 | def _apply_p2pcommunication(self, item): 209 | ops = [] 210 | tensor = 
torch.narrow(self.buffer, 0, 0, item.msg_size // 2) 211 | if item.additional == "send_prev": 212 | if self._get_pipeline_parallel_rank() != 0: 213 | send_prev_op = torch.distributed.P2POp( 214 | torch.distributed.isend, tensor, self._get_pipeline_prev_rank() 215 | ) 216 | ops.append(send_prev_op) 217 | else: 218 | pass 219 | if item.additional == "send_next": 220 | if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1: 221 | send_next_op = torch.distributed.P2POp( 222 | torch.distributed.isend, tensor, self._get_pipeline_next_rank() 223 | ) 224 | ops.append(send_next_op) 225 | else: 226 | pass 227 | if item.additional == "recv_prev": 228 | if self._get_pipeline_parallel_rank() != 0: 229 | tensor_recv_prev = torch.empty( 230 | item.msg_size // 2, dtype=torch.bfloat16, device=self.device 231 | ) 232 | recv_prev_op = torch.distributed.P2POp( 233 | torch.distributed.irecv, 234 | tensor_recv_prev, 235 | self._get_pipeline_prev_rank(), 236 | ) 237 | ops.append(recv_prev_op) 238 | else: 239 | pass 240 | if item.additional == "recv_next": 241 | if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1: 242 | tensor_recv_next = torch.empty( 243 | item.msg_size // 2, dtype=torch.bfloat16, device=self.device 244 | ) 245 | recv_next_op = torch.distributed.P2POp( 246 | torch.distributed.irecv, 247 | tensor_recv_next, 248 | self._get_pipeline_next_rank(), 249 | ) 250 | ops.append(recv_next_op) 251 | else: 252 | pass 253 | if len(ops) > 0: 254 | reqs = torch.distributed.batch_isend_irecv(ops) 255 | for req in reqs: 256 | req.wait() 257 | 258 | torch.cuda.synchronize() 259 | 260 | def _apply_barrier(self, item): 261 | torch.distributed.barrier() 262 | 263 | @bench_logger.log_timing("comm") 264 | def _apply_broadcast(self, item): 265 | tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2) 266 | group = self.comm_group_info[item.comm_group] 267 | src = torch.distributed.get_global_rank(group, 0) 268 | return torch.distributed.broadcast( 269 | tensor=tensor, src=src, group=group, async_op=False 270 | ) 271 | 272 | @bench_logger.log_timing("comm") 273 | def _apply_reduce(self, item): 274 | tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2) 275 | group = self.comm_group_info[item.comm_group] 276 | dst = item.dst 277 | return torch.distributed.reduce( 278 | tensor=tensor, 279 | dst=dst, 280 | op=torch.distributed.ReduceOp.SUM, 281 | group=group, 282 | async_op=False, 283 | ) 284 | 285 | @bench_logger.log_timing("comm") 286 | def _apply_all_reduce(self, item): 287 | tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2) 288 | group = self.comm_group_info[item.comm_group] 289 | return torch.distributed.all_reduce( 290 | tensor=tensor, 291 | op=torch.distributed.ReduceOp.SUM, 292 | group=group, 293 | async_op=False, 294 | ) 295 | 296 | @bench_logger.log_timing("comm") 297 | def _apply_all_gather(self, item): 298 | group = self.comm_group_info[item.comm_group] 299 | num_elements = item.msg_size // 2 300 | padding_size = ( 301 | (group.size() - num_elements % group.size()) 302 | if num_elements % group.size() 303 | else 0 304 | ) 305 | num_elements = num_elements + padding_size 306 | output_tensor = torch.narrow(self.buffer, 0, 0, num_elements) 307 | input_tensor_size = output_tensor.numel() // group.size() 308 | group_rank = torch.distributed.get_group_rank(group, self.rank) 309 | input_tensor = torch.narrow( 310 | output_tensor, 0, group_rank * input_tensor_size, input_tensor_size 311 | ) 312 | return torch.distributed.all_gather_into_tensor( 313 
| output_tensor, input_tensor, group=group, async_op=False
314 | )
315 | @bench_logger.log_timing("comm")
316 | def _overlap(self, item):
317 | item.additional = 'overlap'
318 | 
319 | @bench_logger.log_timing("comm")
320 | def _apply_reduce_scatter(self, item):
321 | group = self.comm_group_info[item.comm_group]
322 | num_elements = item.msg_size // 2
323 | padding_size = (
324 | (group.size() - num_elements % group.size())
325 | if num_elements % group.size()
326 | else 0
327 | )
328 | num_elements = num_elements + padding_size
329 | input_tensor = torch.narrow(self.buffer, 0, 0, num_elements)
330 | group = self.comm_group_info[item.comm_group]
331 | output_tensor_size = input_tensor.numel() // group.size()
332 | group_rank = torch.distributed.get_group_rank(group, self.rank)
333 | output_tensor = torch.narrow(
334 | input_tensor, 0, group_rank * output_tensor_size, output_tensor_size
335 | )
336 | return torch.distributed.reduce_scatter_tensor(
337 | output_tensor, input_tensor, group=group, async_op=False
338 | )
339 | 
340 | @bench_logger.log_timing("comm")
341 | def _apply_all_to_all(self, item):
342 | group = self.comm_group_info[item.comm_group]
343 | num_elements = item.msg_size // 2
344 | input_tensor = torch.narrow(self.buffer, 0, 0, num_elements)
345 | # output_tensor = torch.narrow(self.buffer, 0, 0 , num_elements)
346 | output_tensor = torch.empty(
347 | num_elements * group.size(),
348 | dtype=self.buffer.dtype,
349 | device=self.buffer.device,
350 | )
351 | return torch.distributed.all_to_all_single(
352 | output_tensor, input_tensor, group=group
353 | )
354 | 
355 | @bench_logger.log_timing("comp")
356 | def _apply_computation(self, item):
357 | if self.skip_computation:
358 | return
359 | if self.computation_aiob:
360 | time.sleep(item._elapsed_time / 1e9)
361 | else:
362 | # item.msg_size = 1
363 | input_shape1, input_shape2 = item.msg_size
364 | A, B = torch.rand(input_shape1, device=self.device), torch.rand(
365 | input_shape2, device=self.device
366 | )
367 | torch.matmul(A, B)
368 | return
369 | 
370 | def apply_workload(self):
371 | torch.cuda.synchronize(self.device)
372 | start = time.perf_counter()
373 | key = "backward"
374 | for item in self.workload.workload:
375 | if (
376 | self.computation_aiob
377 | and item.comm_type == CommType.all_reduce
378 | and key in item.stage
379 | ):
380 | # backward-stage all_reduce is assumed to overlap with the AIOB
381 | # computation, so the collective itself is skipped here
382 | pass
383 | else:
384 | comm_func = self.comm_type_function[item.comm_type]
385 | comm_func(item)
386 | torch.cuda.synchronize(self.device)
387 | end = time.perf_counter()
388 | return end - start
389 | 
390 | 
391 | if __name__ == "__main__":
392 | filename = "results/model_workload/local_deepspeed_stage3.csv"
393 | applyer = WorkloadApplyer(filename=filename)
394 | applyer.apply_workload()
395 | # timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
396 | if torch.distributed.get_rank() == 0:
397 | bench_logger.analyze_comm_log(bench_logger.comm_log)
398 | -------------------------------------------------------------------------------- /workload_generator/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /workload_generator/analysis_pytorch_trace.py: -------------------------------------------------------------------------------- 1 | """
2 | Copyright (c) 2021, Alibaba Group;
3 | Licensed under the Apache License, Version 2.0 (the 
"License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import json 15 | from utils.utils import CommGroup, CommType 16 | from log_analyzer.log import LogItem 17 | from workload_generator.mocked_model.MockedModel import MockedModel 18 | from workload_generator.workload_generator import WorkloadGenerator 19 | 20 | comm_node = {} 21 | workload = [] 22 | 23 | 24 | class Pytorch_trace_analyer(WorkloadGenerator): 25 | def __init__(self, args, model, filename): 26 | super().__init__(args, model) 27 | self.name = "pytorch_trace" 28 | self.filename = filename 29 | 30 | def init(self): 31 | pass 32 | 33 | def string2comm_type(self, s): 34 | if "all_gather" in s or "_all_gather_base" in s or "_allgather_base" in s: 35 | return CommType.all_gather 36 | if "reduce_scatter" in s or "_reduce_scatter_base" in s: 37 | return CommType.reduce_scatter 38 | if "all_reduce" in s: 39 | return CommType.all_reduce 40 | if "broadcast" in s: 41 | return CommType.broadcast 42 | if "barrier" in s: 43 | return CommType.barrier 44 | if "reduce" in s: 45 | return CommType.reduce 46 | else: 47 | print(f"can not convert {s} to any comm type") 48 | exit(0) 49 | 50 | def step(self): 51 | item = LogItem() 52 | with open(self.filename) as f: 53 | data = json.load(f) 54 | 55 | nodes_list = data["nodes"] 56 | for node in nodes_list: 57 | if node["name"].startswith("nccl:"): 58 | name = node["name"].split(":")[1] 59 | comm_type = self.string2comm_type(name) 60 | item.comm_type = comm_type 61 | # TODO: set group in default dp_group , get group in trace info 62 | item.comm_group = CommGroup.dp_group 63 | input = node.get("inputs") 64 | item.msg_size = input[0][3] 65 | item.item_size = input[0][4] 66 | self.workload.append(item) 67 | 68 | 69 | # if __name__ == "__main__": 70 | # a = parse_pytorch_trace_log("llama7b_zero3_rank8.json") 71 | # json_parse("test_json.json") 72 | # panda_parse("test_json.json") 73 | -------------------------------------------------------------------------------- /workload_generator/generate_collective_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | from utils.utils import CommType, CommGroup, get_params, WorkloadWriter 15 | from log_analyzer.log import LogItem, Workload 16 | from workload_generator.workload_generator import WorkloadGenerator 17 | 18 | 19 | class Collective_Test(WorkloadGenerator): 20 | def __init__(self, args, model): 21 | super().__init__(args, model) 22 | self.args = args 23 | self.name = "collective_test" 24 | 25 | def init(self): 26 | iter_num = self.args.iter_num 27 | for i in range(iter_num): 28 | # for warmup 29 | self.workload.append( 30 | LogItem( 31 | comm_type=CommType.get_comm_type(self.args.test_comm), 32 | comm_group=CommGroup.dp_group, 33 | comm_group_size=self.args.dp_num, 34 | msg_size=self.args.begin_size, 35 | stage="warmup", 36 | ) 37 | ) 38 | 39 | def step(self): 40 | test_comm = CommType.get_comm_type(self.args.test_comm) 41 | begin_size = self.args.begin_size 42 | end_size = self.args.end_size 43 | curr_size = begin_size 44 | iter_num = self.args.iter_num 45 | multi_all_reduce_enable = self.args.multi_all_reduce_enable 46 | 47 | while curr_size <= end_size: 48 | # self.workload.append(LogItem(comm_type=CommType.epoch_end)) 49 | if not multi_all_reduce_enable: 50 | for i in range(iter_num): 51 | self.workload.append( 52 | LogItem( 53 | comm_type=test_comm, 54 | comm_group=CommGroup.dp_group, 55 | comm_group_size=self.args.dp_num, 56 | msg_size=curr_size, 57 | stage="test_step", 58 | ) 59 | ) 60 | curr_size *= 2 61 | else: 62 | for i in range(iter_num): 63 | self.workload.append( 64 | LogItem( 65 | comm_type=test_comm, 66 | comm_group=CommGroup.pp_group, 67 | comm_group_size=self.args.pipeline_model_parallel, 68 | msg_size=curr_size, 69 | stage="test_step", 70 | ) 71 | ) 72 | curr_size *= 2 73 | 74 | 75 | if __name__ == "__main__": 76 | args = get_params() 77 | workload_generator = Collective_Test(args, None) 78 | workload = workload_generator() 79 | filename = "multi_all_reduce.csv" 80 | workload.dump(filename) 81 | -------------------------------------------------------------------------------- /workload_generator/generate_deepspeed_stage1_2_workload.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | """example of running zero1/2 on llama-13B 15 | python -m workload_generator.deepspeed_stage1_workload \ 16 | --stage=1 --world_size=624 --global_batch=624 \ 17 | --num_layers=40 --epoch_num=2 --hidden_size=5120 --ffn_hidden_size=13696 \ 18 | --reduce_bucket_size=2000000000 --allgather_bucket_size=2000000000 19 | 20 | python -m workload_generator.deepspeed_stage1_2_workload \ 21 | --stage=2 --world_size=256 --global_batch=1024 --num_attention_heads=40 \ 22 | --num_layers=40 --epoch_num=100 --hidden_size=5120 --ffn_hidden_size=13824 \ 23 | --reduce_bucket_size=26214400 --allgather_bucket_size=500000000 --contiguous_gradients 24 | """ 25 | 26 | import math 27 | from workload_generator.mocked_model.MockedDeepspeed import DeepspeedForCausalLM 28 | from workload_generator.mocked_model.MockedModel import MockedModel 29 | from workload_generator.workload_generator import WorkloadGenerator 30 | from utils.utils import CommGroup, CommType, get_params, WorkloadWriter 31 | from log_analyzer.log import LogItem 32 | 33 | 34 | class DeepSpeedStage1(WorkloadGenerator): 35 | """workload generator for deepspeed engine setup 36 | mock comm behavior of DeepSpeedEngine.__init__ 37 | """ 38 | 39 | def __init__(self, args, model) -> None: 40 | super().__init__(args, model) 41 | self.name = "deepspeed_stage1" 42 | self.compute_enable = args.computation_enable 43 | self.batch_size = args.micro_batch 44 | self.seq_len = args.seq_length 45 | self.reduce_bucket, self.num_in_reduce_bucket, self.max_reduce_bucket_size = ( 46 | [], 47 | 0, 48 | args.reduce_bucket_size, 49 | ) 50 | self.allgather_bucket_size = args.allgather_bucket_size 51 | self.amp_enabled = args.amp_enabled 52 | self.dp_world_size = args.dp_num 53 | self.elem_size = 2 54 | self.all_params = list(self.model.parameters()) 55 | 56 | def init(self): 57 | if not self.amp_enabled: 58 | for param in self.model.parameters(): 59 | self.workload.append( 60 | LogItem( 61 | comm_type=CommType.broadcast, 62 | comm_group=CommGroup.dp_group, 63 | comm_group_size=self.dp_world_size, 64 | msg_size=param.msg_size(), 65 | stage="init", 66 | ) 67 | ) 68 | 69 | self.workload.append( 70 | LogItem( 71 | comm_type=CommType.barrier, 72 | comm_group=CommGroup.all, 73 | comm_group_size=self.dp_world_size, 74 | msg_size=param.msg_size(), 75 | stage="init.__init__", 76 | ) 77 | ) 78 | 79 | def forward(self): 80 | if self.compute_enable: 81 | self.all_params = list(self.model.parameters()) 82 | for param in self.all_params: 83 | if param.get_shape()[-1] != 1: 84 | self.workload.append( 85 | LogItem( 86 | comm_type=CommType.computation, 87 | msg_size=( 88 | (self.batch_size, self.seq_len, param.get_shape()[0]), 89 | (param.get_shape()[0], param.get_shape()[1]), 90 | ), 91 | stage="forward.computation", 92 | ) 93 | ) 94 | 95 | def _reduce_ipg_grads(self): 96 | self.workload.append( 97 | LogItem( 98 | comm_type=CommType.all_reduce, 99 | comm_group=CommGroup.dp_group, 100 | comm_group_size=self.dp_world_size, 101 | msg_size=self.num_in_reduce_bucket * self.elem_size, 102 | stage=f"{self.current_op}.allreduce_bucket", 103 | ) 104 | ) 105 | self.reduce_bucket, self.num_in_reduce_bucket = [], 0 106 | 107 | def backward(self): 108 | self.current_op = "backward" 109 | for param in self.all_params[::-1]: 110 | if param.numel() + self.num_in_reduce_bucket > self.max_reduce_bucket_size: 111 | self._reduce_ipg_grads() 112 | self.reduce_bucket.append(param) 113 | self.num_in_reduce_bucket += param.numel() 114 | if self.compute_enable: 115 | if param.get_shape()[-1] != 1: 116 
107 | def backward(self): 108 | self.current_op = "backward" 109 | for param in self.all_params[::-1]: 110 | if param.numel() + self.num_in_reduce_bucket > self.max_reduce_bucket_size: 111 | self._reduce_ipg_grads() 112 | self.reduce_bucket.append(param) 113 | self.num_in_reduce_bucket += param.numel() 114 | if self.compute_enable: 115 | if param.get_shape()[-1] != 1: 116 | self.workload.append( 117 | LogItem( 118 | comm_type=CommType.computation, 119 | msg_size=( 120 | (self.batch_size, self.seq_len, param.get_shape()[0]), 121 | (param.get_shape()[0], param.get_shape()[1]), 122 | ), 123 | stage=f"{self.current_op}.computation", 124 | ) 125 | ) 126 | self.workload.append( 127 | LogItem( 128 | comm_type=CommType.computation, 129 | msg_size=( 130 | (param.get_shape()[0], self.batch_size * self.seq_len), 131 | (self.batch_size * self.seq_len, param.get_shape()[1]), 132 | ), 133 | stage=f"{self.current_op}.computation", 134 | ) 135 | ) 136 | 137 | def step(self): 138 | self.current_op = "step" 139 | self._reduce_ipg_grads() 140 | 141 | self.workload.append( 142 | LogItem( 143 | comm_type=CommType.all_reduce, 144 | comm_group=CommGroup.dp_group, 145 | comm_group_size=self.dp_world_size, 146 | msg_size=1, 147 | stage=f"{self.current_op}.has_overflow", 148 | ) 149 | ) 150 | num_params = sum([param.numel() for param in self.model.parameters()]) 151 | num_shards = max(num_params // self.allgather_bucket_size, 1) 152 | shard_size = num_params // num_shards 153 | 154 | for i in range(num_shards): 155 | num_elements = ( 156 | num_params - i * shard_size if i == (num_shards - 1) else shard_size 157 | ) 158 | padding_size = ( 159 | (self.dp_world_size - num_elements % self.dp_world_size) 160 | if num_elements % self.dp_world_size 161 | else 0 162 | ) 163 | num_elements = num_elements + padding_size 164 | self.workload.append( 165 | LogItem( 166 | comm_type=CommType.all_gather, 167 | comm_group=CommGroup.dp_group, 168 | comm_group_size=self.dp_world_size, 169 | msg_size=num_elements * self.elem_size, 170 | stage=f"{self.current_op}.all_gather_dp_groups", 171 | ) 172 | ) 173 | 174 | 175 | class DeepSpeedStage2(DeepSpeedStage1): 176 | def __init__(self, args, model) -> None: 177 | super().__init__(args, model) 178 | self.name = "deepspeed_stage2" 179 | 180 | self.param_range_map = self.build_model_gbuf_param_range_map( 181 | model, self.dp_world_size 182 | ) 183 | 184 | def build_model_gbuf_param_range_map(self, model: MockedModel, dp_world_size: int): 185 | gbuf_size = sum([param.numel() for param in model.parameters()]) 186 | 187 | gbuf_partition_size = int(math.ceil(gbuf_size / dp_world_size)) 188 | gbuf_world_all_ranges = [] 189 | for r in range(dp_world_size): 190 | gbuf_world_start = r * gbuf_partition_size 191 | gbuf_world_end = min(gbuf_size, gbuf_world_start + gbuf_partition_size) 192 | gbuf_world_all_ranges.append((gbuf_world_start, gbuf_world_end)) 193 | 194 | start_idx, r = 0, 0 195 | gbuf_world_start, gbuf_world_end = gbuf_world_all_ranges[r] 196 | # record which rank(s) each param's gradient should be reduced to 197 | # param_id: int -> List[(rank: int, param_start_idx: int, param_end_idx: int)] 198 | param_range_map = {} 199 | for param in self.all_params: 200 | # current param in [start_idx, end_idx) range of gbuf 201 | param_id = id(param) 202 | param_range_map[param_id] = [] 203 | end_idx = start_idx + param.numel() 204 | 205 | # current rank is in charge of [gbuf_world_start, gbuf_world_end) of gbuf 206 | param_start_idx = start_idx 207 | # if current rank cannot fully cover this param, move to next rank 208 | while gbuf_world_end < end_idx: 209 | param_range_map[param_id].append((r, param_start_idx, gbuf_world_end)) 210 | param_start_idx = gbuf_world_end 211 | r += 1 212 | gbuf_world_start, gbuf_world_end = gbuf_world_all_ranges[r] 213 | param_range_map[param_id].append((r, param_start_idx, end_idx)) 214 | 215 | # for next param 216 | start_idx = end_idx 217 | return param_range_map
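# A worked example of the range map (illustrative sizes, not from the repo):
# with dp_world_size=2 and two params of 6 and 10 elements, gbuf_size=16 and
# gbuf_partition_size=8, so rank 0 owns [0, 8) and rank 1 owns [8, 16). The
# first param maps to [(0, 0, 6)]; the second starts at 6, crosses the
# partition boundary, and maps to [(0, 6, 8), (1, 8, 16)], i.e. its gradient
# is reduced to two different ranks by _reduce_ipg_grads below.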
218 | 219 | def _reduce_ipg_grads(self): 220 | if not self.args.contiguous_gradients: 221 | super()._reduce_ipg_grads() 222 | return 223 | 224 | rank_start_end_idx = [[-1, -1, -1]] # sentinel entry; consecutive ranges on the same rank are merged below 225 | for param in self.reduce_bucket[::-1]: 226 | for rank, start_idx, end_idx in self.param_range_map[id(param)]: 227 | if rank == rank_start_end_idx[-1][0]: 228 | if rank_start_end_idx[-1][-1] != start_idx: 229 | print(f"WARNING {rank_start_end_idx[-1]} - {start_idx}") 230 | rank_start_end_idx[-1][-1] = end_idx 231 | else: 232 | rank_start_end_idx.append([rank, start_idx, end_idx]) 233 | 234 | for rank, start_idx, end_idx in rank_start_end_idx[1:]: 235 | self.workload.append( 236 | LogItem( 237 | comm_type=CommType.reduce, 238 | comm_group=CommGroup.dp_group, 239 | msg_size=(end_idx - start_idx) * self.elem_size, 240 | comm_group_size=self.dp_world_size, 241 | stage=f"{self.current_op}.average_tensor", 242 | dst=rank, 243 | ) 244 | ) 245 | self.reduce_bucket, self.num_in_reduce_bucket = [], 0 246 | 247 | 248 | if __name__ == "__main__": 249 | args = get_params() 250 | print(args.__dict__) 251 | model = DeepspeedForCausalLM(args) 252 | if args.stage == 1: 253 | workload_generator = DeepSpeedStage1(args, model) 254 | filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv" 255 | else: 256 | workload_generator = DeepSpeedStage2(args, model) 257 | filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv" 258 | workload = workload_generator() 259 | workload.dump(filename) 260 | if args.enable_visual: 261 | try: 262 | from visualize.generate import visualize_output 263 | base_name = filename.split(".")[0] 264 | visualize_output(f"./results/mocked_workload/{base_name}_workload.csv",True) 265 | except ImportError: 266 | print("visualize_output is not available because a required library is not found") 267 | # WorkloadWriter.write_workload(workload, args, filename) 268 | -------------------------------------------------------------------------------- /workload_generator/generate_deepspeed_stage3_workload.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License.
12 | """ 13 | 14 | """ 15 | python -m workload_generator.generate_deepspeed_stage3_workload \ 16 | --stage=3 --world_size=256 --global_batch=1024 --num_attention_heads=40 \ 17 | --num_layers=40 --epoch_num=100 --hidden_size=5120 --ffn_hidden_size=13824 \ 18 | --reduce_bucket_size=26214400 --allgather_bucket_size=500000000 --contiguous_gradients 19 | """ 20 | 21 | from workload_generator.mocked_model.MockedDeepspeed import DeepspeedForCausalLM 22 | from workload_generator.mocked_model.MockedModel import MockedModel 23 | from workload_generator.workload_generator import WorkloadGenerator 24 | from utils.utils import CommGroup, CommType, get_params, WorkloadWriter 25 | from collections import deque, defaultdict 26 | from log_analyzer.log import LogItem 27 | 28 | 29 | class DeepSpeedStage3(WorkloadGenerator): 30 | """workload generator for deepspeed engine setup 31 | mock comm behavior of DeepSpeedEngine.__init__ 32 | """ 33 | 34 | def __init__(self, args, model) -> None: 35 | super().__init__(args, model) 36 | self.name = "deepspeed_stage3" 37 | self.amp_enabled = args.amp_enabled 38 | self.dp_world_size = args.dp_num 39 | self.batch_size = args.micro_batch 40 | self.seq_len = args.seq_length 41 | self.compute_enable = args.computation_enable 42 | self.reduce_bucket, self.reduce_bucket_size = 0, args.reduce_bucket_size 43 | self.prefetch_bucket_size = args.prefetch_bucket_size 44 | self.max_live_parameters, self.current_live_parameters = ( 45 | args.max_live_parameters, 46 | 0, 47 | ) 48 | self.stage, self._param_queue, self.all_params = ( 49 | "init", 50 | deque(), 51 | list(self.model.parameters()), 52 | ) 53 | self.__param_order = [ 54 | (param, step_id) 55 | for step_id, param in enumerate(self.all_params + self.all_params[::-1]) 56 | ] 57 | self.__most_recent_step_id_param_fetched_for = defaultdict(lambda: -1) 58 | self._mark_persistent_parameters( 59 | args.param_persistence_threshold, args.model_persistence_threshold 60 | ) 61 | 62 | def _mark_persistent_parameters(self, param_threshold, model_threshold): 63 | self.persistent_params = [] 64 | total_persistent_parameters = 0 65 | count = 0 66 | for param in self.model.parameters(): 67 | param.id = count # is also the step id 68 | count += 1 69 | param.ds_persist = False 70 | param.has_been_allgather = False 71 | if param.numel() + total_persistent_parameters > model_threshold: 72 | continue 73 | if param.numel() <= param_threshold: 74 | param.ds_persist = True 75 | self.persistent_params.append(param) 76 | total_persistent_parameters += param.numel() 77 | 78 | def init(self): 79 | if not self.amp_enabled: 80 | for param in self.model.parameters(): 81 | self.workload.append( 82 | LogItem( 83 | comm_type=CommType.broadcast, 84 | comm_group=CommGroup.dp_group, 85 | comm_group_size=self.dp_world_size, 86 | msg_size=param.msg_size(), 87 | stage="init._broadcast_model", 88 | src=0, 89 | ) 90 | ) 91 | 92 | self.workload.append( 93 | LogItem( 94 | comm_type=CommType.barrier, 95 | comm_group=CommGroup.all, 96 | comm_group_size=self.dp_world_size, 97 | msg_size=0, 98 | stage="init._create_fp16_partitions_with_defragmentation", 99 | ) 100 | ) 101 | 102 | for _ in range(2): 103 | self.workload.append( 104 | LogItem( 105 | comm_type=CommType.barrier, 106 | comm_group=CommGroup.all, 107 | comm_group_size=self.dp_world_size, 108 | msg_size=0, 109 | stage="init._setup_for_real_optimizer", 110 | ) 111 | ) 112 | 113 | for param in self.model.parameters(): 114 | self.workload.append( 115 | LogItem( 116 | comm_type=CommType.all_gather, 117 | 
| comm_group=CommGroup.dp_group, 118 | comm_group_size=self.dp_world_size, 119 | msg_size=param.msg_size(), 120 | stage="init._allgather_params", 121 | ) 122 | ) 123 | 124 | def _compute_for_param(self, param): 125 | if self.stage == "forward": 126 | if param.get_shape()[-1] != 1: 127 | self.workload.append( 128 | LogItem( 129 | comm_type=CommType.computation, 130 | msg_size=( 131 | (self.batch_size, self.seq_len, param.get_shape()[0]), 132 | (param.get_shape()[0], param.get_shape()[1]), 133 | ), 134 | stage=f"{self.stage}.computation", 135 | ) 136 | ) 137 | if self.stage == "backward": 138 | # input grad 139 | if param.get_shape()[-1] != 1: 140 | self.workload.append( 141 | LogItem( 142 | comm_type=CommType.computation, 143 | msg_size=( 144 | (self.batch_size, self.seq_len, param.get_shape()[0]), 145 | (param.get_shape()[0], param.get_shape()[1]), 146 | ), 147 | stage=f"{self.stage}.computation", 148 | ) 149 | ) 150 | 151 | # weight grad 152 | self.workload.append( 153 | LogItem( 154 | comm_type=CommType.computation, 155 | msg_size=( 156 | (param.get_shape()[0], self.batch_size * self.seq_len), 157 | (self.batch_size * self.seq_len, param.get_shape()[1]), 158 | ), stage=f"{self.stage}.computation", # tag the weight-grad item like its sibling above 159 | ) 160 | ) 161 | 162 | def _gather_param_directly(self, param): 163 | if not param.has_been_allgather: 164 | self.workload.append( 165 | LogItem( 166 | comm_type=CommType.all_gather, 167 | comm_group=CommGroup.dp_group, 168 | comm_group_size=self.dp_world_size, 169 | msg_size=param.msg_size(), 170 | stage=f"{self.stage}.allgather_fn", 171 | ) 172 | ) 173 | param.has_been_allgather = True 174 | self.current_live_parameters += param.numel() 175 | if self.compute_enable: 176 | self._compute_for_param(param) 177 | 178 | def _gather_param_prefetch(self, param, step_id): 179 | prefetch_bucket, prefetch_bucket_size = [], 0 180 | if not param.has_been_allgather: 181 | prefetch_bucket.append(param) 182 | prefetch_bucket_size += param.numel() 183 | future_param, future_step_id = self._param_queue.popleft() 184 | if future_param != param: 185 | print( 186 | f"WARNING: expected {param.__dict__, step_id} but got {future_param.__dict__, future_step_id}" 187 | ) 188 | param.has_been_allgather = True 189 | self.current_live_parameters += param.numel() 190 | 191 | while ( 192 | self._param_queue 193 | and prefetch_bucket_size < self.prefetch_bucket_size 194 | and self.current_live_parameters < self.max_live_parameters 195 | ): 196 | future_param, step_id = self._param_queue.popleft() 197 | self.__most_recent_step_id_param_fetched_for[future_param.id] = max( 198 | step_id, self.__most_recent_step_id_param_fetched_for[future_param.id] 199 | ) 200 | if future_param.has_been_allgather: 201 | continue 202 | prefetch_bucket.append(future_param) 203 | future_param.has_been_allgather = True 204 | self.current_live_parameters += future_param.numel() 205 | prefetch_bucket_size += future_param.numel() 206 | 207 | if prefetch_bucket: 208 | self.workload.append( 209 | LogItem( 210 | comm_type=CommType.all_gather, 211 | comm_group=CommGroup.dp_group, 212 | comm_group_size=self.dp_world_size, 213 | msg_size=sum(param.msg_size() for param in prefetch_bucket), 214 | stage=f"{self.stage}.allgather_fn", 215 | ) 216 | ) 217 | if self.compute_enable: 218 | for param in prefetch_bucket: 219 | self._compute_for_param(param) 220 |
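# An illustrative prefetch trace (hypothetical sizes, not from the repo): with
# prefetch_bucket_size=5e7 and max_live_parameters=1e9, reaching a parameter
# that is not yet gathered pops it plus subsequent queue entries until about
# 5e7 elements are collected, then models the whole batch as one fused
# all_gather LogItem over the dp group.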
221 | def _partition_param(self, param, step_id): 222 | if len(self._param_queue) == 0: 223 | # this can incorrectly release some ds_persist parameters, but it does not affect the overall simulation 224 | param.has_been_allgather = False 225 | self.current_live_parameters -= param.numel() 226 | return 227 | if param.ds_persist: 228 | return 229 | # a later recorded step id means this param is needed again soon, so keep it gathered 230 | if self.__most_recent_step_id_param_fetched_for[param.id] > step_id: 231 | return 232 | param.has_been_allgather = False 233 | self.current_live_parameters -= param.numel() 234 | 235 | def forward(self): 236 | self.stage = "forward" 237 | for i, param in enumerate(self.all_params): 238 | if len(self._param_queue) == 0: 239 | self._gather_param_directly(param) 240 | else: 241 | self._gather_param_prefetch(param, i) 242 | self._partition_param(param, i) 243 | 244 | def _reduce_param_with_bucket(self, param): 245 | if param.numel() + self.reduce_bucket > self.reduce_bucket_size: 246 | self.workload.append( 247 | LogItem( 248 | comm_type=CommType.reduce_scatter, 249 | comm_group=CommGroup.dp_group, 250 | comm_group_size=self.dp_world_size, 251 | msg_size=self.reduce_bucket * param.elem_size(), 252 | stage=f"{self.stage}.reduce_scatter_fn", 253 | ) 254 | ) 255 | self.reduce_bucket = param.numel() 256 | else: 257 | self.reduce_bucket += param.numel() 258 | 259 | def backward(self): 260 | self.stage = "backward" 261 | for i, param in enumerate(self.all_params[::-1]): 262 | if len(self._param_queue) == 0: 263 | self._gather_param_directly(param) 264 | else: 265 | self._gather_param_prefetch(param, i) 266 | self._partition_param(param, i + len(self.all_params)) 267 | self._reduce_param_with_bucket(param) 268 | self._param_queue = deque(self.__param_order) 269 | self.__most_recent_step_id_param_fetched_for = defaultdict(lambda: -1) 270 | 271 | def step(self): 272 | self.stage = "step" 273 | self.workload.append( 274 | LogItem( 275 | comm_type=CommType.reduce_scatter, 276 | comm_group=CommGroup.dp_group, 277 | comm_group_size=self.dp_world_size, 278 | msg_size=self.reduce_bucket * 2, # flush the remaining bucket; 2-byte (fp16/bf16) elements assumed 279 | stage=f"{self.stage}.reduce_scatter_fn", 280 | ) 281 | ) 282 | self.reduce_bucket = 0 283 | 284 | self.workload.append( 285 | LogItem( 286 | comm_type=CommType.all_reduce, 287 | comm_group=CommGroup.dp_group, 288 | comm_group_size=self.dp_world_size, 289 | msg_size=1, 290 | stage=f"{self.stage}.has_overflow", 291 | ) 292 | ) 293 | self.workload.append( 294 | LogItem( 295 | comm_type=CommType.all_reduce, 296 | comm_group=CommGroup.dp_group, 297 | comm_group_size=self.dp_world_size, 298 | msg_size=8, 299 | stage=f"{self.stage}.get_grad_norm_direct", 300 | ) 301 | ) 302 | 303 | for param in self.model.parameters(): 304 | param.has_been_allgather = False 305 | self.current_live_parameters = 0 306 | 307 | for param in self.persistent_params: 308 | self._gather_param_directly(param) 309 | 310 | 311 | if __name__ == "__main__": 312 | args = get_params() 313 | model = DeepspeedForCausalLM(args) 314 | workload_generator = DeepSpeedStage3(args, model) 315 | workload = workload_generator() 316 | filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv" 317 | workload.dump(filename) 318 | if args.enable_visual: 319 | try: 320 | from visualize.generate import visualize_output 321 | base_name = filename.split(".")[0] 322 | visualize_output(f"./results/mocked_workload/{base_name}_workload.csv",True) 323 | except ImportError: 324 | print("visualize_output is not available because a required library is not found") 325 | # WorkloadWriter.write_workload(workload, args, filename) 326 | --------------------------------------------------------------------------------
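The next file replays traces from a communication log. A line its regex would match looks like the following; this is a hypothetical example constructed purely from the pattern's own literals, not taken from any real log:

comm op: all_reduce | msg size: 1073741824 | algbw (Gbps): 42.0 | busbw (Gbps): 80.0

Here group(1) is the op name, group(2) the message size field, and group(3) the algorithmic bandwidth.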
/workload_generator/generate_ds_trace_replay_workload.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import re 15 | from utils.utils import CommType, CommGroup, WorkloadWriter, get_params 16 | 17 | 18 | class TraceParser: 19 | def __init__(self, input_file): 20 | self.input_file = input_file 21 | self.comm_workload = [] 22 | 23 | def parse_trace(self): 24 | 25 | pattern = r"comm op: (.*?) \|.*?msg size: (.*?) \|.*?algbw \(Gbps\): (.*?) " 26 | 27 | with open(self.input_file, "r") as file: 28 | for line in file: 29 | match = re.search(pattern, line) 30 | if match: 31 | op = match.group(1) 32 | comm_op = CommType.get_comm_type(op) 33 | msg_size = match.group(2) 34 | self.comm_workload.append( 35 | { 36 | "operation": "trace", 37 | "comm_type": comm_op, 38 | "msg_size": msg_size, 39 | "comm_group": CommGroup.dp_group, 40 | "bw": match.group(3), 41 | } 42 | ) 43 | 44 | def get_trace_workload(self): 45 | return self.comm_workload 46 | 47 | 48 | if __name__ == "__main__": 49 | args = get_params() 50 | output_file = "model_workload/deepspeed_trace.csv" 51 | parser = TraceParser( 52 | "llama-7b-ga8-seq2048-bs3_dlcfw77d07c87pho-master-0_2023-07-05 21_48_38.txt" 53 | ) 54 | parser.parse_trace() 55 | workload = parser.get_trace_workload() 56 | WorkloadWriter().write_workload(workload, args, output_file) 57 | -------------------------------------------------------------------------------- /workload_generator/mocked_model/MockedDeepspeed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License.
12 | """ 13 | 14 | from typing import List 15 | from workload_generator.mocked_model.MockedModel import MockedModel, Linear 16 | 17 | 18 | class DeepspeedMLP(MockedModel): 19 | def __init__(self, hidden_size: int, ffn_hidden_size: int): 20 | self.gate_proj = Linear(hidden_size, ffn_hidden_size) 21 | self.down_proj = Linear(ffn_hidden_size, hidden_size) 22 | self.up_proj = Linear(hidden_size, ffn_hidden_size) 23 | 24 | 25 | class DeepspeedAttention(MockedModel): 26 | def __init__(self, config): 27 | self.hidden_size = config.hidden_size 28 | self.num_heads = config.num_attention_heads 29 | self.head_dim = self.hidden_size // self.num_heads 30 | self.max_position_embeddings = config.max_position_embeddings 31 | self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim) 32 | self.k_proj = Linear(self.hidden_size, self.num_heads * self.head_dim) 33 | self.v_proj = Linear(self.hidden_size, self.num_heads * self.head_dim) 34 | self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size) 35 | # self.rotary_emb = Linear(self.head_dim, self.max_position_embeddings) 36 | 37 | 38 | class DeepspeedDecoderLayer(MockedModel): 39 | def __init__(self, config): 40 | self.input_layernorm = Linear(config.hidden_size, 1) 41 | self.self_attn = DeepspeedAttention(config=config) 42 | self.post_attention_layernorm = Linear(config.hidden_size, 1) 43 | self.mlp = DeepspeedMLP(config.hidden_size, config.ffn_hidden_size) 44 | 45 | 46 | class DeepspeedModel(MockedModel): 47 | def __init__(self, config): 48 | self.embed_tokens = Linear(config.vocab_size, config.hidden_size) 49 | self.layers = [DeepspeedDecoderLayer(config) for _ in range(config.num_layers)] 50 | self.norm = Linear(config.hidden_size, 1) 51 | 52 | 53 | class DeepspeedForCausalLM(MockedModel): 54 | def __init__(self, config): 55 | super().__init__() 56 | self.model = DeepspeedModel(config) 57 | self.lm_head = Linear(config.hidden_size, config.vocab_size) 58 | -------------------------------------------------------------------------------- /workload_generator/mocked_model/MockedModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import math 15 | from typing import List, Tuple 16 | 17 | 18 | class MockedParam: 19 | def __init__(self, shape: Tuple, elem_size=2, name=None) -> None: 20 | self.shape = shape 21 | self._numel = math.prod(shape) 22 | self._elem_size = elem_size 23 | self.name = name if name is not None else "Unknown" 24 | 25 | def numel(self): 26 | return self._numel 27 | 28 | def elem_size(self): 29 | return self._elem_size 30 | 31 | def msg_size(self): 32 | return self._numel * self._elem_size 33 | 34 | def get_shape(self): 35 | return self.shape 36 | 37 | # def name(self): 38 | # return self.param_name 39 | 40 | 41 | def _unpack_params(value: object) -> List[MockedParam]: 42 | if isinstance(value, MockedParam): 43 | return [value] 44 | elif isinstance(value, MockedModel): 45 | return value.parameters() 46 | elif isinstance(value, dict): 47 | params = [] 48 | for k, v in value.items(): 49 | params += _unpack_params(v) 50 | return params 51 | elif isinstance(value, (list, tuple)): 52 | params = [] 53 | for v in value: 54 | params += _unpack_params(v) 55 | return params 56 | else: 57 | return [] 58 | 59 | 60 | def _child_modules(value: object) -> List["MockedModel"]: 61 | if isinstance(value, MockedModel): 62 | modules = [value] 63 | modules.extend(_child_modules(value.__dict__)) 64 | return modules 65 | elif isinstance(value, dict): 66 | modules = [] 67 | for k, v in value.items(): 68 | modules += _child_modules(v) 69 | return modules 70 | elif isinstance(value, (list, tuple)): 71 | modules = [] 72 | for v in value: 73 | modules += _child_modules(v) 74 | return modules 75 | else: 76 | return [] 77 | 78 | 79 | class MockedModel: 80 | def __init__(self) -> None: 81 | self._pre_forward_hook = [] 82 | self._post_forward_hook = [] 83 | self._pre_backward_hook = [] 84 | self._post_backward_hook = [] 85 | 86 | def parameters(self) -> List[MockedParam]: 87 | return _unpack_params(self.__dict__) 88 | 89 | def child_modules(self) -> List["MockedModel"]: 90 | return _child_modules(self.__dict__) 91 | 92 | def register_forward_pre_hook(self, fn): 93 | self._pre_forward_hook.append(fn) 94 | 95 | def register_backward_pre_hook(self, fn): 96 | self._pre_backward_hook.append(fn) 97 | 98 | def register_forward_post_hook(self, fn): 99 | self._post_forward_hook.append(fn) 100 | 101 | def register_backward_post_hook(self, fn): 102 | self._post_backward_hook.append(fn) 103 | 104 | 105 | class Linear(MockedModel): # alias for LlamaRMSNorm, Embedding, LlamaRotaryEmbedding 106 | def __init__(self, in_feature, out_feature): 107 | self.weight = MockedParam((in_feature, out_feature)) 108 | -------------------------------------------------------------------------------- /workload_generator/mocked_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/aicb/9ac2f0ecc233996e8da3f34b2ae4c15f81543134/workload_generator/mocked_model/__init__.py -------------------------------------------------------------------------------- /workload_generator/workload_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, Alibaba Group; 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | from workload_generator.mocked_model.MockedModel import MockedModel 15 | from utils.utils import CommGroup, CommType 16 | from log_analyzer.log import Workload, LogItem 17 | 18 | 19 | class WorkloadGenerator: 20 | # generator = WorkloadGenerator 21 | def __init__(self, args, model: MockedModel) -> None: 22 | self.name = "workload_generator" 23 | self.args = args 24 | self.model = model 25 | self.workload = Workload() 26 | self.epoch = 0 27 | 28 | def __call__(self): 29 | args = self.args 30 | self.workload = Workload() 31 | self.init() 32 | self.workload.append(LogItem(comm_type=CommType.epoch_end)) 33 | for i in range(args.epoch_num): 34 | if args.pipeline_model_parallel > 1 and args.frame != "collective_test": 35 | self.with_pipeline_forward_backward() 36 | self.step() 37 | else: 38 | for _ in range(args.num_microbatches): 39 | self.forward() 40 | self.backward() 41 | self.step() 42 | self.workload.append(LogItem(comm_type=CommType.epoch_end)) 43 | return self.workload 44 | 45 | def forward(self): 46 | pass 47 | 48 | def backward(self): 49 | pass 50 | 51 | def step(self): 52 | pass 53 | 54 | def with_pipeline_forward_backward(self): 55 | pass 56 | --------------------------------------------------------------------------------
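Taken together, a custom generator only has to fill in the lifecycle hooks that the base class's __call__ drives. A minimal sketch (a hypothetical class, not part of the repo; args is assumed to come from utils.utils.get_params() and to carry the fields __call__ reads, such as epoch_num, pipeline_model_parallel, frame, and num_microbatches):

from log_analyzer.log import LogItem
from utils.utils import CommGroup, CommType
from workload_generator.workload_generator import WorkloadGenerator

class AllReduceEveryStep(WorkloadGenerator):  # hypothetical example generator
    def step(self):
        self.workload.append(
            LogItem(
                comm_type=CommType.all_reduce,
                comm_group=CommGroup.dp_group,
                comm_group_size=self.args.dp_num,
                msg_size=1024,  # arbitrary illustrative size in bytes
                stage="step.example",
            )
        )

# __call__ runs init(), then per epoch either the pipelined path or
# num_microbatches x (forward + backward) followed by step():
# workload = AllReduceEveryStep(args, None)()
# workload.dump("all_reduce_every_step.csv")  # hypothetical output file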