├── .gitignore
├── README.md
├── __init__.py
├── cli_benchmark.py
├── cli_perf_visual.py
├── cli_structure_analyzer.py
├── figures
├── Qwen3-32B_a100-sxm-80gb_flops_vs_seq_len.png
├── Qwen3-32B_a100-sxm-80gb_interactive.html
├── Qwen3-32B_a100-sxm-80gb_latency_vs_seq_len.png
├── Qwen3-32B_a100-sxm-80gb_memory_vs_seq_len.png
├── Qwen3-32B_a100-sxm-80gb_overview.png
├── Qwen3-32B_a100-sxm-80gb_throughput_vs_seq_len.png
├── grpah_decode_llama2-70B_tp4_bs16_seqlen1024_genlen128.png
├── grpah_prefill_llama2-70B_tp4_bs16_seqlen1024_genlen128.png
└── roofline_analysis_optimized.png
├── images
├── flops_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
├── flops_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
├── grpah_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
├── grpah_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
├── latency_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
├── latency_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
└── params_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
├── llm_counts
├── benchmark_analyzer.py
├── configs
│ ├── gpu_configs.json
│ ├── gpu_perf.ini
│ └── model_configs.json
├── count_flops.py
├── count_latency.py
├── count_memory.py
├── count_params.py
├── layer_graph_visualizer.py
├── roofline_model.py
└── utils
│ ├── __pycache__
│ ├── config.cpython-310.pyc
│ ├── config.cpython-311.pyc
│ ├── config.cpython-312.pyc
│ ├── constants.cpython-310.pyc
│ ├── constants.cpython-311.pyc
│ ├── constants.cpython-312.pyc
│ ├── utils.cpython-310.pyc
│ └── utils.cpython-311.pyc
│ ├── config.py
│ ├── constants.py
│ ├── utils.py
│ └── visualizer.py
└── test_torch_info.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | tmp
3 | .ruff_cache
4 | **/__pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # llm_profiler
2 |
3 | llm theoretical performance analysis tools and support params, flops, memory and latency analysis.
4 |
5 | ## 主要功能
6 |
7 | - 支持 qwen2.5、qwen3 dense 系列模型。
8 | - 支持张量并行推理模式。
9 | - 支持 `A100`、`V100`、`T4` 等硬件以及主流 decoder-only 的自回归模型,可自行在配置文件中增加。
10 | - 支持分析性能瓶颈,不同 `layer` 是 `memory bound` 还是 `compute bound`,以及 `kv_cache` 的性能瓶颈。
11 | - 支持输出每层和整个模型的参数量、计算量,内存和 `latency`。
12 | - 推理时支持预填充和解码阶段分别计算内存和 latency、以及理论支持的最大 `bs` 等等。
13 | - 支持设置计算效率、内存读取效率(不同推理框架可能不一样,这个设置好后,可推测输出实际值)。
14 | - 推理性能理论分析结果的格式化输出。
15 |
16 | ## 如何使用
17 |
18 | 使用方法,直接调用 `llm_profiler/llm_profiler.py` 文件中函数 `llm_profile()` 函数并输入相关参数即可。
19 |
20 | ```python
21 | def llm_profile(model_name="llama-13b",
22 | gpu_name: str = "v100-sxm-32gb",
23 | bytes_per_param: int = BYTES_FP16,
24 | bs: int = 1,
25 | seq_len: int = 522,
26 | generate_len=1526,
27 | ds_zero: int = 0,
28 | dp_size: int = 1,
29 | tp_size: int = 1,
30 | pp_size: int = 1,
31 | sp_size: int = 1,
32 | layernorm_dtype_bytes: int = BYTES_FP16,
33 | kv_cache_bytes: int = BYTES_FP16,
34 | flops_efficiency: float = FLOPS_EFFICIENCY,
35 | hbm_memory_efficiency: float = HBM_MEMORY_EFFICIENCY,
36 | intra_node_memory_efficiency=INTRA_NODE_MEMORY_EFFICIENCY,
37 | inter_node_memory_efficiency=INTER_NODE_MEMORY_EFFICIENCY,
38 | mode: str = "inference",
39 | ) -> dict:
40 |
41 | """format print dicts of the total floating-point operations, MACs, parameters and latency of a llm.
42 |
43 | Args:
44 | model_name (str, optional): model name to query the pre-defined `model_configs.json`. Defaults to "llama-13b".
45 | gpu_name (str, optional): gpu name to query the pre-defined `model_configs.json`. Defaults to "v100-sxm2-32gb".
46 | bs (int, optional): _description_. Defaults to 1.
47 | seq_len (int, optional): batch size per GPU.. Defaults to 522.
48 | generate_len (int, optional): The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. Defaults to 1526.
49 | dp_size (int, optional): data parallelism size. Defaults to 1.
50 | tp_size (int, optional): tensor parallelism size. Defaults to 1.
51 | pp_size (int, optional): pipeline parallelism size. Defaults to 1.
52 | sp_size (int, optional): sequence parallelism size. Defaults to 1.
53 | speed up decoding. Defaults to True.
54 | layernorm_dtype_bytes (int, optional): number of bytes in the data type for the layernorm activations.. Defaults to BYTES_FP16.
55 | kv_cache_bytes (int, optional): number of bytes in the data type for the kv_cache. Defaults to None.
56 | flops_efficiency (float, optional): flops efficiency, ranging from 0 to 1. Defaults to None.
57 | hbm_memory_efficiency (float, optional): GPU HBM memory efficiency, ranging from 0 to 1. Defaults to HBM_MEMORY_EFFICIENCY.
58 | intra_node_memory_efficiency (_type_, optional): intra-node memory efficiency, ranging from 0 to 1.. Defaults to INTRA_NODE_MEMORY_EFFICIENCY.
59 | inter_node_memory_efficiency (_type_, optional): inter-node memory efficiency, ranging from 0 to 1.. Defaults to INTER_NODE_MEMORY_EFFICIENCY.
60 |
61 | Returns:
62 | None: format print some summary dictionary of the inference analysis
63 | """
64 | ```
65 |
66 | `llama2-70` 模型,tp_size = 8 和 bs = 20,输出示例信息如下所示:
67 |
68 | ```bash
69 | -------------------------- LLM main infer config --------------------------
70 | { 'inference_config': { 'model_name': 'llama2-70b',
71 | 'num_attention_heads': 64,
72 | 'num_kv_heads': 8,
73 | 'head_dim': 128,
74 | 'hidden_size': 8192,
75 | 'intermediate_size': 28672,
76 | 'vocab_size': 32000,
77 | 'max_seq_len': 4096,
78 | 'bs': 32,
79 | 'seq_len': 1024,
80 | 'tp_size': 8,
81 | 'pp_size': 1,
82 | 'generate_len': 128},
83 | 'gpu_config': { 'name': 'a100-sxm-40gb',
84 | 'memory_GPU_in_GB': '40 GB',
85 | 'gpu_hbm_bandwidth': '1555 GB/s',
86 | 'gpu_intra_node_bandwidth': '600 GB/s',
87 | 'gpu_fp16_TFLOPS': '312 TFLOPS'}}
88 |
89 | -------------------------- LLM infer performance analysis --------------------------
90 | { 'weight_memory_per_gpu': '17.18 GB',
91 | 'consume_memory_per_gpu': '20.57 GB',
92 | 'prefill_flops': '4574.25 T',
93 | 'decode_flops_per_step': '4.38 T',
94 | 'TTFT': 2.7060724961666294,
95 | 'TTOT': 0.040541745771914876,
96 | 'kv_cache_latency': '959.04 us',
97 | 'total_infer_latency': '7.9 s',
98 | 'support_max_batch_total_tokens': 240249}
99 |
100 | ---------------------------- LLM Params per_layer analysis ----------------------------
101 | { 'qkvo_proj': '150.99 M',
102 | 'mlp': '704.64 M',
103 | 'rmsnorm': '16.38 K',
104 | 'input_embedding': '262.14 M',
105 | 'output_embedding': '262.14 M'}
106 | {'params_model': '68.71 G'}
107 |
108 | ---------------------------- LLM Prefill Flops per_layer analysis ----------------------------
109 | { 'attention_kernel': '1.1 T',
110 | 'qkvo_proj': '9.9 T',
111 | 'mlp': '46.18 T',
112 | 'rmsnorm': '4.29 G',
113 | 'positional_embedding': '536.87 M',
114 | 'input_embedding': '0'}
115 | {'prefill flops_model': '4574.25 T'}
116 |
117 | ---------------------------- LLM Memory analysis (Prefill) ----------------------------
118 | { 'weight_memory_per_gpu': '17.18 GB',
119 | 'prefill_max_bs': '388B',
120 | 'prefill_act_per_gpu': '1.88 GB'}
121 |
122 | ---------------------------- LLM Memory analysis (Decode) ----------------------------
123 | { 'decode_act_per_gpu': '1.88 GB',
124 | 'kv_cache_memory_per_gpu': '1.51 GB',
125 | 'consume_memory_per_gpu': '20.57 GB',
126 | 'decode_max_bs': '215.0B',
127 | 'max_batch_total_tokens': '240.25 KB'}
128 |
129 | ---------------------------- LLM Latency analysis (Prefill) ----------------------------
130 | { 'prefill_qkvo_proj': '352.41 ms',
131 | 'prefill_attn_kernel': '131.39 ms',
132 | 'prefill_mlp': '1.64 s',
133 | 'prefill_rmsnorm': '61.38 ms',
134 | 'prefill_tp_comm': '501.08 ms',
135 | 'prefill_kv_cache_rw': '959.04 us',
136 | 'prefill_latency': '2.71 s'}
137 |
138 | ---------------------------- LLM Latency analysis (Decode) ----------------------------
139 | { 'decode_qkvo_proj': '6.5 ms',
140 | 'decode_attn_kernel': '2.56 ms',
141 | 'decode_mlp': '30.26 ms',
142 | 'decode_rmsnorm': '64.62 us',
143 | 'decode_tp_comm': '640.0 us',
144 | 'decode_kv_cache_rw': '121.75 us',
145 | 'kv_cache_latency': '959.04 us',
146 | 'decode_latency': '40.54 ms'}
147 | ```
148 |
149 | ## 模型结构可视化
150 |
151 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20,prefill 阶段:
152 |
153 |
154 |

155 |
156 |
157 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20, decode 阶段:
158 |
159 |
160 |

161 |
162 |
163 | ## 模型参数量、计算量、latency 分布
164 |
165 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20,参数量统计分布:
166 |
167 |
168 |

169 |
170 |
171 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20,prefill 阶段计算量统计分布:
172 |
173 |
174 |

175 |
176 |
177 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20,generate_len = 128, decode 阶段计算量统计分布:
178 |
179 |
180 |

181 |
182 |
183 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20,prefill 阶段 latency 统计分布:
184 |
185 |
186 |

187 |
188 |
189 | llama2-70b 模型,A100-SXM40GB,tp_size = 8 和 bs = 20,decode 阶段 latency 统计分布:
190 |
191 |
192 |

