├── tests
│   ├── __init__.py
│   ├── extra_arch.yaml
│   └── test_upmem_pytorch_simulator.py
├── src
│   └── upmem_llm_framework
│       ├── __init__.py
│       ├── utils.py
│       ├── sim_architectures.py
│       ├── architectures_schema.json
│       ├── options.py
│       ├── pytorch_upmem_layers.py
│       ├── sim_architectures.yaml
│       ├── simulator.py
│       ├── base_architecture.py
│       └── profiler.py
├── .pylintrc
├── .vale.ini
├── .editorconfig
├── MANIFEST.in
├── .github
│   └── workflows
│       └── pytest.yml
├── LICENSE
├── examples
│   ├── nn_example.py
│   └── hf_example.py
├── .gitignore
├── pyproject.toml
├── scripts_simulation
│   └── simulations_llama2_7B.py
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Unit test package for upmem_pytorch_simulator."""
2 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/__init__.py:
--------------------------------------------------------------------------------
1 | from upmem_llm_framework.options import initialize_profiling_options
2 | from upmem_llm_framework.pytorch_upmem_layers import profiler_init, profiler_start, profiler_end
3 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 |
3 | disable=
4 | missing-function-docstring,
5 | missing-class-docstring,
6 | missing-module-docstring,
7 |
8 |
9 | [BASIC]
10 |
11 | good-names=sum_energy_mJ
12 | good-names-rgxs=^.*_(GBs|mJ)$
13 |
--------------------------------------------------------------------------------
/.vale.ini:
--------------------------------------------------------------------------------
1 | StylesPath = styles
2 |
3 | MinAlertLevel = suggestion
4 |
5 | Packages = Google, proselint, write-good, alex, Readability
6 |
7 | Vocab = Upmem
8 |
9 | [*.md]
10 | BasedOnStyles = Vale, Google, proselint, write-good, alex, Readability
11 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2014-2024 - UPMEM
3 | #
4 |
5 |
6 | def add_dictionaries(dict1, dict2):
7 | return {
8 | key: dict1.get(key, 0) + dict2.get(key, 0) for key in set(dict1) | set(dict2)
9 | }
10 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include AUTHORS.rst
2 | include CONTRIBUTING.rst
3 | include HISTORY.rst
4 | include LICENSE
5 | include README.rst
6 |
7 | recursive-include tests *
8 | recursive-include examples *
9 |
10 | recursive-exclude * __pycache__
11 | recursive-exclude * *.py[co]
12 |
13 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
14 |
15 | include src/upmem_llm_framework/architectures_schema.json
16 | include src/upmem_llm_framework/sim_architectures.yaml
17 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: "pytest"
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [main]
7 | pull_request:
8 | branches: [main]
9 |
10 | jobs:
11 | test:
12 | runs-on: ubuntu-22.04
13 |
14 | steps:
15 | - name: Checkout
16 | uses: actions/checkout@v4
17 |
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: '3.10'
22 |
23 | - name: Install package
24 | run: pip install --verbose .[dev]
25 |
26 | - name: Run tests
27 | run: pytest
28 |
--------------------------------------------------------------------------------
/tests/extra_arch.yaml:
--------------------------------------------------------------------------------
1 | # yaml-language-server: $schema=../src/upmem_llm_framework/architectures_schema.json
2 |
3 | PIM_AI_4chip_duplicated:
4 | host_to_device_bw_GBs: 12.8
5 | host_to_device_pj_per_bit: 80 # 20 * 4
6 | device_to_host_bw_GBs: 51.2 # 12.8 * 4
7 | device_to_host_pj_per_bit: 20
8 | mem_bw_GBs: 409.6 # 102.4 * 4
9 | mem_pj_per_bit: 0.95
10 | tflops: 20 # 5 * 4
11 | tflops_int4: 128 # 32 * 4
12 | pj_per_tflop: 0.4e+12
13 | softmax_ns_per_element: 0.1 # 0.4 / 4
14 | SiLU_ns_per_element: 0.15 # 0.6 / 4
15 | RMSNorm_ns_per_element: 0.275 # 1.04 / 4
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 UPMEM
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/nn_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import typer
3 |
4 | import upmem_llm_framework as upmem_layers
5 |
6 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options)
7 |
8 |
9 | class TinyModel(torch.nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | self.linear1 = torch.nn.Linear(100, 200)
14 | self.activation = torch.nn.ReLU()
15 | self.linear2 = torch.nn.Linear(200, 10)
16 | # self.softmax = torch.nn.Softmax(dim = 0)
17 |
18 | def forward(self, x):
19 | x = self.linear1(x)
20 | x = self.activation(x)
21 | x = self.linear2(x)
22 | x = torch.nn.functional.softmax(x, dim=0)
23 | return x
24 |
25 | @app.command()
26 | def profile():
27 | upmem_layers.profiler_init()
28 |
29 | tinymodel = TinyModel()
30 |
31 | print("The model:")
32 | print(tinymodel)
33 |
34 | my_tensor = torch.rand(100)
35 |
36 | layer_mapping = {
37 | "linear1": "PIM-AI-1chip",
38 | "linear2": "PIM-AI-1chip",
39 | }
40 |
41 | upmem_layers.profiler_start(layer_mapping, last_layer="linear2")
42 | prediction = tinymodel.forward(my_tensor)
43 | upmem_layers.profiler_end()
44 | print(prediction)
45 |
46 | if __name__ == "__main__":
47 | app()
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
104 | # IDE settings
105 | .vscode/
106 | .idea/
107 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "upmem_llm_framework"
7 | version = "0.0.1"
8 | authors = [
9 | { name = "Cristobal Ortega", email = "cortega@upmem.com" },
10 | { name = "Sylvan Brocard", email = "sbrocard@upmem.com" },
11 | ]
12 | license = "MIT"
13 | license-files = ["LICENSE"]
14 | keywords = [
15 | "upmem",
16 | "llm",
17 | "transformer",
18 | "pytorch",
19 | "profiling",
20 | "accelerator",
21 | ]
22 | description = """\
23 | UPMEM LLM Framework allows profiling PyTorch layers and functions\
24 | and simulate those layers/functions with a given hardware profile.\
25 | """
26 | readme = "README.md"
27 | dependencies = [
28 | "torch==2.4.1",
29 | "transformers==4.44.2",
30 | "jsonschema==4.23.0",
31 | "typer==0.12.5",
32 | ]
33 | requires-python = ">=3.10"
34 | classifiers = [
35 | 'Development Status :: 2 - Pre-Alpha',
36 | 'Intended Audience :: Developers',
37 | 'Natural Language :: English',
38 | 'Programming Language :: Python :: 3.10',
39 | ]
40 |
41 | [project.optional-dependencies]
42 | dev = ["pytest==8.3.3"]
43 |
44 | [project.urls]
45 | Homepage = "https://upmem.com"
46 | Repository = "https://github.com/upmem/upmem_llm_framework"
47 |
48 | [tool.setuptools.package-data]
49 | include = [
50 | "src/upmem_llm_framework/architectures_schema.json",
51 | "src/upmem_llm_framework/sim_architectures.yaml",
52 | ]
53 |
54 | [tool.ruff.lint]
55 | select = ["ALL"]
56 | ignore = ["COM812"]
57 |
58 | [tool.ruff]
59 | line-length = 100
60 |
61 | [tool.ruff.lint.extend-per-file-ignores]
62 | "tests/**/*.py" = [
63 | "S101", # asserts allowed in tests
64 | "T201", # print statements allowed in tests
65 | ]
66 |
--------------------------------------------------------------------------------
/examples/hf_example.py:
--------------------------------------------------------------------------------
1 | # import time
2 | import transformers
3 | import typer
4 | from typing_extensions import Annotated
5 |
6 | import upmem_llm_framework as upmem_layers
7 |
8 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options)
9 |
10 |
11 | @app.command()
12 | def profile(
13 | hf_token: Annotated[
14 | str, typer.Argument(envvar="hf_token", help="Hugging Face API token")
15 | ]
16 | ):
17 | upmem_layers.profiler_init()
18 |
19 | model = transformers.AutoModelForCausalLM.from_pretrained(
20 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token
21 | )
22 | tokenizer = transformers.AutoTokenizer.from_pretrained(
23 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token
24 | )
25 |
26 | layer_mapping = {
27 | "input_layernorm": "PIM-AI-1chip",
28 | "q_proj": "PIM-AI-1chip",
29 | "k_proj": "PIM-AI-1chip",
30 | "rotary_emb": "PIM-AI-1chip",
31 | "v_proj": "PIM-AI-1chip",
32 | "o_proj": "PIM-AI-1chip",
33 | "output_layernorm": "PIM-AI-1chip",
34 | "gate_proj": "PIM-AI-1chip",
35 | "up_proj": "PIM-AI-1chip",
36 | "down_proj": "PIM-AI-1chip",
37 | "norm": "PIM-AI-1chip",
38 | "lm_head": "PIM-AI-1chip",
39 | }
40 |
41 | prompt = "How to prepare coffee?"
42 |
43 | inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False)
44 |
45 | print(inputs.data["input_ids"][0].shape)
46 | model.eval() # Put model in evaluation / inference mode
47 |
48 | # print (model)
49 |
50 | upmem_layers.profiler_start(layer_mapping)
51 | # In case we want to time the original execution (comment out profiler_start)
52 | # start = time.time_ns()
53 | gen_tokens = model.generate(
54 | inputs.input_ids, do_sample=True, temperature=0.9, min_length=64, max_length=64
55 | )
56 | # print ( (time.time_ns() - start)/1e6)
57 | upmem_layers.profiler_end()
58 |
59 | gen_text = tokenizer.batch_decode(
60 | gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
61 | )[0]
62 | print(gen_text)
63 |
64 |
65 | if __name__ == "__main__":
66 | app()
67 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/sim_architectures.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2014-2024 - UPMEM
3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA
4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM
5 | #
6 | # This file implements multiple hardware architectures to be simulated.
7 | # All architecture inherit from the Base_architecture class.
8 | # If an architecture has optimizations for a given operation defined in Base_architecture,
9 | # define them here
10 |
11 | import json
12 | from functools import cache
13 | from importlib.resources import as_file, files
14 | from pathlib import Path
15 | from typing import Dict
16 |
17 | import yaml
18 | from jsonschema import validate
19 |
20 | from upmem_llm_framework.options import options
21 |
22 | def read_architecture_file(file: Path, schema: Dict) -> Dict:
23 | with open(file, "r", encoding="UTF-8") as f:
24 | architectures = yaml.safe_load(f)
25 | validate(architectures, schema)
26 | return architectures
27 |
28 |
29 | @cache
30 | def read_architectures() -> Dict:
31 | """
32 | Read the architectures from the sim_architectures.yaml file
33 | :return: a dictionary containing the architectures
34 | """
35 | with as_file(files("upmem_llm_framework")) as resources_dir:
36 | with open(
37 | resources_dir / "architectures_schema.json", "r", encoding="UTF-8"
38 | ) as f:
39 | schema = json.load(f)
40 |
41 | architectures = read_architecture_file(
42 | resources_dir / "sim_architectures.yaml", schema
43 | )
44 |
45 | if options.extra_archs:
46 | extra_architectures = read_architecture_file(
47 | options.extra_archs, schema
48 | )
49 | architectures.update(extra_architectures)
50 |
51 | return architectures
52 |
53 |
54 | @cache
55 | def get_spec(name: str) -> Dict:
56 | """
57 | Get an architecture object corresponding to the given name
58 | :param name: the name of the architecture
59 | :return: an object corresponding to the architecture
60 | """
61 | architectures = read_architectures()
62 |
63 | architecture_spec = architectures.get(name)
64 | if architecture_spec is None:
65 | raise ValueError(f"Architecture {name} not found in sim_architectures.yaml")
66 |
67 | return architecture_spec
68 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/architectures_schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "object",
3 | "additionalProperties": {
4 | "type": "object",
5 | "properties": {
6 | "host_to_device_bw_GBs": {
7 | "type": "number",
8 | "description": "Host to device bandwidth in GB/s"
9 | },
10 | "host_to_device_pj_per_bit": {
11 | "type": "number",
12 | "description": "Host to device energy per bit in pJ"
13 | },
14 | "device_to_host_bw_GBs": {
15 | "type": "number",
16 | "description": "Device to host bandwidth in GB/s"
17 | },
18 | "device_to_host_pj_per_bit": {
19 | "type": "number",
20 | "description": "Device to host energy per bit in pJ"
21 | },
22 | "mem_bw_GBs": {
23 | "type": "number",
24 | "description": "Memory bandwidth in GB/s"
25 | },
26 | "mem_pj_per_bit": {
27 | "type": "number",
28 | "description": "Memory energy per bit in pJ"
29 | },
30 | "tflops": {
31 | "type": "number",
32 | "description": "TFLOPS"
33 | },
34 | "pj_per_tflop": {
35 | "type": "number",
36 | "description": "Energy per TFLOP in pJ"
37 | },
38 | "tflops_int4": {
39 | "type": "number",
40 | "description": "TFLOPS for INT4 (optional)"
41 | },
42 | "softmax_ns_per_element": {
43 | "type": "number",
44 | "description": "Softmax latency per element in ns (optional)"
45 | },
46 | "SiLU_ns_per_element": {
47 | "type": "number",
48 | "description": "SiLU latency per element in ns (optional)"
49 | },
50 | "RMSNorm_ns_per_element": {
51 | "type": "number",
52 | "description": "RMSNorm latency per element in ns (optional)"
53 | }
54 | },
55 | "required": [
56 | "host_to_device_bw_GBs",
57 | "host_to_device_pj_per_bit",
58 | "device_to_host_bw_GBs",
59 | "device_to_host_pj_per_bit",
60 | "mem_bw_GBs",
61 | "mem_pj_per_bit",
62 | "tflops",
63 | "pj_per_tflop"
64 | ],
65 | "additionalProperties": false
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/tests/test_upmem_pytorch_simulator.py:
--------------------------------------------------------------------------------
1 | """Tests for `upmem_llm_framework` package."""
2 |
3 | import warnings
4 |
5 | import typer
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from typer.testing import CliRunner
8 |
9 | import upmem_llm_framework as upmem_layers
10 |
11 | runner = CliRunner()
12 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options)
13 |
14 | gen_length = 64
15 |
16 | layer_mapping = {
17 | "LlamaRMSNorm": "PIM-AI-1chip,t",
18 | "q_proj": "PIM-AI-4chip-duplicated,t",
19 | "k_proj": "PIM-AI-4chip",
20 | "rotatory_emb": "PIM-AI-4chip",
21 | "v_proj": "PIM-AI-4chip",
22 | "o_proj": "PIM-AI-4chip,t",
23 | "output_layernorm": "PIM-AI-1chip,t",
24 | "gate_proj": "PIM-AI-4chip,t",
25 | "up_proj": "PIM-AI-4chip,t",
26 | "down_proj": "PIM-AI-4chip,t",
27 | "norm": "PIM-AI-1chip,t",
28 | "lm_head": "PIM-AI-4chip,t",
29 | }
30 | layer_attn_ctxt = "q_proj"
31 |
32 | ignored_warning = (
33 | "`pad_token_id` should be positive but got -1. This will cause errors when batch "
34 | "generating, if there is padding. Please set `pad_token_id` explicitly by "
35 | "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, "
36 | "and ensure your `input_ids` input does not have negative values."
37 | )
38 |
39 |
40 | @app.command("profile")
41 | def run_tiny_llama_model_with_profiler():
42 | """Run the tiny LLaMA model with the profiler."""
43 |
44 | # Initialize the profiler
45 | upmem_layers.profiler_init()
46 | # Load the tiny LLaMA model and tokenizer
47 | tokenizer = AutoTokenizer.from_pretrained(
48 | "hf-internal-testing/tiny-random-LlamaForCausalLM"
49 | )
50 | with warnings.catch_warnings():
51 | warnings.filterwarnings("ignore", ignored_warning)
52 | model = AutoModelForCausalLM.from_pretrained(
53 | "hf-internal-testing/tiny-random-LlamaForCausalLM"
54 | )
55 | model.generation_config.pad_token_id = tokenizer.eos_token_id
56 |
57 | # Prepare input data
58 | input_text = "Hello, world!"
59 | input_token = tokenizer.encode(
60 | input_text, return_tensors="pt", return_token_type_ids=False
61 | )
62 | input_ids = {
63 | "input_ids": input_token,
64 | "attention_mask": input_token.new_ones(input_token.shape),
65 | }
66 |
67 | # Run the profiler
68 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt)
69 | outputs = model.generate(
70 | **input_ids,
71 | do_sample=True,
72 | temperature=0.9,
73 | min_length=gen_length,
74 | max_length=gen_length,
75 | )
76 | upmem_layers.profiler_end()
77 |
78 | # Assert the profiler results
79 | assert outputs is not None
80 | assert outputs.shape == (1, gen_length)
81 |
82 |
83 | def test_tiny_llama_model_with_profiler():
84 | """Test the tiny LLaMA model with the profiler."""
85 | result = runner.invoke(
86 | app,
87 | [
88 | "--simulation",
89 | "--report-layers",
90 | "--report-functions",
91 | "--print-log",
92 | "--print-log-summary",
93 | "--extra-archs=tests/extra_arch.yaml",
94 | "profile",
95 | ],
96 | )
97 | print(result.stdout)
98 | assert result.exit_code == 0
99 | assert "##### UPMEM PROFILER OUTPUT #####" in result.stdout
100 | assert "##### Generation Execution summary #####" in result.stdout
101 | assert "##### All (SUM and GEN) Execution summary #####" in result.stdout
102 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/options.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import typer
7 | from typing_extensions import Annotated
8 |
9 |
10 | @dataclass
11 | class Options:
12 | report_layers: bool = False
13 | report_functions: bool = False
14 | print_log: bool = False
15 | print_log_summary: bool = False
16 | simulation: bool = False
17 | sim_compute: bool = False
18 | sim_data_type: str = "bfloat16"
19 | sim_num_key_value_heads: int = -1
20 | sim_sliding_window: int = -1
21 | sim_verbose: bool = False
22 | extra_archs: Optional[Path] = None
23 |
24 |
25 | options = Options()
26 |
27 | class DataType(str, Enum):
28 | int4 = "int4"
29 | int8 = "int8"
30 | float16 = "float16"
31 | bfloat16 = "bfloat16"
32 | float32 = "float32"
33 |
34 |
35 | def initialize_profiling_options(
36 | report_layers: Annotated[
37 | bool,
38 | typer.Option(
39 | help="Enable reporting metrics for all executed layers at the end of the forward pass."
40 | ),
41 | ] = False,
42 | report_functions: Annotated[
43 | bool,
44 | typer.Option(
45 | help="Enable reporting metrics for all executed functions at the end of the forward "
46 | + "pass."
47 | ),
48 | ] = False,
49 | print_log: Annotated[
50 | bool,
51 | typer.Option(
52 | help="Print a trace of the execution of layers and functions.",
53 | ),
54 | ] = False,
55 | print_log_summary: Annotated[
56 | bool,
57 | typer.Option(
58 | help="Print a detailed summary of each layer and function executed, for the "
59 | + "summarization phase, the generation phase, and both combined.",
60 | ),
61 | ] = False,
62 | simulation: Annotated[
63 | bool,
64 | typer.Option(
65 | help="Enable simulation according to the defined layer mapping.",
66 | ),
67 | ] = False,
68 | sim_compute: Annotated[
69 | bool,
70 | typer.Option(
71 | help="Simulate compute-intensive operations. Note that some operations are still "
72 | + "performed due to constraints on the inputs/outputs of other layers/functions. "
73 | + "CAUTION: output tokens will be affected.",
74 | ),
75 | ] = False,
76 | sim_data_type: Annotated[
77 | DataType,
78 | typer.Option(
79 | help="Set the datatype for weights and inputs.",
80 | ),
81 | ] = DataType.bfloat16,
82 | sim_num_key_value_heads: Annotated[
83 | int,
84 | typer.Option(
85 | help="When using GQA, this value is used to simulate fetching the correct KV caches.",
86 | ),
87 | ] = -1,
88 | sim_sliding_window: Annotated[
89 | int,
90 | typer.Option(
91 | help="When set, a sliding window is simulated according to this value. Note that the "
92 | + "real underlying execution will run according to the model parameter.",
93 | ),
94 | ] = -1,
95 | sim_verbose: Annotated[
96 | bool,
97 | typer.Option(
98 | help="Set a verbose mode for simulation",
99 | ),
100 | ] = False,
101 | extra_archs: Annotated[
102 | Optional[Path],
103 | typer.Option(
104 | help="Path to a yaml file containing extra architectures to be used in simulation",
105 | ),
106 | ] = None,
107 | ):
108 | options.report_layers = report_layers
109 | options.report_functions = report_functions
110 | options.print_log = print_log
111 | options.print_log_summary = print_log_summary
112 | options.simulation = simulation
113 | options.sim_compute = sim_compute
114 | options.sim_data_type = sim_data_type
115 | options.sim_num_key_value_heads = sim_num_key_value_heads
116 | options.sim_sliding_window = sim_sliding_window
117 | options.sim_verbose = sim_verbose
118 | options.extra_archs = extra_archs
119 |
--------------------------------------------------------------------------------
/scripts_simulation/simulations_llama2_7B.py:
--------------------------------------------------------------------------------
1 | # import time
2 | from typing import Optional
3 |
4 | import torch
5 | import transformers
6 | import typer
7 | from typing_extensions import Annotated
8 |
9 | import upmem_llm_framework as upmem_layers
10 |
11 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options)
12 |
13 |
14 | @app.command()
15 | def profile(
16 | hf_token: Annotated[
17 | str, typer.Argument(envvar="hf_token", help="Hugging Face API token")
18 | ],
19 | device: Annotated[str, typer.Option(help="Device to simulate for")] = "mixed",
20 | in_tokens: Annotated[int, typer.Option(help="Number of input tokens")] = 64,
21 | out_tokens: Annotated[int, typer.Option(help="Number of output tokens")] = 128,
22 | bs: Annotated[int, typer.Option(help="Batch size")] = 1,
23 | ):
24 | upmem_layers.profiler_init()
25 |
26 | print("Simulating with device...", device)
27 | print("in:", in_tokens, "out:", out_tokens)
28 |
29 | model = transformers.AutoModelForCausalLM.from_pretrained(
30 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token
31 | )
32 | tokenizer = transformers.AutoTokenizer.from_pretrained(
33 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token
34 | )
35 | layer_mapping = (
36 | {
37 | "LlamaRMSNorm": "PIM-AI-1chip,t",
38 | "q_proj": "PIM-AI-4chip,t",
39 | "k_proj": "PIM-AI-4chip",
40 | "rotary_emb": "PIM-AI-4chip",
41 | "v_proj": "PIM-AI-4chip",
42 | "o_proj": "PIM-AI-4chip,t",
43 | "output_layernorm": "PIM-AI-1chip,t",
44 | "gate_proj": "PIM-AI-4chip,t",
45 | "up_proj": "PIM-AI-4chip,t",
46 | "down_proj": "PIM-AI-4chip,t",
47 | "norm": "PIM-AI-1chip,t",
48 | "lm_head": "PIM-AI-4chip,t",
49 | }
50 | if device == "mixed"
51 | else {
52 | "LlamaRMSNorm": device,
53 | "q_proj": device,
54 | "k_proj": device,
55 | "rotary_emb": device,
56 | "v_proj": device,
57 | "o_proj": device,
58 | "output_layernorm": device,
59 | "gate_proj": device,
60 | "up_proj": device,
61 | "down_proj": device,
62 | "norm": device,
63 | "lm_head": device,
64 | }
65 | )
66 | layer_attn_ctxt = "q_proj"
67 |
68 | print("Batch 1")
69 | prompt = "placeholder"
70 | prompt_batch = [prompt] * bs
71 | input_ids = tokenizer(
72 | prompt_batch, return_tensors="pt", return_token_type_ids=False
73 | )
74 | input_ids["input_ids"] = torch.randint(100, [bs, in_tokens])
75 | input_ids["attention_mask"] = torch.ones([bs, in_tokens], dtype=torch.int)
76 | print(input_ids.data["input_ids"][0].shape)
77 |
78 | model.eval()
79 | print(model)
80 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt)
81 | # start = time.time_ns()
82 | gen_tokens = model.generate(
83 | **input_ids,
84 | do_sample=True,
85 | temperature=0.9,
86 | min_length=out_tokens,
87 | max_length=out_tokens,
88 | )
89 | # print ( (time.time_ns() - start)/1e6)
90 | upmem_layers.profiler_end()
91 |
92 | gen_text = tokenizer.batch_decode(gen_tokens)
93 | print(gen_text)
94 |
95 | raise typer.Exit()
96 |
97 | print("Batch 10")
98 | prompt_batch = [prompt] * 10
99 | input_ids = tokenizer(prompt_batch, return_tensors="pt")
100 | print(input_ids.data["input_ids"][0].shape)
101 |
102 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt)
103 | gen_tokens = model.generate(
104 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128
105 | )
106 | upmem_layers.profiler_end()
107 |
108 | print("Batch 30")
109 | prompt_batch = [prompt] * 30
110 | input_ids = tokenizer(prompt_batch, return_tensors="pt")
111 | print(input_ids.data["input_ids"][0].shape)
112 |
113 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt)
114 | gen_tokens = model.generate(
115 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128
116 | )
117 | upmem_layers.profiler_end()
118 |
119 | print("Batch 40")
120 | prompt_batch = [prompt] * 40
121 | input_ids = tokenizer(prompt_batch, return_tensors="pt")
122 | print(input_ids.data["input_ids"][0].shape)
123 |
124 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt)
125 | gen_tokens = model.generate(
126 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128
127 | )
128 | upmem_layers.profiler_end()
129 |
130 | print("Batch 200")
131 | prompt_batch = [prompt] * 200
132 | input_ids = tokenizer(prompt_batch, return_tensors="pt")
133 | print(input_ids.data["input_ids"][0].shape)
134 |
135 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt)
136 | gen_tokens = model.generate(
137 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128
138 | )
139 | upmem_layers.profiler_end()
140 |
141 | # Batching (from https://lukesalamone.github.io/posts/what-are-attention-masks/)
142 | # tokenizer.padding_side = "left"
143 | # tokenizer.pad_token = tokenizer.eos_token
144 | #
145 | # sentences = ["It will rain in the",
146 | # "I want to eat a big bowl of",
147 | # "My dog is"]
148 | # inputs = tokenizer(sentences, return_tensors="pt", padding=True)
149 |
150 | # gen_text = tokenizer.batch_decode(gen_tokens)[0]
151 | gen_text = tokenizer.batch_decode(gen_tokens)
152 | print(gen_text)
153 |
154 | ## torch profiler snippet
155 | # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
156 | # with record_function("forward"):
157 | # gen_text = tokenizer.batch_decode(gen_tokens)[0]
158 | # #model(inputs)
159 | #
160 | #
161 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=100))
162 | #
163 | # print ("----- Group by input shape")
164 | # print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))
165 | #
166 | # prof.export_chrome_trace("trace.json")
167 |
168 |
169 | if __name__ == "__main__":
170 | app()
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | UPMEM LLM framework for profiling / simulation
2 | ==============================================
3 |
4 | [](https://github.com/upmem/upmem_llm_framework/actions/workflows/pytest.yml)
5 |
6 | This library allows:
7 |
8 | 1. Profiling PyTorch neural networks on a CPU,
9 | 2. Simulating the execution of the neural network on a target hardware
10 | accelerator.
11 |
12 | Usage
13 | -----
14 |
15 | 1. Import the `upmem_llm_framework` library and
16 | [`typer`](https://typer.tiangolo.com/). Create a `typer` app to handle the user
17 | input for the profiler.
18 |
19 | ```python
20 | # file: my_profiler.py
21 | import typer
22 |
23 | import upmem_llm_framework as upmem_layers
24 |
25 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options)
26 | ```
27 |
28 | 2. Define your main function and add the desired user input.
29 | Initialize the library before creating or importing the neural network:
30 |
31 | ```python
32 | @app.command()
33 | def profile(my_input: str):
34 | upmem_layers.profiler_init()
35 | # Create or import the neural network
36 | model = ...
37 | # Define the input tensor
38 | myTensor = ...
39 | ```
40 |
41 | 3. Call the profiler when doing a forward pass / inference:
42 |
43 | ```python
44 | upmem_layers.profiler_start()
45 | prediction = model.forward(myTensor)
46 | upmem_layers.profiler_end()
47 | ```
48 |
49 | 4. Call the app:
50 |
51 | ```python
52 | if __name__ == "__main__":
53 | app()
54 | ```
55 |
56 | 5. See the available options:
57 |
58 | ```bash
59 | python my_profiler.py --help
60 | ```
61 |
62 | 6. Run the app:
63 |
64 | ```bash
65 | python my_profiler.py --some-option profile my_input
66 | ```
67 |
68 | ### Examples
69 |
70 | You can find usage examples with a custom PyTorch model in `nn_example.py` and
71 | with a model from HuggingFace in `hf_example.py`.
72 |
73 |
74 | #### PyTorch model
75 |
76 | ```bash
77 | python3 nn_example.py profile
78 | ```
79 |
80 | Expected output:
81 |
82 | ```text
83 | Options: Options(report_layers=False, report_functions=False, print_log=False, print_log_summary=False, simulation=False, sim_compute=False, sim_data_type=, sim_num_key_value_heads=-1, sim_sliding_window=-1, sim_verbose=False, extra_archs=None)
84 | The model:
85 | TinyModel(
86 | (linear1): UPM_Linear(in_features=100, out_features=200, bias=True)
87 | (activation): ReLU()
88 | (linear2): UPM_Linear(in_features=200, out_features=10, bias=True)
89 | )
90 | ##### UPMEM PROFILER OUTPUT #####
91 | Total time (SUM + GEN): 0.002975238 s, with data type: bfloat16, batch size: 1
92 | Generated tokens: 0 in 0.002175651 s, with tokens/s: 0.0
93 | Summarization step took: 0.000799587 s, weight in the execution: SUM: 0.268747239716621%, GEN: 0.7312527602833789%
94 | ##### END UPMEM PROFILER OUTPUT #####
95 | tensor([0.0983, 0.0919, 0.1012, 0.0836, 0.0796, 0.1157, 0.1202, 0.0996, 0.0930,
96 | 0.1168], grad_fn=)
97 | ```
98 |
99 |
100 |
101 |
102 | #### HuggingFace model
103 |
104 | ```bash
105 | python3 hf_example.py profile
106 | ```
107 |
108 | Expected output:
109 |
110 | ```text
111 | Options: Options(report_layers=False, report_functions=False, print_log=False, print_log_summary=False, simulation=False, sim_compute=False, sim_data_type=, sim_num_key_value_heads=-1, sim_sliding_window=-1, sim_verbose=False, extra_archs=None)
112 | Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.03it/s]
113 | torch.Size([6])
114 | ##### UPMEM PROFILER OUTPUT #####
115 | Total time (SUM + GEN): 42.124470486 s, with data type: bfloat16, batch size: 1
116 | Generated tokens: 57 in 41.404485553 s, with tokens/s: 1.3766624373834302
117 | Summarization step took: 0.719984933 s, weight in the execution: SUM: 0.017091845302584535%, GEN: 0.9829081546974154%
118 | ##### END UPMEM PROFILER OUTPUT #####
119 | How to prepare coffee?
120 |
121 | There are many ways to prepare coffee, and the method you choose will depend on your personal preferences and the equipment you have available. Here are some common methods for preparing coffee:
122 |
123 | 1. Drip brewing: This is one of the most common methods of prepar
124 | ```
125 |
126 |
127 |
128 | Profiler
129 | --------
130 |
131 | The profiler records the start time and end time of a computation layer or
132 | function.
133 | Currently, the profiler doesn't track the real power consumption of the CPU.
134 |
135 | The profiler identifies a layer or function by 4 parameters:
136 |
137 | 1. Layer type (e.g. a `Linear` module) or function (e.g. `softmax`),
138 | 2. Context in which the layer or function is called, meaning the variable name
139 | assigned to the layer or function (e.g. `q_proj = torch.nn.Linear(...)` has a
140 | context of `q_proj`), as shown in the sketch below,
141 | 3. The input dimensions of the layer or function,
142 | 4. Specifically for layers, a unique ID assigned at layer initialization.
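
The context is simply the attribute or variable name that the layer is assigned
to when it is created. A minimal sketch (the module and attribute names are
illustrative only):

```python
import torch


class TinyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # When the profiler is active, this layer is recorded with the
        # context "q_proj": the name of the attribute it is assigned to.
        self.q_proj = torch.nn.Linear(4096, 4096)
```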
143 |
144 | ### Profiler output
145 |
146 | By default, the profiler reports a summary with execution time, energy (when
147 | simulating), and power consumption (when simulating) at the end of its
148 | execution.
149 |
150 | When simulating, this summary breaks down into the summarization (encoding)
151 | phase and the generation (decoding) phase.
152 |
153 | You can enable the following flags to show more information:
154 |
155 | * `--report-layers`: reports the created layers in the neural network with their
156 | associated parameters
157 | * `--report-functions`: reports the called functions during the forward pass of
158 | the neural network with their associated parameters
159 | * `--print-log`: prints a time-ordered detailed log of each layer and function
160 | executed during the forward pass of the neural network
161 |
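For example, with the `nn_example.py` script from this repository, the flags go
before the command name:

```bash
python3 nn_example.py --report-layers --report-functions --print-log profile
```
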
162 | Simulation
163 | ----------
164 |
165 | To run a simulation, library users need to provide a dictionary mapping layers
166 | to a device or hardware accelerator.
167 |
168 | This dictionary contains `name_of_layer:device,options` key-value pairs.
169 | The name of the layer corresponds to the context concept introduced before.
170 | The device corresponds to one of the accelerators defined in
171 | `sim_architectures.yaml`.
172 |
173 | Currently supported options:
174 |
175 | * 't' (transfer point): the input of a layer with this option comes from the
176 | CPU, meaning the previous device sent its results back to the CPU and the CPU
177 | forwards them as input to this layer's device.
178 | * 'm' (MoE transfer point): the input of a layer with this option comes from the
179 | CPU, but only once, since the input is shared across the different MoE experts.
180 |
181 | For instance, for a neural network composed of 2 Linear layers that execute
182 | sequentially on different chips:
183 |
184 | ```python
185 | layer_mapping = {
186 | "linear1":"PIM-AI-1chip,t",
187 | "linear2":"PIM-AI-1chip,t",
188 | }
189 |
190 | upmem_layers.profiler_start(layer_mapping)
191 | prediction = model.forward(myTensor)
192 | upmem_layers.profiler_end()
193 | ```
194 |
195 | This mapping corresponds to the following scheme:
196 |
197 | ```mermaid
198 | graph LR
199 | **CPU** -->|input of *linear1* is sent to **PIM-AI-1chip1** device| PIM-AI-1chip1["`**PIM-AI-1chip1**
200 | Execute *linear1*`"]
201 | PIM-AI-1chip1 -->|output of *linear1* is sent to **CPU**| **CPU**
202 | **CPU** -->|input of *linear2* is sent to **PIM-AI-1chip2** device| PIM-AI-1chip2["`**PIM-AI-1chip2**
203 | Execute *linear2*`"]
204 | ```
205 |
206 | ### Running a simulation
207 |
208 | After specifying the layer mapping, run the simulation with:
209 |
210 | ```bash
211 | python3 hf_example.py --simulation profile
212 | ```
213 |
214 | ### Adding a hardware accelerator
215 |
216 | The file `sim_architectures.yaml` contains hardware accelerator profiles.
217 |
218 | To add a new hardware accelerator profile, create a YAML file with the following
219 | structure:
220 |
221 | ```yaml
222 | # yaml-language-server: $schema=/architectures_schema.json
223 | # (The above line is optional, it will enable autocompletion and validation in editors that support
224 | # the YAML language server)
225 |
226 | My_accelerator:
227 | # * Required parameters:
228 | # - HOST communication
229 | host_to_device_bw_GBs: 22
230 | host_to_device_pj_per_bit: 200
231 | device_to_host_bw_GBs: 88
232 | device_to_host_pj_per_bit: 50
233 | # - Device memory (shared memory like)
234 | mem_bw_GBs: 6553.6
235 | mem_pj_per_bit: 0.95
236 | # - Compute
237 | tflops: 320
238 | pj_per_tflop: 0.4e+12
239 | # * Optional parameters:
240 | softmax_ns_per_element: 6.25e-03
241 | SiLU_ns_per_element: 9.375e-03
242 | RMSNorm_ns_per_element: 1.625e-02
243 |
244 | My_accelerator2:
245 | <...>
246 | ```
247 |
248 | Use the `--extra-archs` option to add the new accelerator to the simulation:
249 |
250 | ```bash
251 | python3 simulations_llama2_7B.py --simulation --extra-archs my_archs.yaml profile --device My_accelerator
252 | ```
253 |
254 | *Note:* underscores in device names such as `new_device` convert to hyphens,
255 | resulting in `new-device` in the layer mapping.
256 |
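For example, the `My_accelerator` profile defined above would be referenced with
hyphens in the layer mapping (the layer names here are illustrative):

```python
layer_mapping = {
    "q_proj": "My-accelerator,t",
    "lm_head": "My-accelerator,t",
}
```
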
257 | ### Notes on simulation
258 |
259 | This library makes two assumptions to simplify execution modelling across
260 | hardware profiles:
261 |
262 | 1. Interconnect communication latency is ignored: the library assumes that
263 | communication between devices finishes fast enough to overlap with compute and
264 | be hidden.
265 | For instance, when simulating more than one GPU, it doesn't model the required
266 | data exchange between them.
267 | For an AI-PIM device (DIMM), it doesn't model communication within a DIMM.
268 | 2. Devices always reach peak performance: all hardware profiles perform
269 | operations at their peak throughput, which is unrealistic in some scenarios.
270 | Adding a performance ratio to model this is left to future work.
271 |
272 | Installation
273 | ------------
274 |
275 | ### Environment setup
276 |
277 | #### Python version
278 |
279 | This library expects Python 3.10.
280 |
281 | If your distribution doesn't provide it, you can use
282 | [`pyenv`](https://github.com/pyenv/pyenv) to install it, or any other Python
283 | version manager:
284 |
285 | ```bash
286 | pyenv install 3.10
287 | pyenv shell 3.10
288 | ```
289 |
290 | Now your shell runs Python 3.10.
291 |
292 | #### Virtual environment
293 |
294 | Preferably, create a virtual environment to install the library:
295 |
296 | ```bash
297 | python -m venv venv
298 | source venv/bin/activate
299 | ```
300 |
301 | This avoids conflicts with other Python libraries in your system.
302 |
303 | ### User installation
304 |
305 | To install the library in your current Python environment:
306 |
307 | ```bash
308 | python -m pip install .
309 | ```
310 |
311 | ### Developer installation
312 |
313 | To install the library for editing in your current Python environment, with the
314 | necessary development dependencies:
315 |
316 | ```bash
317 | python -m pip install -e '.[dev]'
318 | ```
319 |
320 | #### Running tests
321 |
322 | Run the tests with:
323 |
324 | ```bash
325 | python -m pytest
326 | ```
327 |
328 | #### Formatting
329 |
330 | This project uses `black` formatting. Please make sure to run it before
331 | committing:
332 |
333 | ```bash
334 | python -m black src/upmem_llm_framework/*.py
335 | ```
336 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/pytorch_upmem_layers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2014-2024 - UPMEM
3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA
4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM
5 | #
6 | # This file wraps PyTorch classes and functions into new UPM classes and functions able to
7 | # track the start, inputs, end and outputs of the corresponding function.
8 | # Currently, forward from multiple modules and other minor functions
9 | # (normalizations, softmax, activations, etc.) are tracked and profiled.
10 |
11 | from inspect import getframeinfo, stack
12 | from typing import Tuple
13 |
14 | import torch
15 | import transformers
16 |
17 | from upmem_llm_framework.options import options
18 | from upmem_llm_framework.profiler import UPM_Profiler
19 |
20 | profiler: UPM_Profiler = None
21 | profiling = 0
22 |
23 |
24 | def get_context():
25 | # https://stackoverflow.com/questions/24438976/debugging-get-filename-and-line-number-from-which-a-function-is-called
26 | return getframeinfo(stack()[2][0]).code_context[0].split()[0].replace("self.", "")
27 |
28 |
29 | class UPM_Module(torch.nn.Module):
30 |
31 | def forward(self, x):
32 | x = super().forward(x)
33 | return x
34 |
35 |
36 | class UPM_Linear(torch.nn.Linear):
37 |
38 | def __init__(self, *args, **kwargs):
39 | super().__init__(*args, **kwargs)
40 | profiler.add(self, get_context())
41 |
42 | def forward(self, x):
43 | context = get_context()
44 | profiler.forward_start(x.shape)
45 | if options.sim_compute:
46 | shape = list(x.shape)
47 | shape[-1] = self.out_features
48 | x = torch.zeros(shape)
49 | else:
50 | x = super().forward(x)
51 | profiler.forward_end(x.shape, context, layer_obj=self)
52 | return x
53 |
54 |
55 | class UPM_NonDynamicallyQuantizableLinear(
56 | torch.nn.modules.linear.NonDynamicallyQuantizableLinear
57 | ):
58 | def __init__(self, *args, **kwargs):
59 | super().__init__(*args, **kwargs)
60 | profiler.add(self, get_context())
61 | print("HERE")
62 |
63 | def forward(self, x):
64 | context = get_context()
65 | profiler.forward_start(x.shape)
66 | if options.sim_compute:
67 | shape = list(x.shape)
68 | shape[-1] = self.out_features
69 | x = torch.zeros(shape)
70 | else:
71 | x = super().forward(x)
72 | profiler.forward_end(x.shape, context, layer_obj=self)
73 | return x
74 |
75 |
76 | class UPM_LayerNorm(torch.nn.LayerNorm):
77 |
78 | def __init__(self, *args, **kwargs):
79 | super().__init__(*args, **kwargs)
80 | profiler.add(self, get_context())
81 |
82 | def forward(self, x):
83 | context = get_context()
84 | profiler.forward_start(x.shape)
85 | x = super().forward(x)
86 | profiler.forward_end(x.shape, context, layer_obj=self)
87 | return x
88 |
89 |
90 | class UPM_Embedding(torch.nn.Embedding):
91 |
92 | def __init__(self, *args, **kwargs):
93 | super().__init__(*args, **kwargs)
94 | profiler.add(self, get_context())
95 |
96 | def forward(self, x):
97 | context = get_context()
98 | profiler.forward_start(x.shape)
99 | x = super().forward(x)
100 | profiler.forward_end(x.shape, context, layer_obj=self)
101 | return x
102 |
103 |
104 | class UPM_LlamaRotaryEmbedding(
105 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
106 | ):
107 |
108 | def __init__(self, *args, **kwargs):
109 | super().__init__(*args, **kwargs)
110 | profiler.add(self, get_context())
111 |
112 | def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]:
113 | context = get_context()
114 | shape = x.shape
115 | profiler.forward_start(shape)
116 | x = super().forward(x, position_ids)
117 | profiler.forward_end(shape, context, layer_obj=self)
118 | return x
119 |
120 |
121 | class UPM_LlamaRMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
122 |
123 | def __init__(self, *args, **kwargs):
124 | super().__init__(*args, **kwargs)
125 | profiler.add(self, get_context())
126 |
127 | def forward(self, hidden_states):
128 | context = get_context()
129 | profiler.forward_start(hidden_states.shape)
130 | hidden_states = super().forward(hidden_states)
131 | profiler.forward_end(hidden_states.shape, context, layer_obj=self)
132 | return hidden_states
133 |
134 |
135 | class UPM_SiLUActivation(torch.nn.SiLU):
136 |
137 | def __init__(self, *args, **kwargs):
138 | super().__init__(*args, **kwargs)
139 | profiler.add(self, get_context())
140 |
141 | def forward(self, x):
142 | context = get_context()
143 | profiler.forward_start(x.shape)
144 | x = super().forward(x)
145 | profiler.forward_end(x.shape, context, layer_obj=self)
146 | return x
147 |
148 |
149 | class UPM_NewGELUActivation(transformers.activations.NewGELUActivation):
150 |
151 | def __init__(self, *args, **kwargs):
152 | super().__init__(*args, **kwargs)
153 | profiler.add(self, get_context())
154 |
155 | def forward(self, x):
156 | context = get_context()
157 | profiler.forward_start(x.shape)
158 | x = super().forward(x)
159 | profiler.forward_end(x.shape, context, layer_obj=self)
160 | return x
161 |
162 |
163 | # Not used in inference
164 | class UPM_Dropout(torch.nn.Dropout):
165 |
166 | def __init__(self, *args, **kwargs):
167 | super().__init__(*args, **kwargs)
168 | profiler.add(self, get_context())
169 |
170 |
171 | class UPM_Conv1d(torch.nn.Conv1d):
172 |
173 | def __init__(self, *args, **kwargs):
174 | super().__init__(*args, **kwargs)
175 | profiler.add(self, get_context())
176 |
177 |
178 | class UPM_Conv2d(torch.nn.Conv2d):
179 | def __init__(self, *args, **kwargs):
180 | super().__init__(*args, **kwargs)
181 | profiler.add(self, get_context())
182 |
183 | def forward(self, x):
184 | context = get_context()
185 | profiler.forward_start(x.shape)
186 | x = super().forward(x)
187 | profiler.forward_end(x.shape, context, layer_obj=self)
188 | return x
189 |
190 |
191 | class UPM_Conv1D(transformers.pytorch_utils.Conv1D):
192 |
193 | def __init__(self, *args, **kwargs):
194 | super().__init__(*args, **kwargs)
195 | profiler.add(self, get_context())
196 |
197 | def forward(self, x):
198 | context = get_context()
199 | profiler.forward_start(x.shape)
200 | x = super().forward(x)
201 | profiler.forward_end(x.shape, context)
202 | return x
203 |
204 |
205 | class UPM_Softmax(torch.nn.Softmax):
206 |
207 | def __init__(self, *args, **kwargs):
208 | super().__init__(*args, **kwargs)
209 | profiler.add(self, get_context())
210 |
211 | def forward(self, x):
212 | context = get_context()
213 | profiler.forward_start(x.shape)
214 | x = super().forward(x)
215 | profiler.forward_end(x.shape, context)
216 | return x
217 |
218 |
219 | class UPM_Tensor(torch.Tensor):
220 |
221 | def transpose(self, dim0, dim1):
222 | print("MyTranspose with input:", self, "dim0", dim0, "dim1", dim1)
223 | return super().transpose(dim0, dim1)
224 |
225 |
226 | __pytorch_nn_functional_softmax = torch.nn.functional.softmax
227 |
228 |
229 | # TODO: change logic here to not use stringly types
230 | def UPM_Softmax_functional(input, dim=None, dtype=None):
231 | context = get_context()
232 | profiler.forward_func_start("softmax", context, input.shape)
233 | x = __pytorch_nn_functional_softmax(input, dim=dim, dtype=dtype)
234 | profiler.forward_func_end(__pytorch_nn_functional_softmax, context, x.shape)
235 | return x
236 |
237 |
238 | __pytorch_matmul = torch.matmul
239 |
240 |
241 | # TODO: here too
242 | def UPM_Matmul(input, other, *, out=None):
243 | context = get_context()
244 | profiler.forward_func_start("matmul", context, input.shape)
245 | x = __pytorch_matmul(input, other, out=out)
246 | profiler.forward_func_end(__pytorch_matmul, context, x.shape)
247 | return x
248 |
249 |
250 | __pytorch_scaled_dot_product_attention = (
251 | torch.nn.functional.scaled_dot_product_attention
252 | )
253 |
254 |
255 | # TODO: here too
256 | def UPM_scaled_dot_product_attention(query, key, value, **kwargs):
257 | context = get_context()
258 | # profiler.forward_func_start("scaled_dot_product_attention", context, [query.shape, key.shape, value.shape])
259 | profiler.forward_func_start("scaled_dot_product_attention", context, key.shape)
260 | if options.sim_compute:
261 | q_shape = list(query.shape)
262 | v_shape = list(value.shape)
263 | q_shape[-1] = v_shape[-1]
264 | x = torch.zeros(q_shape)
265 | else:
266 | x = __pytorch_scaled_dot_product_attention(query, key, value, **kwargs)
267 | profiler.forward_func_end(__pytorch_scaled_dot_product_attention, context, x.shape)
268 | return x
269 |
270 |
271 | __pytorch_transpose = torch.transpose
272 |
273 |
274 | def UPM_Transpose(input, dim0, dim1):
275 | print("UPM_Transpose with input", input.shape, "dim0:", dim0, "dim1", dim1)
276 | x = __pytorch_transpose(input, dim0, dim1)
277 | return x
278 |
279 |
280 | def profiler_init():
281 |
282 | print(f"Options: {options}")
283 |
284 | global profiling, profiler
285 | profiling = 1
286 | profiler = UPM_Profiler(options)
287 |
288 | # torch library
289 | torch.nn.Module = UPM_Module
290 | torch.nn.Linear = UPM_Linear
291 | torch.nn.modules.linear.NonDynamicallyQuantizableLinear = (
292 | UPM_NonDynamicallyQuantizableLinear
293 | )
294 | torch.nn.LayerNorm = UPM_LayerNorm
295 | torch.nn.Embedding = UPM_Embedding
296 | torch.nn.Dropout = UPM_Dropout
297 | torch.nn.Conv1d = UPM_Conv1d
298 | torch.nn.Conv2d = UPM_Conv2d
299 | torch.nn.Softmax = UPM_Softmax
300 | torch.nn.functional.softmax = UPM_Softmax_functional
301 | torch.matmul = UPM_Matmul
302 | torch.transpose = UPM_Transpose
303 | torch.nn.functional.scaled_dot_product_attention = UPM_scaled_dot_product_attention
304 | # torch.Tensor = UPM_Tensor
305 |
306 | # transformers library
307 | transformers.pytorch_utils.Conv1D = UPM_Conv1D
308 | transformers.activations.NewGELUActivation = UPM_NewGELUActivation
309 | transformers.activations.ACT2FN["gelu_new"] = (
310 | UPM_NewGELUActivation # classes are hardcoded in ACT2FN
311 | )
312 | transformers.activations.ACT2FN["silu"] = (
313 | UPM_SiLUActivation # classes are hardcoded in ACT2FN
314 | )
315 | transformers.models.llama.modeling_llama.LlamaRMSNorm = UPM_LlamaRMSNorm
316 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = (
317 | UPM_LlamaRotaryEmbedding
318 | )
319 |
320 | transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm = (
321 | UPM_LlamaRMSNorm # miXtral models
322 | )
323 | transformers.models.mistral.modeling_mistral.MistralRMSNorm = (
324 | UPM_LlamaRMSNorm # miStral models
325 | )
326 |
327 |
328 | def profiler_start(
329 | layer_mapping=None,
330 | layer_attn_ctxt="",
331 | last_layer="lm_head",
332 | batch_size=1,
333 | moe_end="",
334 | experts_per_token=2,
335 | ):
336 | profiler.start(
337 | layer_mapping=layer_mapping,
338 | layer_attn_ctxt=layer_attn_ctxt,
339 | last_layer=last_layer,
340 | batch_size=batch_size,
341 | moe_end=moe_end,
342 | experts_per_token=experts_per_token,
343 | )
344 |
345 |
346 | def profiler_end():
347 | profiler.end()
348 |
--------------------------------------------------------------------------------
/src/upmem_llm_framework/sim_architectures.yaml:
--------------------------------------------------------------------------------
1 | # yaml-language-server: $schema=./architectures_schema.json
2 |
3 | DGX100:
4 | host_to_device_bw_GBs: 450
5 | host_to_device_pj_per_bit: 27
6 | device_to_host_bw_GBs: 450
7 | device_to_host_pj_per_bit: 27
8 | mem_bw_GBs: 26800
9 | mem_pj_per_bit: 7
10 | tflops: 7916
11 | pj_per_tflop: 0.5e+12
12 |
13 | V100:
14 | host_to_device_bw_GBs: 64
15 | host_to_device_pj_per_bit: 27
16 | device_to_host_bw_GBs: 64
17 | device_to_host_pj_per_bit: 27
18 | mem_bw_GBs: 900
19 | mem_pj_per_bit: 7
20 | tflops: 112
21 | pj_per_tflop: 0.5e+12
22 |
23 | H100_x4:
24 | host_to_device_bw_GBs: 64
25 | host_to_device_pj_per_bit: 27
26 | device_to_host_bw_GBs: 64
27 | device_to_host_pj_per_bit: 27
28 | mem_bw_GBs: 8000 # 2000 * 4
29 | mem_pj_per_bit: 7
30 | tflops: 3026 # 756.5 * 4
31 | pj_per_tflop: 0.5e+12
32 | softmax_ns_per_element: 5.208e-4 # 0.4 / (16 * 12) / 4
33 | SiLU_ns_per_element: 7.813e-4 # 0.6 / (16 * 12) / 4
34 | RMSNorm_ns_per_element: 1.354e-3 # 1.04 / (16 * 12) / 4
35 |
36 | H100_x5:
37 | host_to_device_bw_GBs: 64
38 | host_to_device_pj_per_bit: 27
39 | device_to_host_bw_GBs: 64
40 | device_to_host_pj_per_bit: 27
41 | mem_bw_GBs: 10000 # 2000 * 5
42 | mem_pj_per_bit: 7
43 | tflops: 3782.5 # 756.5 * 5
44 | pj_per_tflop: 0.5e+12
45 | softmax_ns_per_element: 4.167e-4 # 0.4 / (16 * 12) / 5
46 | SiLU_ns_per_element: 6.25e-4 # 0.6 / (16 * 12) / 5
47 | RMSNorm_ns_per_element: 1.0833e-3 # 1.04 / (16 * 12) / 5
48 |
49 | H100_x8:
50 | host_to_device_bw_GBs: 450
51 | host_to_device_pj_per_bit: 280 # 40 * (8 - 1)
52 | device_to_host_bw_GBs: 450
53 | device_to_host_pj_per_bit: 40
54 | mem_bw_GBs: 16000 # 2000 * 8
55 | mem_pj_per_bit: 7
56 | tflops: 6052 # 756.5 * 8
57 | pj_per_tflop: 0.5e+12
58 | softmax_ns_per_element: 2.604e-4 # 0.4 / (16 * 12) / 8
59 | SiLU_ns_per_element: 3.906e-4 # 0.6 / (16 * 12) / 8
60 | RMSNorm_ns_per_element: 6.771e-4 # 1.04 / (16 * 12) / 8
61 |
62 | H100_x3:
63 | host_to_device_bw_GBs: 64
64 | host_to_device_pj_per_bit: 27
65 | device_to_host_bw_GBs: 64
66 | device_to_host_pj_per_bit: 27
67 | mem_bw_GBs: 6000 # 2000 * 3
68 | mem_pj_per_bit: 7
69 | tflops: 2269.5 # 756.5 * 3
70 | pj_per_tflop: 0.5e+12
71 | # Assuming a H100 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size
72 | softmax_ns_per_element: 1.0417e-03 # 0.4 / (16 * 2 * 4) / 3
73 | SiLU_ns_per_element: 1.5625e-03 # 0.6 / (16 * 2 * 4) / 3
74 | RMSNorm_ns_per_element: 2.7083e-3 # 1.04 / (16 * 2 * 4) / 3
75 |
76 | H100_x2:
77 | host_to_device_bw_GBs: 64
78 | host_to_device_pj_per_bit: 27
79 | device_to_host_bw_GBs: 64
80 | device_to_host_pj_per_bit: 27
81 | mem_bw_GBs: 4000 # 2000 * 2
82 | mem_pj_per_bit: 7
83 | tflops: 1513 # 756.5 * 2
84 | pj_per_tflop: 0.5e+12
85 | softmax_ns_per_element: 1.5625e-3 # 0.4 / (16 * 2 * 4) / 2
86 | SiLU_ns_per_element: 2.34375e-3 # 0.6 / (16 * 2 * 4) / 2
87 | RMSNorm_ns_per_element: 4.0625e-3 # 1.04 / (16 * 2 * 4) / 2
88 |
89 | A800:
90 | host_to_device_bw_GBs: 64
91 | host_to_device_pj_per_bit: 27
92 | device_to_host_bw_GBs: 64
93 | device_to_host_pj_per_bit: 27
94 | mem_bw_GBs: 1500
95 | mem_pj_per_bit: 7
96 | tflops: 312
97 | pj_per_tflop: 0.5e+12
98 | # Assuming a A800 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size
99 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4)
100 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4)
101 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4)
102 |
103 | H20:
104 | host_to_device_bw_GBs: 64
105 | host_to_device_pj_per_bit: 27
106 | device_to_host_bw_GBs: 64
107 | device_to_host_pj_per_bit: 27
108 | mem_bw_GBs: 4000
109 | mem_pj_per_bit: 7
110 | tflops: 148
111 | pj_per_tflop: 0.5e+12
112 | # Assuming a H20 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size
113 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4)
114 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4)
115 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4)
116 |
117 | H200:
118 | host_to_device_bw_GBs: 64
119 | host_to_device_pj_per_bit: 27
120 | device_to_host_bw_GBs: 64
121 | device_to_host_pj_per_bit: 27
122 | mem_bw_GBs: 2860
123 | mem_pj_per_bit: 7
124 | tflops: 989
125 | pj_per_tflop: 0.5e+12
126 |   # Assuming a H200 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size
127 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4)
128 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4)
129 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4)
130 |
131 | H100:
132 | host_to_device_bw_GBs: 64
133 | host_to_device_pj_per_bit: 27
134 | device_to_host_bw_GBs: 64
135 | device_to_host_pj_per_bit: 27
136 | mem_bw_GBs: 2000
137 | mem_pj_per_bit: 7
138 | tflops: 756.5
139 | pj_per_tflop: 0.5e+12
140 | # Assuming a H100 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size
141 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4)
142 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4)
143 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4)
144 |
145 | A6000:
146 | host_to_device_bw_GBs: 32
147 | host_to_device_pj_per_bit: 35
148 | device_to_host_bw_GBs: 32
149 | device_to_host_pj_per_bit: 35
150 | mem_bw_GBs: 768
151 | mem_pj_per_bit: 15
152 | tflops: 155
153 | pj_per_tflop: 0.5e+12
154 |
155 | A17Pro:
156 | host_to_device_bw_GBs: 51.2
157 | host_to_device_pj_per_bit: 20
158 | device_to_host_bw_GBs: 51.2
159 | device_to_host_pj_per_bit: 20
160 | mem_bw_GBs: 51.2
161 | mem_pj_per_bit: 20
162 | tflops: 17 # GPU (FP16): 4.3, ANE (INT4): 35, ANE (INT8): 17
163 | pj_per_tflop: 0.4e+12
164 |
165 | Dimensity9300:
166 | host_to_device_bw_GBs: 76.8
167 | host_to_device_pj_per_bit: 10
168 | device_to_host_bw_GBs: 76.8
169 | device_to_host_pj_per_bit: 10
170 | mem_bw_GBs: 76.8
171 | mem_pj_per_bit: 10
172 | tflops: 16 # GPU (FP16): 6, APU (INT4): 33, APU (INT8): 16
173 | pj_per_tflop: 0.4e+12
174 |
175 | Snapdragon8gen3:
176 | host_to_device_bw_GBs: 77
177 | host_to_device_pj_per_bit: 10
178 | device_to_host_bw_GBs: 77
179 | device_to_host_pj_per_bit: 10
180 | mem_bw_GBs: 77
181 | mem_pj_per_bit: 10
182 | tflops: 17 # GPU (FP16): 4.73, Hexagon (INT4): 34, Hexagon (INT8): 17
183 | pj_per_tflop: 0.4e+12
184 |
185 | SAM_LPDDR5PIM:
186 | host_to_device_bw_GBs: 12.8
187 | host_to_device_pj_per_bit: 22
188 | device_to_host_bw_GBs: 12.8
189 | device_to_host_pj_per_bit: 22
190 | mem_bw_GBs: 102.4
191 | mem_pj_per_bit: 0.95
192 | tflops: 0.1024
193 | tflops_int4: 0.4096 # 0.1024 * 4
194 | pj_per_tflop: 0.8e+12
195 |
196 | PIM_AI_1chip:
197 | host_to_device_bw_GBs: 12.8
198 | host_to_device_pj_per_bit: 20
199 | device_to_host_bw_GBs: 12.8
200 | device_to_host_pj_per_bit: 20
201 | mem_bw_GBs: 102.4
202 | mem_pj_per_bit: 0.95
203 | tflops: 5
204 | pj_per_tflop: 0.4e+12
205 |
206 | PIM_AI_4chip:
207 | host_to_device_bw_GBs: 12.8
208 | host_to_device_pj_per_bit: 80 # 20 * 4
209 | device_to_host_bw_GBs: 51.2 # 12.8 * 4
210 | device_to_host_pj_per_bit: 20
211 | mem_bw_GBs: 409.6 # 102.4 * 4
212 | mem_pj_per_bit: 0.95
213 | tflops: 20 # 5 * 4
214 | tflops_int4: 128 # 32 * 4
215 | pj_per_tflop: 0.4e+12
216 | softmax_ns_per_element: 0.1 # 0.4 / 4
217 | SiLU_ns_per_element: 0.15 # 0.6 / 4
218 |   RMSNorm_ns_per_element: 0.275 # 1.1 / 4
219 |
220 | PIM_AI_1dimm:
221 | host_to_device_bw_GBs: 44
222 | host_to_device_pj_per_bit: 50
223 | device_to_host_bw_GBs: 44
224 | device_to_host_pj_per_bit: 50
225 | mem_bw_GBs: 1638.4 # 102.4 * 16
226 | mem_pj_per_bit: 0.95
227 | tflops: 80 # 5 * 16
228 | pj_per_tflop: 0.4e+12
229 |
230 | PIM_AI_2dimm:
231 | host_to_device_bw_GBs: 22
232 | host_to_device_pj_per_bit: 100 # 50 * 2
233 | device_to_host_bw_GBs: 44
234 | device_to_host_pj_per_bit: 50
235 | mem_bw_GBs: 3276.8 # 102.4 * 16 * 2
236 | mem_pj_per_bit: 0.95
237 | tflops: 160 # 5 * 16 * 2
238 | pj_per_tflop: 0.4e+12
239 | softmax_ns_per_element: 1.25e-02 # 0.4 / (16 * 2)
240 | SiLU_ns_per_element: 1.875e-02 # 0.6 / (16 * 2)
241 | RMSNorm_ns_per_element: 3.25e-02 # 1.04 / (16 * 2)
242 |
243 | PIM_AI_4dimm:
244 | host_to_device_bw_GBs: 22
245 | host_to_device_pj_per_bit: 200 # 50 * 4
246 | device_to_host_bw_GBs: 88 # 44 * 2
247 | device_to_host_pj_per_bit: 50
248 | mem_bw_GBs: 6553.6 # 102.4 * 16 * 4
249 | mem_pj_per_bit: 0.95
250 | tflops: 320 # 5 * 16 * 4
251 | pj_per_tflop: 0.4e+12
252 | softmax_ns_per_element: 6.25e-03 # 0.4 / (16 * 4)
253 | SiLU_ns_per_element: 9.375e-03 # 0.6 / (16 * 4)
254 | RMSNorm_ns_per_element: 1.625e-02 # 1.04 / (16 * 4)
255 |
256 | PIM_AI_16dimm:
257 | host_to_device_bw_GBs: 22
258 | host_to_device_pj_per_bit: 800 # 50 * 16
259 | device_to_host_bw_GBs: 352 # 44 * 8
260 | device_to_host_pj_per_bit: 50
261 | mem_bw_GBs: 26214.4 # 102.4 * 16 * 16
262 | mem_pj_per_bit: 0.95
263 | tflops: 2048 # 8 * 16 * 16
264 | pj_per_tflop: 0.4e+12
265 | softmax_ns_per_element: 1.5625e-03 # 0.4 / (16 * 16)
266 | SiLU_ns_per_element: 2.34375e-03 # 0.6 / (16 * 16)
267 | RMSNorm_ns_per_element: 4.0625e-03 # 1.04 / (16 * 16)
268 |
269 | PIM_AI_8dimm:
270 | host_to_device_bw_GBs: 22
271 | host_to_device_pj_per_bit: 640 # (50 + 2 * 15) * 8
272 | device_to_host_bw_GBs: 176 # 44 * 4
273 | device_to_host_pj_per_bit: 50
274 | mem_bw_GBs: 13107.2 # 102.4 * 16 * 8
275 | mem_pj_per_bit: 0.95
276 | tflops: 1024 # 8 * 16 * 8
277 | pj_per_tflop: 0.5e+12
278 | softmax_ns_per_element: 3.125e-03 # 0.4 / (16 * 8)
279 | SiLU_ns_per_element: 4.6875e-03 # 0.6 / (16 * 8)
280 | RMSNorm_ns_per_element: 8.125e-03 # 1.04 / (16 * 8)
281 |
282 | PIM_AI_10dimm:
283 | host_to_device_bw_GBs: 22
284 | host_to_device_pj_per_bit: 500 # 50 * 10
285 | device_to_host_bw_GBs: 220 # 44 * 5
286 | device_to_host_pj_per_bit: 50
287 | mem_bw_GBs: 16384 # 102.4 * 16 * 10
288 | mem_pj_per_bit: 0.95
289 | tflops: 1280 # 8 * 16 * 10
290 | pj_per_tflop: 0.4e+12
291 | softmax_ns_per_element: 2.5e-03 # 0.4 / (16 * 10)
292 | SiLU_ns_per_element: 3.75e-03 # 0.6 / (16 * 10)
293 | RMSNorm_ns_per_element: 6.5e-03 # 1.04 / (16 * 10)
294 |
295 | PIM_AI_6dimm:
296 | host_to_device_bw_GBs: 22
297 | host_to_device_pj_per_bit: 480 # (50 + 2 * 15) * 6
298 | device_to_host_bw_GBs: 132 # 44 * 3
299 | device_to_host_pj_per_bit: 50
300 | mem_bw_GBs: 9830.4 # 102.4 * 16 * 6
301 | mem_pj_per_bit: 0.95
302 | tflops: 768 # 8 * 16 * 6
303 | pj_per_tflop: 0.5e+12
304 | softmax_ns_per_element: 4.1667e-03 # 0.4 / (16 * 6)
305 | SiLU_ns_per_element: 6.25e-03 # 0.6 / (16 * 6)
306 | RMSNorm_ns_per_element: 1.0833e-02 # 1.04 / (16 * 6)
307 |
308 | CXL_PIM_BC:
309 | # CXL board with:
310 | # 8-lane full duplex PCIe GEN5
311 | # 16 LPDDR controllers, 16 bits, 9.6 GT/s, dual rank (2 devices per IFC)
312 | # A device is a stack of 4 LPDDR-PIM
313 | # 256 GB overall memory (stacking 4 LPDDR-PIM = 8 dies)
314 | # This might be seen as 8x AI PIM DIMM with C2C connection between groups of 4 chips
315 | # Broadcast between LPDDR-PIM is possible
316 | host_to_device_bw_GBs: 19.2 # 8-lane PCIe GEN5, but only one LPDDR5 at a time
317 | host_to_device_pj_per_bit: 50 # crossing PCIe and LPDDR interfaces on both host and device
318 | device_to_host_bw_GBs: 19.2 # 8-lane PCIe GEN5, but only one LPDDR5 at a time
319 | device_to_host_pj_per_bit: 50 # crossing PCIe and LPDDR interfaces on both host and device
320 | mem_bw_GBs: 13107.2 # 102.4 * 16 * 2 * 4
321 | mem_pj_per_bit: 0.95
322 | tflops: 640 # 5 * 16 * 2 *4
323 | pj_per_tflop: 0.4e+12
324 | softmax_ns_per_element: 3.125e-03 # 0.4 / (16 * 2 * 4)
325 | SiLU_ns_per_element: 4.6875e-03 # 0.6 / (16 * 2 * 4)
326 | RMSNorm_ns_per_element: 8.125e-03 # 1.04 / (16 * 2 * 4)
327 |
328 | CXL_PIM_nBC:
329 | # CXL board with:
330 | # 8-lane full duplex PCIe GEN5
331 | # 16 LPDDR controllers, 16 bits, 9.6 GT/s, dual rank (2 devices per IFC)
332 | # A device is a stack of 4 LPDDR-PIM
333 | # 256 GB overall memory (stacking 4 LPDDR-PIM = 8 dies)
334 | # This might be seen as 8x AI PIM DIMM with C2C connection between groups of 4 chips
335 | # Broadcast between LPDDR-PIM is not possible
336 | host_to_device_bw_GBs: 0.6 # 19.2 /32: 8-lane PCIe GEN5, but only one LPDDR5 at a time
337 | host_to_device_pj_per_bit: 1600 # 50 * 32: crossing PCIe and LPDDR interfaces on both host and device
338 | device_to_host_bw_GBs: 19.2 # 8-lane PCIe GEN5, but only one LPDDR5 at a time
339 | device_to_host_pj_per_bit: 50 # crossing PCIe and LPDDR interfaces on both host and device
340 | mem_bw_GBs: 13107.2 # 102.4 * 16 * 2 * 4
341 | mem_pj_per_bit: 0.95
342 | tflops: 640 # 5 * 16 * 2 * 4
343 | pj_per_tflop: 0.4e+12
344 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4)
345 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4)
346 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4)
347 |
348 | PIM_AI_12dimm:
349 | host_to_device_bw_GBs: 22
350 | host_to_device_pj_per_bit: 960 # (50 + 2 * 15) * 12
351 | device_to_host_bw_GBs: 264 # 44 * 6
352 | device_to_host_pj_per_bit: 50
353 | mem_bw_GBs: 19660.8 # 102.4 * 16 * 12
354 | mem_pj_per_bit: 0.95
355 | tflops: 1536 # 8 * 16 * 12
356 | pj_per_tflop: 0.5e+12
357 | softmax_ns_per_element: 2.0833e-3 # 0.4 / (16 * 12)
358 | SiLU_ns_per_element: 3.125e-3 # 0.6 / (16 * 12)
359 | RMSNorm_ns_per_element: 5.4167e-03 # 1.04 / (16 * 12)
360 |
361 | PIM_AI_24dimm:
362 | host_to_device_bw_GBs: 22
363 | host_to_device_pj_per_bit: 1920 # (50 + 2 * 15) * 24
364 | device_to_host_bw_GBs: 528 # 44 * 12
365 | device_to_host_pj_per_bit: 50
366 | mem_bw_GBs: 39321.6 # 102.4 * 16 * 24
367 | mem_pj_per_bit: 0.95
368 | tflops: 3072 # 8 * 16 * 24
369 | pj_per_tflop: 0.5e+12
370 | softmax_ns_per_element: 1.0417e-3 # 0.4 / (16 * 24)
371 | SiLU_ns_per_element: 1.5625e-3 # 0.6 / (16 * 24)
372 | RMSNorm_ns_per_element: 2.7083e-3 # 1.04 / (16 * 24)
373 |
--------------------------------------------------------------------------------
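The per-element latencies in the entries above follow one pattern: a single-chip baseline of roughly 0.4 ns (softmax), 0.6 ns (SiLU) and 1.04 ns (RMSNorm) divided by the number of active chips (16 per PIM DIMM). A minimal sketch, under the assumption that the same scaling holds, of building such a spec for a made-up 32-DIMM configuration and loading it with BaseArchitecture.load_spec (defined in base_architecture.py below); "PIM_AI_32dimm" is not an entry in this file.

# Hedged sketch: derive a spec dict for a hypothetical 32-DIMM system by
# scaling the per-DIMM numbers the way the comments above do. Any key left
# out keeps the BaseArchitecture constructor defaults.
from upmem_llm_framework.base_architecture import BaseArchitecture

n_dimms = 32
chips = 16 * n_dimms
spec = {
    "host_to_device_bw_GBs": 22,
    "host_to_device_pj_per_bit": (50 + 2 * 15) * n_dimms,
    "device_to_host_bw_GBs": 44 * (n_dimms // 2),
    "device_to_host_pj_per_bit": 50,
    "mem_bw_GBs": 102.4 * chips,
    "mem_pj_per_bit": 0.95,
    "tflops": 8 * chips,
    "pj_per_tflop": 0.5e12,
    "softmax_ns_per_element": 0.4 / chips,
    "SiLU_ns_per_element": 0.6 / chips,
    "RMSNorm_ns_per_element": 1.04 / chips,
}
device = BaseArchitecture(data_type_bytes=2.0)
device.load_spec("PIM_AI_32dimm", spec)
print(device.mem_bw_GBs, device.softmax_ns_per_element)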
/src/upmem_llm_framework/simulator.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2014-2024 - UPMEM
3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA
4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM
5 | #
6 | # This file implements multiple entry points called by the profiler to simulate the underlying
7 | # hardware.
8 |
9 |
10 | import torch
11 |
12 | from upmem_llm_framework.base_architecture import BaseArchitecture
13 | from upmem_llm_framework.sim_architectures import get_spec
14 | from upmem_llm_framework.utils import add_dictionaries
15 |
16 |
17 | class Simulator:
18 |
19 | def __init__(
20 | self,
21 | data_type_bytes=2.0,
22 | sliding_window=-1,
23 | num_key_value_heads=-1,
24 | verbose=False,
25 | ):
26 | self.data_type_bytes = data_type_bytes
27 | self.layer_mapping = {}
28 | self.layer_attn_ctxt = ""
29 | self.use_kv_cache = True
30 | self.sum = True
31 | self.sum_size = 0
32 | self.batch_size = 0
33 | self.sliding_window = sliding_window
34 | self.num_key_value_heads = num_key_value_heads
35 | self.verbose = verbose
36 |
37 | self.moe_already_sent = {}
38 | self.moe_end = ""
39 | self.experts_per_token = 2 # from Mixtral 8x7B
40 |
41 | self.current_device, _, _ = self.name_to_device("HOST")
42 |
43 | def start_gen(self):
44 | self.sum = False
45 |
46 | def map_layers(self, mapping, layer_attn_ctxt="", moe_end="", experts_per_token=2):
47 | self.layer_mapping = mapping
48 | self.layer_attn_ctxt = layer_attn_ctxt
49 | if self.verbose:
50 | print(self.layer_mapping)
51 | print(self.layer_attn_ctxt)
52 | self.moe_end = moe_end
53 | self.experts_per_token = experts_per_token
54 |
55 | def name_to_device(self, full_name: str):
56 |
57 | device_name = full_name.split(",")[0].replace("-", "_")
58 | flags = full_name.split(",")[1] if len(full_name.split(",")) > 1 else ""
59 |
60 | do_transfer = "t" in flags
61 | moe = "m" in flags
62 |
63 | new_device = BaseArchitecture(
64 | data_type_bytes=self.data_type_bytes,
65 | sliding_window=self.sliding_window,
66 | num_key_value_heads=self.num_key_value_heads,
67 | verbose=self.verbose,
68 | )
69 |
70 | if device_name == "HOST":
71 | new_device.name = "HOST"
72 | return new_device, do_transfer, moe
73 |
74 | spec = get_spec(device_name)
75 | new_device.load_spec(device_name, spec)
76 |
77 | # new_device.adjust_for_quantization()
78 |
79 | return new_device, do_transfer, moe
80 |
81 | def simulate_attn(self, input_shape):
82 | batch_size = input_shape[0] if (len(input_shape) > 2) else 1
83 | n_rows = input_shape[1] if (len(input_shape) > 1) else 1
84 | n_columns = input_shape[-1]
85 |
86 | # input (n_rows, n_columns) x Wqkv (n_columns, Wqkv),
87 | # where Wqkv is n_columns*3 (all Wq, Wk, Wv together)
88 | # Qall_heads = (n_rows, n_columns), each head computes n_columns/n_heads
89 | # Kall_heads = (n_rows, n_columns)
90 | # Vall_heads = (n_rows, n_columns)
91 |
92 | # TODO: add transpose time, concat time
93 | compute_time_ns = 0
94 | performance = {}
95 | energy_compute = {}
96 |
97 | kt = torch.Size([n_columns, n_rows])
98 | qkt = torch.Size([batch_size, n_rows, n_rows])
99 | v = torch.Size([n_rows, n_columns])
100 |
101 | if self.use_kv_cache:
102 | if self.sum:
103 | self.sum_size = n_rows # Just to keep it updated
104 | else:
105 | # load KV cache
106 | kv_cache = torch.Size([batch_size, self.sum_size, n_columns * 2])
107 | step_time, step_perf, step_energy = self.current_device.load_data(
108 | kv_cache
109 | )
110 | compute_time_ns += step_time
111 | performance = add_dictionaries(performance, step_perf)
112 | energy_compute = add_dictionaries(energy_compute, step_energy)
113 |
114 | # concat K + 1
115 | kt = torch.Size([n_columns, self.sum_size + 1])
116 | # concat V + 1
117 | v = torch.Size([self.sum_size + 1, n_columns])
118 | qkt = torch.Size([batch_size, n_rows, self.sum_size + 1])
119 |
120 | # Q x Kt
121 | if self.verbose:
122 | print("Computing Q x Kt")
123 | step_time, step_perf, step_energy = self.current_device.compute_ns(
124 | input_shape,
125 | torch.nn.Linear(in_features=n_columns, out_features=n_rows),
126 | kt,
127 | load_input=False,
128 | )
129 | compute_time_ns += step_time
130 | performance = add_dictionaries(performance, step_perf)
131 | energy_compute = add_dictionaries(energy_compute, step_energy)
132 |
133 | # output = V * QKt
134 | if self.verbose:
135 | print("Computing V x QKt")
136 | step_time, step_perf, step_energy = self.current_device.compute_ns(
137 | qkt,
138 | torch.nn.Linear(in_features=n_rows, out_features=n_rows),
139 | v,
140 | load_input=False,
141 | )
142 | compute_time_ns += step_time
143 | performance = add_dictionaries(performance, step_perf)
144 | energy_compute = add_dictionaries(energy_compute, step_energy)
145 |
146 | if self.verbose:
147 | print(
148 | f"Attn ({'SUM.' if self.sum else 'GEN.'}): "
149 | f"Q: {input_shape} Kt: {kt} QKt: {qkt} V: {v}"
150 | )
151 | print(performance)
152 |
153 | return compute_time_ns, performance, energy_compute
154 |
155 | def simulate_end(self, input_shape, generated_tokens=1):
156 |
157 | time_send_ans_to_host = 0
158 | perf_send_ans_to_host = {}
159 | energy_send_ans_to_host = {}
160 | data_send_ans_to_host = {}
161 |
162 | if self.current_device.name != "HOST":
163 |             # we assume the new input is what needs to be written back to host from previous layer
164 | (
165 | time_send_ans_to_host,
166 | perf_send_ans_to_host,
167 | energy_send_ans_to_host,
168 | data_send_ans_to_host,
169 | ) = self.current_device.host_transfer(
170 | input_shape, direction="to_host", generated_tokens=generated_tokens
171 | )
172 |
173 | return (
174 | time_send_ans_to_host,
175 | perf_send_ans_to_host,
176 | energy_send_ans_to_host,
177 | data_send_ans_to_host,
178 | )
179 |
180 | def check_moe(self, context):
181 | num_seen = self.moe_already_sent.get(context, 0)
182 |
183 | self.moe_already_sent[context] = num_seen + 1
184 |
185 | return num_seen == 0
186 |
187 | def reset_moe(self):
188 | all_moe_seen = True
189 | for v in self.moe_already_sent.values():
190 | all_moe_seen = all_moe_seen and (v == self.experts_per_token)
191 |
192 | if all_moe_seen:
193 | # Reset dict for next iteration
194 | self.moe_already_sent = {}
195 |
196 | return all_moe_seen
197 |
198 | def check_sync_point(self, context, input_shape):
199 | time_send_ans_to_host = 0
200 | time_send_ans_from_host = 0
201 | perf = {}
202 | energy = {}
203 | moved_data = {}
204 |
205 | new_device = None
206 |
207 | # print(f"Try mapping {context}")
208 | # assume that if the layer is not mapped, it stays in the current device
209 | if context in self.layer_mapping:
210 | # print(f"Mapping {context} to {self.layer_mapping[context]}")
211 | new_device, gather_at_host, moe = self.name_to_device(
212 | self.layer_mapping[context]
213 | )
214 |
215 | # if new_device != current_device --> pay transfer
216 | if (
217 | new_device.name != self.current_device.name
218 | or gather_at_host
219 | or (moe and self.check_moe(context))
220 | ):
221 |
222 | # if HOST is not current device, transfer to HOST the output from the last layer
223 |                 # we assume the new input is what needs to be written back to host from previous
224 | # layer
225 | if self.current_device.name != "HOST":
226 | time_send_ans_to_host, step_perf, step_energy, step_data = (
227 | self.current_device.host_transfer(
228 | input_shape, direction="to_host"
229 | )
230 | )
231 | perf = add_dictionaries(perf, step_perf)
232 | energy = add_dictionaries(energy, step_energy)
233 | moved_data = add_dictionaries(moved_data, step_data)
234 |
235 | # change current_device = new_device
236 | old_device = self.current_device
237 | self.current_device = new_device
238 |
239 | # then, host writes into the new device
240 | if self.current_device.name != "HOST":
241 | time_send_ans_from_host, step_perf, step_energy, step_data = (
242 | self.current_device.host_transfer(input_shape)
243 | )
244 | perf = add_dictionaries(perf, step_perf)
245 | energy = add_dictionaries(energy, step_energy)
246 | moved_data = add_dictionaries(moved_data, step_data)
247 |
248 | if self.verbose:
249 | print(
250 | f"Changing device from {old_device.name} to {new_device.name} "
251 | f"took {perf} with energy: {energy}"
252 | )
253 |
254 | return time_send_ans_to_host, time_send_ans_from_host, perf, energy, moved_data
255 |
256 | def simulate_layer(self, layer, input_shape, layer_obj, weight_shape, output_shape):
257 |
258 | time_send_ans_to_host = 0
259 | time_send_ans_from_host = 0
260 | compute_time_ns = 0
261 |
262 | perf_transfer = {}
263 | perf_compute = {}
264 |
265 | energy_transfer = {}
266 | energy_compute = {}
267 |
268 | data_transfer = {}
269 |
270 | if self.verbose:
271 | print("Simulating layer:", layer.context, "n_layer:", layer.n_layer)
272 |
273 | (
274 | time_send_ans_to_host,
275 | time_send_ans_from_host,
276 | perf_transfer,
277 | energy_transfer,
278 | data_transfer,
279 | ) = self.check_sync_point(layer.context, input_shape)
280 |
281 | if layer.context == self.moe_end and self.moe_end != "":
282 | if self.reset_moe():
283 | # output_shape expected to be [tokens * batch_size, features]
284 |                 # Send all experts' output except one, which shall be accounted for by the layer mapping
285 | transfer_shape = torch.Size(
286 | [self.experts_per_token - 1, output_shape[0], output_shape[1]]
287 | )
288 | (
289 | time_send_ans_to_host_moe,
290 | perf_transfer_moe,
291 | energy_transfer_moe,
292 | data_transfer_moe,
293 | ) = self.current_device.host_transfer(
294 | transfer_shape, direction="to_host"
295 | )
296 | if self.verbose:
297 | print("Last layer of MoE sends back to HOST: ", transfer_shape)
298 | time_send_ans_to_host += time_send_ans_to_host_moe
299 | perf_transfer = add_dictionaries(perf_transfer, perf_transfer_moe)
300 | energy_transfer = add_dictionaries(energy_transfer, energy_transfer_moe)
301 | data_transfer = add_dictionaries(data_transfer, data_transfer_moe)
302 |
303 | # pay compute
304 | step_time, step_perf, step_energy = self.current_device.compute_ns(
305 | input_shape, layer_obj, weight_shape
306 | )
307 | compute_time_ns += step_time
308 | perf_compute = add_dictionaries(perf_compute, step_perf)
309 | energy_compute = add_dictionaries(energy_compute, step_energy)
310 |
311 | # If self-attention required, simulate
312 | # if (layer.context == self.layer_attn_ctxt):
313 | # step_time, step_perf, step_energy = self.simulate_attn(input_shape, weight_shape)
314 | # compute_time_ns += step_time
315 | # perf_compute = add_dictionaries(perf_compute , step_perf )
316 | # energy_compute = add_dictionaries(energy_compute, step_energy)
317 |
318 | if self.verbose:
319 | print("Time send ans to host (ns):", time_send_ans_to_host)
320 | print("Time send ans from host (ns):", time_send_ans_from_host)
321 | print("Compute time (ns):", compute_time_ns)
322 | print("Energy send ans to/from host (pJ):", energy_transfer)
323 | print("Energy Compute (pJ):", energy_compute)
324 |
325 | # we can pipeline TODO: calculate this
326 | # max (time_send_ans_from_host, compute_time)
327 |
328 | # return simulated stats
329 | total_time = time_send_ans_to_host + time_send_ans_from_host + compute_time_ns
330 |
331 | total_perf = add_dictionaries(perf_transfer, perf_compute)
332 |
333 | total_energy = add_dictionaries(energy_transfer, energy_compute)
334 |
335 | return total_time, total_perf, total_energy, data_transfer
336 |
337 | def simulate_function(self, function, context, input_shape, output_shape):
338 |
339 | function_name = (
340 | function.__name__ if hasattr(function, "__name__") else function.name
341 | )
342 |
343 | if self.verbose:
344 | print(
345 | f"Simulating function: {function_name}, "
346 | f"context: {context}, "
347 | f"input shape: {input_shape}, "
348 | f"output shape: {output_shape}"
349 | )
350 |
351 | (
352 | time_send_ans_to_host,
353 | time_send_ans_from_host,
354 | perf_transfer,
355 | energy_transfer,
356 | data_transfer,
357 | ) = self.check_sync_point(function_name, input_shape)
358 |
359 | compute_time_ns, perf_compute, energy_compute = self._compute_function_metrics(
360 | function, context, input_shape, output_shape
361 | )
362 |
363 | if self.verbose:
364 | print("Time send ans to host (ns):", time_send_ans_to_host)
365 | print("Time send ans from host (ns):", time_send_ans_from_host)
366 | print("Compute time (ns):", compute_time_ns)
367 | print("Energy send ans to/from host (pJ):", energy_transfer)
368 | print("Energy Compute (pJ):", energy_compute)
369 |
370 | # return simulated stats
371 | total_time = time_send_ans_to_host + time_send_ans_from_host + compute_time_ns
372 |
373 | total_perf = add_dictionaries(perf_transfer, perf_compute)
374 | total_energy = add_dictionaries(energy_transfer, energy_compute)
375 |
376 | return total_time, total_perf, total_energy, data_transfer
377 |
378 | def _compute_function_metrics(self, function, context, input_shape, output_shape):
379 | if hasattr(function, "__name__") and function.__name__.endswith("softmax"):
380 | return self.current_device.compute_softmax_ns(input_shape)
381 | if hasattr(function, "name") and function.name.endswith("LlamaRMSNorm"):
382 | return self.current_device.compute_RMSNorm_ns(input_shape, output_shape)
383 | if hasattr(function, "name") and (
384 | function.name.endswith("SiLU") or function.name.endswith("SiLUActivation")
385 | ):
386 | return self.current_device.compute_activation_ns(
387 | input_shape, activation="SiLU"
388 | )
389 | if hasattr(function, "__name__") and function.__name__.endswith("matmul"):
390 | return self.current_device.compute_matmul_ns(
391 | context,
392 | input_shape,
393 | output_shape,
394 | summarization=self.sum,
395 | sum_size=self.sum_size,
396 | )
397 | if hasattr(function, "__name__") and function.__name__.endswith(
398 | "scaled_dot_product_attention"
399 | ):
400 | return self.current_device.compute_scaled_dot_product_ns(
401 | context,
402 | input_shape,
403 | output_shape,
404 | summarization=self.sum,
405 | sum_size=self.sum_size,
406 | )
407 | raise ValueError(
408 | "Unsupported function: "
409 | f"{function.__name__ if hasattr(function, '__name__') else function.name}, "
410 | f"type: {type(function)}, "
411 | f"string: {function}"
412 | )
413 |
--------------------------------------------------------------------------------
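For orientation, a minimal usage sketch of the device-resolution convention implemented by name_to_device and consumed by check_sync_point above. It assumes the package and its bundled sim_architectures.yaml are importable; the "q_proj" and "lm_head" context names are hypothetical.

# Hedged sketch: mapping values are "ARCH_NAME[,flags]"; flag "t" forces a
# gather at the host before the layer, flag "m" marks an MoE expert
# (see name_to_device above).
from upmem_llm_framework.simulator import Simulator

sim = Simulator(data_type_bytes=2.0, verbose=False)
sim.map_layers({"q_proj": "PIM_AI_1dimm", "lm_head": "H100,t"},
               layer_attn_ctxt="q_proj")

device, gather_at_host, moe = sim.name_to_device("H100,t")
print(device.name, gather_at_host, moe)  # H100 True False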
/src/upmem_llm_framework/base_architecture.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2014-2024 - UPMEM
3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA
4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM
5 | #
6 | # This file implements Base_architecture class
7 | # This class contains a default implementation of the following functions:
8 | # - adjust_for_quantization: scales up/down the TFLOPs depending on the quantization chosen
9 | # - get_tflops: returns the TFLOPs required in a MxM
10 | # - get_moved_data_bytes: returns the required bytes to move in order to do an operation
11 | # - load_data: models loading the KV cache
12 | # - host_transfer: simulates a data transfer with host in any direction
13 | # - compute_ns: simulates the computation of a MxM
14 | # - compute_scaled_dot_product_ns: simulates the computation of function scaled_dot_product where,
15 | # usually, attention computation occurs
16 | # - compute_matmul_ns: simulates the computation of a matmul for self-attention
17 | # - compute_activation_ns: simulates an activation layer
18 | # - compute_RMSNorm_ns: simulates a RMSNorm layer
19 | # - compute_softmax_ns: simulates a softmax operation
20 | #
21 | # Note that all simulation functions return compute_time_ns, performance_dict, energy_dict:
22 | # - compute_time_ns: the simulated time in ns,
23 | # - performance_dict: dictionary containing the simulated time in ns for each operation simulated,
24 | # - energy_dict: dictionary containing the simulated energy in pJ for each operation simulated.
25 |
26 | import math
27 | from typing import Dict
28 |
29 | import torch
30 |
31 | from upmem_llm_framework.utils import add_dictionaries
32 |
33 |
34 | class BaseArchitecture:
35 |
36 | def __init__(
37 | self,
38 | active_chips=1,
39 | tflops=1,
40 | pj_per_tflop=1,
41 | host_to_device_bw_GBs=1,
42 | device_to_host_bw_GBs=1,
43 | # inter_bw = 1,
44 | memory=1,
45 | mem_bw_GBs=1,
46 | mem_pj_per_bit=1,
47 | data_type_bytes=2.0, # float16
48 | # 3000 cycles per row of 2048 elements --> 1.4 cycles / element
49 |         # assuming 1 GHz, 1.5 ns / element, parallelized across 4 chips -> 0.37
50 | softmax_ns_per_element=0.4, # ns, considering it cycles in 1GHz config
51 | SiLU_ns_per_element=0.6, # ns, softmax * 1.5
52 | # (empiric number based on execution of Llama2-7b)
53 | RMSNorm_ns_per_element=1.1, # ns, softmax * 2.6
54 | # (empiric number based on execution of Llama2-7b)
55 | # 3000 cycles per row of 2048 elements with 5 TFLOPs of computing power
56 |         # assuming 1 GHz, 0.000003 s --> 3 MOPS per row of 2048 --> 1.5 kOPS per element
57 | misc_tflops_per_element=1500 / 1e12,
58 | sliding_window=-1,
59 | num_key_value_heads=-1,
60 | verbose=False,
61 | ):
62 |
63 | self.name = ""
64 | self.active_chips = active_chips
65 | # Compute capabilities
66 | self.tflops = tflops
67 |         self.pj_per_tflop = 0.4  # default; overridden below by the pj_per_tflop argument
68 | # Interface with HOST
69 | self.host_to_device_bw_GBs = host_to_device_bw_GBs
70 | self.device_to_host_bw_GBs = device_to_host_bw_GBs
71 | self.host_to_device_pj_per_bit = 25
72 | self.device_to_host_pj_per_bit = 25
73 | # self.inter_bw = inter_bw
74 |
75 | # Device memory (shared memory like)
76 | self.memory = memory # unused at the moment
77 | self.mem_bw_GBs = mem_bw_GBs
78 | self.mem_pj_per_bit = mem_pj_per_bit
79 | self.pj_per_tflop = pj_per_tflop
80 |
81 | self.data_type_bytes = data_type_bytes
82 |
83 | self.softmax_ns_per_element = softmax_ns_per_element
84 | self.RMSNorm_ns_per_element = RMSNorm_ns_per_element
85 | self.SiLU_ns_per_element = SiLU_ns_per_element
86 |
87 | self.misc_tflops_per_element = misc_tflops_per_element
88 |
89 | self.sliding_window = sliding_window
90 | self.num_key_value_heads = num_key_value_heads
91 |
92 | self.verbose = verbose
93 |
94 | def load_spec(self, name: str, spec: Dict):
95 | """Load accelerator specification from a dictionary"""
96 | self.name = name
97 | for key, value in spec.items():
98 | if key == "tflops_int4":
99 | continue
100 | if not hasattr(self, key):
101 | raise ValueError(
102 | f"Warning: {key} is not a valid attribute for {self.__class__}"
103 | )
104 | setattr(self, key, value)
105 | if "tflops_int4" in spec and self.data_type_bytes == 0.5:
106 | self.tflops = spec["tflops_int4"]
107 |
108 |     # TFLOPS values are defined for float16;
109 |     # assume that performance is doubled if the data type is demoted
110 | def adjust_for_quantization(self):
111 | ratio = 2 / self.data_type_bytes # Assume pj_per_tflop corresponds to float16
112 | self.pj_per_tflop = self.pj_per_tflop / ratio
113 | # self.tflops = self.tflops * (2 / self.data_type_bytes)
114 |
115 | # self.softmax_ns_per_element = self.softmax_ns_per_element / ratio
116 | # self.RMSNorm_ns_per_element = self.RMSNorm_ns_per_element / ratio
117 | # self.SiLU_ns_per_element = self.SiLU_ns_per_element / ratio
118 |
119 | def get_tflops_Linear(self, input_shape, weight_shape):
120 | out_features = weight_shape[0]
121 | tflops = input_shape.numel() * out_features * 2 / 1e12
122 |
123 | return tflops
124 |
125 | # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
126 | def get_tflops_Conv2d(self, input_shape, layer, weight_shape):
127 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1
128 | n_channels = input_shape[-3] if (len(input_shape) > 2) else 1
129 | n_height = input_shape[-2] if (len(input_shape) > 1) else 1
130 | n_width = input_shape[-1]
131 |
132 | stride = layer.stride
133 | padding = layer.padding
134 |
135 | n_height = n_height + 2 * (
136 | padding[0] if isinstance(padding, tuple) else padding
137 | )
138 | n_width = n_width + 2 * (padding[1] if isinstance(padding, tuple) else padding)
139 |
140 | # Example of how many times a kernel is applied depending on stride:
141 | # | 0 | 1 | 2 | 3 | 4 | 5 |
142 | # Stride 1:
143 | # 00001111
144 | # 11112222
145 | # 22223333
146 | # 33334444
147 | # 44445555
148 | # Stride 2:
149 | # 00001111
150 | # 22223333
151 | # 44445555
152 |
153 | width_times = math.ceil(
154 | (n_width - 1) / (stride[1] if isinstance(stride, tuple) else stride)
155 | )
156 | height_times = math.ceil(
157 |             (n_height - 1) / (stride[0] if isinstance(stride, tuple) else stride)
158 | )
159 |
160 | # TFLOPS when applying once the kernel
161 | tflops_kernel = (
162 | 2 * batch_size * n_channels * weight_shape[1] * weight_shape[0]
163 | ) / 1e12
164 |
165 | tflops = tflops_kernel * width_times * height_times
166 |
167 | return tflops
168 |
169 | def get_tflops_LayerNorm(self, input_shape):
170 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1
171 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1
172 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1
173 | n_columns = input_shape[-1]
174 |
175 | tflops = (
176 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element
177 | )
178 |
179 | return tflops
180 |
181 | # TODO: see DeepSpeed implementation
182 | # def _attn_flops_compute
183 |
184 | def get_tflops(self, input_shape, layer, weight_shape):
185 | if issubclass(torch.nn.Conv2d, type(layer)):
186 | tflops = self.get_tflops_Conv2d(input_shape, layer, weight_shape)
187 | elif issubclass(torch.nn.LayerNorm, type(layer)):
188 | tflops = self.get_tflops_LayerNorm(input_shape)
189 | else:
190 | # Treat everything else as Linear
191 | tflops = self.get_tflops_Linear(input_shape, weight_shape)
192 | # print ("get_tflops not defined for layer: ", type(layer))
193 | # sys.exit(-1)
194 |
195 | return tflops
196 |
197 | def get_moved_data_bytes(
198 | self, input_shape, weight_shape, load_input=False, load_weight=True
199 | ):
200 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1
201 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1
202 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1
203 | n_columns = input_shape[-1]
204 |
205 | weight_size = weight_shape[1] * weight_shape[0] if load_weight else 0
206 | input_size = batch_size * n_heads * n_rows * n_columns if load_input else 0
207 | # output_size = batch_size * n_rows * weight_shape[1]
208 | return self.data_type_bytes * (weight_size + input_size) # + output_size)
209 |
210 | # KV cache load
211 | def load_data(self, input_shape):
212 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1
213 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1
214 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1
215 | n_columns = input_shape[-1]
216 |
217 | # data_size_bytes = self.data_type_bytes * (
218 | # batch_size * n_rows * n_heads * n_columns
219 | # )
220 |
221 | # Hardcoded to FP16
222 | data_size_bytes = 2 * (batch_size * n_rows * n_heads * n_columns)
223 | # B / (GB/s) --> s/G --> ns
224 | transfer_time_ns = data_size_bytes / self.mem_bw_GBs
225 |
226 | performance = {"kv_load": transfer_time_ns}
227 |         # bytes * 8 --> bits, bits * pJ/bit --> pJ
228 | energy = {"main_mem": data_size_bytes * 8 * self.mem_pj_per_bit}
229 |
230 | if self.verbose:
231 | print(
232 | "Load time for input_shape:",
233 | input_shape,
234 | "=",
235 | transfer_time_ns,
236 | "(ns) with",
237 | energy,
238 | "pj",
239 | performance,
240 | )
241 |
242 | return transfer_time_ns, performance, energy
243 |
244 | def host_transfer(self, input_shape, direction="to_device", generated_tokens=1):
245 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1
246 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1
247 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1
248 | n_columns = input_shape[-1]
249 |
250 | bandwidth = (
251 | self.host_to_device_bw_GBs
252 | if direction == "to_device"
253 | else self.device_to_host_bw_GBs
254 | )
255 |
256 | data_size_bytes = self.data_type_bytes * (
257 | batch_size * n_heads * n_rows * n_columns * generated_tokens
258 | )
259 | # B / (GB/s) --> s / G --> ns
260 | transfer_time_ns = data_size_bytes / bandwidth
261 |
262 | name_op = "host_to_device" if direction == "to_device" else "device_to_host"
263 |
264 | performance = {name_op: transfer_time_ns}
265 | moved_data = {name_op: data_size_bytes}
266 |
267 | ddr_pj_per_bit = (
268 | self.host_to_device_pj_per_bit
269 | if direction == "to_device"
270 | else self.device_to_host_pj_per_bit
271 | )
272 | energy = {name_op: data_size_bytes * 8 * ddr_pj_per_bit}
273 |
274 | if self.verbose:
275 | print(
276 | f"Transfer time for input_shape: {input_shape} = {transfer_time_ns} (ns) "
277 | f"with {energy} pj perf: {performance} data in bytes: {moved_data}"
278 | )
279 |
280 | return transfer_time_ns, performance, energy, moved_data
281 |
282 | def compute_ns(
283 | self, input_shape, layer_obj, weight_shape, load_input=False, load_weight=True
284 | ):
285 | tflops = self.get_tflops(input_shape, layer_obj, weight_shape)
286 |
287 | data_size_bytes = self.get_moved_data_bytes(
288 | input_shape, weight_shape, load_input=load_input, load_weight=load_weight
289 | )
290 |
291 | compute_time_ns = (tflops / self.tflops) * 1e9
292 | transfer_time_ns = data_size_bytes / self.mem_bw_GBs
293 | real_time_ns = max(compute_time_ns, transfer_time_ns)
294 |
295 | performance = {
296 | "compute": compute_time_ns,
297 | "mem_transfer": transfer_time_ns,
298 | }
299 |
300 | energy = {
301 | "compute": tflops * self.pj_per_tflop,
302 | "main_mem": data_size_bytes * 8 * self.mem_pj_per_bit,
303 | }
304 |
305 | if self.verbose:
306 | print(
307 | f"Computing {input_shape} x {weight_shape} with TFLOPS: {tflops} "
308 | f"with {data_size_bytes} bytes"
309 | )
310 | print(
311 | f"takes {real_time_ns} ns with {compute_time_ns} in compute and {transfer_time_ns} "
312 | f"in loading data"
313 | )
314 | print(
315 | f"and consumes {energy['compute']} pJ for compute and {energy['main_mem']} pJ "
316 | "for loading data"
317 | )
318 | print(f"performance: {performance}")
319 |
320 | return real_time_ns, performance, energy
321 |
322 | def compute_scaled_dot_product_ns(
323 | self,
324 | context,
325 | key_shape, # same dimensions as value_shape
326 | output_shape,
327 | use_kv_cache=True,
328 | summarization=False,
329 | sum_size=0,
330 | ):
331 | batch_size = key_shape[-4] if (len(key_shape) > 3) else 1
332 | n_heads = key_shape[-3] if (len(key_shape) > 2) else 1
333 | n_rows = key_shape[-2] if (len(key_shape) > 1) else 1 # already concatenated!
334 | n_columns = key_shape[-1]
335 |
336 | q_rows = (
337 | output_shape[-2] if (len(output_shape) > 2) else 1
338 | ) # 1 when using kv cache in GEN.
339 |
340 | compute_time_ns = 0
341 | load_k_time = 0
342 | load_v_time = 0
343 | performance = {}
344 | energy = {}
345 |
346 | # Load KV cache if GENeration and kv cache is enabled
347 | # Only K is required for next step
348 | if not summarization and use_kv_cache:
349 | if self.sliding_window != -1:
350 | n_rows = self.sliding_window
351 | if self.num_key_value_heads != -1:
352 | n_heads = self.num_key_value_heads
353 |
354 | k_cache = torch.Size([batch_size, n_heads, n_rows, n_columns])
355 | load_k_time, load_k_perf, load_k_energy = self.load_data(k_cache)
356 | performance = add_dictionaries(performance, load_k_perf)
357 | energy = add_dictionaries(energy, load_k_energy)
358 |
359 | # Q x Kt
360 | # (q_rows, embedding) x (embedding, kv_cache_length) = (q_rows, kv_cache_length)
361 | query_shape = torch.Size([batch_size, n_heads, q_rows, n_columns])
362 | kt_shape = torch.Size([batch_size, n_heads, n_columns, n_rows])
363 | # TODO: fix this calculation
364 | step_time, step_perf, step_energy = self.compute_ns(
365 | query_shape,
366 | torch.nn.Linear(in_features=n_columns, out_features=n_rows),
367 | kt_shape,
368 | load_input=False,
369 | load_weight=False,
370 | )
371 | compute_time_ns += max(
372 | load_k_time, step_time
373 | ) # overlap loading K with Q x Kt computation
374 | performance = add_dictionaries(performance, step_perf)
375 | energy = add_dictionaries(energy, step_energy)
376 |
377 | # Load KV cache if GENeration and kv cache is enabled
378 | # Only V is required for next step
379 | if not summarization and use_kv_cache:
380 |
381 | v_cache = torch.Size([batch_size, n_heads, n_rows, n_columns])
382 | load_v_time, load_v_perf, load_v_energy = self.load_data(v_cache)
383 | performance = add_dictionaries(performance, load_v_perf)
384 | energy = add_dictionaries(energy, load_v_energy)
385 |
386 | # QxKT x V
387 | # (q_rows, kv_cache_length) x (kv_cache_length, embedding) = (q_rows, embedding)
388 | qxkt_shape = torch.Size([batch_size, n_heads, q_rows, n_rows])
389 | # TODO: fix this calculation
390 | step_time, step_perf, step_energy = self.compute_ns(
391 | query_shape,
392 | torch.nn.Linear(in_features=n_rows, out_features=n_columns),
393 | key_shape,
394 | load_input=False,
395 | load_weight=False,
396 | )
397 | compute_time_ns += max(
398 | load_v_time, step_time
399 | ) # overlap loading V with QxKt x V computation
400 | performance = add_dictionaries(performance, step_perf)
401 | energy = add_dictionaries(energy, step_energy)
402 |
403 | if self.verbose:
404 | print(
405 | f"Computing scaled_dot_product: {query_shape} x {kt_shape} x {key_shape} "
406 | f"in {compute_time_ns} with {energy}"
407 | )
408 |
409 | return compute_time_ns, performance, energy
410 |
411 | def compute_matmul_ns(
412 | self,
413 | context,
414 | shape_a,
415 | shape_b,
416 | summarization=False,
417 | sum_size=0,
418 | ):
419 | batch_size = shape_a[-4] if (len(shape_a) > 3) else 1
420 | n_heads = shape_a[-3] if (len(shape_a) > 2) else 1
421 | n_rows = shape_a[-2] if (len(shape_a) > 1) else 1
422 | n_columns = shape_a[-1]
423 |
424 | compute_time_ns = 0
425 | performance = {}
426 | energy = {}
427 |
428 | if context == "attn_weights":
429 | if not summarization:
430 | kv_cache = torch.Size([batch_size, n_heads, sum_size, n_columns * 2])
431 | step_time, step_perf, step_energy = self.load_data(kv_cache)
432 | compute_time_ns += step_time
433 | performance = add_dictionaries(performance, step_perf)
434 | energy = add_dictionaries(energy, step_energy)
435 |
436 | step_time, step_perf, step_energy = self.compute_ns(
437 | shape_a,
438 | torch.nn.Linear(in_features=n_rows, out_features=n_columns),
439 | shape_b,
440 | load_input=False,
441 | load_weight=False,
442 | )
443 |
444 | compute_time_ns += step_time
445 | performance = add_dictionaries(performance, step_perf)
446 | energy = add_dictionaries(energy, step_energy)
447 |
448 | if self.verbose:
449 | print(
450 | f"Computing matmul: {shape_a} x {shape_b} in {compute_time_ns} with {energy}"
451 | )
452 |
453 | return compute_time_ns, performance, energy
454 |
455 | def compute_activation_ns(self, data_shape, activation="SiLU"):
456 | batch_size = data_shape[-4] if (len(data_shape) > 3) else 1
457 | n_heads = data_shape[-3] if (len(data_shape) > 2) else 1
458 | n_rows = data_shape[-2] if (len(data_shape) > 1) else 1
459 | n_columns = data_shape[-1]
460 |
461 | activation_ns_per_element = 0
462 | if activation == "SiLU":
463 | activation_ns_per_element = self.SiLU_ns_per_element
464 |
465 | tflops = (
466 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element
467 | )
468 | compute_time_ns = (
469 | batch_size * n_heads * (n_rows * (activation_ns_per_element * n_columns))
470 | )
471 |
472 | performance = {"compute": compute_time_ns}
473 | energy = {"compute": tflops * self.pj_per_tflop}
474 |
475 | if self.verbose:
476 | print(
477 | f"Computing Activation: {activation} : {data_shape} in {compute_time_ns} "
478 | f"with {energy}"
479 | )
480 |
481 | return compute_time_ns, performance, energy
482 |
483 | def compute_RMSNorm_ns(self, data_shape, dimension):
484 | batch_size = data_shape[-4] if (len(data_shape) > 3) else 1
485 | n_heads = data_shape[-3] if (len(data_shape) > 2) else 1
486 | n_rows = data_shape[-2] if (len(data_shape) > 1) else 1
487 | n_columns = dimension
488 |
489 | tflops = (
490 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element
491 | )
492 | compute_time_ns = (
493 | batch_size * n_heads * n_rows * (self.RMSNorm_ns_per_element * n_columns)
494 | )
495 |
496 | performance = {"compute": compute_time_ns}
497 | energy = {"compute": tflops * self.pj_per_tflop}
498 |
499 | if self.verbose:
500 | print(
501 | "Computing RMSNorm:", data_shape, "in", compute_time_ns, "with", energy
502 | )
503 |
504 | return compute_time_ns, performance, energy
505 |
506 | def compute_softmax_ns(self, data_shape):
507 | batch_size = data_shape[-4] if (len(data_shape) > 3) else 1
508 | n_heads = data_shape[-3] if (len(data_shape) > 2) else 1
509 | n_rows = data_shape[-2] if (len(data_shape) > 1) else 1
510 | n_columns = data_shape[-1]
511 |
512 | tflops = (
513 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element
514 | )
515 | compute_time_ns = (
516 | batch_size * n_heads * n_rows * (self.softmax_ns_per_element * n_columns)
517 | )
518 |
519 | performance = {"compute": compute_time_ns}
520 | energy = {"compute": tflops * self.pj_per_tflop}
521 |
522 | if self.verbose:
523 | print(
524 | "Computing softmax:", data_shape, "in", compute_time_ns, "with", energy
525 | )
526 |
527 | return compute_time_ns, performance, energy
528 |
--------------------------------------------------------------------------------
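compute_ns above is a plain roofline: the matrix-multiply time implied by the device's TFLOPS is compared against the time to stream the weights from device memory, the larger of the two is charged, and energy is split into a pJ-per-TFLOP and a pJ-per-bit term. A small self-contained re-derivation for one token through a 4096x4096 fp16 linear layer using the H100 numbers from sim_architectures.yaml; the layer size is an arbitrary illustration.

# Hedged re-derivation of BaseArchitecture.compute_ns for a (1, 4096) input
# through a 4096x4096 nn.Linear on the H100 entry (756.5 TFLOPS, 2000 GB/s,
# fp16 weights streamed from device memory, input already resident).
tflops_device = 756.5
mem_bw_GBs = 2000
data_type_bytes = 2.0

tflop = 4096 * 4096 * 2 / 1e12                 # get_tflops_Linear: numel * out_features * 2
weight_bytes = data_type_bytes * 4096 * 4096   # get_moved_data_bytes, weights only

compute_time_ns = tflop / tflops_device * 1e9          # ~44 ns
transfer_time_ns = weight_bytes / mem_bw_GBs           # ~16777 ns
real_time_ns = max(compute_time_ns, transfer_time_ns)
energy_pj = tflop * 0.5e12 + weight_bytes * 8 * 7      # pj_per_tflop and mem_pj_per_bit terms
print(real_time_ns, energy_pj)

At this size the weight stream dominates, so with batch size 1 the simulated per-layer time is governed by mem_bw_GBs rather than by tflops.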
/src/upmem_llm_framework/profiler.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2014-2024 - UPMEM
3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA
4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM
5 | #
6 | # This file implements all profiling related classes and functions
7 |
8 | import sys
9 | import time
10 | from collections import OrderedDict
11 | from typing import Mapping, Callable, Tuple
12 |
13 | import torch
14 | from torch.nn import Linear, SiLU, LayerNorm, Embedding, Dropout, Softmax, Conv2d
15 | from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
16 | from transformers.activations import NewGELUActivation
17 | from transformers.pytorch_utils import Conv1D
18 |
19 | from upmem_llm_framework.simulator import Simulator
20 | from upmem_llm_framework.utils import add_dictionaries
21 |
22 |
23 | class layer_profile:
24 | def __init__(self, uniq_id, name, n_layer, context, dim_in, dim_out, obj=None):
25 | self.id = uniq_id
26 | self.name = name
27 | self.n_layer = n_layer
28 | self.context = context
29 | self.dim_in = dim_in
30 | self.dim_out = dim_out
31 | self.exec_time = 0
32 | self.exec_nums = 0
33 | self.energy = {}
34 | self.obj = obj
35 |
36 |
37 | class layer_log:
38 | def __init__(
39 | self,
40 | uniq_id,
41 | name,
42 | context,
43 | summarization,
44 | start_time,
45 | input,
46 | weights,
47 | output,
48 | exec_time_ms,
49 | performance,
50 | energy_pj,
51 | transfer_bytes,
52 | ):
53 | self.id = uniq_id
54 | self.name = name
55 | self.context = context
56 | self.summarization = summarization
57 | self.start_time = start_time
58 | self.input = input
59 | self.weights = weights
60 | self.output = output
61 | self.exec_time_ms = exec_time_ms
62 | self.performance = performance
63 | self.energy = energy_pj
64 | self.transfer_bytes = transfer_bytes
65 |
66 |
67 | class UPM_Profiler:
68 | layer_dimensions: Mapping[type, Callable[[torch.nn.Module], Tuple[int, int]]] = {
69 | Linear: (lambda l: (l.in_features, l.out_features)),
70 | NewGELUActivation: (lambda l: (1, 1)),
71 | SiLU: (lambda l: (1, 1)),
72 | LlamaRMSNorm: (lambda l: (l.weight.size()[0], l.weight.size()[0])),
73 | LlamaRotaryEmbedding: (lambda l: (1, 1)),
74 | LayerNorm: (lambda l: (1, 1)),
75 | Embedding: (lambda l: (1, 1)),
76 | Dropout: (lambda l: (1, 1)),
77 | Softmax: (lambda l: (l.dim, l.dim)),
78 | Conv1D: (lambda l: (l.weight.shape[0], l.weight.shape[1])),
79 | Conv2d: (lambda l: (l.kernel_size[0], l.kernel_size[1])),
80 | }
81 |
82 | functional_layers = {LlamaRMSNorm, SiLU}
83 |
84 | def __init__(self, options):
85 | self.n_layers = 0
86 | self.n_executions = 0
87 | self.layers = {}
88 | self.functions = {}
89 |
90 | self.options = options
91 | self.sim_compute = False
92 | self.sim_sliding_window = -1
93 | self.sim_num_key_value_heads = 8
94 | self.sim_data_type = "float16"
95 | self.sim_data_type_bytes = 2.0
96 | self.simulator = None
97 | self.set_options(options)
98 |
99 | self.start_inference = time.time_ns()
100 | self.inference_time = 0
101 | self.summarization_time = 0
102 | self.sum_perf = {}
103 | self.gen_perf = {}
104 | self.sum_energy = {}
105 | self.gen_energy = {}
106 | self.sum_transfer_bytes = {}
107 | self.gen_transfer_bytes = {}
108 |
109 | self.last_layer = "lm_head"
110 | self.batch_size = 1
111 |
112 | self.layers_start = {}
113 | self.layers_end = {}
114 | self.log = []
115 |
116 | self.forward_input_shape = None
117 | self.forward_time_start = 0
118 | self.forward_time_end = 0
119 |
120 | self.func_input_shape = None
121 | self.start_func = 0
122 | self.end_func = 0
123 |
124 | def set_options(self, options):
125 | self.options = options
126 | # simulation related
127 | self.sim_compute = options.sim_compute
128 | self.sim_sliding_window = options.sim_sliding_window
129 | self.sim_num_key_value_heads = options.sim_num_key_value_heads
130 |
131 | self.sim_data_type = options.sim_data_type
132 |
133 | self.sim_data_type_bytes = {
134 | "int4": 0.5,
135 | "int8": 1.0,
136 | "float16": 2.0,
137 | "bfloat16": 2.0,
138 | "float32": 4.0,
139 | }.get(self.sim_data_type)
140 |
141 | self.simulator = self.create_arch_simulator() if options.simulation else None
142 |
143 | def create_arch_simulator(self):
144 | return Simulator(
145 | data_type_bytes=self.sim_data_type_bytes,
146 | sliding_window=self.sim_sliding_window,
147 | num_key_value_heads=self.sim_num_key_value_heads,
148 | verbose=self.options.sim_verbose,
149 | )
150 |
151 | def print_layers_model(self):
152 | print("##### Layers of Model in order of creation #####")
153 | print(
154 | "Layer ID (creation order), Context, Function, Dimensions (rows x columns), "
155 | "times executed, avg. execution time (ms)"
156 | )
157 | for layer in self.layers.values():
158 | print(
159 | f"{str(layer.id)},"
160 | f"{layer.context},"
161 | f"{layer.name},"
162 | f"({layer.dim_in}x{layer.dim_out}),"
163 | f"{layer.exec_nums},"
164 | f"{layer.exec_time / self.n_executions / 1e6}"
165 | )
166 |
167 | def print_functions_model(self):
168 | print("##### Functions called by the Model in order of calling #####")
169 | print(
170 | "Function name, Context, Dimensions in (columns), Dimensions out (columns), "
171 | "times executed, avg. execution time (ms)"
172 | )
173 | for name, func in self.functions.items():
174 | print(
175 | f"{str(name)},"
176 | f"{func.context},"
177 | f"({func.dim_in}),"
178 | f"({func.dim_out}),"
179 | f"{func.exec_nums},"
180 | f"{func.exec_time / 1e6}"
181 | )
182 |
183 | def print_log_summary(self, show_summarization=False, show_all=False):
184 | phase = "Generation"
185 | if show_summarization:
186 | phase = "Summarization"
187 | if show_all:
188 | phase = "All (SUM and GEN)"
189 | print("#####", phase, "Execution summary #####")
190 | name_ctxt = []
191 | summary_time = OrderedDict()
192 | summary_perf = OrderedDict()
193 | summary_energy = OrderedDict()
194 | summary_transfer_bytes = OrderedDict()
195 | summary_nexec = OrderedDict()
196 | input_shapes = OrderedDict()
197 | weights_shapes = OrderedDict()
198 | output_shapes = OrderedDict()
199 | for log in self.log:
200 | if not show_all and not show_summarization and log.summarization:
201 | continue
202 | if not show_all and show_summarization and not log.summarization:
203 | continue
204 | ctxt = log.name + ":" + log.context
205 |             if ctxt not in summary_time:
206 | name_ctxt.append(ctxt)
207 | summary_nexec[ctxt] = 1 + summary_nexec.get(ctxt, 0)
208 | summary_time[ctxt] = log.exec_time_ms + summary_time.get(ctxt, 0)
209 | summary_energy[ctxt] = add_dictionaries(
210 | summary_energy.get(ctxt, {}), log.energy
211 | )
212 | summary_perf[ctxt] = add_dictionaries(
213 | summary_perf.get(ctxt, {}), log.performance
214 | )
215 | summary_transfer_bytes[ctxt] = add_dictionaries(
216 | summary_transfer_bytes.get(ctxt, {}), log.transfer_bytes
217 | )
218 |
219 | input_shapes[ctxt] = "(" + ":".join([str(x) for x in log.input]) + ")"
220 | weights_shapes[ctxt] = "(" + ":".join([str(x) for x in log.weights]) + ")"
221 | output_shapes[ctxt] = "(" + ":".join([str(x) for x in log.output]) + ")"
222 |
223 | executed_times = 1 if show_summarization else (self.n_executions - 1)
224 | print(
225 |             "Function: Context: input shape: weights shape: output shape: "
226 |             "time(ms):H2C(ms):C2H(ms):compute(ms):mem_transfer(ms):kv_load(ms):"
227 |             "host_to_device(mJ):device_to_host(mJ):main_mem(mJ):compute(mJ)"
228 | )
229 | for key in name_ctxt:
230 |
231 | perf_values = []
232 | for perf_key in [
233 | "host_to_device",
234 | "device_to_host",
235 | "compute",
236 | "mem_transfer",
237 | "kv_load",
238 | ]:
239 | perf_values.append(
240 | str(summary_perf[key].get(perf_key, 0) / 1e6 / executed_times)
241 | )
242 | perf_string = ":".join(perf_values)
243 |
244 | energy_values = []
245 | for ene_key in ["host_to_device", "device_to_host", "main_mem", "compute"]:
246 | energy_values.append(
247 | str(summary_energy[key].get(ene_key, 0) / 1e6 / executed_times)
248 | )
249 | energy_string = ":".join(energy_values)
250 | print(
251 | key,
252 | input_shapes[key],
253 | weights_shapes[key],
254 | output_shapes[key],
255 | (summary_time[key] / executed_times),
256 | perf_string,
257 | energy_string,
258 | )
259 |
260 | total_time_explained = sum(summary_time.values())
261 | total_percentage_explained = 0
262 | for key in name_ctxt:
263 | print(
264 | key,
265 | "explains",
266 | (summary_time[key] / total_time_explained) * 100,
267 | "% of the total inference time (num. executions:",
268 | summary_nexec[key],
269 | ") average time:",
270 | summary_time[key] / summary_nexec[key],
271 | "(ms)",
272 | )
273 | total_percentage_explained += (
274 | summary_time[key] / total_time_explained
275 | ) * 100
276 | print(
277 | "Profiler captures", total_percentage_explained, "% of the total execution"
278 | )
279 | print("Profiler captures", total_time_explained, "ms of the total execution")
280 | print(summary_time)
281 |
282 | def print_log(self):
283 | print("##### Execution log #####")
284 | print(
285 | "Start time, exec time, Function, Context, input shape, weights shape, output shape"
286 | )
287 | for log in self.log:
288 | input_shape = "(" + ",".join([str(x) for x in log.input]) + ")"
289 | weights_shape = "(" + ",".join([str(x) for x in log.weights]) + ")"
290 | output_shape = "(" + ",".join([str(x) for x in log.output]) + ")"
291 | print(
292 | log.start_time / 1e6,
293 | log.exec_time_ms,
294 | log.id,
295 | log.name,
296 | log.context,
297 | input_shape,
298 | weights_shape,
299 | output_shape,
300 | )
301 |
302 | def update_inference_perf(self, step_perf):
303 | if self.simulator.sum:
304 | for key in step_perf.keys():
305 | self.sum_perf[key] = self.sum_perf.get(key, 0) + step_perf[key]
306 | else:
307 | for key in step_perf.keys():
308 | self.gen_perf[key] = self.gen_perf.get(key, 0) + step_perf[key]
309 |
310 | def update_inference_energy(self, step_energy):
311 | if self.simulator.sum:
312 | for key in step_energy.keys():
313 | self.sum_energy[key] = self.sum_energy.get(key, 0) + step_energy[key]
314 | else:
315 | for key in step_energy.keys():
316 | self.gen_energy[key] = self.gen_energy.get(key, 0) + step_energy[key]
317 |
318 | def update_inference_transfer_bytes(self, step_transfer_bytes):
319 | if self.simulator.sum:
320 | for key in step_transfer_bytes.keys():
321 | self.sum_transfer_bytes[key] = (
322 | self.sum_transfer_bytes.get(key, 0) + step_transfer_bytes[key]
323 | )
324 | else:
325 | for key in step_transfer_bytes.keys():
326 | self.gen_transfer_bytes[key] = (
327 | self.gen_transfer_bytes.get(key, 0) + step_transfer_bytes[key]
328 | )
329 |
330 | def start(
331 | self,
332 | layer_mapping=None,
333 | layer_attn_ctxt="",
334 | last_layer="lm_head",
335 | batch_size=1,
336 | moe_end="",
337 | experts_per_token=2,
338 | ):
339 | self.start_inference = time.time_ns()
340 | self.n_executions = 0
341 | self.inference_time = 0
342 | self.summarization_time = 0
343 | self.sum_perf = {}
344 | self.gen_perf = {}
345 | self.sum_energy = {}
346 | self.gen_energy = {}
347 | self.sum_transfer_bytes = {}
348 | self.gen_transfer_bytes = {}
349 |
350 | self.last_layer = last_layer
351 | self.batch_size = batch_size
352 |
353 | self.layers_start = {}
354 | self.layers_end = {}
355 | self.log = []
356 |
357 | if self.simulator:
358 | self.simulator.map_layers(
359 | layer_mapping,
360 | layer_attn_ctxt=layer_attn_ctxt,
361 | moe_end=moe_end,
362 | experts_per_token=experts_per_token,
363 | )
364 |
365 | def end(self):
366 | if self.simulator:
367 | step_time, step_perf, step_energy, step_transfer_bytes = (
368 | self.simulator.simulate_end(
369 | self.forward_input_shape, generated_tokens=(self.n_executions)
370 | )
371 | )
372 | self.inference_time += step_time
373 | self.update_inference_perf(step_perf)
374 | self.update_inference_energy(step_energy)
375 | self.update_inference_transfer_bytes(step_transfer_bytes)
376 | else:
377 | self.inference_time = time.time_ns() - self.start_inference
378 |
379 | inference_time_sec = self.inference_time / 1e9
380 | sum_energy_mJ = 0
381 | gen_energy_mJ = 0
382 | sum_time_s = self.summarization_time / 1e9
383 | gen_time_s = inference_time_sec - sum_time_s
384 | gen_n_executions = self.n_executions - 1
385 |
386 | print("##### UPMEM PROFILER OUTPUT #####")
387 | print(
388 | f"Total time (SUM + GEN): {inference_time_sec} s, "
389 | f"with data type: {self.sim_data_type}, "
390 | f"batch size: {self.batch_size}"
391 | )
392 | print(
393 | f"Generated tokens: {gen_n_executions * self.batch_size} "
394 | f"in {gen_time_s} s, "
395 | f"with tokens/s: {(gen_n_executions * self.batch_size) / gen_time_s}"
396 | )
397 | print(
398 | f"Summarization step took: {sum_time_s} s, "
399 | f"share of total execution: SUM: {100 * sum_time_s / inference_time_sec}%, "
400 | f"GEN: {100 * gen_time_s / inference_time_sec}%"
401 | )
402 |
403 | if self.simulator:
404 | print("SUMMARIZATION summary")
405 | for key, transfer in self.sum_transfer_bytes.items():
406 | print(f"Transferred data in {key}: {transfer / 1e6} MB")
407 | for key, energy in self.sum_energy.items():
408 | energy_mj = energy / 1e9
409 | print(f"Energy in {key}: {energy_mj} mJ")
410 | sum_energy_mJ += energy_mj
411 | print(f"Energy: {sum_energy_mJ} mJ")
412 | print(f"Power: {sum_energy_mJ / 1e3 / sum_time_s} W")
413 |
414 | if gen_n_executions > 0:
415 | print("GENERATION summary")
416 | for key, transfer in self.gen_transfer_bytes.items():
417 | print(
418 | f"Transferred data in {key}: {transfer / 1e6} MB, "
419 | f"MB/token: {transfer / 1e6 / gen_n_executions / self.batch_size}"
420 | )
421 | for key, energy in self.gen_energy.items():
422 | energy_mj = energy / 1e9
423 | print(
424 | f"Energy in {key}: {energy_mj} mJ, "
425 | f"mJ/token: {energy_mj / gen_n_executions / self.batch_size}"
426 | )
427 | gen_energy_mJ += energy_mj
428 | print(
429 | f"Energy: {gen_energy_mJ} mJ, "
430 | f"mJ/token: {gen_energy_mJ / gen_n_executions / self.batch_size}"
431 | )
432 | print(f"Power: {gen_energy_mJ / 1e3 / gen_time_s} W")
433 |
434 | print("Execution time breakdown (ms / %)")
435 | print("SUMMARIZATION phase")
436 | for perf_key in [
437 | "host_to_device",
438 | "device_to_host",
439 | "compute",
440 | "mem_transfer",
441 | "kv_load",
442 | ]:
443 | perf_value = self.sum_perf.get(perf_key, 0)
444 | print(
445 | f"{perf_key}: {perf_value / 1e6} ms, {100 * perf_value / 1e9 / sum_time_s}%"
446 | )
447 |
448 | if gen_n_executions > 0:
449 | print("GENERATION phase")
450 | for perf_key in [
451 | "host_to_device",
452 | "device_to_host",
453 | "compute",
454 | "mem_transfer",
455 | "kv_load",
456 | ]:
457 | perf_value = self.gen_perf.get(perf_key, 0)
458 | print(
459 | f"{perf_key}: {perf_value / 1e6} ms, {100 * perf_value / 1e9 / gen_time_s}%"
460 | )
461 |
462 | if self.options.report_layers:
463 | self.print_layers_model()
464 |
465 | if self.options.report_functions:
466 | self.print_functions_model()
467 |
468 | if self.options.print_log:
469 | self.print_log()
470 |
471 | if self.options.print_log_summary:
472 | self.print_log_summary()
473 | self.print_log_summary(show_summarization=True)
474 | self.print_log_summary(show_all=True)
475 |
476 | print("##### END UPMEM PROFILER OUTPUT #####")
477 |
478 | def add(self, layer, context):
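479 | # Resolve the layer's input/output dimensions via the first issubclass match registered in self.layer_dimensions.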
479 | layer_type = type(layer)
480 | dim_in, dim_out = next(
481 | (
482 | dim_func(layer)
483 | for key, dim_func in self.layer_dimensions.items()
484 | if issubclass(layer_type, key)
485 | ),
486 | (None, None),
487 | )
488 |
489 | if dim_in is None or dim_out is None:
490 | print(f"Layer: {layer_type.__name__} is not supported by the profiler")
491 | sys.exit(1)
492 |
493 | name_layer = layer_type.__name__
494 |
495 | self.layers[layer] = layer_profile(
496 | self.n_layers,
497 | name_layer,
498 | self.n_layers,
499 | context,
500 | dim_in,
501 | dim_out,
502 | obj=layer,
503 | )
504 | self.n_layers += 1
505 |
506 | def forward_start(self, input_shape):
507 | self.forward_input_shape = input_shape
508 |
509 | if self.simulator:
510 | self.forward_time_start = self.inference_time
511 | else:
512 | self.forward_time_start = time.time_ns()
513 |
514 | def forward_end(self, output_shape, context, layer_obj=None):
515 | self.forward_time_end = time.time_ns()
516 |
517 | weights_shape = torch.Size(
518 | [self.layers[layer_obj].dim_in, self.layers[layer_obj].dim_out]
519 | )
520 |
521 | cur_exec_time = self.forward_time_end - self.forward_time_start
522 | performance = {}
523 | energy = {}
524 | relative_start_time = self.forward_time_start - self.start_inference
525 |
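526 | # With a simulator, timing, energy and transfer figures come from the simulated
526 | # layer/function models; otherwise plain wall-clock timing is recorded.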
526 | if self.simulator:
527 | name = self.layers[layer_obj].name
528 | if any(
529 | isinstance(layer_obj, layer_t) for layer_t in self.functional_layers
530 | ):
531 | cur_exec_time, performance, energy, transfer_bytes = (
532 | self.simulator.simulate_function(
533 | self.layers[layer_obj],
534 | context,
535 | output_shape,
536 | self.layers[layer_obj].dim_out,
537 | )
538 | )
539 | else:
540 | cur_exec_time, performance, energy, transfer_bytes = (
541 | self.simulator.simulate_layer(
542 | self.layers[layer_obj],
543 | self.forward_input_shape,
544 | layer_obj,
545 | weights_shape,
546 | output_shape,
547 | )
548 | )  # TODO: should context be passed here as well?
549 | self.inference_time += cur_exec_time
550 | self.update_inference_perf(performance)
551 | self.update_inference_energy(energy)
552 | self.update_inference_transfer_bytes(transfer_bytes)
553 |
554 | self.layers[layer_obj].exec_time += cur_exec_time
555 | self.layers[layer_obj].energy = add_dictionaries(
556 | self.layers[layer_obj].energy, energy
557 | )
558 | self.layers_start[layer_obj] = self.forward_time_start
559 | self.layers_end[layer_obj] = self.inference_time
560 | relative_start_time = self.forward_time_start
561 | else:
562 | self.inference_time += cur_exec_time
563 | self.layers[layer_obj].exec_time += cur_exec_time
564 | self.layers_start[layer_obj] = self.forward_time_start
565 | self.layers_end[layer_obj] = self.forward_time_end
566 | transfer_bytes = {}  # match the dict type produced by the simulator path
567 |
568 | self.layers[layer_obj].exec_nums += 1
569 |
570 | summarization_phase = self.simulator.sum if self.simulator else False
571 |
572 | cur_logging = layer_log(
573 | self.layers[layer_obj].id,
574 | self.layers[layer_obj].name,
575 | self.layers[layer_obj].context,
576 | summarization_phase,
577 | relative_start_time / 1e6,
578 | self.forward_input_shape,
579 | weights_shape,
580 | output_shape,
581 | cur_exec_time / 1e6,
582 | performance,
583 | energy,
584 | transfer_bytes,
585 | )
586 |
587 | self.log.append(cur_logging)
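588 | # A pass through the last layer (lm_head by default) completes one token; the first such pass closes the summarization phase.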
588 | if self.layers[layer_obj].context == self.last_layer:
589 | if self.simulator and not self.simulator.sum:
590 | self.simulator.sum_size += 1
591 | if self.n_executions == 0:
592 | if self.simulator:
593 | self.simulator.start_gen()
594 | self.simulator.sum_size = (
595 | output_shape[-2] if (len(output_shape) > 1) else 1
596 | )
597 | self.summarization_time = self.inference_time
598 | self.n_executions += 1
599 | print(f"New token generated ({self.n_executions})", end="\r")
600 |
601 | def forward_func_start(self, name, context, input_shape):
602 | self.func_input_shape = input_shape
603 |
604 | self.start_func = time.time_ns()
605 | if self.simulator:
606 | self.start_func = self.inference_time
607 |
608 | def forward_func_end(self, function, context, output_shape):
609 | self.end_func = time.time_ns()
610 |
611 | func_profile = self.functions.get(
612 | function.__name__,
613 | layer_profile(
614 | 0, function.__name__, 0, context, self.func_input_shape[-1], output_shape[-1]
615 | ),
616 | )
617 |
618 | cur_exec_time = self.end_func - self.start_func
619 | performance = {}
620 | energy = {}
621 | relative_exec_time = self.start_func  # already relative under the simulator; wall clock is adjusted below
622 |
623 | if self.simulator:
624 | cur_exec_time, performance, energy, transfer_bytes = (
625 | self.simulator.simulate_function(
626 | function, context, self.func_input_shape, output_shape
627 | )
628 | )
629 | self.inference_time += cur_exec_time
630 | self.update_inference_perf(performance)
631 | self.update_inference_energy(energy)
632 | self.update_inference_transfer_bytes(transfer_bytes)
633 | else:
634 | transfer_bytes = {}  # match the dict type produced by the simulator path
635 | # wall-clock timestamps are reported relative to the start of the inference
636 | relative_exec_time -= self.start_inference
637 |
638 | summarization_phase = self.simulator.sum if self.simulator else False
639 | cur_logging = layer_log(
640 | 0, # functions have ID set to 0
641 | function.__name__,
642 | context,
643 | summarization_phase,
644 | relative_exec_time / 1e6,
645 | self.func_input_shape,
646 | output_shape,
647 | output_shape,
648 | cur_exec_time / 1e6,
649 | performance,
650 | energy,
651 | transfer_bytes,
652 | )
653 |
654 | self.log.append(cur_logging)
655 |
656 | func_profile.exec_nums += 1
657 | func_profile.exec_time += cur_exec_time
658 | self.functions[function.__name__] = func_profile
659 |
--------------------------------------------------------------------------------