193 |
194 |
195 | ## 参考链接
196 | - [Transformer 性能分析理论基础](https://github.com/HarleysZhang/dl_note/blob/main/6-llm_note/transformer_basic/Transformer%E6%80%A7%E8%83%BD%E5%88%86%E6%9E%90%E7%90%86%E8%AE%BA%E5%9F%BA%E7%A1%80.md)
197 | - [llm_analysis](https://github.com/cli99/llm-analysis)
198 | - [Transformer Inference Arithmetic](https://kipp.ly/blog/transformer-inference-arithmetic/)
199 | - [LLM-Viewer](https://github.com/hahnyuan/LLM-Viewer.git)
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/__init__.py
--------------------------------------------------------------------------------
/cli_benchmark.py:
--------------------------------------------------------------------------------
1 | # usage: python cli_benchmark.py --model_name Qwen3-32B --gpu_name a100-sxm-80gb --batch_size 16 --generate_len 1024 --tp_size 4
2 | import pandas as pd
3 | import argparse
4 | from llm_counts.utils.config import *
5 | from llm_counts.benchmark_analyzer import llm_profile
6 | from llm_counts.utils.visualizer import SeqLenVisualizer
7 |
8 |
9 | def sweep_seq_len(model_name, gpu_name="h100-sxm-80gb", batch_size=16, generate_len=1024, tp_size=2, seq_len_list=None, **kwargs):
10 | """Profile a model over several sequence lengths and print / plot a table.
11 |
12 | Args:
13 | model_name (str): name of the LLM
14 | gpu_name (str): target GPU
15 | seq_len_list (List[int]): list of sequence lengths to test
16 | **kwargs: forwarded to llm_profile
17 | Returns:
18 | pandas.DataFrame: one row per sequence length with metrics
19 | """
20 | if seq_len_list is None:
21 | seq_len_list = [128, 256, 512, 1024, 1334, 1567, 1678, 2567, 3072,
22 | 4096, 5120, 6144, 8192, 10240, 12288, 16384,
23 | 21472, 24576, 30346, 32768, 33792, 34980, 36790]
24 |
25 | records1 = []
26 | records2 = []
27 | for seq in seq_len_list:
28 | res1, res2 = llm_profile(
29 | model_name=model_name,
30 | gpu_name=gpu_name,
31 | batch_size=batch_size,
32 | seq_len=seq,
33 | generate_len=generate_len,
34 | tp_size=tp_size,
35 | print_flag=False,
36 | visual_flag=False,
37 | )
38 | print("=" * 80)
39 | print(f"model_name: {model_name}, gpu_name: {gpu_name}, tp_size: {tp_size}, batch_size: {batch_size}, seq_len: {seq}, generate_len: {generate_len}")
40 |
41 | records1.append(res1)
42 | records2.append(res2)
43 |
44 | df1 = pd.DataFrame(records1)
45 | print("=" * 80)
46 | print(df1.to_string(index=False))
47 | print("=" * 80)
48 |
49 | df2 = pd.DataFrame(records2)
50 | # Derive throughput in tokens / second for visualisation
51 | if "TTFT" in df2.columns:
52 | df2["throughput_tok_per_second"] = df2["seq_len"] * batch_size / df2["TTFT"].replace(0, float("nan"))
53 | # Visualise the results using *plot_seq_len_sweep*
54 | if kwargs.get("visual_flag", True):
55 | viz = SeqLenVisualizer(df2, model_name, gpu_name, show=True)
56 | viz.visualize()
57 |
58 | return df1
59 |
60 |
61 | def _parse_args():
62 | parser = argparse.ArgumentParser(
63 | description="Sweep sequence lengths, profile an LLM, and generate visualisations."
64 | )
65 | parser.add_argument("--model_name", required=True, help="LLM model name, e.g. Qwen3-32B")
66 | parser.add_argument("--gpu_name", default="h100-sxm-80gb", help="Target GPU name")
67 | parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
68 | parser.add_argument("--generate_len", type=int, default=1024, help="Generation length")
69 | parser.add_argument("--tp_size", type=int, default=2, help="Tensor‑parallel size")
70 | parser.add_argument(
71 | "--seq_lens",
72 | type=int,
73 | nargs="*",
74 | default=None,
75 | help="Space‑separated list of sequence lengths (tokens) to sweep",
76 | )
77 | parser.add_argument(
78 | "--no_visual",
79 | action="store_true",
80 | help="Disable visualisation (figures will not be generated)",
81 | )
82 | return parser.parse_args()
83 |
84 |
85 | if __name__ == "__main__":
86 | args = _parse_args()
87 | sweep_seq_len(
88 | model_name=args.model_name,
89 | gpu_name=args.gpu_name,
90 | batch_size=args.batch_size,
91 | generate_len=args.generate_len,
92 | tp_size=args.tp_size,
93 | seq_len_list=args.seq_lens,
94 | visual_flag=not args.no_visual,
95 | )
96 |
--------------------------------------------------------------------------------
/cli_perf_visual.py:
--------------------------------------------------------------------------------
1 | from llm_counts.utils.config import *
2 | from framework_tools.LLMCounts.llm_counts.benchmark_analyzer import llm_profile
3 | import math
4 |
5 |
6 | ####################################################################################################################
7 | def print_list(list):
8 | """print one-dimensional list
9 |
10 | :param list: List[int]
11 | :return: None
12 | """
13 | for _, x in enumerate(list):
14 | print(x, end="\n")
15 |
16 | ####################################################################################################################
17 | def print_all_llm_analyzer():
18 | model_name_list = [
19 | "llama-7b",
20 | "llama-13b",
21 | "llama-65b",
22 | "llama2-70b",
23 | "internlm-20b",
24 | ]
25 | gpu_name_list = [
26 | "a30-sxm-24gb",
27 | "a40-pcie-48gb",
28 | "a100-sxm-40gb",
29 | "a100-sxm-80gb",
30 | "910b-64gb",
31 | "v100-sxm-32gb",
32 | "t4-pcie-15gb",
33 | ]
34 | tp_nums_list = [1, 2, 4, 8]
35 | tgi_service_dict_list = []
36 | seq_len, generate_len = 1024, 1024
37 |
38 | for model_name in model_name_list:
39 | if model_name in ["llama2-70b", "internlm-20b"]:
40 | seq_len, generate_len = 1024, 1024
41 |
42 | for gpu_name in gpu_name_list:
43 | for tp_size in tp_nums_list:
44 | try:
45 | res1, _ = llm_profile(
46 | model_name=model_name,
47 | gpu_name=gpu_name,
48 | tp_size=tp_size,
49 | seq_len=seq_len,
50 | generate_len=generate_len,
51 | print_flag=False,
52 | visual_flag=False,
53 | )
54 | max_batch_total_tokens = int(res1["max_batch_total_tokens"])
55 | except Exception as e:
56 | print(
57 | f"model_name: {model_name}, gpu_name: {gpu_name}, tp_size: {tp_size}, error: {e}"
58 | )
59 | continue
60 |
61 | tgi_service_dict = {
62 | "model_name": model_name,
63 | "gpu_name": gpu_name,
64 | "tp_size": tp_size,
65 | "max_batch_total_tokens": max_batch_total_tokens,
66 | "max_bs": math.floor(
67 | max_batch_total_tokens / (seq_len + generate_len)
68 | ),
69 | }
70 | tgi_service_dict_list.append(tgi_service_dict)
71 |
72 | print(
73 | "============================ TGI+LightLLM service max_batch_total_tokens params list ======================"
74 | )
75 | print_list(tgi_service_dict_list)
76 |
77 | if __name__ == "__main__":
78 | # llm_profile(model_name="llama-7b", tp_size=1, print_flag=True, visual_flag=True)
79 | llm_profile(model_name="llama2-70b", gpu_name = "a100-sxm-40gb", tp_size=8,
80 | batch_size = 32, seq_len = 1024, generate_len=128,
81 | print_flag=True, visual_flag=True)
82 |
--------------------------------------------------------------------------------
/cli_structure_analyzer.py:
--------------------------------------------------------------------------------
1 | from llm_counts.layer_graph_visualizer import LayerAnalyzer, LayerGraphVisualizer
2 | from llm_counts.utils.utils import *
3 | from llm_counts.utils.config import get_model_and_gpu_config_by_name
4 | import pprint
5 | import argparse
6 |
7 |
8 | def print_format_summary_dict(summary_dict: dict, depth: int) -> str:
9 | """打印时对 params / flops / latency / memory 等进行统一转换显示。"""
10 | for key, value in summary_dict.items():
11 | if "params" in key or "flops" in key:
12 | if not isinstance(value, dict):
13 | summary_dict.update({key: num_to_string(value)})
14 | else:
15 | print_format_summary_dict(
16 | value, get_dict_depth(value) - 1
17 | ) # 递归
18 | if "latency" in key:
19 | if not isinstance(value, dict):
20 | summary_dict.update({key: latency_to_string(value)})
21 | else:
22 | print_format_summary_dict(value, get_dict_depth(value) - 1)
23 | if "memory" in key:
24 | if not isinstance(value, dict):
25 | summary_dict.update({key: f"{num_to_string(value)}B"})
26 | else:
27 | print_format_summary_dict(value, get_dict_depth(value) - 1)
28 | if depth >= 1:
29 | pprint.pprint(summary_dict, indent=4, sort_dicts=False)
30 |
31 | def test_llm_analyzer(
32 | model_name: str = "Qwen/Qwen3-8B",
33 | gpu_name="a100-sxm-80gb",
34 | bs: int = 1,
35 | seq_len: int = 522,
36 | generate_len: int = 1526,
37 | tp_size: int = 1,
38 | ):
39 | model_config, gpu_config = get_model_and_gpu_config_by_name(model_name, gpu_name)
40 | model_type = model_config.model_type
41 | llm_analyzer = LayerAnalyzer(model_config, gpu_config, tp_size=tp_size)
42 | results = llm_analyzer.analyze_model(bs=bs, seq_len=seq_len, generate_len=generate_len)
43 |
44 | # -------------------------- 绘图:模型 graph 图示例 --------------------------
45 | base_filename = f"{model_name.replace('/', '_')}_tp{tp_size}_bs{bs}_seqlen{seq_len}_genlen{generate_len}"
46 | print("base_filename", base_filename)
47 | LayerGraphVisualizer(model_type, results).render(base_filename)
48 | depth = get_dict_depth(results)
49 | # print_format_summary_dict(results, depth)
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser(
54 | description="Run LayerAnalyzer, print a formatted summary, "
55 | "and generate per‑stage layer‑graph PNGs."
56 | )
57 | parser.add_argument("--model-name", default="Qwen3-32B")
58 | parser.add_argument("--gpu-name", default="a100-sxm-80gb")
59 | parser.add_argument("--bs", type=int, default=16)
60 | parser.add_argument("--seq-len", type=int, default=1024)
61 | parser.add_argument("--generate-len",type=int, default=128)
62 | parser.add_argument("--tp-size", type=int, default=4)
63 | args = parser.parse_args()
64 |
65 | test_llm_analyzer(
66 | model_name=args.model_name,
67 | gpu_name=args.gpu_name,
68 | bs=args.bs,
69 | seq_len=args.seq_len,
70 | generate_len=args.generate_len,
71 | tp_size=args.tp_size,
72 | )
73 |
74 | """"
75 | python cli_structure_analyzer.py \
76 | --model-name llama2-70B \
77 | --gpu-name a100-sxm-80gb \
78 | --bs 16 \
79 | --seq-len 1024 \
80 | --generate-len 128 \
81 | --tp-size 4
82 | """
--------------------------------------------------------------------------------
/figures/Qwen3-32B_a100-sxm-80gb_flops_vs_seq_len.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/Qwen3-32B_a100-sxm-80gb_flops_vs_seq_len.png
--------------------------------------------------------------------------------
/figures/Qwen3-32B_a100-sxm-80gb_interactive.html:
--------------------------------------------------------------------------------
1 | TTFT (s)
TTOT (ms)
Prefill TFLOPs
HBM (GiB)
Throughput (tok/s)
--------------------------------------------------------------------------------
/figures/Qwen3-32B_a100-sxm-80gb_latency_vs_seq_len.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/Qwen3-32B_a100-sxm-80gb_latency_vs_seq_len.png
--------------------------------------------------------------------------------
/figures/Qwen3-32B_a100-sxm-80gb_memory_vs_seq_len.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/Qwen3-32B_a100-sxm-80gb_memory_vs_seq_len.png
--------------------------------------------------------------------------------
/figures/Qwen3-32B_a100-sxm-80gb_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/Qwen3-32B_a100-sxm-80gb_overview.png
--------------------------------------------------------------------------------
/figures/Qwen3-32B_a100-sxm-80gb_throughput_vs_seq_len.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/Qwen3-32B_a100-sxm-80gb_throughput_vs_seq_len.png
--------------------------------------------------------------------------------
/figures/grpah_decode_llama2-70B_tp4_bs16_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/grpah_decode_llama2-70B_tp4_bs16_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/figures/grpah_prefill_llama2-70B_tp4_bs16_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/grpah_prefill_llama2-70B_tp4_bs16_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/figures/roofline_analysis_optimized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/figures/roofline_analysis_optimized.png
--------------------------------------------------------------------------------
/images/flops_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/flops_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/images/flops_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/flops_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/images/grpah_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/grpah_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/images/grpah_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/grpah_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/images/latency_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/latency_decode_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/images/latency_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/latency_prefill_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/images/params_llama2-70b_tp8_bs32_seqlen1024_genlen128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/images/params_llama2-70b_tp8_bs32_seqlen1024_genlen128.png
--------------------------------------------------------------------------------
/llm_counts/benchmark_analyzer.py:
--------------------------------------------------------------------------------
1 | # -*- coding : utf-8 -*-
2 | # author : honggao.zhang
3 | # Create : 2024-7-19
4 | # Update : 2025-01-05
5 | # Version : 0.2.0
6 | # Description : transformer model(llm) profiling tools,
7 | # with latency, memory, flops, and params distribution analysis.
8 |
9 | import logging
10 | import pprint
11 | import matplotlib.pyplot as plt
12 | import argparse
13 | import json
14 | import copy
15 |
16 | from .utils.config import *
17 | from .utils.utils import *
18 |
19 | from .count_flops import CountCausalLMFlops
20 | from .count_params import CountCausalLMParams
21 | from .count_memory import CountCausalLMMemory
22 | from .count_latency import CountCausalLMLatency
23 | from .layer_graph_visualizer import LayerAnalyzer
24 |
25 | logger = logging.getLogger()
26 |
27 |
28 | class LayerAnalyzerVisual(object):
29 | """Measures the latency, memory, number of estimated floating-point operations,
30 | and parameters of each module in a PyTorch model.
31 | """
32 |
33 | def __init__(self, llm_configs: LLMConfigs) -> None:
34 | self.llm_configs = llm_configs
35 | self.model_config = llm_configs.model_config
36 | self.gpu_config = llm_configs.gpu_config
37 | self.inference_config = llm_configs.inference_config
38 | self.parallelism_config = llm_configs.parallelism_config
39 | self.gpu_efficiency_config = llm_configs.gpu_efficiency_config
40 |
41 | self.h = self.model_config.hidden_size
42 | self.l = self.model_config.num_layers
43 | self.V = self.model_config.vocab_size
44 |
45 | self.b = llm_configs.inference_config.bs
46 | self.s = llm_configs.inference_config.seq_len
47 | self.o = llm_configs.inference_config.generate_len
48 | self.bytes_per_param = llm_configs.inference_config.bytes_per_param
49 |
50 | self.tp_size = self.parallelism_config.tp_size
51 | self.pp_size = self.parallelism_config.pp_size
52 | self.num_layers_per_gpu = int(self.l / self.parallelism_config.pp_size)
53 |
54 | self.gpu_memory_in_GB = (
55 | llm_configs.gpu_config.memory_GPU_in_GB * 10**9
56 | )
57 |
58 | self.llm_params = CountCausalLMParams(self.model_config)
59 | self.llm_flops = CountCausalLMFlops(self.model_config)
60 | self.llm_memory = CountCausalLMMemory(llm_configs)
61 | self.llm_latency = CountCausalLMLatency(llm_configs)
62 |
63 | def infer_profile(
64 | self,
65 | bs: int = 1,
66 | seq_len: int = 522,
67 | generate_len: int = 1526,
68 | act_dtype_bytes: int = BYTES_FP16,
69 | kv_cache_bytes: int = BYTES_FP16,
70 | qkvo_weight_dtype_bytes: int = BYTES_FP16,
71 | mlp_weight_dtype_bytes=BYTES_FP16,
72 | flops_efficiency: float = None,
73 | hbm_memory_efficiency: float = HBM_MEMORY_EFFICIENCY,
74 | intra_node_memory_efficiency=INTRA_NODE_MEMORY_EFFICIENCY,
75 | inter_node_memory_efficiency=INTER_NODE_MEMORY_EFFICIENCY,
76 | print_flag=False,
77 | visual_flag=False,
78 | ) -> dict:
79 | """LLM inference analysis given the llm configs and inputs."""
80 |
81 | if self.model_config.max_seq_len is not None:
82 | assert seq_len + generate_len <= self.model_config.max_seq_len, (
83 | f"seq_len {seq_len} + generate_len {generate_len} Exceeding the model max_seq_len {self.model_config.max_seq_len}"
84 | )
85 |
86 | if self.l % self.pp_size != 0:
87 | logger.warning(
88 | "Warning: the number of layers is not divisible by pp_size, please taking the floor!"
89 | )
90 |
91 | infer_config_dict = {
92 | "inference_config": {
93 | "model_name": self.model_config.model_name,
94 | "num_attention_heads": self.model_config.num_heads,
95 | "num_kv_heads": self.model_config.num_kv_heads,
96 | "head_dim": self.model_config.head_dim,
97 | "hidden_size": self.model_config.hidden_size,
98 | "intermediate_size": self.model_config.intermediate_size,
99 | "vocab_size": self.model_config.vocab_size,
100 | "max_seq_len": self.model_config.max_seq_len,
101 | "bs": bs,
102 | "seq_len": seq_len,
103 | "tp_size": self.tp_size,
104 | "pp_size": self.pp_size,
105 | "generate_len": generate_len,
106 | },
107 | "gpu_config": {
108 | "name": self.gpu_config.name,
109 | "memory_GPU_in_GB": f"{self.gpu_config.memory_GPU_in_GB} GB",
110 | "gpu_hbm_bandwidth": f"{self.gpu_config.hbm_bandwidth_in_GB_per_sec} GB/s",
111 | "gpu_intra_node_bandwidth": f"{self.gpu_config.intra_node_bandwidth_in_GB_per_sec} GB/s",
112 | "gpu_fp16_TFLOPS": f"{self.gpu_config.peak_fp16_TFLOPS} TFLOPS",
113 | },
114 | }
115 |
116 | # -------------------------- 1. Params --------------------------
117 | params_per_layer, dict_params_per_layer = (
118 | self.llm_params.count_params_per_layer()
119 | )
120 | num_params_model = self.llm_params.count_params_model()
121 |
122 | # -------------------------- 2. FLOPs ---------------------------
123 | prefill_flops_per_layer, prefill_dict_flops_per_layer = (
124 | self.llm_flops.count_flops_per_layer(bs, seq_len, generate_len)
125 | )
126 | decode_flops_per_layer, decode_dict_flops_per_layer = (
127 | self.llm_flops.count_flops_per_layer(bs, 1, generate_len)
128 | )
129 |
130 | prefill_num_flops_model = self.llm_flops.count_flops_model(bs, seq_len, generate_len)
131 | decode_num_flops_model = self.llm_flops.count_flops_model(bs, 1, generate_len)
132 |
133 | # -------------------------- 3. Memory --------------------------
134 | memory_prefill_summary_dict, memory_decode_summary_dict = (
135 | self.llm_memory.count_memory_per_gpu(
136 | bs,
137 | seq_len,
138 | generate_len,
139 | flash_attn=False,
140 | qkvo_weight_dtype_bytes=qkvo_weight_dtype_bytes,
141 | mlp_weight_dtype_bytes=mlp_weight_dtype_bytes,
142 | kv_cache_bytes=kv_cache_bytes,
143 | )
144 | )
145 |
146 | # -------------------------- 4. Latency -------------------------
147 | prefill_latency_per_layer, prefill_dict_latency_per_layer = (
148 | self.llm_latency.count_latency_per_layer(bs, seq_len, 0)
149 | )
150 | decode_latency_per_layer, decode_dict_latency_per_layer = (
151 | self.llm_latency.count_latency_per_layer(bs, 1, generate_len)
152 | )
153 | prefill_latency_breakdown, decode_latency_breakdown = (
154 | self.llm_latency.count_latency(
155 | bs,
156 | seq_len,
157 | generate_len,
158 | kv_cache_bytes=kv_cache_bytes,
159 | )
160 | )
161 |
162 | infer_result_dict = {
163 | "weight_memory_per_gpu": memory_prefill_summary_dict["weight_memory_per_gpu"],
164 | "consume_memory_per_gpu": memory_decode_summary_dict["consume_memory_per_gpu"],
165 | "prefill_flops": prefill_num_flops_model,
166 | "decode_flops_per_step": decode_num_flops_model,
167 | "TTFT": prefill_latency_breakdown["TTFT"],
168 | "TTOT": decode_latency_breakdown["TTOT"],
169 | "kv_cache_latency": decode_latency_breakdown["kv_cache_latency"],
170 | "total_infer_latency": prefill_latency_breakdown["TTFT"] + decode_latency_breakdown["TTOT"] * generate_len,
171 | "support_max_batch_total_tokens": memory_decode_summary_dict["max_batch_total_tokens"],
172 | }
173 |
174 | # --------------------------- 5. Memory Access ----------------------
175 | if visual_flag:
176 | model_type = self.model_config.model_type
177 | llm_analyzer = LayerAnalyzer(self.model_config, self.gpu_config, tp_size=self.tp_size)
178 | results = llm_analyzer.analyze_model(bs=bs, seq_len=seq_len, generate_len=generate_len)
179 |
180 | # -------------------------- 绘图:模型 graph 图示例 --------------------------
181 | base_path = f"_{self.model_config.model_name}_tp{self.tp_size}_bs{self.b}_seqlen{self.s}_genlen{self.o}.png"
182 | llm_analyzer.create_layer_graph(model_type, results, base_path)
183 | # Formatter.print_format_summary_dict(results, get_dict_depth(results))
184 |
185 | # -------------------------- 绘图:Pie 图示例 --------------------------
186 | prefill_latency_pie_save_path = f"./figures/latency_prefill" + base_path
187 | decode_latency_pie_save_path = f"./figures/latency_decode" + base_path
188 | prefill_flops_pie_save_path = f"./figures/flops_prefill" + base_path
189 | decode_flops_pie_save_path = f"./figures/flops_decode" + base_path
190 | params_pie_save_path = f"./figures/params" + base_path
191 |
192 | pie_tasks = [
193 | (dict_params_per_layer, "Params Distribution", params_pie_save_path),
194 | (prefill_dict_flops_per_layer, "Prefill FLOPS Distribution", prefill_flops_pie_save_path),
195 | (decode_dict_flops_per_layer, "Decode FLOPS Distribution", decode_flops_pie_save_path),
196 | (prefill_dict_latency_per_layer, "Prefill Latency Distribution", prefill_latency_pie_save_path),
197 | (decode_dict_latency_per_layer, "Decode Latency Distribution", decode_latency_pie_save_path),
198 | ]
199 | for data, title, path in pie_tasks:
200 | self.plot_distribution_pie(data, title, path)
201 |
202 | # ------------------------- 6. pretty‑print report --------------------
203 | if print_flag:
204 | self._print_report(
205 | infer_config_dict,
206 | copy.deepcopy(infer_result_dict),
207 | dict_params_per_layer,
208 | num_params_model,
209 | prefill_dict_flops_per_layer,
210 | prefill_num_flops_model,
211 | memory_prefill_summary_dict,
212 | memory_decode_summary_dict,
213 | prefill_latency_breakdown,
214 | decode_latency_breakdown,
215 | )
216 |
217 | return infer_result_dict
218 |
219 | def plot_distribution_pie(
220 | self,
221 | data: dict[str, float],
222 | title: str,
223 | save_path: str,
224 | *,
225 | explode_small_pct: float = 4.0, # explode slices whose pct < this value
226 | label_pct_threshold: float = 0.5, # display "= label_display_threshold else "" for lbl, pct in zip(labels, pct_list)
247 | ]
248 |
249 | # colour palette
250 | cmap = plt.get_cmap("tab20" if len(labels) > 9 else "tab10")
251 | colors = [cmap(i % cmap.N) for i in range(len(labels))]
252 |
253 | # proportional explode: smaller share → larger offset (capped at 0.18)
254 | explode = [
255 | min(0.18, 0.04 + (explode_small_pct - pct) / explode_small_pct * 0.10)
256 | if (pct := 100 * s / total) < explode_small_pct
257 | else 0
258 | for s in sizes
259 | ]
260 |
261 | # formatting tiny percentage
262 | def _autopct(pct: float) -> str:
263 | return (
264 | f"<{label_pct_threshold:.1f}%" if pct < label_pct_threshold else f"{pct:.1f}%"
265 | )
266 |
267 | # high‑dpi for clarity
268 | fig, ax = plt.subplots(figsize=(7, 6), dpi=300)
269 |
270 | wedges, texts, autotexts = ax.pie(
271 | sizes,
272 | labels=labels_display,
273 | labeldistance=1.18,
274 | autopct=_autopct,
275 | pctdistance=0.78,
276 | startangle=140,
277 | colors=colors,
278 | explode=explode,
279 | wedgeprops={"edgecolor": "white", "linewidth": 1.0},
280 | textprops={"fontsize": 10, "color": "black"},
281 | )
282 | # inner % text style
283 | plt.setp(autotexts, size=9, weight="bold", color="white")
284 |
285 | # keep legend for color reference but remove title to save space
286 | ax.legend(
287 | wedges,
288 | labels,
289 | loc="upper center",
290 | bbox_to_anchor=(0.5, -0.14),
291 | ncol=min(len(labels), 5),
292 | fontsize=9,
293 | frameon=False,
294 | )
295 |
296 | ax.axis("equal") # perfect circle
297 |
298 | # Title
299 | fig.suptitle(
300 | title,
301 | fontsize=18,
302 | weight="bold",
303 | y=0.98,
304 | color="#2c3e50",
305 | )
306 |
307 | # tidy layout – adjust bottom for legend
308 | fig.subplots_adjust(left=0.05, right=0.95, top=0.88, bottom=0.25)
309 | fig.savefig(save_path, bbox_inches="tight", pad_inches=0.06, dpi=300)
310 | plt.close(fig)
311 |
312 | # ------------------------- Pretty‑print helpers -------------------- #
313 | def _print_section(self, title, summary_dict, category, extra_totals=None):
314 | """Print a single analysis section with optional totals."""
315 | print(f"\n---------------------------- {title} ----------------------------")
316 | Formatter.print_format_summary_dict(
317 | summary_dict=summary_dict,
318 | depth=get_dict_depth(summary_dict),
319 | category=category,
320 | )
321 | if extra_totals:
322 | pprint.pprint(extra_totals, indent=4, sort_dicts=False)
323 |
324 | def _print_report(
325 | self,
326 | infer_config_dict,
327 | infer_result_dict,
328 | dict_params_per_layer,
329 | num_params_model,
330 | prefill_dict_flops_per_layer,
331 | prefill_num_flops_model,
332 | memory_prefill_summary_dict,
333 | memory_decode_summary_dict,
334 | prefill_latency_breakdown,
335 | decode_latency_breakdown,
336 | ):
337 | """Pretty‑print a full performance report."""
338 | print("\n-------------------------- LLM main infer config --------------------------")
339 | pprint.pprint(infer_config_dict, indent=4, sort_dicts=False)
340 |
341 | print("\n-------------------------- LLM infer performance analysis --------------------------")
342 | Formatter.print_format_summary_dict(
343 | infer_result_dict, get_dict_depth(infer_result_dict)
344 | )
345 |
346 | sections = [
347 | (
348 | "LLM Params per_layer analysis",
349 | dict_params_per_layer,
350 | "params",
351 | {"params_model": num_to_string(num_params_model)},
352 | ),
353 | (
354 | "LLM Prefill Flops per_layer analysis",
355 | prefill_dict_flops_per_layer,
356 | "flops",
357 | {"prefill flops_model": num_to_string(prefill_num_flops_model)},
358 | ),
359 | (
360 | "LLM Memory analysis (Prefill)",
361 | memory_prefill_summary_dict,
362 | "memory",
363 | None,
364 | ),
365 | (
366 | "LLM Memory analysis (Decode)",
367 | memory_decode_summary_dict,
368 | "memory",
369 | None,
370 | ),
371 | (
372 | "LLM Latency analysis (Prefill)",
373 | prefill_latency_breakdown,
374 | "latency",
375 | None,
376 | ),
377 | (
378 | "LLM Latency analysis (Decode)",
379 | decode_latency_breakdown,
380 | "latency",
381 | None,
382 | ),
383 | ]
384 |
385 | for title, summary_dict, category, extra in sections:
386 | self._print_section(title, summary_dict, category, extra)
387 |
388 | def llm_profile(
389 | model_name,
390 | gpu_name: str = "a100-sxm-40gb",
391 | bytes_per_param: int = BYTES_FP16,
392 | batch_size: int = 20,
393 | seq_len: int = 1024,
394 | generate_len=1024,
395 | dp_size: int = 1,
396 | tp_size: int = 8,
397 | pp_size: int = 1,
398 | sp_size: int = 1,
399 | act_dtype_bytes: int = BYTES_FP16,
400 | kv_cache_bytes: int = BYTES_FP16,
401 | flops_efficiency: float = FLOPS_EFFICIENCY,
402 | hbm_memory_efficiency: float = HBM_MEMORY_EFFICIENCY,
403 | intra_node_memory_efficiency=INTRA_NODE_MEMORY_EFFICIENCY,
404 | inter_node_memory_efficiency=INTER_NODE_MEMORY_EFFICIENCY,
405 | print_flag: bool = False,
406 | visual_flag: bool = False,
407 | ) -> dict:
408 | """Returns dict of the total floating-point operations, MACs, parameters and latency of a llm.
409 | It now returns a dictionary containing FLOPs, latency, HBM memory usage and max_batch_total_tokens.
410 |
411 | Args:
412 | model_name (str, optional): model name to query the pre-defined `model_configs.json`.
413 | Defaults to "llama-13b".
414 | gpu_name (str, optional): gpu name to query the pre-defined `model_configs.json`.
415 | Defaults to "v100-sxm2-32gb".
416 | batch_size (int, optional): _description_. Defaults to 1.
417 | seq_len (int, optional): batch size per GPU.. Defaults to 522.
418 | generate_len (int, optional): The maximum numbers of tokens to generate,
419 | ignoring the number of tokens in the prompt. Defaults to 1526.
420 | dp_size (int, optional): data parallelism size. Defaults to 1.
421 | tp_size (int, optional): tensor parallelism size. Defaults to 1.
422 | pp_size (int, optional): pipeline parallelism size. Defaults to 1.
423 | sp_size (int, optional): sequence parallelism size. Defaults to 1.
424 | past last key/values attentions (if applicable to the model) to speed up decoding. Defaults to True.
425 | layernorm_dtype_bytes (int, optional): number of bytes in the data type for the layernorm activations..
426 | Defaults to BYTES_FP16.
427 | kv_cache_bytes (int, optional): number of bytes in the data type for the kv_cache. Defaults to None.
428 | flops_efficiency (float, optional): flops efficiency, ranging from 0 to 1. Defaults to None.
429 | hbm_memory_efficiency (float, optional): GPU HBM memory efficiency, ranging from 0 to 1.
430 | Defaults to HBM_MEMORY_EFFICIENCY.
431 | intra_node_memory_efficiency (_type_, optional): intra-node memory efficiency, ranging from 0 to 1..
432 | Defaults to INTRA_NODE_MEMORY_EFFICIENCY.
433 | inter_node_memory_efficiency (_type_, optional): inter-node memory efficiency, ranging from 0 to 1..
434 | Defaults to INTER_NODE_MEMORY_EFFICIENCY.
435 |
436 | Returns:
437 | dict: a summary dictionary of the inference analysis
438 | """
439 | model_config, gpu_config = get_model_and_gpu_config_by_name(model_name, gpu_name)
440 |
441 | parallelism_config = ParallelismConfig(
442 | tp_size=tp_size, pp_size=pp_size, dp_size=dp_size, sp_size=sp_size
443 | )
444 |
445 | inference_config = InferenceConfig(
446 | bs=batch_size,
447 | seq_len=seq_len,
448 | generate_len=generate_len,
449 | bytes_per_param=bytes_per_param,
450 | act_dtype_bytes=act_dtype_bytes,
451 | kv_cache_bytes=kv_cache_bytes,
452 | )
453 |
454 | gpu_efficiency_config = GPUEfficiencyConfig(
455 | flops_efficiency=flops_efficiency,
456 | hbm_memory_efficiency=hbm_memory_efficiency,
457 | intra_node_memory_efficiency=intra_node_memory_efficiency,
458 | inter_node_memory_efficiency=inter_node_memory_efficiency,
459 | )
460 |
461 | llm_configs = LLMConfigs(
462 | model_config=model_config,
463 | gpu_config=gpu_config,
464 | parallelism_config=parallelism_config,
465 | inference_config=inference_config,
466 | gpu_efficiency_config=gpu_efficiency_config,
467 | )
468 |
469 | profiler = LayerAnalyzerVisual(llm_configs)
470 |
471 | infer_result_dict = profiler.infer_profile(
472 | bs=batch_size,
473 | seq_len=seq_len,
474 | generate_len=generate_len,
475 | act_dtype_bytes=act_dtype_bytes,
476 | flops_efficiency=flops_efficiency,
477 | hbm_memory_efficiency=hbm_memory_efficiency,
478 | print_flag=print_flag,
479 | visual_flag=visual_flag,
480 | )
481 |
482 | # ---------------------------------------------------------------------
483 | # Collect summary metrics (keep raw numbers for downstream maths) #
484 | # ---------------------------------------------------------------------
485 | weight_memory_per_gpu = infer_result_dict.get("weight_memory_per_gpu", None)
486 | consume_memory_per_gpu = infer_result_dict.get("consume_memory_per_gpu", None)
487 | prefill_flops = infer_result_dict.get("prefill_flops", None)
488 |
489 | table_results = {
490 | "seq_len": seq_len,
491 | "generate_len": generate_len,
492 | "prefill_flops": num_to_string(prefill_flops),
493 | "weight_memory_per_gpu": num_to_string(weight_memory_per_gpu),
494 | "consume_memory_per_gpu": num_to_string(consume_memory_per_gpu),
495 | "TTFT": infer_result_dict.get("TTFT", None),
496 | "TTOT": infer_result_dict.get("TTOT", None),
497 | "Total_latency": infer_result_dict.get("total_infer_latency", None),
498 | }
499 | visual_results = {
500 | "seq_len": seq_len,
501 | "generate_len": generate_len,
502 | "prefill_flops": prefill_flops, # raw number
503 | "weight_memory_per_gpu": weight_memory_per_gpu,
504 | "consume_memory_per_gpu": consume_memory_per_gpu, # raw bytes
505 | "TTFT": infer_result_dict.get("TTFT", None),
506 | "TTOT": infer_result_dict.get("TTOT", None),
507 | "Total_latency": infer_result_dict.get("total_infer_latency", None),
508 | }
509 | return table_results, visual_results
510 |
511 |
512 | # ----------------------------- Command‑line interface ----------------------------- #
513 | def _cli():
514 | """Command‑line wrapper for quick profiling."""
515 | parser = argparse.ArgumentParser(description="LLMCounts – quick model inference profiler")
516 | parser.add_argument("--model_name", required=True, help="Model name defined in model_configs.json")
517 | parser.add_argument("--gpu_name", default="a100-sxm-40gb", help="GPU name defined in model_configs.json")
518 | parser.add_argument("--batch_size", type=int, default=1)
519 | parser.add_argument("--seq_len", type=int, default=1024)
520 | parser.add_argument("--generate_len", type=int, default=1024)
521 | parser.add_argument("--tp_size", type=int, default=1)
522 | parser.add_argument("--pp_size", type=int, default=1)
523 | parser.add_argument("--dp_size", type=int, default=1)
524 | parser.add_argument("--sp_size", type=int, default=1)
525 | parser.add_argument("--visual", action="store_true", help="Generate pie‑charts and layer graph")
526 | parser.add_argument("--print", dest="print_flag", action="store_true", help="Pretty‑print verbose breakdown")
527 | parser.add_argument("--json", dest="json_flag", action="store_true", help="Output raw results as JSON")
528 | args = parser.parse_args()
529 |
530 | table_results, visual_results = llm_profile(
531 | model_name=args.model_name,
532 | gpu_name=args.gpu_name,
533 | batch_size=args.batch_size,
534 | seq_len=args.seq_len,
535 | generate_len=args.generate_len,
536 | tp_size=args.tp_size,
537 | pp_size=args.pp_size,
538 | dp_size=args.dp_size,
539 | sp_size=args.sp_size,
540 | print_flag=args.print_flag,
541 | visual_flag=args.visual,
542 | )
543 |
544 | if args.json_flag:
545 | print(json.dumps(visual_results, indent=2))
546 | else:
547 | import pprint
548 | pprint.pprint(table_results, indent=2)
549 |
550 |
551 | if __name__ == "__main__":
552 | _cli()
--------------------------------------------------------------------------------
/llm_counts/configs/gpu_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "t4-pcie-15gb": {
3 | "name": "t4-pcie-15gb",
4 | "memory_GPU_in_GB": 15,
5 | "hbm_bandwidth_in_GB_per_sec": 300,
6 | "intra_node_bandwidth_in_GB_per_sec": 32,
7 | "peak_fp16_TFLOPS": 65,
8 | "peak_int8_TFLOPS": 130,
9 | "peak_int4_TFLOPS": 260,
10 | "intra_node_min_message_latency": 8e-06
11 | },
12 | "v100-pcie-32gb": {
13 | "name": "v100-pcie-32gb",
14 | "memory_GPU_in_GB": 32,
15 | "hbm_bandwidth_in_GB_per_sec": 900,
16 | "intra_node_bandwidth_in_GB_per_sec": 32,
17 | "inter_node_bandwidth_in_GB_per_sec": 200,
18 | "peak_fp16_TFLOPS": 112,
19 | "peak_int8_TFLOPS": 224,
20 | "peak_int4_TFLOPS": 448,
21 | "intra_node_min_message_latency": 8e-06,
22 | "onchip_buffer": 20480e3
23 | },
24 | "v100-sxm-32gb": {
25 | "name": "v100-sxm-32gb",
26 | "memory_GPU_in_GB": 32,
27 | "hbm_bandwidth_in_GB_per_sec": 900,
28 | "intra_node_bandwidth_in_GB_per_sec": 300,
29 | "inter_node_bandwidth_in_GB_per_sec": 200,
30 | "peak_fp16_TFLOPS": 112,
31 | "peak_int8_TFLOPS": 224,
32 | "peak_int4_TFLOPS": 448,
33 | "intra_node_min_message_latency": 8e-06,
34 | "onchip_buffer": 20480e3
35 | },
36 | "br104p": {
37 | "name": "br104p",
38 | "memory_GPU_in_GB": 32,
39 | "hbm_bandwidth_in_GB_per_sec": 819,
40 | "intra_node_bandwidth_in_GB_per_sec": 192,
41 | "inter_node_bandwidth_in_GB_per_sec": 200,
42 | "peak_fp32_TFLOPS": 256,
43 | "peak_fp16_TFLOPS": 512,
44 | "peak_int8_TFLOPS": 1024,
45 | "intra_node_min_message_latency": 8e-06
46 | },
47 | "a100-pcie-40gb": {
48 | "name": "a100-pcie-40gb",
49 | "memory_GPU_in_GB": 40,
50 | "hbm_bandwidth_in_GB_per_sec": 1555,
51 | "intra_node_bandwidth_in_GB_per_sec": 64,
52 | "inter_node_bandwidth_in_GB_per_sec": 200,
53 | "peak_fp32_TFLOPS": 156,
54 | "peak_fp16_TFLOPS": 312,
55 | "peak_int8_TFLOPS": 624,
56 | "peak_int4_TFLOPS": 1248,
57 | "intra_node_min_message_latency": 8e-06,
58 | "onchip_buffer": 27648e3
59 | },
60 | "a100-sxm-40gb": {
61 | "name": "a100-sxm-40gb",
62 | "memory_GPU_in_GB": 40,
63 | "hbm_bandwidth_in_GB_per_sec": 1555,
64 | "intra_node_bandwidth_in_GB_per_sec": 600,
65 | "inter_node_bandwidth_in_GB_per_sec": 200,
66 | "peak_fp32_TFLOPS": 156,
67 | "peak_fp16_TFLOPS": 312,
68 | "peak_int8_TFLOPS": 624,
69 | "peak_int4_TFLOPS": 1248,
70 | "intra_node_min_message_latency": 8e-06,
71 | "onchip_buffer": 27648e3
72 | },
73 | "a100-pcie-80gb": {
74 | "name": "a100-pcie-80gb",
75 | "memory_GPU_in_GB": 80,
76 | "hbm_bandwidth_in_GB_per_sec": 1935,
77 | "intra_node_bandwidth_in_GB_per_sec": 64,
78 | "inter_node_bandwidth_in_GB_per_sec": 200,
79 | "peak_fp32_TFLOPS": 156,
80 | "peak_fp16_TFLOPS": 312,
81 | "peak_int8_TFLOPS": 624,
82 | "peak_int4_TFLOPS": 1248,
83 | "intra_node_min_message_latency": 8e-06,
84 | "onchip_buffer": 27648e3
85 | },
86 | "a100-sxm-80gb": {
87 | "name": "a100-sxm-80gb",
88 | "memory_GPU_in_GB": 80,
89 | "hbm_bandwidth_in_GB_per_sec": 2039,
90 | "intra_node_bandwidth_in_GB_per_sec": 600,
91 | "inter_node_bandwidth_in_GB_per_sec": 200,
92 | "peak_fp32_TFLOPS": 156,
93 | "peak_fp16_TFLOPS": 312,
94 | "peak_int8_TFLOPS": 624,
95 | "peak_int4_TFLOPS": 1248,
96 | "intra_node_min_message_latency": 8e-06,
97 | "onchip_buffer": 27648e3
98 | },
99 | "910b-64gb": {
100 | "name": "910b-64gb",
101 | "memory_GPU_in_GB": 64,
102 | "hbm_bandwidth_in_GB_per_sec": 460,
103 | "intra_node_bandwidth_in_GB_per_sec": 392,
104 | "inter_node_bandwidth_in_GB_per_sec": 200,
105 | "peak_fp32_TFLOPS": 188,
106 | "peak_fp16_TFLOPS": 376,
107 | "peak_int8_TFLOPS": 752,
108 | "peak_int4_TFLOPS": 1504,
109 | "intra_node_min_message_latency": 9e-06
110 | },
111 | "h100-sxm-80gb": {
112 | "name": "a100-sxm-80gb",
113 | "memory_GPU_in_GB": 80,
114 | "hbm_bandwidth_in_GB_per_sec": 3430,
115 | "intra_node_bandwidth_in_GB_per_sec": 900,
116 | "inter_node_bandwidth_in_GB_per_sec": 400,
117 | "peak_fp32_TFLOPS": 989,
118 | "peak_fp16_TFLOPS": 1979,
119 | "peak_int8_TFLOPS": 3958,
120 | "intra_node_min_message_latency": 8e-06,
121 | "onchip_buffer": 33792e3
122 | },
123 | "h100-pcie-80gb": {
124 | "name": "a100-sxm-80gb",
125 | "memory_GPU_in_GB": 80,
126 | "hbm_bandwidth_in_GB_per_sec": 2048,
127 | "intra_node_bandwidth_in_GB_per_sec": 128,
128 | "inter_node_bandwidth_in_GB_per_sec": 400,
129 | "peak_fp32_TFLOPS": 756,
130 | "peak_fp16_TFLOPS": 1513,
131 | "peak_int8_TFLOPS": 3026,
132 | "intra_node_min_message_latency": 8e-06,
133 | "onchip_buffer": 33792e3
134 | },
135 | "a30-pcie-24gb": {
136 | "name": "a30-pcie-24gb",
137 | "memory_GPU_in_GB": 24,
138 | "hbm_bandwidth_in_GB_per_sec": 933,
139 | "intra_node_bandwidth_in_GB_per_sec": 64,
140 | "inter_node_bandwidth_in_GB_per_sec": 200,
141 | "peak_fp32_TFLOPS": 82,
142 | "peak_fp16_TFLOPS": 165,
143 | "peak_int8_TFLOPS": 330,
144 | "peak_int4_TFLOPS": 661,
145 | "intra_node_min_message_latency": 8e-06
146 | },
147 | "a30-sxm-24gb": {
148 | "name": "a30-sxm-24gb",
149 | "memory_GPU_in_GB": 24,
150 | "hbm_bandwidth_in_GB_per_sec": 933,
151 | "intra_node_bandwidth_in_GB_per_sec": 200,
152 | "inter_node_bandwidth_in_GB_per_sec": 200,
153 | "peak_fp32_TFLOPS": 82,
154 | "peak_fp16_TFLOPS": 165,
155 | "peak_int8_TFLOPS": 330,
156 | "peak_int4_TFLOPS": 661,
157 | "intra_node_min_message_latency": 8e-06
158 | },
159 | "a40-pcie-48gb": {
160 | "name": "a40-pcie-48gb",
161 | "memory_GPU_in_GB": 44.98,
162 | "hbm_bandwidth_in_GB_per_sec": 696,
163 | "intra_node_bandwidth_in_GB_per_sec": 64,
164 | "inter_node_bandwidth_in_GB_per_sec": 200,
165 | "peak_fp32_TFLOPS": 74.8,
166 | "peak_fp16_TFLOPS": 149.7,
167 | "peak_int8_TFLOPS": 299.3,
168 | "peak_int4_TFLOPS": 598.7,
169 | "intra_node_min_message_latency": 8e-06
170 | }
171 | }
--------------------------------------------------------------------------------
/llm_counts/configs/gpu_perf.ini:
--------------------------------------------------------------------------------
1 | [T4]
2 | gpu_memory=16GB
3 | single_precision=8.1TFLOPS
4 | gpu_memory_bandwidth=300GB/s
5 | interconnect_bandwidth=32GB/s
6 | [L4]
7 | gpu_memory=30GB
8 | single_precision=24TFLOPS
9 | gpu_memory_bandwidth=300GB/s
10 | interconnect_bandwidth=64GB/s
11 | [L40]
12 | gpu_memory=48GB
13 | single_precision=90.5TFLOPS
14 | gpu_memory_bandwidth=864GB/s
15 | interconnect_bandwidth=64GB/s
16 | [V100]
17 | gpu_memory=36GB
18 | single_precision=14TFLOPS
19 | gpu_memory_bandwidth=900GB/s
20 | interconnect_bandwidth=32GB/s
21 | [A100]
22 | gpu_memory=80GB
23 | single_precision=19.5TFLOPS
24 | gpu_memory_bandwidth=1935GB/s
25 | interconnect_bandwidth=64GB/s
--------------------------------------------------------------------------------
/llm_counts/configs/model_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama-7B":{
3 | "num_layers": 32,
4 | "num_heads": 32,
5 | "hidden_size": 4096,
6 | "intermediate_size": 11008,
7 | "vocab_size": 32000,
8 | "max_seq_len": 2048,
9 | "model_type": "llama",
10 | "model_name": "llama-7B"
11 | },
12 | "llama-13B":{
13 | "num_layers": 40,
14 | "num_heads": 40,
15 | "hidden_size": 5120,
16 | "intermediate_size": 13824,
17 | "vocab_size": 32000,
18 | "max_seq_len": 2048,
19 | "model_type": "llama",
20 | "model_name": "llama-13B"
21 | },
22 | "llama-30B":{
23 | "num_layers": 60,
24 | "num_heads": 52,
25 | "hidden_size": 6656,
26 | "intermediate_size": 17920,
27 | "vocab_size": 32000,
28 | "max_seq_len": 2048,
29 | "model_type": "llama",
30 | "model_name": "llama-30B"
31 | },
32 | "llama-65B":{
33 | "num_layers": 80,
34 | "num_heads": 64,
35 | "hidden_size": 8192,
36 | "intermediate_size": 22016,
37 | "vocab_size": 32000,
38 | "max_seq_len": 2048,
39 | "model_type": "llama",
40 | "model_name": "llama-65B"
41 | },
42 | "llama2-13B":{
43 | "num_layers": 40,
44 | "num_heads": 40,
45 | "num_kv_heads": 40,
46 | "hidden_size": 5120,
47 | "intermediate_size": 13824,
48 | "vocab_size": 32000,
49 | "max_seq_len": 4096,
50 | "model_type": "llama",
51 | "model_name": "llama2-13B"
52 | },
53 | "llama2-70B":{
54 | "num_layers": 80,
55 | "num_heads": 64,
56 | "num_kv_heads": 8,
57 | "hidden_size": 8192,
58 | "intermediate_size": 28672,
59 | "vocab_size": 32000,
60 | "max_seq_len": 4096,
61 | "model_type": "llama2",
62 | "model_name": "llama2-70B"
63 | },
64 | "internlm-20B": {
65 | "num_layers": 60,
66 | "num_heads": 40,
67 | "num_kv_heads": 40,
68 | "hidden_size": 5120,
69 | "intermediate_size": 13824,
70 | "vocab_size": 103168,
71 | "max_seq_len": 16384,
72 | "model_type": "internlm",
73 | "model_name": "internlm-20B"
74 | },
75 | "internlm2-20b-chat": {
76 | "num_layers": 48,
77 | "num_heads": 48,
78 | "num_kv_heads": 8,
79 | "hidden_size": 6144,
80 | "intermediate_size": 16384,
81 | "vocab_size": 92544,
82 | "max_seq_len": 32768,
83 | "model_type": "internlm2",
84 | "model_name": "internlm2-20b-chat"
85 | },
86 | "Qwen3-8B": {
87 | "num_layers": 36,
88 | "head_dim": 128,
89 | "hidden_size": 4096,
90 | "num_heads": 32,
91 | "num_kv_heads": 8,
92 | "intermediate_size": 12288,
93 | "vocab_size": 151936,
94 | "max_seq_len": 40960,
95 | "model_type": "qwen3",
96 | "model_name": "Qwen3-8B"
97 | },
98 | "Qwen3-32B": {
99 | "num_layers": 64,
100 | "head_dim": 128,
101 | "hidden_size": 5120,
102 | "num_heads": 64,
103 | "num_kv_heads": 8,
104 | "intermediate_size": 25600,
105 | "vocab_size": 151936,
106 | "max_seq_len": 40960,
107 | "model_type": "qwen3",
108 | "model_name": "Qwen3-32B"
109 | }
110 | }
111 |
--------------------------------------------------------------------------------
/llm_counts/count_flops.py:
--------------------------------------------------------------------------------
1 | from .utils.config import ModelConfig
2 |
3 |
4 | class CountCausalLMFlops(object):
5 | """CountCausalLMFlops is a class that counts the number of floating point operations (FLOPs)
6 | for a causal language model (LLM) during the forward passes."""
7 |
8 | def __init__(
9 | self,
10 | model_config: ModelConfig,
11 | ) -> None:
12 | self.model_type = model_config.model_type
13 | self.num_heads = model_config.num_heads
14 | self.num_kv_heads = model_config.num_kv_heads
15 | self.head_dim = model_config.head_dim
16 | self.hidden_size = model_config.hidden_size
17 | self.intermediate_size = model_config.intermediate_size
18 | self.l = model_config.num_layers
19 | self.V = model_config.vocab_size
20 |
21 | def count_flops_per_layer_qkvo_proj(self, bs: int, seq_len: int) -> int:
22 | """Get the number of floating point operations (flops) for the forward
23 | pass of the attention linear layers, given the batch size and sequence length.
24 |
25 | flops_qkvo_proj = flops_q + flops_k + flops_v + flops_output
26 |
27 | Args:
28 | bs (int): batch size
29 | seq_len (int): sequence length
30 | """
31 | q_proj_flops = 2 * bs * seq_len * self.hidden_size * self.num_heads * self.head_dim
32 | k_proj_flops = 2 * bs * seq_len * self.hidden_size * self.num_kv_heads * self.head_dim
33 | v_proj_flops = 2 * bs * seq_len * self.hidden_size * self.num_kv_heads * self.head_dim
34 | o_proj_flops = 2 * bs * seq_len * self.hidden_size * self.num_heads * self.head_dim
35 | qkvo_proj_flops = q_proj_flops + k_proj_flops + v_proj_flops + o_proj_flops
36 |
37 | return qkvo_proj_flops
38 |
39 | def count_flops_per_layer_mlp(self, bs: int, seq_len: int) -> int:
40 | """Count two flops of matrices multiplication(two linear layers in the MLP module.)
41 | eg. llama3.2-1B: self.intermediate_size = 4 * self.hidden_size
42 | eg. flops_mlp(llama3.2-1B) = flops_fc1 + flops_fc2 + flops_fc3
43 | = 2bs(4h^2) + 2bs(4h^2) + 2bs(4h^2) = 24bsh^2
44 | """
45 | flops_gate_proj = 2 * bs * seq_len * self.hidden_size * self.intermediate_size
46 | flops_up_proj = 2 * bs * seq_len * self.hidden_size * self.intermediate_size
47 | flops_down_proj = 2 * bs * seq_len * self.intermediate_size * self.hidden_size
48 |
49 | return flops_gate_proj + flops_up_proj + flops_down_proj
50 |
51 | def count_flops_per_layer_attn_kernel(self, bs: int, seq_len: int, generate_len: int) -> int:
52 | q_norm_flops = bs * 4 * seq_len * self.head_dim
53 | k_norm_flops = q_norm_flops
54 | # e^x / sum(e^x); bs = 1 和 seq_len = 1 时 flops 为 3d-1, 张量中每个元素约执行 3 次操作
55 | softmax_flops = bs * 3 * seq_len * self.num_heads * self.head_dim
56 |
57 | if seq_len != 1:
58 | qk_matmul_flops = bs * 2 * seq_len * seq_len * self.num_heads * self.head_dim
59 | sv_matmul_flops = qk_matmul_flops
60 |
61 | else:
62 | qk_matmul_flops = 2 * self.num_heads * self.head_dim * (seq_len + generate_len)
63 | sv_matmul_flops = qk_matmul_flops
64 |
65 | flops_self_attention_kernel = q_norm_flops + k_norm_flops + qk_matmul_flops + sv_matmul_flops + softmax_flops
66 |
67 | return flops_self_attention_kernel
68 |
69 | def count_flops_per_layer_norm(self, bs: int, seq_len: int) -> int:
70 | """flops of 2 rmsnrom per layer"""
71 | return bs * 4 * seq_len * self.hidden_size
72 |
73 | def count_flops_per_layer(self, bs: int, seq_len: int, generate_len:int) -> tuple:
74 | flops_per_layer_qkvo_proj = self.count_flops_per_layer_qkvo_proj(bs, seq_len)
75 | flops_per_layer_mlp = self.count_flops_per_layer_mlp(bs, seq_len)
76 |
77 | flops_per_layer_attention_kernel = self.count_flops_per_layer_attn_kernel(
78 | bs, seq_len, generate_len,
79 | )
80 | flops_per_layer_rmsnorm = (
81 | self.count_flops_per_layer_norm(bs, seq_len) * 2
82 | ) # atten_rmsnorm and mlp_rmsnorm
83 |
84 | flops_positional_embedding = self.count_flops_positional_embedding(bs, seq_len)
85 |
86 | flops_per_layer = (
87 | flops_per_layer_qkvo_proj
88 | + flops_per_layer_mlp
89 | + flops_per_layer_rmsnorm
90 | + flops_per_layer_attention_kernel
91 | + flops_positional_embedding
92 | )
93 |
94 | dict_flops_per_layer = {
95 | "attention_kernel": flops_per_layer_attention_kernel,
96 | "qkvo_proj": flops_per_layer_qkvo_proj,
97 | "mlp": flops_per_layer_mlp,
98 | "rmsnorm": flops_per_layer_rmsnorm * 2,
99 | "positional_embedding": flops_positional_embedding,
100 | "input_embedding": 0,
101 | }
102 |
103 | return flops_per_layer, dict_flops_per_layer
104 |
105 | def count_flops_positional_embedding(
106 | self,
107 | bs:int,
108 | seq_len:int,
109 | ) -> int:
110 | """flops of output token logits layer"""
111 | return 2 * bs * seq_len * self.hidden_size
112 |
113 | def count_flops_model(self, bs: int, seq_len: int, generate_len: int) -> int:
114 | """Count flops of the forward pass of the transformer model,
115 | given the batch size and sequence length.
116 | """
117 | num_flops_model = self.count_flops_per_layer(bs, seq_len, generate_len)[0] * self.l
118 |
119 | return num_flops_model
120 |
121 | def count_flops_bwd_model(self, bs: int, seq_len: int, generate_len:int) -> int:
122 | """Get the number of floating point operations (flops) for the backward
123 | pass of the entire transformer model, given the batch size and sequence
124 | """
125 | return 2 * self.count_flops_model(bs, seq_len, generate_len)
126 |
--------------------------------------------------------------------------------
/llm_counts/count_latency.py:
--------------------------------------------------------------------------------
1 | from .utils.config import (
2 | LLMConfigs,
3 | get_gpu_hbm_bandwidth,
4 | get_intra_node_bandwidth,
5 | get_TFLOPS_per_gpu,
6 | )
7 | from .utils.constants import *
8 | from .utils.utils import latency_to_string
9 |
10 | from .count_flops import CountCausalLMFlops
11 | from .count_params import CountCausalLMParams
12 | from .count_memory import CountCausalLMMemory
13 |
14 |
15 | class CountCausalLMLatency(object):
16 | """Count latency by roof-line performance model."""
17 |
18 | def __init__(self, llm_configs: LLMConfigs) -> None:
19 | self.model_config = llm_configs.model_config
20 | self.gpu_config = llm_configs.gpu_config
21 | self.inference_config = llm_configs.inference_config
22 | self.parallelism_config = llm_configs.parallelism_config
23 |
24 | self.h = self.model_config.hidden_size
25 | self.l = self.model_config.num_layers
26 | self.V = self.model_config.vocab_size
27 |
28 | self.b = llm_configs.inference_config.bs
29 | self.s = llm_configs.inference_config.seq_len
30 | self.o = llm_configs.inference_config.generate_len
31 | self.bytes_per_param = llm_configs.inference_config.bytes_per_param
32 |
33 | self.tp_size = self.parallelism_config.tp_size
34 | self.pp_size = self.parallelism_config.pp_size
35 | self.num_layers_per_gpu = int(self.l / self.parallelism_config.pp_size)
36 |
37 | self.gpu_hbm_bandwidth, self.onchip_buffer = (
38 | get_gpu_hbm_bandwidth(self.gpu_config, HBM_MEMORY_EFFICIENCY)
39 | )
40 | self.gpu_hbm_bandwidth *= 10**9
41 |
42 | self.gpu_intra_node_bandwidth = (
43 | get_intra_node_bandwidth(self.gpu_config, INTRA_NODE_MEMORY_EFFICIENCY)
44 | * 10**9
45 | ) # intra node bandwidth,GB/s
46 | self.gpu_TFLOPS = (
47 | get_TFLOPS_per_gpu(self.gpu_config, flops_efficiency=FLOPS_EFFICIENCY)
48 | * 10**12
49 | ) # TFLOPS
50 |
51 | self.llm_params = CountCausalLMParams(self.model_config)
52 | self.llm_memory = CountCausalLMMemory(llm_configs)
53 | self.llm_flops = CountCausalLMFlops(self.model_config)
54 |
55 | @staticmethod
56 | def print_kernel_bound_info(stage, memory_latency, compute_latency, ops_type):
57 | """Print the kernel bound information for the given stage."""
58 | if memory_latency > compute_latency:
59 | print(
60 | f"{stage} stage: memory_latency {latency_to_string(memory_latency)} \
61 | > compute_latency {latency_to_string(compute_latency)}, this {ops_type} layer is memory bound!"
62 | )
63 | else:
64 | print(
65 | f"{stage} stage: memory_latency {latency_to_string(memory_latency)} \
66 | <= compute_latency {latency_to_string(compute_latency)}, this {ops_type} layer is compute bound!"
67 | )
68 |
69 | def common_count_latency_for_ops(
70 | self,
71 | bs: int,
72 | seq_len: int,
73 | generate_len: int = 0,
74 | ops_type: str = "qkvo_proj",
75 | stage="decode_",
76 | print_bound: bool = False,
77 | ) -> float:
78 | """Count the latency for the forward layer or model, assuming the compute and memory operations are perfectly overlapped.
79 |
80 | Args:
81 | flops (float): flops of the forward layer or model
82 | memory (float): r/w memory(bytes) of the forward layer or model
83 | tp_size (float): tensor parallelism size
84 | gpu_TFLOPS (float): GPU TFLOPS in T(10^12)FLOPS
85 | gpu_hbm_bandwidth (float): GPU HBM bandwidth in GB/s(10^9)
86 |
87 | Returns:
88 | float: the latency in seconds for the forward pass
89 | """
90 | ops_type = ops_type.lower()
91 |
92 | if ops_type == "qkvo_proj":
93 | flops = (
94 | self.llm_flops.count_flops_per_layer_qkvo_proj(bs, seq_len)
95 | / self.tp_size
96 | )
97 | weight_memory = (
98 | self.llm_params.count_params_per_layer_mha()
99 | * self.bytes_per_param
100 | / self.tp_size
101 | ) * BYTES_FP16
102 | mac = self.llm_memory.count_mac_per_layer_qkvo_proj(bs, seq_len)[1] / self.tp_size
103 |
104 | memory = weight_memory + mac
105 | elif ops_type == "attn_kernel":
106 | flops = (
107 | self.llm_flops.count_flops_per_layer_attn_kernel(bs, seq_len, generate_len)
108 | / self.tp_size
109 | )
110 | weight_memory = 0
111 | mac = self.llm_memory.count_mac_per_layer_attn_kernel(bs, seq_len, generate_len, kv_cache_bytes=BYTES_FP16)[1] / self.tp_size
112 | memory = weight_memory + mac
113 |
114 | elif ops_type == "mlp":
115 | flops = self.llm_flops.count_flops_per_layer_mlp(bs, seq_len) / self.tp_size
116 | weight_memory = (
117 | self.llm_params.count_params_per_layer_mlp()
118 | * self.bytes_per_param
119 | / self.tp_size
120 | ) * BYTES_FP16
121 | mac = (self.llm_memory.count_mac_per_layer_mlp(bs, seq_len)[1] / self.tp_size)
122 | memory = weight_memory + mac
123 |
124 | elif ops_type == "rmsnorm":
125 | # Two RMSNorm operations (pre‑attention & pre‑MLP) share the same
126 | # vector weight, replicated across TP ranks.
127 | weight_memory = 2 * self.llm_params.count_params_per_layer_norm() * BYTES_FP16
128 | flops = self.llm_flops.count_flops_per_layer_norm(bs, seq_len)
129 | mac = self.llm_memory.count_mac_per_layer_norm(bs, seq_len)[1]
130 | memory = weight_memory + mac
131 | else:
132 | raise ValueError(f"Unsupported ops_type: {ops_type}")
133 |
134 | compute_latency = flops / (self.gpu_TFLOPS) # 单位秒
135 | memory_latency = memory / (self.gpu_hbm_bandwidth)
136 |
137 | if print_bound:
138 | self.print_kernel_bound_info(stage, memory_latency, compute_latency, ops_type)
139 |
140 | return max(compute_latency, memory_latency)
141 |
142 | def count_latency_per_layer_tp_comm(self, bs: int, seq_len: int) -> float:
143 | """Count the latency of a single allreduce communication across the
144 | tensor parallel group in the forward pass of a transformer layer.
145 | The latency is the max of the latency for the allreduce and the minimum
146 | message latency through intra-node connect.
147 | """
148 |
149 | if self.tp_size == 1:
150 | return 0
151 |
152 | # 一次 AllReduce 产生的通讯量为 \phi = 2bsh
153 | # Self-Attention 和 MLP 部分的计算各需要进行一次 All-Reduce 操作, 即每层做 2 次 All-Reduce操作
154 | # if tp_size is large enough num_data_per_all_reduce can be 4bsh
155 | num_data_per_all_reduce = (
156 | 6 * bs * seq_len * self.h * (self.tp_size - 1) / (self.tp_size)
157 | )
158 |
159 | latency_per_layer_tp_comm = (
160 | num_data_per_all_reduce
161 | * self.bytes_per_param
162 | / self.gpu_intra_node_bandwidth
163 | )
164 |
165 | # intra_node_min_message_latency: 节点内连接的最小消息延迟
166 | return max(
167 | latency_per_layer_tp_comm,
168 | self.gpu_config.intra_node_min_message_latency,
169 | )
170 |
171 | def count_latency_per_layer(
172 | self,
173 | bs: int,
174 | seq_len: int,
175 | generate_len: int = 0,
176 | flash_attn=False,
177 | kv_cache_bytes: int = BYTES_FP16,
178 | ) -> tuple:
179 | kernel_latency_per_layer = 0.0
180 | dict_latency_per_layer = dict()
181 | ops_list = ["qkvo_proj", "attn_kernel", "mlp", "rmsnorm"]
182 |
183 | for ops_name in ops_list:
184 | kernel_latency = self.common_count_latency_for_ops(
185 | bs, seq_len, generate_len, ops_name,
186 | )
187 | dict_latency_per_layer[ops_name] = kernel_latency
188 | kernel_latency_per_layer += kernel_latency
189 |
190 | latency_per_layer_tp_comm = self.count_latency_per_layer_tp_comm(bs, seq_len)
191 | kv_cache_latency = self.count_latency_kv_cache_per_layer(
192 | bs, seq_len, generate_len, flash_attn, kv_cache_bytes
193 | )
194 |
195 | latency_per_layer = (
196 | kernel_latency_per_layer
197 | + latency_per_layer_tp_comm
198 | + kv_cache_latency
199 | )
200 |
201 | dict_latency_per_layer["tp_comm"] = latency_per_layer_tp_comm
202 | dict_latency_per_layer["kv_cache_rw"] = kv_cache_latency
203 |
204 | return latency_per_layer, dict_latency_per_layer
205 |
206 | def count_latency_input_embedding(self, bs: int, seq_len: int) -> float:
207 | """Get the latency for the forward pass of the input embedding layer,
208 | given the batch size, sequence length, and data type of the embedding
209 | weight.
210 |
211 | Args:
212 | bs (int): batch size
213 | seq_len (int): sequence length
214 |
215 | Returns:
216 | float: the latency in seconds for the forward pass of the input embedding layer
217 | """
218 | memory_latency = (
219 | self.model_config.vocab_size
220 | * self.model_config.hidden_size
221 | * self.bytes_per_param
222 | / (self.gpu_hbm_bandwidth)
223 | )
224 | comm_latency = self.count_latency_per_layer_tp_comm(bs, seq_len)
225 | return memory_latency + comm_latency
226 |
227 | def count_latency_output_embedding(self, bs: int, seq_len: int) -> float:
228 | """Get the latency for the forward pass of the output embedding layer (computing the logits).
229 | The operation is compute bound. With tensor parallelism size > 1,
230 | an allgather communicates `bs * seq_len` elements,
231 | which is ignored here. Refer to https://arxiv.org/abs/1909.08053 for more details.
232 |
233 | Args:
234 | bs (int): batch size
235 | seq_len (int): sequence length
236 | """
237 |
238 | compute_latency = (
239 | 2 * bs * seq_len * self.h * self.V / self.tp_size / self.gpu_TFLOPS
240 | )
241 |
242 | return compute_latency
243 |
244 | def count_latency_kv_cache_per_layer(
245 | self,
246 | bs: int,
247 | seq_len: int,
248 | generate_len: int,
249 | flash_attn: bool = False,
250 | kv_cache_bytes: int = BYTES_FP16,
251 | ) -> tuple:
252 | """Get the latency for the forward pass of the key and value cache in a transformer layer,
253 | given the batch size, sequence length, and whether the key and value cache is used.
254 |
255 | Args:
256 | bs (int): batch size
257 | seq_len (int): sequence length
258 | generate_len (int): number of tokens to generate
259 | """
260 | kv_cache_mac = (
261 | self.llm_memory.count_mac_per_layer_kv_cache(
262 | bs, seq_len, generate_len, flash_attn, kv_cache_bytes
263 | )
264 | / self.tp_size
265 | )
266 |
267 | memory_latency = kv_cache_mac / (self.gpu_hbm_bandwidth)
268 |
269 | return memory_latency
270 |
271 | def count_latency_model(
272 | self,
273 | bs: int,
274 | seq_len: int,
275 | generate_len: int,
276 | flash_attn: bool = False,
277 | kv_cache_bytes: int = BYTES_FP16,
278 | breakdown_prefix: str = "",
279 | ) -> tuple:
280 | latency_per_layer, breakdown_per_layer = self.count_latency_per_layer(
281 | bs,
282 | seq_len,
283 | generate_len,
284 | flash_attn,
285 | kv_cache_bytes,
286 | )
287 | num_layers_per_gpu = self.num_layers_per_gpu
288 |
289 | latency_all_layers = latency_per_layer * self.num_layers_per_gpu
290 | latency_input_embedding = self.count_latency_input_embedding(bs, seq_len)
291 | latency_output_embedding = self.count_latency_output_embedding(bs, seq_len)
292 |
293 | model_latency = (
294 | latency_all_layers + latency_input_embedding + latency_output_embedding
295 | )
296 |
297 | model_latency_breakdown = {
298 | breakdown_prefix + "qkvo_proj": (
299 | breakdown_per_layer["qkvo_proj"] * num_layers_per_gpu
300 | ),
301 | breakdown_prefix + "attn_kernel": (
302 | breakdown_per_layer["attn_kernel"] * num_layers_per_gpu
303 | ),
304 | breakdown_prefix + "mlp": (breakdown_per_layer["mlp"] * num_layers_per_gpu),
305 | breakdown_prefix + "rmsnorm": (
306 | breakdown_per_layer["rmsnorm"] * num_layers_per_gpu
307 | ),
308 | breakdown_prefix + "tp_comm": (
309 | breakdown_per_layer["tp_comm"] * num_layers_per_gpu
310 | ),
311 | breakdown_prefix + "kv_cache_rw": (
312 | breakdown_per_layer["kv_cache_rw"] * num_layers_per_gpu
313 | ),
314 | }
315 |
316 | return model_latency, model_latency_breakdown
317 |
318 | def count_latency(
319 | self,
320 | bs: int,
321 | seq_len: int,
322 | generate_len: int,
323 | flash_attn: bool = False,
324 | kv_cache_bytes: int = BYTES_FP16,
325 | ) -> tuple:
326 | # 1, 预填充阶段
327 | prefill_latency, prefill_latency_breakdown = self.count_latency_model(
328 | bs,
329 | seq_len,
330 | generate_len=0,
331 | flash_attn=flash_attn,
332 | kv_cache_bytes=kv_cache_bytes,
333 | breakdown_prefix="prefill_",
334 | )
335 |
336 | prefill_latency_breakdown.update(
337 | {
338 | "TTFT": prefill_latency,
339 | }
340 | )
341 |
342 | # 2, 解码阶段
343 | kv_cache_latency = self.count_latency_kv_cache_per_layer(
344 | bs, seq_len, generate_len, flash_attn, kv_cache_bytes
345 | ) * self.num_layers_per_gpu
346 |
347 | decode_model_latency, decode_latency_breakdown = self.count_latency_model(
348 | bs,
349 | 1,
350 | generate_len=generate_len,
351 | flash_attn=flash_attn,
352 | kv_cache_bytes=kv_cache_bytes,
353 | breakdown_prefix="decode_",
354 | )
355 |
356 | decode_latency = decode_model_latency + kv_cache_latency
357 |
358 | decode_latency_breakdown.update(
359 | {
360 | "kv_cache_latency": kv_cache_latency,
361 | "TTOT": (decode_latency),
362 | }
363 | )
364 | return prefill_latency_breakdown, decode_latency_breakdown
365 |
--------------------------------------------------------------------------------
/llm_counts/count_memory.py:
--------------------------------------------------------------------------------
1 | from .utils.config import LLMConfigs
2 | from .utils.constants import BYTES_FP16
3 | from .count_params import CountCausalLMParams
4 |
5 | from functools import reduce
6 | import operator as _op
7 |
8 | def _B(*dims):
9 | """Utility: multiply arbitrary dimensions to get a byte count."""
10 | return reduce(_op.mul, dims, 1)
11 |
12 |
13 | class CountCausalLMMemory(object):
14 | """Count memory of the model and layers."""
15 |
16 | def __init__(self, llm_configs: LLMConfigs) -> None:
17 | self.model_config = llm_configs.model_config
18 | self.model_type = self.model_config.model_type
19 | self.hidden_size = self.model_config.hidden_size
20 | self.intermediate_size = self.model_config.intermediate_size
21 |
22 | self.num_heads = self.model_config.num_heads
23 | self.num_kv_heads = self.model_config.num_kv_heads
24 | self.head_dim = self.model_config.head_dim
25 | self.num_layers = self.model_config.num_layers
26 | self.V = self.model_config.vocab_size
27 |
28 | self.bytes_per_param = llm_configs.inference_config.bytes_per_param
29 | self.act_dtype_bytes = BYTES_FP16
30 |
31 | self.tp_size = llm_configs.parallelism_config.tp_size
32 | self.pp_size = llm_configs.parallelism_config.pp_size
33 | self.num_layers_per_gpu = int(self.num_layers / self.pp_size)
34 |
35 | self.gpu_memory_in_GB = llm_configs.gpu_config.memory_GPU_in_GB * 10**9
36 | self.llm_params = CountCausalLMParams(self.model_config)
37 |
38 | def count_memory_weight_per_gpu(self, ):
39 | """Get the memory of the model weights"""
40 | params_model = self.llm_params.count_params_model()
41 | memory_weight_per_gpu = params_model * self.bytes_per_param / self.tp_size
42 |
43 | return memory_weight_per_gpu
44 |
45 | def count_mac_per_layer_attn_kernel(
46 | self,
47 | bs: int,
48 | seq_len,
49 | generate_len: int = 0,
50 | flash_attn: bool = False,
51 | kv_cache_bytes: int = BYTES_FP16,
52 | ):
53 | if self.model_type == "qwen3":
54 | norm_bytes = 2 * (
55 | _B(self.head_dim, BYTES_FP16) # load γ
56 | + 2 * _B(bs, seq_len, self.head_dim, BYTES_FP16) # load + store acts
57 | )
58 | else:
59 | norm_bytes = 0
60 |
61 | if not flash_attn:
62 | if seq_len != 1:
63 | # dim changge: (bs, seq_len, hidden_size) -> (bs, seq_len, num_heads, head_dim)
64 | # (bs, seq_len, num_heads, head_dim) -> (bs, num_heads, seq_len, head_dim)
65 | # qk^t: (bs, num_heads, seq_len, head_dim) * (bs, num_kv_heads, seq_len, head_sim) -> (bs, num_heads, seq_len, seq_len)
66 | # sv: (bs, num_heads, seq_len, seq_len) * (bs, num_kv_heads, seq_len, head_dim) -> (bs, num_heads, seq_len, head_dim)
67 |
68 | load_q_mem = bs * self.num_heads * seq_len * self.head_dim
69 | load_k_mem = bs * self.num_kv_heads * seq_len * self.head_dim
70 | qk_store_mem = bs * self.num_heads * seq_len * seq_len
71 |
72 | load_softmax_mem = qk_store_mem
73 | softmax_store_mem = bs * self.num_heads * seq_len * seq_len
74 |
75 | load_s_mem = softmax_store_mem
76 | load_v_mem = bs * self.num_kv_heads * seq_len * self.head_dim
77 | sv_store_mem = bs * self.num_heads * seq_len * self.head_dim
78 |
79 | self_atten_mac = (load_q_mem + load_k_mem + qk_store_mem
80 | + load_softmax_mem + softmax_store_mem
81 | + load_s_mem + load_v_mem + sv_store_mem)
82 | max_act = max(load_q_mem, load_k_mem, qk_store_mem,
83 | load_softmax_mem, softmax_store_mem,
84 | load_s_mem, load_v_mem, sv_store_mem) * self.act_dtype_bytes
85 | return max_act, self_atten_mac * kv_cache_bytes + norm_bytes
86 |
87 | else:
88 | # dim changge: (bs, 1, hidden_size) -> (bs, 1, num_heads, head_dim)
89 | # (bs, 1, num_heads, head_dim) -> (bs, num_heads, 1, head_dim)
90 | # qk^t: (bs, num_heads, seq_len + generate_len, head_dim) * (bs, num_kv_heads, seq_len + generate_len, head_sim) -> (bs, num_heads, seq_len + generate_len, seq_len + generate_len)
91 | # sv: (bs, num_heads, seq_len + generate_len, seq_len + generate_len) * (bs, num_kv_heads, seq_len + generate_len, head_dim) -> (bs, num_heads, seq_len + generate_len, head_dim)
92 |
93 | load_q_mem = bs * self.num_heads * 1 * self.head_dim
94 | load_k_mem = bs * self.num_kv_heads * (seq_len + generate_len) * self.head_dim
95 | qk_store_mem = bs * self.num_heads * (seq_len + generate_len) * (seq_len + generate_len)
96 |
97 | load_softmax_mem = qk_store_mem
98 | softmax_store_mem = bs * self.num_heads * (seq_len + generate_len) * (seq_len + generate_len)
99 |
100 | load_s_mem = softmax_store_mem
101 | load_v_mem = bs * self.num_kv_heads * (seq_len + generate_len) * self.head_dim
102 | sv_store_mem = bs * self.num_heads * (seq_len + generate_len) * self.head_dim
103 |
104 | max_act = max(load_q_mem, load_k_mem, qk_store_mem,
105 | load_softmax_mem, softmax_store_mem,
106 | load_s_mem, load_v_mem, sv_store_mem) * self.act_dtype_bytes
107 | self_atten_mac = (load_q_mem + load_k_mem + qk_store_mem
108 | + load_softmax_mem + softmax_store_mem
109 | + load_s_mem + load_v_mem + sv_store_mem)
110 |
111 | return max_act, self_atten_mac * kv_cache_bytes + norm_bytes
112 |
113 | def count_mac_per_layer_kv_cache(
114 | self,
115 | bs,
116 | seq_len,
117 | generate_len: int = 0,
118 | flash_attn: bool = False,
119 | kv_cache_bytes: int = BYTES_FP16,
120 | ):
121 | if not flash_attn:
122 | store_k_cache = (
123 | self.num_kv_heads * self.head_dim * bs * seq_len * kv_cache_bytes
124 | )
125 | store_v_cache = (
126 | self.num_kv_heads * self.head_dim * bs * seq_len * kv_cache_bytes
127 | )
128 | if seq_len != 1:
129 | return store_k_cache + store_v_cache
130 | else:
131 | qk_matmul_load_k_cache = (
132 | (seq_len + generate_len)
133 | * self.head_dim
134 | * bs
135 | * self.num_kv_heads
136 | * kv_cache_bytes
137 | )
138 | sv_matmul_load_v_cache = (
139 | (seq_len + generate_len)
140 | * self.head_dim
141 | * bs
142 | * self.num_kv_heads
143 | * kv_cache_bytes
144 | )
145 |
146 | kv_cache_mac = (
147 | qk_matmul_load_k_cache
148 | + sv_matmul_load_v_cache
149 | + store_k_cache
150 | + store_v_cache
151 | )
152 | return kv_cache_mac
153 | else:
154 | # FlashAttention path: compute attention on‑the‑fly; only new K/V cache entries are stored
155 | kv_cache_mac = (
156 | self.num_kv_heads
157 | * self.head_dim
158 | * bs
159 | * seq_len
160 | * 2 # K + V
161 | * kv_cache_bytes
162 | )
163 |
164 | return kv_cache_mac
165 |
166 | def count_mac_per_layer_qkvo_proj(
167 | self,
168 | bs: int,
169 | seq_len: int,
170 | qkvo_weight_dtype_bytes=BYTES_FP16,
171 | ) -> int:
172 | """
173 | Count memory access cost for Q/K/V/O projection layers.
174 | """
175 | atten_linear_layers = {
176 | "q_proj": [self.hidden_size, self.num_heads * self.head_dim],
177 | "k_proj": [self.hidden_size, self.num_kv_heads * self.head_dim],
178 | "v_proj": [self.hidden_size, self.num_kv_heads * self.head_dim],
179 | "out_proj": [self.num_heads * self.head_dim, self.hidden_size],
180 | }
181 |
182 | atten_linear_layers_mac = 0
183 | max_act = 0
184 | for name, (in_ch, out_ch) in atten_linear_layers.items():
185 | is_kv_proj = name in ["k_proj", "v_proj"]
186 | is_normal_proj = not is_kv_proj
187 |
188 | load_weight = in_ch * out_ch
189 | load_act = in_ch * bs * seq_len
190 | store_act = 0 if is_kv_proj else bs * seq_len * out_ch
191 | load_kv_cache = 0
192 | store_kv_cache = 0 if is_normal_proj else out_ch * bs * seq_len
193 |
194 | max_act = max(max_act, load_weight, load_act, store_act, store_kv_cache)
195 |
196 | mac = load_weight + load_act + store_act + load_kv_cache + store_kv_cache
197 | atten_linear_layers_mac += mac
198 |
199 | return max_act * self.act_dtype_bytes, atten_linear_layers_mac * qkvo_weight_dtype_bytes
200 |
201 | def count_mac_per_layer_mlp(
202 | self,
203 | bs: int,
204 | seq_len: int,
205 | mlp_weight_dtype_bytes=BYTES_FP16,
206 | ) -> float:
207 | """The `mlp` acts include the input to the two linear layers.
208 | Refer to https://arxiv.org/abs/2205.05198 for details.
209 | The two linear layers store their inputs with size 2bsh and 8bsh
210 | """
211 | mlp_linear_layers = {
212 | "gate_proj": [self.hidden_size, self.intermediate_size],
213 | "up_proj": [self.hidden_size, self.intermediate_size],
214 | "down_proj": [self.intermediate_size, self.hidden_size],
215 | }
216 |
217 | mlp_linear_layers_mac = 0
218 | max_act = 0
219 | for _, (in_ch, out_ch) in mlp_linear_layers.items():
220 | load_weight = in_ch * out_ch
221 | load_act = in_ch * bs * seq_len
222 | store_act = bs * seq_len * out_ch
223 |
224 | max_act = max(max_act, load_weight, load_act, store_act)
225 | mac = load_weight + load_act + store_act
226 | mlp_linear_layers_mac += mac
227 |
228 | return max_act * self.act_dtype_bytes, mlp_linear_layers_mac * mlp_weight_dtype_bytes
229 |
230 | def count_mac_per_layer_norm(
231 | self,
232 | bs: int,
233 | seq_len: int,
234 | ) -> float:
235 | """Get the memory (in bytes) required to store the acts of a single layernorm in a transformer layer."""
236 | rmsnorm_load_weight = self.hidden_size * self.act_dtype_bytes
237 | rmsnorm_load_act = bs * seq_len * self.hidden_size * self.act_dtype_bytes
238 | rmsnorm_store_act = bs * seq_len * self.hidden_size * self.act_dtype_bytes
239 |
240 | norm_mac_per_gpu = (
241 | rmsnorm_load_weight + rmsnorm_load_act + rmsnorm_store_act
242 | )
243 | max_act = max(rmsnorm_load_weight, rmsnorm_load_act, rmsnorm_store_act) * self.act_dtype_bytes
244 | return max_act, norm_mac_per_gpu
245 |
246 | def count_mac_input_embedding(self, bs: int, seq_len: int) -> float:
247 | input_embedding_load_act = bs * seq_len * self.act_dtype_bytes
248 | input_embedding_store_act = (
249 | bs * seq_len * self.hidden_size * self.act_dtype_bytes
250 | )
251 | input_embedding_mac_per_gpu = (
252 | input_embedding_load_act + input_embedding_store_act
253 | )
254 |
255 | return input_embedding_mac_per_gpu
256 |
257 | def count_memory_kv_cache_per_layer(
258 | self,
259 | bs: int,
260 | seq_len: int,
261 | generate_len: int,
262 | kv_cache_bytes: int = BYTES_FP16,
263 | ) -> float:
264 | """Get the memory (in bytes) required to store the key and value cache
265 | for a transformer layer in inference, given the batch size, sequence
266 | length, act data type, and tensor parallelism size.
267 |
268 | memory_kv_cache = 4blh(s+o) unit is byte
269 | Args:
270 | bs (int): batch size
271 | context_len (int): seq_len + generate_len
272 |
273 | Returns:
274 | float: the memory (in bytes) required to store the key and value cache
275 | for a transformer layer in inference.
276 | """
277 |
278 | # At least on attention head on each tensor-parallel GPU
279 | num_kv_heads_per_gpu = max(self.num_kv_heads, 1)
280 | memory_kv_cache_per_layer = (
281 | bs
282 | * (seq_len + generate_len)
283 | * num_kv_heads_per_gpu
284 | * self.head_dim
285 | * 2
286 | * kv_cache_bytes
287 | )
288 |
289 | return memory_kv_cache_per_layer
290 |
291 | def count_max_act_per_layer(
292 | self,
293 | bs: int,
294 | seq_len_ctx: int,
295 | generate_len: int = 0, # used only for decode stage
296 | *,
297 | stage: str = "prefill", # "prefill" | "decode"
298 | flash_attn: bool = False,
299 | qkvo_weight_dtype_bytes: int = BYTES_FP16,
300 | mlp_weight_dtype_bytes: int = BYTES_FP16,
301 | ) -> float:
302 | assert stage in {"prefill", "decode"}
303 |
304 | # For decode stage each step handles just **one token**.
305 | tokens = 1 if stage == "decode" else seq_len_ctx
306 |
307 | act_per_layer_self_atten, _ = self.count_mac_per_layer_attn_kernel(
308 | bs,
309 | tokens,
310 | generate_len=generate_len,
311 | flash_attn=flash_attn,
312 | kv_cache_bytes=qkvo_weight_dtype_bytes,
313 | )
314 | act_per_layer_qkvo_proj, _ = self.count_mac_per_layer_qkvo_proj(
315 | bs,
316 | tokens,
317 | qkvo_weight_dtype_bytes=qkvo_weight_dtype_bytes,
318 | )
319 | act_per_layer_mlp, _ = self.count_mac_per_layer_mlp(
320 | bs,
321 | tokens,
322 | mlp_weight_dtype_bytes=mlp_weight_dtype_bytes,
323 | )
324 | act_per_layer_rn, _ = self.count_mac_per_layer_norm(bs, tokens)
325 |
326 | act_per_layer = max(act_per_layer_self_atten, act_per_layer_qkvo_proj, act_per_layer_mlp, act_per_layer_rn)
327 |
328 | return act_per_layer
329 |
330 | def count_memory_per_gpu(
331 | self,
332 | bs: int,
333 | seq_len: int,
334 | generate_len: int,
335 | flash_attn: bool = True,
336 | qkvo_weight_dtype_bytes: int = BYTES_FP16,
337 | mlp_weight_dtype_bytes=BYTES_FP16,
338 | kv_cache_bytes: int = BYTES_FP16,
339 | ) -> tuple:
340 | # 1, prefill stage count memory and max_bs
341 | weight_memory_per_gpu = self.count_memory_weight_per_gpu() # count model weights memory
342 | memory_left_per_gpu = self.gpu_memory_in_GB - weight_memory_per_gpu
343 |
344 | # --- 1) PREFILL stage ----------------------------------------- #
345 | prefill_act_bs_1 = self.count_max_act_per_layer(
346 | 1,
347 | seq_len,
348 | generate_len=generate_len,
349 | stage="prefill",
350 | flash_attn=flash_attn,
351 | qkvo_weight_dtype_bytes=qkvo_weight_dtype_bytes,
352 | mlp_weight_dtype_bytes=mlp_weight_dtype_bytes,
353 | ) // self.tp_size
354 |
355 | prefill_max_bs = int(memory_left_per_gpu / prefill_act_bs_1)
356 | prefill_act_per_gpu = bs * prefill_act_bs_1
357 |
358 | # --- 2) DECODE stage ------------------------------------------ #
359 | kv_cache_memory_bs_1_per_gpu = (self.count_memory_kv_cache_per_layer(1, seq_len, generate_len, kv_cache_bytes) * self.num_layers_per_gpu) / self.tp_size
360 | decode_act_bs_1_per_gpu = self.count_max_act_per_layer(
361 | 1,
362 | seq_len,
363 | generate_len=generate_len,
364 | stage="decode",
365 | flash_attn=flash_attn,
366 | qkvo_weight_dtype_bytes=qkvo_weight_dtype_bytes,
367 | mlp_weight_dtype_bytes=mlp_weight_dtype_bytes,
368 | ) // self.tp_size
369 | decode_max_bs = memory_left_per_gpu // (decode_act_bs_1_per_gpu + kv_cache_memory_bs_1_per_gpu)
370 |
371 | kv_cache_memory_per_gpu = bs * kv_cache_memory_bs_1_per_gpu
372 | decode_act_per_gpu = decode_act_bs_1_per_gpu * bs
373 | max_batch_total_tokens = decode_max_bs * (seq_len + generate_len)
374 |
375 | assert bs <= decode_max_bs, (
376 | f"For context length: {seq_len + generate_len}, bs {bs} is too large to fit"
377 | " in GPU memory, decode_max_bs:"
378 | f" {decode_max_bs}"
379 | )
380 |
381 | assert memory_left_per_gpu > (
382 | kv_cache_memory_per_gpu + decode_act_per_gpu
383 | ), (
384 | "kv_cache and act memory with bs ="
385 | f" {bs} is too large to fit in GPU memory"
386 | )
387 |
388 | consume_memory_per_gpu = (
389 | weight_memory_per_gpu + decode_act_per_gpu + kv_cache_memory_per_gpu
390 | )
391 |
392 | # memory summary
393 | memory_prefill_summary_dict = {
394 | "weight_memory_per_gpu": weight_memory_per_gpu,
395 | "prefill_max_bs": prefill_max_bs,
396 | "prefill_act_per_gpu": prefill_act_per_gpu,
397 | }
398 |
399 | memory_decode_summary_dict = {
400 | "decode_act_per_gpu": decode_act_per_gpu,
401 | "kv_cache_memory_per_gpu": kv_cache_memory_per_gpu,
402 | "consume_memory_per_gpu": consume_memory_per_gpu,
403 | "decode_max_bs": decode_max_bs,
404 | "max_batch_total_tokens": int(max_batch_total_tokens * 0.97),
405 | }
406 |
407 | return memory_prefill_summary_dict, memory_decode_summary_dict
408 |
--------------------------------------------------------------------------------
/llm_counts/count_params.py:
--------------------------------------------------------------------------------
1 | from .utils.config import ModelConfig
2 | from .utils.constants import *
3 |
4 |
5 | class CountCausalLMParams(object):
6 | def __init__(self, model_config: ModelConfig) -> None:
7 | self.hidden_size = model_config.hidden_size
8 | self.intermediate_size = model_config.intermediate_size
9 | self.num_layers = model_config.num_layers
10 | self.V = model_config.vocab_size
11 |
12 | self.num_heads = model_config.num_heads
13 | self.num_kv_heads = model_config.num_kv_heads
14 | self.head_dim = model_config.head_dim
15 | self.model_type = model_config.model_type
16 |
17 | def count_params_embedding(self, shared_embedding: bool = True) -> int:
18 | """Get the number of parameters in the embedding layer.
19 | params_te = vocab_size * d_model
20 | Args:
21 | shared_embedding (bool, optional): whether the output embedding
22 | shares weights with the input embedding. Defaults to True.
23 |
24 | Returns:
25 | int: the number of parameters in the embedding layer.
26 | """
27 | num_params_input_embedding = self.V * self.hidden_size
28 | num_params_output_embedding = (
29 | self.V * self.hidden_size if not shared_embedding else 0
30 | )
31 |
32 | return num_params_input_embedding + num_params_output_embedding
33 |
34 | def count_params_per_layer_mha(self) -> int:
35 | """Get the number of parameters per layer in the attention module
36 | which include 4 linear layer: q/k/v/o linear layers.
37 |
38 | Returns:
39 | int: the number of parameters per layer in the attention module(mha)
40 | """
41 | params_qo_proj = 2 * self.hidden_size * self.num_heads * self.head_dim
42 | params_kv_proj = 2 * self.hidden_size * self.num_kv_heads * self.head_dim
43 | return params_qo_proj + params_kv_proj
44 |
45 | def count_params_per_layer_mlp(self) -> int:
46 | """Get the number of parameters in the MLP linear layers, including the
47 | intermediate and output matrices.
48 | params_mlp = params_gate_proj + params_up_proj + params_down_proj
49 | Returns:
50 | int: the number of parameters in the two MLP linear layers
51 | """
52 | params_gate_proj = self.hidden_size * self.intermediate_size
53 | params_up_proj = self.hidden_size * self.intermediate_size
54 | params_down_proj = self.intermediate_size * self.hidden_size
55 | params_mlp = params_gate_proj + params_up_proj + params_down_proj
56 |
57 | return params_mlp
58 |
59 | def count_params_per_layer_norm(self) -> int:
60 | """Get the number of atten_norm and mlp_norm parameters per layer.
61 | """
62 | # q_norm、k_norm、atten_norm、mlp_norm
63 | if self.model_type == "qwen3":
64 | return 2 * self.hidden_size + 2 * self.head_dim
65 | else:
66 | return 2 * self.hidden_size
67 |
68 | def count_params_per_layer(self, norm_ignore=False) -> tuple:
69 | """Get the number of params per layer mainly including the attention and MLP layers.
70 |
71 | params_per_layer = params_mha + params_mlp + params_norm
72 |
73 | """
74 | params_per_layer_mha = self.count_params_per_layer_mha()
75 | params_per_layer_mlp = self.count_params_per_layer_mlp()
76 | params_per_layer_norm = 0 if norm_ignore else self.count_params_per_layer_norm()
77 | params_input_embedding = self.count_params_embedding()
78 |
79 | params_per_layer = (
80 | params_per_layer_mha + params_per_layer_mlp + params_per_layer_norm
81 | )
82 |
83 | dict_params_per_layer = {
84 | "qkvo_proj": params_per_layer_mha,
85 | "mlp": params_per_layer_mlp,
86 | "rmsnorm": params_per_layer_norm,
87 | "input_embedding": params_input_embedding,
88 | "output_embedding": params_input_embedding,
89 | }
90 |
91 | return params_per_layer, dict_params_per_layer
92 |
93 | def count_params_model(self) -> int:
94 | """Get the total number of parameters in the model
95 | including all layers and token embedding layer.
96 | params_model = params_embedding + params_per_layer * num_layers
97 | = V * d_model + 12 * d_model**2 * num_layers
98 | Returns:
99 | int: the total number of parameters in the model
100 | """
101 | params_per_layer, _ = self.count_params_per_layer()
102 | params_model = (
103 | params_per_layer * self.num_layers + self.count_params_embedding()
104 | )
105 |
106 | return params_model
--------------------------------------------------------------------------------
/llm_counts/layer_graph_visualizer.py:
--------------------------------------------------------------------------------
1 | """
2 | cli entry point for LayerAnalyzer, which analyzes the memory access and FLOPs of a model.
3 | Usage:
4 | ```bash
5 | python -m llm_counts.llm_analyzer \
6 | --result-json path/to/results.json \
7 | --model-type qwen3 \
8 | --output my_layer_graph
9 | ```
10 | """
11 | from .utils.constants import BYTES_FP16
12 | from .utils.config import *
13 | from .utils.utils import num_to_string
14 | from .roofline_model import roofline_analysis
15 |
16 |
17 | class LayerAnalyzer(object):
18 | """Count memory access of the model and layers."""
19 |
20 | def __init__(self, model_config, gpu_config, tp_size) -> None:
21 | self.tp_size = tp_size
22 | self.bandwidth, self.onchip_buffer = get_gpu_hbm_bandwidth(gpu_config) # GB/s
23 | self.bandwidth *= 10**9
24 | self.gpu_max_ops = get_TFLOPS_per_gpu(gpu_config, data_type="fp16") * 10**12 # TFLOPs
25 |
26 | self.model_type = model_config.model_type
27 | self.hidden_size = model_config.hidden_size
28 | self.intermediate_size = model_config.intermediate_size
29 | self.num_heads = model_config.num_heads
30 | self.num_kv_heads = model_config.num_kv_heads
31 | self.head_dim = self.hidden_size // self.num_heads
32 |
33 | # attention linear layers
34 | self.linear_layers = {
35 | "q_proj": [self.hidden_size, self.num_heads * self.head_dim],
36 | "k_proj": [self.hidden_size, self.num_kv_heads * self.head_dim],
37 | "v_proj": [self.hidden_size, self.num_kv_heads * self.head_dim],
38 | "out_proj": [self.num_heads * self.head_dim, self.hidden_size],
39 |
40 | "gate_proj": [self.hidden_size, self.intermediate_size],
41 | "up_proj": [self.hidden_size, self.intermediate_size],
42 | "down_proj": [self.intermediate_size, self.hidden_size],
43 | }
44 |
45 | self.results = {"decode": {}, "prefill": {}}
46 |
47 | def _analyze_to_results(
48 | self,
49 | stage,
50 | kernel_name,
51 | flops,
52 | load_weight,
53 | load_act,
54 | store_act,
55 | load_kv_cache,
56 | store_kv_cache,
57 | data_type="fp16"
58 | ):
59 | memory_access = (load_weight + load_act + store_act + load_kv_cache + store_kv_cache)
60 | a_intensity, att_flops, bound = roofline_analysis(self.gpu_max_ops,
61 | self.bandwidth,
62 | flops, memory_access) # Arithmetic Intensity
63 |
64 | self.results[stage][kernel_name] = {
65 | "flops": num_to_string(flops),
66 | "memory_access": f"{num_to_string(memory_access)}B",
67 | "arithmetic_intensity": int(a_intensity),
68 | "att_flops": num_to_string(att_flops),
69 | "bound": bound,
70 | "load_weight": f"{num_to_string(load_weight)}B",
71 | "load_act": num_to_string(load_act),
72 | "store_act": num_to_string(store_act),
73 | "load_kv_cache": num_to_string(load_kv_cache),
74 | "store_kv_cache": num_to_string(store_kv_cache),
75 | }
76 |
77 | return self.results
78 |
79 | def analyze_linear_layers(
80 | self,
81 | bs: int,
82 | seq_len: int,
83 | linear_weight_bytes: int = BYTES_FP16,
84 | act_byte: int = BYTES_FP16,
85 | kv_byte: int = BYTES_FP16,
86 | ):
87 | """
88 | Count and save the FLOPs and memory access of self-attention layers.
89 | This function is used to analyze the self-attention layers in the model.
90 | """
91 | # 1. attention linear layers analysis
92 | for name, (in_ch, out_ch) in self.linear_layers.items():
93 | is_kv_proj = name in ["k_proj", "v_proj"]
94 | is_normal_proj = not is_kv_proj
95 |
96 | self._analyze_to_results(
97 | "prefill",
98 | name,
99 | flops=2 * bs * seq_len * in_ch * out_ch // self.tp_size,
100 | load_weight=in_ch * out_ch * linear_weight_bytes // self.tp_size,
101 | load_act=in_ch * bs * seq_len * act_byte // self.tp_size,
102 | store_act=0 if is_kv_proj else bs * seq_len * out_ch * act_byte // self.tp_size,
103 | load_kv_cache=0,
104 | store_kv_cache=(0 if is_normal_proj else out_ch * bs * seq_len * kv_byte) // self.tp_size
105 | )
106 | self._analyze_to_results(
107 | "decode",
108 | name,
109 | flops=2 * bs * in_ch * out_ch // self.tp_size,
110 | load_weight=in_ch * out_ch * linear_weight_bytes // self.tp_size,
111 | load_act=in_ch * bs * act_byte // self.tp_size,
112 | store_act=0 if is_kv_proj else out_ch * bs * act_byte // self.tp_size,
113 | load_kv_cache=0,
114 | store_kv_cache=(0 if is_normal_proj else out_ch * bs * kv_byte) // self.tp_size,
115 | )
116 |
117 | def analyze_self_atten_kernel(
118 | self,
119 | bs: int,
120 | seq_len: int,
121 | generate_len: int,
122 | num_kv_heads: int,
123 | num_heads: int,
124 | head_dim: int,
125 | flash_attn: bool = False,
126 | act_byte: int = BYTES_FP16,
127 | kv_byte: int = BYTES_FP16,
128 | ):
129 | """
130 | Count and save the FLOPs and memory access of self-attention kernels.
131 | This function is used to analyze the self-attention kernels in the model.
132 | """
133 | hidden_size = num_heads * head_dim
134 | if not flash_attn:
135 | ##########################prefill stage##########################
136 | # 1, qkt kernel analysis
137 | name = "qk_matmul"
138 | load_q_mem = bs * self.num_heads * seq_len * self.head_dim
139 | load_k_mem = bs * self.num_kv_heads * seq_len * self.head_dim
140 | qk_store_mem = bs * self.num_heads * seq_len * seq_len
141 | self._analyze_to_results(
142 | "prefill",
143 | name,
144 | flops=2 * seq_len * seq_len * self.head_dim * bs * self.num_heads,
145 | load_weight=0,
146 | load_act=(load_q_mem + load_k_mem) * act_byte, # load q and k act, shape is [s, h]
147 | store_act=qk_store_mem * act_byte,
148 | load_kv_cache=0,
149 | store_kv_cache=0,
150 | )
151 | # 2, softmax kernel analysis
152 | name = f"softmax"
153 | load_softmax_mem = qk_store_mem
154 | softmax_store_mem = bs * self.num_heads * seq_len * seq_len
155 | self._analyze_to_results(
156 | "prefill",
157 | name,
158 | flops= (bs * num_heads * seq_len * seq_len * 1 * 5),
159 | load_weight=0,
160 | load_act=load_softmax_mem * act_byte,
161 | store_act=softmax_store_mem * act_byte,
162 | load_kv_cache=0,
163 | store_kv_cache=0,
164 | )
165 | # 3, sv kernel analysis
166 | name = "sv_matmul"
167 | load_s_mem = softmax_store_mem
168 | load_v_mem = bs * self.num_kv_heads * seq_len * self.head_dim
169 | sv_store_mem = bs * self.num_heads * seq_len * self.head_dim
170 | self._analyze_to_results(
171 | "prefill",
172 | name,
173 | flops=bs * 2 * seq_len * seq_len * head_dim * num_heads,
174 | load_weight=0,
175 | load_act=load_s_mem * act_byte, # load score(qkt) act, shape is [s, s]
176 | store_act=sv_store_mem * act_byte,
177 | load_kv_cache=load_v_mem,
178 | store_kv_cache=0,
179 | )
180 | ##########################decode stage##########################
181 | name = "qk_matmul"
182 | # load q and k, k is form kv cache
183 | qk_matmul_flops = 2 * self.num_heads * self.head_dim * (seq_len + generate_len)
184 | load_q_mem = bs * self.num_heads * 1 * self.head_dim
185 | load_k_mem = bs * self.num_kv_heads * (seq_len + generate_len) * self.head_dim
186 | qk_store_mem = bs * self.num_heads * (seq_len + generate_len) * (seq_len + generate_len)
187 | self._analyze_to_results(
188 | "decode",
189 | name,
190 | flops=qk_matmul_flops,
191 | load_weight=0,
192 | load_act=load_q_mem * act_byte,
193 | store_act=qk_store_mem * act_byte,
194 | load_kv_cache=load_k_mem * kv_byte,
195 | store_kv_cache=0,
196 | )
197 | # 2, softmax kernel analysis
198 | name = f"softmax"
199 | load_softmax_mem = qk_store_mem
200 | softmax_store_mem = bs * self.num_heads * (seq_len + generate_len) * (seq_len + generate_len)
201 | self._analyze_to_results(
202 | "decode",
203 | name,
204 | flops= (bs * num_heads * seq_len * seq_len * 1 * 5),
205 | load_weight=0,
206 | load_act=load_softmax_mem * act_byte,
207 | store_act=softmax_store_mem * act_byte,
208 | load_kv_cache=0,
209 | store_kv_cache=0,
210 | )
211 | # 3, sv kernel analysis
212 | name = "sv_matmul"
213 | load_s_mem = softmax_store_mem
214 | load_v_mem = bs * self.num_kv_heads * (seq_len + generate_len) * self.head_dim
215 | sv_store_mem = bs * self.num_heads * (seq_len + generate_len) * self.head_dim
216 | self._analyze_to_results(
217 | "decode",
218 | name,
219 | flops=qk_matmul_flops,
220 | load_weight=0,
221 | load_act=load_s_mem * act_byte, # load score(qkt) act, shape is [s, s]
222 | store_act=sv_store_mem * act_byte,
223 | load_kv_cache=load_v_mem,
224 | store_kv_cache=0,
225 | )
226 | else:
227 | name = f"fused_attention" # flash_attn2
228 | qk_matmul_OPs = seq_len * seq_len * head_dim * num_heads * bs * 2
229 | sv_matmul_OPs = seq_len * head_dim * seq_len * num_heads * bs * 2
230 | softmax_OPs = bs * num_heads * seq_len * seq_len * 5
231 |
232 | block_size_r = min(math.ceil(self.onchip_buffer / (kv_byte * head_dim)), head_dim)
233 | n_blocks_r = math.ceil(seq_len / block_size_r)
234 | q_numel = seq_len * head_dim * bs * num_heads * act_byte
235 | o_numel = seq_len * seq_len * bs * num_heads * act_byte
236 |
237 | self._analyze_to_results(
238 | "prefill",
239 | name,
240 | flops=qk_matmul_OPs + sv_matmul_OPs + softmax_OPs,
241 | load_weight=0,
242 | load_act=q_numel,
243 | store_act=o_numel * 2, # initialize O and save O
244 | load_kv_cache=n_blocks_r * (seq_len) * head_dim * bs * num_kv_heads * kv_byte * 2,
245 | store_kv_cache=0,
246 | )
247 |
248 | qk_matmul_OPs = seq_len * head_dim * num_heads * bs * 2
249 | sv_matmul_OPs = 1 * head_dim * seq_len * num_heads * bs * 2
250 | softmax_OPs = bs * num_heads * seq_len * 1 * 5
251 |
252 | n_blocks_r = math.ceil(1 / block_size_r)
253 | q_numel = (1) * head_dim * bs * num_heads * act_byte
254 | o_numel = 1 * seq_len * bs * num_heads * act_byte
255 | self._analyze_to_results(
256 | "decode",
257 | name,
258 | OPs=qk_matmul_OPs + sv_matmul_OPs + softmax_OPs,
259 | load_weight=0,
260 | load_act=q_numel,
261 | store_act=o_numel * 2, # initialize O and save O
262 | load_kv_cache=n_blocks_r * (seq_len) * head_dim * bs * num_kv_heads * kv_byte * 2,
263 | store_kv_cache=0,
264 | )
265 |
266 | if self.model_type == "qwen3":
267 | kernel_names = ["q_norm", "k_norm"]
268 | # qwen3 模型中 rms_norm 计算中使用了一个额外的线性变换
269 | q_norm_flops = bs * 4 * seq_len * self.head_dim
270 | q_norm_load_weight = self.head_dim * BYTES_FP16
271 | q_norm_load_act = bs * seq_len * self.head_dim * BYTES_FP16 # equal k_norm_load_act
272 | q_norm_store_act = bs * seq_len * self.head_dim * BYTES_FP16
273 |
274 | # prefill/decode 阶段
275 | for stage in ["prefill", "decode"]:
276 | if stage == "decode":
277 | q_norm_flops = int(q_norm_flops // seq_len)
278 | q_norm_load_act = int(q_norm_load_act // seq_len)
279 | q_norm_store_act = int(q_norm_store_act // seq_len)
280 |
281 | for _, kernel_name in enumerate(kernel_names):
282 | self._analyze_to_results(
283 | stage,
284 | kernel_name,
285 | flops=q_norm_flops // self.tp_size,
286 | load_weight=q_norm_load_weight // self.tp_size,
287 | load_act=q_norm_load_act // self.tp_size,
288 | store_act=q_norm_store_act // self.tp_size,
289 | load_kv_cache=0,
290 | store_kv_cache=0,
291 | )
292 |
293 | def analyze_other_kernels(
294 | self,
295 | bs: int,
296 | seq_len: int,
297 | act_byte: int = BYTES_FP16,
298 | ):
299 | norm_flops = bs * seq_len * 4 * self.hidden_size # mlp_norm, attn_norm
300 | norm_load_weight = self.hidden_size * BYTES_FP16
301 | norm_load_act = bs * seq_len * self.hidden_size * BYTES_FP16
302 | norm_store_act = bs * seq_len * self.hidden_size * BYTES_FP16
303 |
304 | # silu 和 dot * 都是纯逐元素操作算子
305 | silu_dot_flops = (bs * 4 * seq_len * self.intermediate_size) # 每个张量元素执行 4 次操作
306 | silu_dot_load_act = bs * 2 * seq_len * self.intermediate_size * act_byte
307 | silu_dot_store_act = (bs * 2 * seq_len * self.intermediate_size * act_byte)
308 |
309 | mlp_add_flops = bs * seq_len * self.hidden_size
310 | mlp_add_load_act = bs * seq_len * self.hidden_size * act_byte
311 | mlp_add_store_act = bs * seq_len * self.hidden_size * act_byte
312 |
313 | # other kernels (memory bound)
314 | kernel_names = ["attn_norm", "mlp_norm", "mlp_silu_dot", "attn_add", "mlp_add"]
315 | flops_list = [norm_flops, norm_flops, silu_dot_flops, mlp_add_flops, mlp_add_flops]
316 |
317 | load_act_list = [norm_load_act, norm_load_act, silu_dot_load_act, mlp_add_load_act, mlp_add_load_act,]
318 | store_act_list = [norm_store_act, norm_store_act, silu_dot_store_act, mlp_add_store_act, mlp_add_store_act,]
319 |
320 | # prefill/decode 阶段
321 | for stage in ["prefill", "decode"]:
322 | for i, kernel_name in enumerate(kernel_names):
323 | load_weight = (0 if (kernel_name not in ["attn_norm", "mlp_norm"]) else norm_load_weight)
324 |
325 | load_act = load_act_list[i]
326 | store_act = store_act_list[i]
327 | flops = flops_list[i]
328 |
329 | if stage == "decode":
330 | flops = int(flops // seq_len)
331 | load_act = int(load_act // seq_len)
332 | store_act = int(store_act // seq_len)
333 |
334 | self._analyze_to_results(
335 | stage,
336 | kernel_name,
337 | flops=flops // self.tp_size,
338 | load_weight=load_weight // self.tp_size,
339 | load_act=load_act // self.tp_size,
340 | store_act=store_act // self.tp_size,
341 | load_kv_cache=0,
342 | store_kv_cache=0,
343 | )
344 |
345 | def analyze_model(
346 | self,
347 | bs: int,
348 | seq_len: int,
349 | generate_len: int = 0,
350 | flash_attn: bool = False,
351 | act_byte: int = BYTES_FP16,
352 | kv_byte: int = BYTES_FP16,
353 | ):
354 | """
355 | Analyze the model and save the results.
356 | This function is used to analyze the model and save the results.
357 | """
358 | # 1. analyze linear layers
359 | self.analyze_linear_layers(bs, seq_len, act_byte=act_byte, kv_byte=kv_byte)
360 |
361 | # 2. analyze self attention kernels
362 | self.analyze_self_atten_kernel(
363 | bs, seq_len, generate_len,
364 | num_kv_heads=self.num_kv_heads,
365 | num_heads=self.num_heads,
366 | head_dim=self.head_dim,
367 | flash_attn=flash_attn,
368 | act_byte=act_byte,
369 | kv_byte=kv_byte
370 | )
371 |
372 | # 3. analyze other kernels
373 | self.analyze_other_kernels(
374 | bs, seq_len,
375 | )
376 |
377 | return self.results
378 |
379 |
380 | # ---------------------------------------------------------------------------
381 | # Transformer‑layer graph visualisation
382 | # ---------------------------------------------------------------------------
383 |
384 | _DEPENDENCIES = {
385 | "input": [],
386 | "attn_norm": ["input"],
387 | "q_proj": ["attn_norm"],
388 | "k_proj": ["attn_norm"],
389 | "v_proj": ["attn_norm"],
390 | "qk_matmul": ["q_proj", "k_proj"],
391 | "softmax": ["qk_matmul"],
392 | "sv_matmul": ["softmax", "v_proj"],
393 | "out_proj": ["sv_matmul"],
394 | "attn_add": ["input", "out_proj"],
395 | "mlp_norm": ["attn_add"],
396 | "gate_proj": ["mlp_norm"],
397 | "up_proj": ["mlp_norm"],
398 | "mlp_silu_dot": ["up_proj", "gate_proj"],
399 | "down_proj": ["mlp_silu_dot"],
400 | "mlp_add": ["attn_add", "down_proj"],
401 | "output": ["mlp_add"],
402 | }
403 |
404 | class LayerGraphVisualizer:
405 | """Render a transformer layer’s roofline‑analysis graph as a PNG."""
406 |
407 | def __init__(self, model_type: str, results: dict, shapes: dict = None) -> None:
408 | self.model_type = model_type
409 | self.results = results
410 | if model_type == "qwen3":
411 | # qwen3 模型中有额外的 q_norm 和 k_norm 层
412 | _DEPENDENCIES["q_norm"] = ["q_proj"]
413 | _DEPENDENCIES["k_norm"] = ["k_proj"]
414 | # self.shapes = shapes or {} # optional {kernel: "B×S×C"} mapping
415 |
416 | # --------------------------------------------------------------------- #
417 | # internal helpers
418 | # --------------------------------------------------------------------- #
419 | def _label(self, node: str, kernel_stats: dict) -> str:
420 | """Build a neat multi‑line Graphviz label, optionally with shape info."""
421 | label = f"{node}\nFlops: {kernel_stats['flops']}, Access: {kernel_stats['memory_access']}, \nParams: {kernel_stats.get('load_weight', 0)}, Bound: {kernel_stats.get('bound', 'N/A')}"
422 | return label
423 |
424 | # --------------------------------------------------------------------- #
425 | # public API
426 | # --------------------------------------------------------------------- #
427 | def render(self, base_path: str = "layer_graph") -> None:
428 | """Generate one PNG per stage (prefill / decode) under ./figures/."""
429 | from graphviz import Digraph
430 |
431 | for stage, stage_res in self.results.items():
432 | dot = Digraph(
433 | format="png",
434 | node_attr={"style": "filled", "shape": "box", "fontname": "Arial"},
435 | )
436 |
437 | # Only include nodes and deps relevant for this stage, but always include "input" and "output"
438 | pruned_deps = {
439 | n: [d for d in deps if d in stage_res or d in ("input","output")]
440 | for n, deps in _DEPENDENCIES.items()
441 | if n in stage_res or n in ("input","output")
442 | }
443 |
444 | for node, deps in pruned_deps.items():
445 | color = (
446 | "lightblue" if "proj" in node
447 | else "plum" if "matmul" in node
448 | else "lightcyan"
449 | )
450 | if node in stage_res:
451 | label = self._label(node, stage_res[node])
452 | else:
453 | # default zero stats for input/output
454 | label = (
455 | f"{node}\n"
456 | "Flops: 0, Access: 0\n"
457 | "Params: 0, Bound: N/A"
458 | )
459 | dot.node(node, label=label, fillcolor=color)
460 | for dep in deps:
461 | if dep in pruned_deps:
462 | dot.edge(dep, node)
463 | graph_path = f"./figures/grpah_{stage}_{base_path}"
464 | dot.render(graph_path, cleanup=True)
465 |
466 | # ---------------------------------------------------------------------------
467 | # Command‑line entry‑point
468 | # ---------------------------------------------------------------------------
469 | def _main() -> None:
470 | import argparse, json
471 | from pathlib import Path
472 |
473 | parser = argparse.ArgumentParser(
474 | description="Generate a transformer layer graph (Graphviz PNG) from "
475 | "an LayerAnalyzer result JSON."
476 | )
477 | parser.add_argument("--result-json", type=Path, required=True,
478 | help="Path to the analysis‑result JSON produced by LayerAnalyzer")
479 | parser.add_argument("--model-type", required=True,
480 | help="Model type tag, e.g. 'llama' or 'qwen3'")
481 | parser.add_argument("--output", default="layer_graph",
482 | help="Base filename for the generated PNG(s)")
483 | args = parser.parse_args()
484 |
485 | with args.result_json.open() as fp:
486 | results = json.load(fp)
487 |
488 | LayerGraphVisualizer(args.model_type, results).render(args.output)
489 |
490 | if __name__ == "__main__": # pragma: no cover
491 | _main()
--------------------------------------------------------------------------------
/llm_counts/roofline_model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Sequence, List, Dict
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import argparse
6 |
7 | # 全局常量:最大算术强度和最大 TFLOPS,用于坐标轴范围
8 | MAX_OI = 1400
9 | MAX_TFLOPS = 2000
10 |
11 | @dataclass(frozen=True)
12 | class GPU:
13 | """GPU 硬件参数:FP16 峰值算力(TFLOPS)和带宽(TB/s)。"""
14 | name: str
15 | fp16_tflops: float # TFLOPS
16 | hbm_bw: float # TB/s
17 |
18 | @dataclass(frozen=True)
19 | class ModelConfig:
20 | """模型配置参数"""
21 | name: str
22 | total_flops: float # 总 FLOP (以 TeraFLOP 为单位)
23 | total_bytes: float # 总内存访问 (以 TeraByte 为单位)
24 | color: str # 绘图颜色
25 |
26 | def roofline_analysis(
27 | peak_flops: float,
28 | bandwidth: float,
29 | total_flops: float,
30 | total_mac_bytes: float
31 | ) -> tuple[float, float, str]:
32 | """
33 | Analyzes the roofline model and returns the arithmetic intensity,
34 | attainable FLOPs, and the bounding factor (memory or compute).
35 | """
36 | if total_mac_bytes == 0: # 防止除以零
37 | return 0, peak_flops, "compute"
38 |
39 | ai = total_flops / total_mac_bytes
40 | turning_point = peak_flops / bandwidth
41 |
42 | if ai < turning_point:
43 | return ai, ai * bandwidth, "memory"
44 | else:
45 | return ai, peak_flops, "compute"
46 |
47 | def plot_roofline(
48 | models: Sequence[ModelConfig],
49 | gpus: Sequence[GPU],
50 | output_file: str = "roofline_optimized.png"
51 | ) -> None:
52 | """
53 | 绘制经过优化的、用于比较的 Roofline 曲线。
54 |
55 | 主要优化点:
56 | 1. 使用 Log-Log 坐标轴,符合行业标准。
57 | 2. 采用智能图例管理,避免图例冗长。
58 | 3. 使用 adjust_text 自动防止文本标签重叠。
59 | 4. 优化视觉设计,突出重点信息。
60 | """
61 | fig, ax = plt.subplots(figsize=(14, 10))
62 | plot_colors = ['red', 'blue', 'green', 'orange', 'purple']
63 | # --- 1. 绘制 GPU 屋顶线 (作为背景) ---
64 | # 使用对数坐标轴,范围从 0.1 到 10000
65 | oi_range = np.logspace(-1, 4, 200)
66 | gpu_linestyles = ['-', '--', '-.', ':']
67 |
68 | for i, gpu in enumerate(gpus):
69 | roofline = np.minimum(oi_range * gpu.hbm_bw, gpu.fp16_tflops)
70 | linestyle = gpu_linestyles[i % len(gpu_linestyles)]
71 | # 使用统一的灰色系,不同线型来区分,作为背景不干扰主要数据
72 | ax.plot(
73 | oi_range,
74 | roofline,
75 | linestyle=linestyle,
76 | linewidth=2,
77 | label=f"{gpu.name} Roof (Turn @ {gpu.fp16_tflops / gpu.hbm_bw:.1f})",
78 | color=plot_colors[i % len(plot_colors)],
79 | alpha=0.9
80 | )
81 |
82 | # --- 2. 绘制模型性能点并收集文本标签 ---
83 | text_labels = []
84 |
85 | for model in models:
86 | # **智能图例技巧**: 为每个模型创建一个“虚拟”的图例条目,
87 | # 这样图例中每个模型只显示一次。
88 | ax.scatter([], [], color=model.color, marker='o', s=120, label=f"{model.name}")
89 |
90 | for gpu in gpus:
91 | ai, attainable, bound = roofline_analysis(
92 | gpu.fp16_tflops,
93 | gpu.hbm_bw,
94 | model.total_flops,
95 | model.total_bytes
96 | )
97 |
98 | ax.scatter(
99 | ai,
100 | attainable,
101 | s=120,
102 | marker='o', # 使用统一标记,用颜色区分模型
103 | color=model.color,
104 | edgecolors='black',
105 | zorder=5 # 确保点在最上层
106 | )
107 |
108 | # 准备文本标签,稍后由 adjust_text 统一处理
109 | label_text = f"{gpu.name}\n{attainable:.0f} TFLOPS ({bound[:3]}.)"
110 | text_labels.append(
111 | ax.text(ai, attainable, label_text, fontsize=9, ha='center')
112 | )
113 |
114 | # --- 3. 图表美化与最终处理 ---
115 | # 切换到对数坐标轴
116 | ax.set_xscale('log')
117 | ax.set_yscale('log')
118 |
119 | ax.set_xlabel("Arithmetic Intensity (FLOPs / Bytes) [log scale]", fontsize=12)
120 | ax.set_ylabel("Attainable Performance (TFLOPS) [log scale]", fontsize=12)
121 | ax.set_title("Comparative Roofline Analysis", fontsize=16, fontweight='bold')
122 |
123 | # 使用 'both' 在主次刻度上都显示网格,对 log 尺度很友好
124 | ax.grid(True, which="both", linestyle="--", alpha=0.5)
125 |
126 | # 自动调整坐标轴范围,并留出一些边距
127 | ax.autoscale(True)
128 | ax.set_xlim(left=max(ax.get_xlim()[0], 0.5))
129 | ax.set_ylim(bottom=max(ax.get_ylim()[0], 10))
130 |
131 | # **关键步骤**: 调用 adjust_text 来智能地防止标签重叠
132 | # 它会自动移动标签,并可以用箭头指向原始数据点
133 | from adjustText import adjust_text
134 | adjust_text(
135 | text_labels,
136 | ax=ax,
137 | arrowprops=dict(arrowstyle='->', color='gray', lw=0.5)
138 | )
139 |
140 | # 图例现在很简洁,可以优雅地放在图内
141 | ax.legend(fontsize=10, loc='lower right')
142 |
143 | fig.tight_layout()
144 | plt.savefig(output_file, dpi=300)
145 | print(f"Optimized roofline plot saved to {output_file}")
146 | plt.close(fig)
147 |
148 |
149 | def main():
150 | # 预定义 GPU 配置
151 | GPUS = [
152 | GPU("H100", 989, 3.35),
153 | GPU("A100", 312, 2.039),
154 | GPU("RTX4090", 330, 1.008),
155 | GPU("MI300X", 1150, 5.2),
156 | GPU("L40S", 363, 0.864),
157 | ]
158 |
159 | # 预定义模型配置
160 | MODELS = {
161 | "gpt3": ModelConfig(
162 | "GPT-3 (175B)",
163 | total_flops=314000, # TFLOPs (3.14e14 FLOPs)
164 | total_bytes=1000, # TB (1e15 bytes)
165 | color='red'
166 | ),
167 | "llama2-70b": ModelConfig(
168 | "LLaMA2-70B",
169 | total_flops=70000, # TFLOPs (7e13 FLOPs)
170 | total_bytes=200, # TB (2e14 bytes)
171 | color='blue'
172 | ),
173 | "qwen2.5-3b": ModelConfig(
174 | "Qwen2.5-3B",
175 | total_flops=3000, # TFLOPs (3e12 FLOPs)
176 | total_bytes=10, # TB (1e13 bytes)
177 | color='green'
178 | ),
179 | }
180 |
181 | # 命令行参数解析
182 | parser = argparse.ArgumentParser(description="Roofline Model Analysis Tool")
183 | parser.add_argument("--models", nargs="+", choices=list(MODELS.keys()),
184 | default=["gpt3", "llama2-70b", "qwen2.5-3b"],
185 | help="Models to analyze (default: all)")
186 | parser.add_argument("--gpus", nargs="+",
187 | default=["H100", "A100", "RTX4090"],
188 | help="GPUs to analyze (default: H100, A100, RTX4090)")
189 | parser.add_argument("--output", default="roofline_analysis.png",
190 | help="Output filename (default: roofline_analysis.png)")
191 |
192 | args = parser.parse_args()
193 |
194 | # 获取选中的模型和GPU
195 | selected_models = [MODELS[model] for model in args.models]
196 | selected_gpus = [gpu for gpu in GPUS if gpu.name in args.gpus]
197 |
198 | if not selected_gpus:
199 | print("Error: No valid GPUs selected. Available options:")
200 | for gpu in GPUS:
201 | print(f" - {gpu.name}")
202 | return
203 |
204 | # 生成屋顶线图
205 | plot_roofline(selected_models, selected_gpus, args.output)
206 |
207 | if __name__ == "__main__":
208 | main()
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/config.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/config.cpython-311.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/config.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/config.cpython-312.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/constants.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/constants.cpython-310.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/constants.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/constants.cpython-311.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/constants.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/constants.cpython-312.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harleyszhang/llm_counts/2cd9bc0dd0aa757000b8bd9d5867e986d42828aa/llm_counts/utils/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/llm_counts/utils/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding : utf-8 -*-
2 | # Description : gpu, model, Parallelism, data, train and inference config definition
3 |
4 | import math, json
5 | from .constants import *
6 | from typing import Optional
7 | from dataclasses import dataclass
8 | from enum import Enum
9 | from functools import total_ordering
10 | from transformers import AutoConfig
11 | import os
12 |
13 |
14 | class ActivationRecomputation(Enum):
15 | NONE = 0
16 | """No activation recomputation; requires the most amount of memory."""
17 |
18 | SELECTIVE = 1
19 | """Selectively checkpoints and recomputes only parts of each transformer
20 | layer that take up a considerable amount of memory but are not
21 | computationally expensive to recompute, i.e. Q K V matrix multiplies,
22 | QK^T matrix multiply, softmax, softmax dropout, and attention over V."""
23 |
24 | FULL = 2
25 | """Full activation recomputation stores the input to EVERY transformer
26 | layer, which is sharded across the tensor parallel group, thus requiring an
27 | extra all-gather (ignored for now) per layer and add communication
28 | overhead; requires the lease amount of memory; requires an extra forward
29 | pass."""
30 |
31 |
32 | @total_ordering
33 | class DSZeRO(Enum):
34 | NONE = 0
35 | """No DeepSPeed ZeRO; requires the most amount of memory."""
36 |
37 | STAGE_1 = 1
38 | """ZeRO stage 1 shards the optimizer states across the data parallel
39 | group."""
40 |
41 | STAGE_2 = 2
42 | """ZeRO stage 2 shards the optimizer states and gradients across the data
43 | parallel group."""
44 |
45 | STAGE_3 = 3
46 | """ZeRO stage 3 shards the optimizer states, gradients, and model weights
47 | across the data parallel group."""
48 |
49 | def __lt__(self, other):
50 | # 炫技写法
51 | if other.__class__ is self.__class__:
52 | return self.value < other.value # Enum 枚举类自动赋值
53 | return NotImplemented
54 |
55 | def __eq__(self, other):
56 | if isinstance(other, DSZeRO):
57 | return self.value == other.value
58 | return NotImplemented
59 |
60 |
61 | @dataclass
62 | class GPUEfficiencyConfig:
63 | flops_efficiency: float = 1.0
64 | hbm_memory_efficiency: float = 1.0
65 | intra_node_memory_efficiency: float = 1.0
66 | inter_node_memory_efficiency: float = 1.0
67 |
68 |
69 | @dataclass
70 | class InferenceConfig:
71 | """Inference configuration dataclass."""
72 |
73 | bs: int = None # batch size
74 | seq_len: int = 522 # input sequence length
75 | generate_len: int = 1526 # number of tokens to generate
76 | context_len: int = None # context length
77 | bytes_per_param: int = BYTES_FP16 # model weight bytes
78 | act_dtype_bytes: int = BYTES_FP16 # activation data type bytes
79 | kv_cache_bytes: int = BYTES_FP16 # key/value cache data type bytes
80 |
81 | def __post_init__(self):
82 | if self.context_len is None:
83 | self.context_len = self.seq_len + self.generate_len
84 |
85 |
86 | @dataclass
87 | class ParallelismConfig:
88 | """dataclass module provides a decorator and functions for automatically adding
89 | generated special methods such as __init__() and __repr__() to user-defined classes.
90 | """
91 |
92 | tp_size: int = (
93 | 1 # tensor parallelism size, Megatron-LM tensor parallelism implementation
94 | )
95 | pp_size: int = (
96 | 1 # pipeline parallelism size, Megatron-LM pipeline parallelism implementation
97 | )
98 | dp_size: int = 1 # data parallelism size, DeepSpeed Zero parallelism implementation
99 | sp_size: int = (
100 | 1 # sequence parallelism size, Megatron-LM sequence parallelism implementation
101 | )
102 |
103 |
104 | @dataclass
105 | class ModelConfig:
106 | num_layers: Optional[int] = None # number of transformer layers (blocks)
107 | num_heads: Optional[int] = None # number of attention heads
108 | head_dim: Optional[int] = None # <— 新增:允许显式传入
109 | hidden_size: Optional[int] = None # hidden dimension
110 | vocab_size: Optional[int] = None # vocabulary size
111 | num_kv_heads: Optional[int] = None
112 | max_seq_len: Optional[int] = None # max sequence length
113 | intermediate_size: Optional[int] = None # hidden dimension of FFN, default to 4 * hidden_size
114 | model_type: str = (
115 | None # model type as tagged on Hugging Face (e.g., gpt2, opt, llama.)
116 | )
117 | model_name: str = (
118 | None # model name as tagged on Hugging Face (e.g., gpt2-xl, opt, llama-13b.)
119 | )
120 |
121 | # -------- post-init 逻辑 -------- #
122 | def __post_init__(self) -> None:
123 | # ① KV-heads 默认 = Q-heads
124 | if self.num_kv_heads is None:
125 | self.num_kv_heads = self.num_heads
126 |
127 | # ② FFN 维度默认 = 4×hidden_size
128 | if self.intermediate_size is None:
129 | self.intermediate_size = self.hidden_size * 4
130 |
131 | # ③ **核心:head_dim 计算**
132 | # 若用户 / HF config 已提供,则直接用;否则按经典公式推断
133 | if self.head_dim is None:
134 | self.head_dim = self.hidden_size // self.num_heads
135 |
136 | # ④ 一致性检查(可选:遇到 MoE/GQA 可放宽)
137 | assert (
138 | self.hidden_size == self.head_dim * self.num_heads
139 | ), (
140 | "hidden_size 与 num_heads×head_dim 不一致;"
141 | "若模型采用变体架构,请显式指定 head_dim"
142 | )
143 |
144 | @classmethod
145 | def from_pretrained(
146 | cls, pretrained_model_name_or_path: str, trust_remote_code: bool = True
147 | ):
148 | """
149 | Load a Hugging Face model configuration and map it to ModelConfig.
150 |
151 | Args:
152 | pretrained_model_name_or_path (str): Path or name of the pretrained model.
153 | trust_remote_code (bool): Whether to trust remote code for custom models.
154 |
155 | Returns:
156 | ModelConfig: An instance of the custom ModelConfig class.
157 | """
158 | # Load the Hugging Face configuration
159 | hf_config = AutoConfig.from_pretrained(
160 | pretrained_model_name_or_path, trust_remote_code=trust_remote_code
161 | )
162 |
163 | # Create a ModelConfig instance by mapping the fields
164 | return cls(
165 | num_layers=hf_config.num_hidden_layers,
166 | num_heads=hf_config.num_attentionum_headss,
167 | hidden_size=hf_config.hidden_size,
168 | vocab_size=hf_config.vocab_size,
169 | num_kv_heads=getattr(hf_config, "num_kv_heads", None),
170 | max_seq_len=hf_config.max_position_embeddings,
171 | intermediate_size=hf_config.intermediate_size,
172 | model_type=hf_config.model_type,
173 | model_name=hf_config.name_or_path,
174 | )
175 |
176 |
177 | @dataclass
178 | class GPUConfig:
179 | # 1, gpu 型号和显存大小
180 | name: str # GPU config name
181 | memory_GPU_in_GB: float # memory per GPU in GB
182 | onchip_buffer: float = None # on-chip buffer size in bytes, e.g., register file size
183 |
184 | # 2, gpu 显存带宽、节点内带宽、节点间带宽
185 | hbm_bandwidth_in_GB_per_sec: float=None # GPU HBM bandwidth in GB/s
186 | intra_node_bandwidth_in_GB_per_sec: float=None # intra node GPU bandwidth in GB/s.(PCIE/NVLINK)
187 | intra_node_min_message_latency: float=None # minimum intra node message latency in seconds
188 | # inter node bandwidth in GB/s, assuming Mellanox 200Gbps HDR Infiniband
189 | inter_node_bandwidth_in_GB_per_sec: float = 200
190 |
191 | # 3, 不同精度的 Tensor core 的计算性能
192 | peak_fp32_TFLOPS: float = None # peak Tensor TFLOPS for FP32
193 | peak_fp16_TFLOPS: float = None # peak Tensor TFLOPS for FP16
194 | peak_int8_TFLOPS: float = None # peak Tensor TFLOPS for INT8
195 | peak_int4_TFLOPS: float = None # peak Tensor TFLOPS for INT4
196 |
197 | FLOPS_EFFICIENCY = 0.9
198 | HBM_MEMORY_EFFICIENCY = 0.9
199 | INTRA_NODE_BANDWIDTH_EFFICIENCY = 0.9
200 |
201 | def __post_init__(self):
202 | """
203 | Post-initialization processing to compute missing values and apply efficiencies.
204 | """
205 | # Ensure FP32 TFLOPS is calculated if missing
206 | if self.peak_fp32_TFLOPS is None and self.peak_fp16_TFLOPS is not None:
207 | self.peak_fp32_TFLOPS = self.peak_fp16_TFLOPS / 2
208 |
209 | # Ensure INT8 and INT4 TFLOPS are calculated if missing
210 | if self.peak_int8_TFLOPS is None and self.peak_fp16_TFLOPS is not None:
211 | self.peak_int8_TFLOPS = 2 * self.peak_fp16_TFLOPS
212 | if self.peak_int4_TFLOPS is None and self.peak_fp16_TFLOPS is not None:
213 | self.peak_int4_TFLOPS = 4 * self.peak_fp16_TFLOPS
214 |
215 | # Apply FLOPS efficiency and round to nearest integer
216 | if self.FLOPS_EFFICIENCY:
217 | self.actual_peak_fp32_TFLOPS = math.ceil(
218 | self.peak_fp32_TFLOPS * self.FLOPS_EFFICIENCY
219 | )
220 | self.actual_peak_fp16_TFLOPS = math.ceil(
221 | self.peak_fp16_TFLOPS * self.FLOPS_EFFICIENCY
222 | )
223 | self.actual_peak_int8_TFLOPS = math.ceil(
224 | self.peak_int8_TFLOPS * self.FLOPS_EFFICIENCY
225 | )
226 | self.actual_peak_int4_TFLOPS = math.ceil(
227 | self.peak_int4_TFLOPS * self.FLOPS_EFFICIENCY
228 | )
229 |
230 |
231 | class LLMConfigs(object):
232 | """LLMConfigs is a dataclass that contains all the configurations for the LLM model."""
233 |
234 | def __init__(
235 | self,
236 | gpu_config: GPUConfig,
237 | model_config: ModelConfig,
238 | parallelism_config: ParallelismConfig = ParallelismConfig(),
239 | inference_config: InferenceConfig = InferenceConfig(),
240 | gpu_efficiency_config: GPUEfficiencyConfig = GPUEfficiencyConfig(),
241 | ) -> None:
242 | self.model_config = model_config
243 | self.gpu_config = gpu_config
244 | self.parallelism_config = parallelism_config
245 | self.inference_config = inference_config # 用户自行指定配置
246 | self.gpu_efficiency_config = gpu_efficiency_config # 用户自行指定配置
247 |
248 |
249 | def get_model_and_gpu_config_by_name(
250 | model_name="llama-13b", gpu_name="v100-pcie-32gb"
251 | ) -> dict:
252 | """Read model and gpu configs from a json file."""
253 | current_dir = os.path.dirname(__file__)
254 | model_config_path = os.path.join(current_dir, "../configs/model_configs.json")
255 | gpu_config_path = os.path.join(current_dir, "../configs/gpu_configs.json")
256 |
257 | with open(model_config_path, "r") as f:
258 | config_json = json.load(f) # 类似于 dict 类型
259 | if model_name in config_json:
260 | print(f"model name {model_name} is found in {model_config_path}")
261 | config_dict = config_json[model_name]
262 | model_config = ModelConfig(**config_dict)
263 | else:
264 | print(
265 | f"model name {model_name} is not found in {model_config_path} so need to apply transformers AutoConfig"
266 | )
267 | # 加载模型配置
268 | model_config = ModelConfig.from_pretrained(model_name, trust_remote_code=True)
269 |
270 | with open(gpu_config_path, "r") as f:
271 | config_json = json.load(f) # 类似于 dict 类型
272 | config_dict = config_json[gpu_name]
273 | assert gpu_name in config_json, (
274 | f"gpu name {gpu_name} not found in {gpu_config_path}"
275 | )
276 | gpu_config = GPUConfig(**config_dict)
277 |
278 | return model_config, gpu_config
279 |
280 |
281 | def get_TFLOPS_per_gpu(
282 | gpu_config: GPUConfig, data_type="fp16", flops_efficiency=FLOPS_EFFICIENCY
283 | ) -> float:
284 | """Get the expected TFLOPS per GPU for the specified data type
285 | configuration/GPU (adjusted by flops_efficiency)
286 |
287 | Returns:
288 | float: TFLOPS per GPU and unit is T.
289 | """
290 | if data_type == "int8":
291 | gemm_TFOPS = gpu_config.peak_int8_TFLOPS
292 | elif data_type == "fp16":
293 | gemm_TFOPS = gpu_config.peak_fp16_TFLOPS
294 | else:
295 | print("weight_bits and activation_bits must be 8, or 16!")
296 |
297 | return gemm_TFOPS * flops_efficiency
298 |
299 |
300 | def get_gpu_hbm_bandwidth(
301 | gpu_config: GPUConfig, hbm_memory_efficiency=HBM_MEMORY_EFFICIENCY
302 | ) -> list:
303 | return gpu_config.hbm_bandwidth_in_GB_per_sec * hbm_memory_efficiency, gpu_config.onchip_buffer
304 |
305 |
306 | def get_intra_node_bandwidth(
307 | gpu_config: GPUConfig, intra_node_memory_efficiency=INTRA_NODE_MEMORY_EFFICIENCY
308 | ) -> float:
309 | return gpu_config.intra_node_bandwidth_in_GB_per_sec * intra_node_memory_efficiency
310 |
311 |
312 | def get_inter_node_bandwidth(
313 | gpu_config: GPUConfig, inter_node_memory_efficiency=INTER_NODE_MEMORY_EFFICIENCY
314 | ) -> float:
315 | return gpu_config.inter_node_bandwidth_in_GB_per_sec * inter_node_memory_efficiency
316 |
--------------------------------------------------------------------------------
/llm_counts/utils/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from functools import total_ordering
3 |
4 | #########################################
5 | ####### llm profiler ############
6 | #########################################
7 |
8 | FLOPS_EFFICIENCY = (
9 | 0.9 # FLOPS efficiency achieved by Megatron-LM is ~0.5 for LLM training
10 | )
11 | HBM_MEMORY_EFFICIENCY = 0.9 # GPU HBM memory efficiency
12 | INTRA_NODE_MEMORY_EFFICIENCY = 0.75 # intra-node (nvlink) memory efficiency
13 | INTER_NODE_MEMORY_EFFICIENCY = 0.9 # inter-node memory efficiency
14 |
15 | NUM_GPUS_PER_NODE = 8 # number of GPUs per node
16 |
17 | TOLERANCE = 0.01 # tolerance for floating point comparisons
18 |
19 | BITS_PER_BYTE = 8 # number of bits in a byte
20 |
21 | BITS_FP32 = 32 # number of bits in FP32 data type
22 | BITS_FP16 = 16 # number of bits in FP16 data type
23 | BITS_INT8 = 8 # number of bits in INT8 data type
24 | BITS_INT4 = 4 # number of bits in INT4 data type
25 |
26 | BYTES_FP32 = BITS_FP32 // BITS_PER_BYTE # number of bytes in FP32 data type
27 | BYTES_FP16 = BITS_FP16 // BITS_PER_BYTE # number of bytes in FP16 data type
28 | BYTES_INT8 = BITS_INT8 // BITS_PER_BYTE # number of bytes in INT8 data type
29 | BYTES_INT4 = BITS_INT4 // BITS_PER_BYTE # number of bytes in INT4 data type
30 |
31 | PRINT_LINE_WIDTH = 100
32 |
33 | GPUS = [1, 2, 4, 8]
34 |
35 |
36 | @total_ordering
37 | class ActivationRecomputation(Enum):
38 | NONE = 0
39 | """No activation recomputation; requires the most amount of memory."""
40 | ATTN_COMPUTE = 1
41 | """Selectively checkpoints the attention computation (
42 | QK^T matrix multiply, softmax, softmax dropout, and attention overV.)
43 | in the attention module of a transformer layer; this part takes up a
44 | considerable amount of memory but are not computationally expensive to recompute"""
45 | ATTN = 2
46 | """Selectively checkpoints the input to the attention module in a transformer layer;
47 | requires an extra forward pass on attention."""
48 | NORM_ATTN_NORM = 3
49 | """Selectively checkpoints the input to the sequence of modules (layernom-attention-layernom)
50 | in a transformer layer; requires an extra forward pass on (layernom-attention-layernom)."""
51 | FULL = 4
52 | """Full activation recomputation stores the input to the transformer layer; requires the least
53 | amount of memory; requires an extra forward pass of the layer."""
54 |
55 | def __lt__(self, other):
56 | if self.__class__ is other.__class__:
57 | return self.value < other.value
58 | return NotImplemented
59 |
--------------------------------------------------------------------------------
/llm_counts/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | from .constants import *
3 |
4 | class Formatter(object):
5 | @classmethod
6 | def format_value(cls, value, category):
7 | """根据类别统一格式化 value."""
8 | if category == "params" or category == "flops":
9 | return num_to_string(value)
10 | elif category == "latency":
11 | return latency_to_string(value)
12 | elif category == "memory":
13 | return f"{num_to_string(value)}B"
14 | return value # 如果没有匹配,返回原值
15 |
16 | @classmethod
17 | def print_format_summary_dict(
18 | self,
19 | summary_dict: dict,
20 | depth: int,
21 | category: str | None = None,
22 | ) -> str:
23 | """
24 | 打印时对 params / flops / latency / memory 等进行统一转换显示。
25 | If *category* is provided, apply that formatting to every leaf value that is
26 | not a nested dict; otherwise fall back to key‑based inference.
27 | """
28 | if category is not None and not isinstance(summary_dict, dict):
29 | # Safety bail‑out (shouldn't happen)
30 | return summary_dict
31 | for key, value in summary_dict.items():
32 | # If category is explicitly provided, ignore key‑name heuristics
33 | explicit_cat = category
34 | if (explicit_cat == "params" or explicit_cat == "flops") or ("params" in key or "flops" in key):
35 | if not isinstance(value, dict):
36 | summary_dict.update({key: num_to_string(value)})
37 | else:
38 | self.print_format_summary_dict(
39 | value, get_dict_depth(value) - 1, category
40 | ) # 递归
41 | if explicit_cat == "latency" or "latency" in key:
42 | if not isinstance(value, dict):
43 | summary_dict.update({key: latency_to_string(value)})
44 | else:
45 | self.print_format_summary_dict(value, get_dict_depth(value) - 1, category)
46 | if explicit_cat == "memory" or "memory" in key:
47 | if not isinstance(value, dict):
48 | summary_dict.update({key: f"{num_to_string(value)}B"})
49 | else:
50 | self.print_format_summary_dict(value, get_dict_depth(value) - 1, category)
51 | if depth >= 1:
52 | pprint.pprint(summary_dict, indent=4, sort_dicts=False)
53 |
54 |
55 | def print_list(list):
56 | """print one-dimensional list
57 |
58 | :param list: List[int]
59 | :return: None
60 | """
61 | for i, x in enumerate(list):
62 | print(x, end="\n")
63 |
64 |
65 | def get_dict_depth(d, depth=0):
66 | if not isinstance(d, dict):
67 | return depth
68 | if not d:
69 | return depth
70 |
71 | return max(get_dict_depth(v, depth + 1) for v in d.values())
72 |
73 |
74 | def latency_to_string(latency_in_s, precision=2, return_type="string"):
75 | if latency_in_s is None:
76 | return "None" if return_type == "string" else None
77 |
78 | day = 24 * 60 * 60
79 | hour = 60 * 60
80 | minute = 60
81 | ms = 1 / 1000
82 | us = 1 / 1000000
83 |
84 | if latency_in_s // day > 0:
85 | value = round(latency_in_s / day, precision)
86 | unit = "days"
87 | elif latency_in_s // hour > 0:
88 | value = round(latency_in_s / hour, precision)
89 | unit = "hours"
90 | elif latency_in_s // minute > 0:
91 | value = round(latency_in_s / minute, precision)
92 | unit = "minutes"
93 | elif latency_in_s > 1:
94 | value = round(latency_in_s, precision)
95 | unit = "s"
96 | elif latency_in_s > ms:
97 | value = round(latency_in_s / ms, precision)
98 | unit = "ms"
99 | else:
100 | value = round(latency_in_s / us, precision)
101 | unit = "us"
102 |
103 | if return_type == "string":
104 | return f"{value} {unit}"
105 | elif return_type == "float":
106 | return value
107 | else:
108 | return (value, unit)
109 |
110 |
111 | def num_to_string(num, precision=2, return_type="string"):
112 | if num is None:
113 | return "None" if return_type == "string" else None
114 |
115 | if num // 10**12 > 0:
116 | value = round(num / 10.0**12, precision)
117 | unit = "T"
118 | elif num // 10**9 > 0:
119 | value = round(num / 10.0**9, precision)
120 | unit = "G"
121 | elif num // 10**6 > 0:
122 | value = round(num / 10.0**6, precision)
123 | unit = "M"
124 | elif num // 10**3 > 0:
125 | value = round(num / 10.0**3, precision)
126 | unit = "K"
127 | else:
128 | value = num
129 | unit = ""
130 |
131 | if return_type == "string":
132 | return f"{value} {unit}".strip()
133 | elif return_type == "float":
134 | return value
135 | else:
136 | return (value, unit)
137 |
138 |
139 | def get_readable_summary_dict(summary_dict: dict, title="Summary") -> str:
140 | log_str = f"\n{title.center(PRINT_LINE_WIDTH, '-')}\n"
141 | for key, value in summary_dict.items():
142 | if "num_tokens" in key or "num_params" in key or "flops" in key:
143 | log_str += f"{key}: {num_to_string(value)}\n"
144 | elif "gpu_hours" == key:
145 | log_str += f"{key}: {int(value)}\n"
146 | elif "memory" in key and "efficiency" not in key:
147 | log_str += f"{key}: {num_to_string(value)}B\n"
148 | elif "latency" in key:
149 | log_str += f"{key}: {latency_to_string(value)}\n"
150 | else:
151 | log_str += f"{key}: {value}\n"
152 | log_str += f"{'-' * PRINT_LINE_WIDTH}\n"
153 | return log_str
154 |
155 |
156 | def within_range(val, target, tolerance):
157 | return abs(val - target) / target < tolerance
158 |
159 |
160 | def average(lst):
161 | if not lst:
162 | return None
163 | return sum(lst) / len(lst)
164 |
165 |
166 | def max_value(lst):
167 | if not lst:
168 | return None
169 | return max(lst)
170 |
--------------------------------------------------------------------------------
/llm_counts/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | # ====================================================================== #
4 | # Refactored visualisation as a lightweight class #
5 | # ====================================================================== #
6 | class SeqLenVisualizer:
7 | """Encapsulates all plots & tables for a sequence‑length sweep."""
8 |
9 | def __init__(
10 | self,
11 | df: pd.DataFrame,
12 | model: str,
13 | gpu: str,
14 | *,
15 | out_dir: str = "figures",
16 | flops_unit: str = "TFLOPs",
17 | mem_unit: str = "GiB",
18 | dpi: int = 300,
19 | show: bool = False,
20 | ):
21 | from pathlib import Path
22 | import matplotlib.pyplot as plt
23 |
24 | self.df = df.sort_values("seq_len")
25 | self.model = model
26 | self.gpu = gpu
27 | self.out_dir = Path(out_dir)
28 | self.flops_unit = flops_unit
29 | self.mem_unit = mem_unit
30 | self.dpi = dpi
31 | self.show = show
32 |
33 | self.out_dir.mkdir(parents=True, exist_ok=True)
34 | plt.style.use("seaborn-v0_8-paper")
35 | plt.rcParams.update(
36 | {
37 | "figure.facecolor": "white",
38 | "axes.facecolor": "white",
39 | "axes.grid": True,
40 | "grid.alpha": 0.3,
41 | "grid.linestyle": "--",
42 | "axes.edgecolor": "#cccccc",
43 | "axes.spines.top": False,
44 | "axes.spines.right": False,
45 | "font.size": 12,
46 | "axes.titleweight": "bold",
47 | }
48 | )
49 |
50 | # Pre‑compute unit scaling divisors
51 | _scale = {
52 | "GFLOPs": 1e9,
53 | "TFLOPs": 1e12,
54 | "PFLOPs": 1e15,
55 | "MiB": 2**20,
56 | "GiB": 2**30,
57 | }
58 | self.flops_div = _scale.get(self.flops_unit, 1.0)
59 | self.mem_div = _scale.get(self.mem_unit, 1.0)
60 |
61 | # -------------------------- helpers ------------------------------ #
62 | def _save(self, fig, suffix: str):
63 | fig.savefig(
64 | self.out_dir
65 | / f"{self.model}_{self.gpu}_{suffix}.png",
66 | dpi=self.dpi,
67 | bbox_inches="tight",
68 | )
69 | if self.show:
70 | import webbrowser, os, matplotlib.pyplot as _plt
71 | _plt.show()
72 | import matplotlib.pyplot as plt
73 | plt.close(fig)
74 |
75 | @staticmethod
76 | def _line_scatter(ax, x, y, y_label, title, cmap="viridis_r"):
77 | sc = ax.scatter(
78 | x,
79 | y,
80 | c=y,
81 | cmap=cmap,
82 | s=70,
83 | edgecolor="black",
84 | linewidths=0.4,
85 | )
86 | ax.plot(x, y, linewidth=1.2, alpha=0.75)
87 | ax.set_xlabel("Sequence length (tokens)")
88 | ax.set_ylabel(y_label)
89 | ax.set_title(title)
90 | ax.grid(True, linestyle="--", alpha=0.3)
91 | return sc
92 |
93 |
94 | # ----------------------- public interface ------------------------ #
95 | def visualize(self):
96 | self._metric_figs()
97 | self._latency_fig()
98 | self._composite_fig()
99 | self._interactive_html()
100 | self._print_table()
101 |
102 | # ------------------ individual plot generators ------------------ #
103 | def _metric_figs(self):
104 | import matplotlib.pyplot as plt
105 |
106 | metrics = [
107 | ("flops", "prefill_flops", self.flops_div, f"Prefill {self.flops_unit}"),
108 | ("memory", "consume_memory_per_gpu", self.mem_div, f"HBM ({self.mem_unit})"),
109 | ]
110 | if "throughput_tok_per_second" in self.df.columns:
111 | metrics.append(("throughput", "throughput_tok_per_second", 1.0, "Throughput (tok/s)"))
112 |
113 | for suffix, col, div, label in metrics:
114 | if col not in self.df:
115 | continue
116 | y = (self.df[col] / div) if div != 1.0 else self.df[col]
117 | fig, ax = plt.subplots(figsize=(7, 4), constrained_layout=True)
118 | sc = self._line_scatter(
119 | ax,
120 | self.df["seq_len"],
121 | y,
122 | label,
123 | f"{self.model} on {self.gpu}\n{label} vs seq_len",
124 | )
125 | plt.colorbar(sc, ax=ax, label=label)
126 | self._save(fig, f"{suffix}_vs_seq_len")
127 |
128 | def _latency_fig(self):
129 | import matplotlib.pyplot as plt
130 |
131 | fig, ax1 = plt.subplots(figsize=(7, 4), constrained_layout=True)
132 | if "TTFT" in self.df.columns:
133 | ax1.plot(
134 | self.df["seq_len"],
135 | self.df["TTFT"],
136 | "s-.",
137 | linewidth=1.5,
138 | label="TTFT (s)",
139 | )
140 | ax1.set_ylabel("TTFT (s)")
141 | ax2 = ax1.twinx()
142 | if "TTOT" in self.df.columns:
143 | ax2.plot(
144 | self.df["seq_len"],
145 | self.df["TTOT"] * 1000.0,
146 | "^:",
147 | linewidth=1.5,
148 | color="tab:red",
149 | label="TTOT (ms)",
150 | )
151 | ax2.set_ylabel("TTOT (ms)")
152 | ax1.set_xlabel("Sequence length (tokens)")
153 | ax1.set_title(f"{self.model} on {self.gpu}\nTTFT & TTOT vs seq_len")
154 | handles, labels = [], []
155 | for ax in (ax1, ax2):
156 | h, l = ax.get_legend_handles_labels()
157 | handles += h
158 | labels += l
159 | ax1.legend(handles, labels, loc="upper left")
160 | ax1.grid(True, linestyle="--", alpha=0.3)
161 | self._save(fig, "latency_vs_seq_len")
162 |
163 | def _composite_fig(self):
164 | import matplotlib.pyplot as plt
165 | from matplotlib.gridspec import GridSpec
166 |
167 | mem_norm = self.df["consume_memory_per_gpu"] / self.mem_div
168 | fig = plt.figure(figsize=(10, 10), constrained_layout=True)
169 | gs = GridSpec(2, 2, figure=fig)
170 | axes = {
171 | "FLOPs": fig.add_subplot(gs[0, 0]),
172 | "Latency": fig.add_subplot(gs[0, 1]),
173 | "Memory": fig.add_subplot(gs[1, 0]),
174 | "Throughput": fig.add_subplot(gs[1, 1]),
175 | }
176 |
177 | self._line_scatter(
178 | axes["FLOPs"],
179 | self.df["seq_len"],
180 | self.df["prefill_flops"] / self.flops_div,
181 | f"Prefill {self.flops_unit}",
182 | "FLOPs",
183 | )
184 |
185 | if "TTFT" in self.df.columns:
186 | axes["Latency"].plot(self.df["seq_len"], self.df["TTFT"], "s-.", label="TTFT (s)")
187 | if "TTOT" in self.df.columns:
188 | axes["Latency"].plot(self.df["seq_len"], self.df["TTOT"] * 1000.0, "^:", label="TTOT (ms)")
189 | axes["Latency"].set_title("Latency"); axes["Latency"].legend()
190 |
191 | self._line_scatter(
192 | axes["Memory"],
193 | self.df["seq_len"],
194 | mem_norm,
195 | f"HBM ({self.mem_unit})",
196 | "Memory",
197 | )
198 |
199 | if "throughput_tok_per_second" in self.df.columns:
200 | self._line_scatter(
201 | axes["Throughput"],
202 | self.df["seq_len"],
203 | self.df["throughput_tok_per_second"],
204 | "Throughput (tok/s)",
205 | "Throughput",
206 | )
207 |
208 | fig.suptitle(f"{self.model} on {self.gpu}\nOverview", fontsize=14)
209 | self._save(fig, "overview")
210 |
211 | def _interactive_html(self):
212 | try:
213 | import plotly.graph_objects as go
214 | from plotly.offline import plot as psave
215 | from pathlib import Path
216 | figs = []
217 |
218 | # basic metrics
219 | meta = [
220 | ("TTFT (s)", "TTFT", 1.0),
221 | ("TTOT (ms)", "TTOT", 1000.0),
222 | (f"Prefill {self.flops_unit}", "prefill_flops", self.flops_div),
223 | (f"HBM ({self.mem_unit})", "consume_memory_per_gpu", self.mem_div),
224 | ("Throughput (tok/s)", "throughput_tok_per_second", 1.0),
225 | ]
226 | for name, col, div in meta:
227 | if col not in self.df:
228 | continue
229 | y = self.df[col] * div if div != 1.0 else self.df[col]
230 | f = go.Figure(go.Scatter(x=self.df["seq_len"], y=y, mode="lines+markers"))
231 | f.update_layout(title=f"{name} vs seq_len", template="seaborn")
232 | figs.append((name, f))
233 |
234 | html_path = Path(self.out_dir) / f"{self.model}_{self.gpu}_interactive.html"
235 | with open(html_path, "w") as fhtml:
236 | for i, (title, fig) in enumerate(figs):
237 | fhtml.write(f"{title}
")
238 | fhtml.write(psave(fig, include_plotlyjs="cdn" if i == 0 else False, output_type="div"))
239 | if self.show:
240 | import webbrowser, os
241 | webbrowser.open("file://" + os.path.abspath(html_path))
242 | except ImportError:
243 | print("[INFO] plotly missing – skipped interactive output.")
244 |
245 | def _print_table(self):
246 | summary = self.df.copy()
247 | summary["prefill_flops"] = (summary["prefill_flops"] / self.flops_div).map("{:,.2f}".format)
248 | summary["consume_memory_per_gpu"] = (summary["consume_memory_per_gpu"] / self.mem_div).map("{:,.2f}".format)
249 | if "throughput_tok_per_second" in summary.columns:
250 | summary["throughput_tok_per_second"] = summary["throughput_tok_per_second"].map("{:,.2f}".format)
251 | print("=" * 80)
252 | print(summary.to_string(index=False))
253 | print("=" * 80)
--------------------------------------------------------------------------------
/test_torch_info.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, TextIO, Optional
3 | import torch
4 | from transformers import AutoModel, AutoConfig, PreTrainedModel
5 | from accelerate import init_empty_weights
6 |
7 | MODEL_ID = "/home/honggao/llm_weights/Qwen3-8B"
8 |
9 |
10 | def print_empty_model(model_id):
11 | """
12 | Accelerate 提供 init_empty_weights 上下文管理器,令所有 Parameter 和 Buffer
13 | 都放在 meta device,尺寸为 0,因此既 不下载权重 也 不占内存。
14 | """
15 | cfg = AutoConfig.from_pretrained(model_id) # 只拉配置
16 |
17 | with init_empty_weights():
18 | model = AutoModel.from_config(cfg)
19 | print(model)
20 | return model
21 |
22 | def print_transformers_model_summary(
23 | model: PreTrainedModel,
24 | *,
25 | use_torchinfo: bool = False,
26 | input_size: Optional[tuple] = None,
27 | file: Union[str, TextIO, None] = None,
28 | ) -> None:
29 | """
30 | 打印 Hugging Face Transformers 模型结构 + 权重 shape。
31 |
32 | Args:
33 | model (PreTrainedModel): 已加载好的模型实例。
34 | use_torchinfo (bool): 是否调用 torchinfo.summary() 生成额外摘要。
35 | input_size (tuple): 当 use_torchinfo=True 时需提供 (seq_len, ) or (bs, seq_len, ...)。
36 | file: None -> 输出到 stdout;
37 | str -> 输出到指定路径文件;
38 | TextIO -> 已打开的文件句柄。
39 | """
40 | import math
41 |
42 | def _human_readable(num: float, *, base: int = 1000,
43 | units=("", "K", "M", "G", "T", "P"), suffix=""):
44 | """Convert a large number to human‑readable form (e.g. 12.3G)."""
45 | if num == 0:
46 | return f"0{suffix}"
47 | exp = min(int(math.log(num, base)), len(units) - 1)
48 | value = num / (base ** exp)
49 | return f"{value:.2f}{units[exp]}{suffix}"
50 |
51 | def _dump(msg: str = ""):
52 | if fh:
53 | fh.write(msg + "\n")
54 | else:
55 | print(msg)
56 |
57 | # 0) 处理输出目标
58 | fh = open(file, "w") if isinstance(file, str) else file
59 |
60 | # 1) 模型 __repr__
61 | _dump("=" * 60)
62 | _dump("Model architecture (__repr__):")
63 | _dump("=" * 60)
64 | _dump(str(model))
65 |
66 | # 2) 权重 shape
67 | _dump("\n" + "=" * 60)
68 | _dump("Parameter shapes (name -> shape, #elements):")
69 | _dump("=" * 60)
70 |
71 | # Token count estimation for FLOPs (default = 1 token if unknown)
72 | tokens = 1
73 | if input_size is not None:
74 | # Accept (seq_len,), (bs, seq_len) or any shape where last dim is seq_len
75 | if len(input_size) == 1:
76 | tokens = input_size[0]
77 | else:
78 | tokens = input_size[0] * input_size[-1]
79 |
80 | total_params = 0
81 | total_flops = 0
82 | total_mem_bytes = 0
83 | for name, param in model.named_parameters():
84 | numel = param.numel()
85 | total_params += numel
86 |
87 | # ---- Estimate per‑parameter FLOPs ----
88 | if param.dim() == 2: # typical (out, in) weight matrix
89 | flops = 2 * param.shape[0] * param.shape[1] * tokens
90 | elif param.dim() == 1: # bias / norm weight
91 | flops = param.shape[0] * tokens
92 | else:
93 | flops = numel # fallback crude estimate
94 | total_flops += flops
95 |
96 | # ---- Memory access cost (parameter bytes only) ----
97 | mem_bytes = numel * param.element_size()
98 | total_mem_bytes += mem_bytes
99 |
100 | # ---- Pretty print ----
101 | flops_str = _human_readable(flops, suffix="F")
102 | mem_str = _human_readable(mem_bytes, base=1024, units=("B","KB","MB","GB","TB","PB"))
103 | _dump(f"{name:<60} {str(tuple(param.shape)):<20} {numel:,} | {flops_str:<8} | {mem_str}")
104 |
105 | _dump(f"\nTotal parameters: {total_params:,}")
106 | _dump(f"Estimated forward FLOPs: {_human_readable(total_flops, suffix='F')}")
107 | _dump(f"Parameter memory: {_human_readable(total_mem_bytes, base=1024, units=('B','KB','MB','GB','TB','PB'))}")
108 |
109 | # 3) 可选 torchinfo 摘要
110 | if use_torchinfo:
111 | try:
112 | from torchinfo import summary # pip install torchinfo
113 | assert input_size is not None, "`input_size` must be provided when use_torchinfo=True"
114 | info = summary(
115 | model,
116 | input_size=input_size,
117 | depth=3,
118 | col_names=("kernel_size", "output_size", "num_params", "mult_adds"),
119 | dtypes=[torch.long], # 对 NLP 模型输入通常是 int64 token id
120 | )
121 | _dump("\n" + "=" * 60)
122 | _dump("torchinfo summary():")
123 | _dump("=" * 60)
124 | _dump(str(info))
125 | except ImportError:
126 | _dump("torchinfo 未安装,跳过摘要。pip install torchinfo 获取更丰富视图。")
127 |
128 | if isinstance(file, str): # 自动关闭文件
129 | fh.close()
130 |
131 | from torchviz import make_dot # pip install torchviz graphviz
132 | def save_model_graph(
133 | model,
134 | input_example: torch.Tensor,
135 | file_name: str = "model_graph.svg"
136 | ) -> None:
137 | """
138 | 利用 torchviz 生成前向图;input_example 必须能直接送入 model。
139 | """
140 | model.eval()
141 | y = model(input_example)
142 | dot = make_dot(y, params=dict(model.named_parameters()))
143 | dot.format = file_name.split(".")[-1] # 自动根据后缀决定 svg/png
144 | dot.render(file_name, cleanup=True)
145 | print(f"✅ Graph saved to {file_name}")
146 |
147 | if __name__ == "__main__":
148 | # model = AutoModel.from_pretrained(MODEL_ID)
149 | model = print_empty_model(MODEL_ID)
150 | input_example = torch.randint(0, 1000, (2, 2048)) # 随机输入
151 | print_transformers_model_summary(
152 | model=model,
153 | use_torchinfo=True,
154 | input_size=(2, 2048),
155 | file="qwen3_8b_structure.txt"
156 | )
--------------------------------------------------------------------------------