├── tests ├── __init__.py ├── extra_arch.yaml └── test_upmem_pytorch_simulator.py ├── src └── upmem_llm_framework │ ├── __init__.py │ ├── utils.py │ ├── sim_architectures.py │ ├── architectures_schema.json │ ├── options.py │ ├── pytorch_upmem_layers.py │ ├── sim_architectures.yaml │ ├── simulator.py │ ├── base_architecture.py │ └── profiler.py ├── .pylintrc ├── .vale.ini ├── .editorconfig ├── MANIFEST.in ├── .github └── workflows │ └── pytest.yml ├── LICENSE ├── examples ├── nn_example.py └── hf_example.py ├── .gitignore ├── pyproject.toml ├── scripts_simulation └── simulations_llama2_7B.py └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for upmem_pytorch_simulator.""" 2 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/__init__.py: -------------------------------------------------------------------------------- 1 | from upmem_llm_framework.options import initialize_profiling_options 2 | from upmem_llm_framework.pytorch_upmem_layers import profiler_init, profiler_start, profiler_end 3 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | 3 | disable= 4 | missing-function-docstring, 5 | missing-class-docstring, 6 | missing-module-docstring, 7 | 8 | 9 | [BASIC] 10 | 11 | good-names=sum_energy_mJ 12 | good-names-rgxs=^.*_(GBs|mJ)$ 13 | -------------------------------------------------------------------------------- /.vale.ini: -------------------------------------------------------------------------------- 1 | StylesPath = styles 2 | 3 | MinAlertLevel = suggestion 4 | 5 | Packages = Google, proselint, write-good, alex, Readability 6 | 7 | Vocab = Upmem 8 | 9 | [*.md] 10 | BasedOnStyles = Vale, Google, proselint, write-good, alex, Readability 11 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2014-2024 - UPMEM 3 | # 4 | 5 | 6 | def add_dictionaries(dict1, dict2): 7 | return { 8 | key: dict1.get(key, 0) + dict2.get(key, 0) for key in set(dict1) | set(dict2) 9 | } 10 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | 7 | recursive-include tests * 8 | recursive-include examples * 9 | 10 | recursive-exclude * __pycache__ 11 | recursive-exclude * *.py[co] 12 | 13 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 14 | 15 | include src/upmem_llm_framework/architectures_schema.json 16 | include 
src/upmem_llm_framework/sim_architectures.yaml 17 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: "pytest" 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | branches: [main] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-22.04 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.10' 22 | 23 | - name: Install package 24 | run: pip install --verbose .[dev] 25 | 26 | - name: Run tests 27 | run: pytest 28 | -------------------------------------------------------------------------------- /tests/extra_arch.yaml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=../src/upmem_llm_framework/architectures_schema.json 2 | 3 | PIM_AI_4chip_duplicated: 4 | host_to_device_bw_GBs: 12.8 5 | host_to_device_pj_per_bit: 80 # 20 * 4 6 | device_to_host_bw_GBs: 51.2 # 12.8 * 4 7 | device_to_host_pj_per_bit: 20 8 | mem_bw_GBs: 409.6 # 102.4 * 4 9 | mem_pj_per_bit: 0.95 10 | tflops: 20 # 5 * 4 11 | tflops_int4: 128 # 32 * 4 12 | pj_per_tflop: 0.4e+12 13 | softmax_ns_per_element: 0.1 # 0.4 / 4 14 | SiLU_ns_per_element: 0.15 # 0.6 / 4 15 | RMSNorm_ns_per_element: 0.275 # 1.04 / 4 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 UPMEM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /examples/nn_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typer 3 | 4 | import upmem_llm_framework as upmem_layers 5 | 6 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options) 7 | 8 | 9 | class TinyModel(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | self.linear1 = torch.nn.Linear(100, 200) 14 | self.activation = torch.nn.ReLU() 15 | self.linear2 = torch.nn.Linear(200, 10) 16 | # self.softmax = torch.nn.Softmax(dim = 0) 17 | 18 | def forward(self, x): 19 | x = self.linear1(x) 20 | x = self.activation(x) 21 | x = self.linear2(x) 22 | x = torch.nn.functional.softmax(x, dim=0) 23 | return x 24 | 25 | @app.command() 26 | def profile(): 27 | upmem_layers.profiler_init() 28 | 29 | tinymodel = TinyModel() 30 | 31 | print("The model:") 32 | print(tinymodel) 33 | 34 | my_tensor = torch.rand(100) 35 | 36 | layer_mapping = { 37 | "linear1": "PIM-AI-1chip", 38 | "linear2": "PIM-AI-1chip", 39 | } 40 | 41 | upmem_layers.profiler_start(layer_mapping, last_layer="linear2") 42 | prediction = tinymodel.forward(my_tensor) 43 | upmem_layers.profiler_end() 44 | print(prediction) 45 | 46 | if __name__ == "__main__": 47 | app() 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ 106 | .idea/ 107 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "upmem_llm_framework" 7 | version = "0.0.1" 8 | authors = [ 9 | { name = "Cristobal Ortega", email = "cortega@upmem.com" }, 10 | { name = "Sylvan Brocard", email = "sbrocard@upmem.com" }, 11 | ] 12 | license = "MIT" 13 | license-files = ["LICENSE"] 14 | keywords = [ 15 | "upmem", 16 | "llm", 17 | "transformer", 18 | "pytorch", 19 | "profiling", 20 | "accelerator", 21 | ] 22 | description = """\ 23 | UPMEM LLM Framework allows profiling PyTorch layers and functions\ 24 | and simulate those layers/functions with a given hardware profile.\ 25 | """ 26 | readme = "README.md" 27 | dependencies = [ 28 | "torch==2.4.1", 29 | "transformers==4.44.2", 30 | "jsonschema==4.23.0", 31 | "typer==0.12.5", 32 | ] 33 | requires-python = ">=3.10" 34 | classifiers = [ 35 | 'Development Status :: 2 - Pre-Alpha', 36 | 'Intended Audience :: Developers', 37 | 'Natural Language :: English', 38 | 'Programming Language :: Python :: 3.10', 39 | ] 40 | 41 | [project.optional-dependencies] 42 | dev = ["pytest==8.3.3"] 43 | 44 | [project.urls] 45 | Homepage = "https://upmem.com" 46 | Repository = "https://github.com/upmem/upmem_llm_framework" 47 | 48 | [tool.setuptools.package-data] 49 | include = [ 50 | "src/upmem_llm_framework/architectures_schema.json", 51 | "src/upmem_llm_framework/sim_architectures.yaml", 52 | ] 53 | 54 | [tool.ruff.lint] 55 | select = ["ALL"] 56 | ignore = ["COM812"] 57 | 58 | [tool.ruff] 59 | line-length = 100 60 | 61 | [tool.ruff.lint.extend-per-file-ignores] 62 | "tests/**/*.py" = [ 63 | "S101", # asserts allowed in tests 64 | "T201", # print statements allowed in tests 65 | ] 66 | -------------------------------------------------------------------------------- /examples/hf_example.py: -------------------------------------------------------------------------------- 1 | # import time 2 | import transformers 3 | import typer 4 | from typing_extensions import Annotated 5 | 6 | import upmem_llm_framework as upmem_layers 7 | 8 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options) 9 | 10 | 11 | @app.command() 12 | def profile( 13 | hf_token: 
Annotated[ 14 | str, typer.Argument(envvar="hf_token", help="Hugging Face API token") 15 | ] 16 | ): 17 | upmem_layers.profiler_init() 18 | 19 | model = transformers.AutoModelForCausalLM.from_pretrained( 20 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token 21 | ) 22 | tokenizer = transformers.AutoTokenizer.from_pretrained( 23 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token 24 | ) 25 | 26 | layer_mapping = { 27 | "input_layernorm": "PIM-AI-1chip", 28 | "q_proj": "PIM-AI-1chip", 29 | "k_proj": "PIM-AI-1chip", 30 | "rotary_emb": "PIM-AI-1chip", 31 | "v_proj": "PIM-AI-1chip", 32 | "o_proj": "PIM-AI-1chip", 33 | "output_layernorm": "PIM-AI-1chip", 34 | "gate_proj": "PIM-AI-1chip", 35 | "up_proj": "PIM-AI-1chip", 36 | "down_proj": "PIM-AI-1chip", 37 | "norm": "PIM-AI-1chip", 38 | "lm_head": "PIM-AI-1chip", 39 | } 40 | 41 | prompt = "How to prepare coffee?" 42 | 43 | inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False) 44 | 45 | print(inputs.data["input_ids"][0].shape) 46 | model.eval() # Put model in evaluation / inference mode 47 | 48 | # print (model) 49 | 50 | upmem_layers.profiler_start(layer_mapping) 51 | # In case we want to time the original execution (comment out profiler_start) 52 | # start = time.time_ns() 53 | gen_tokens = model.generate( 54 | inputs.input_ids, do_sample=True, temperature=0.9, min_length=64, max_length=64 55 | ) 56 | # print ( (time.time_ns() - start)/1e6) 57 | upmem_layers.profiler_end() 58 | 59 | gen_text = tokenizer.batch_decode( 60 | gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False 61 | )[0] 62 | print(gen_text) 63 | 64 | 65 | if __name__ == "__main__": 66 | app() 67 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/sim_architectures.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2014-2024 - UPMEM 3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA 4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM 5 | # 6 | # This file implements multiple hardware architectures to be simulated. 7 | # All architecture inherit from the Base_architecture class. 
8 | # If an architecture has optimizations for a given operation defined in Base_architecture, 9 | # define them here 10 | 11 | import json 12 | from functools import cache 13 | from importlib.resources import as_file, files 14 | from pathlib import Path 15 | from typing import Dict 16 | 17 | import yaml 18 | from jsonschema import validate 19 | 20 | from upmem_llm_framework.options import options 21 | 22 | def read_architecture_file(file: Path, schema: Dict) -> Dict: 23 | with open(file, "r", encoding="UTF-8") as f: 24 | architectures = yaml.safe_load(f) 25 | validate(architectures, schema) 26 | return architectures 27 | 28 | 29 | @cache 30 | def read_architectures() -> Dict: 31 | """ 32 | Read the architectures from the sim_architectures.yaml file 33 | :return: a dictionary containing the architectures 34 | """ 35 | with as_file(files("upmem_llm_framework")) as resources_dir: 36 | with open( 37 | resources_dir / "architectures_schema.json", "r", encoding="UTF-8" 38 | ) as f: 39 | schema = json.load(f) 40 | 41 | architectures = read_architecture_file( 42 | resources_dir / "sim_architectures.yaml", schema 43 | ) 44 | 45 | if options.extra_archs: 46 | extra_architectures = read_architecture_file( 47 | options.extra_archs, schema 48 | ) 49 | architectures.update(extra_architectures) 50 | 51 | return architectures 52 | 53 | 54 | @cache 55 | def get_spec(name: str) -> Dict: 56 | """ 57 | Get an architecture object corresponding to the given name 58 | :param name: the name of the architecture 59 | :return: an object corresponding to the architecture 60 | """ 61 | architectures = read_architectures() 62 | 63 | architecture_spec = architectures.get(name) 64 | if architecture_spec is None: 65 | raise ValueError(f"Architecture {name} not found in sim_architectures.yaml") 66 | 67 | return architecture_spec 68 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/architectures_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "additionalProperties": { 4 | "type": "object", 5 | "properties": { 6 | "host_to_device_bw_GBs": { 7 | "type": "number", 8 | "description": "Host to device bandwidth in GB/s" 9 | }, 10 | "host_to_device_pj_per_bit": { 11 | "type": "number", 12 | "description": "Host to device energy per bit in pJ" 13 | }, 14 | "device_to_host_bw_GBs": { 15 | "type": "number", 16 | "description": "Device to host bandwidth in GB/s" 17 | }, 18 | "device_to_host_pj_per_bit": { 19 | "type": "number", 20 | "description": "Device to host energy per bit in pJ" 21 | }, 22 | "mem_bw_GBs": { 23 | "type": "number", 24 | "description": "Memory bandwidth in GB/s" 25 | }, 26 | "mem_pj_per_bit": { 27 | "type": "number", 28 | "description": "Memory energy per bit in pJ" 29 | }, 30 | "tflops": { 31 | "type": "number", 32 | "description": "TFLOPS" 33 | }, 34 | "pj_per_tflop": { 35 | "type": "number", 36 | "description": "Energy per TFLOP in pJ" 37 | }, 38 | "tflops_int4": { 39 | "type": "number", 40 | "description": "TFLOPS for INT4 (optional)" 41 | }, 42 | "softmax_ns_per_element": { 43 | "type": "number", 44 | "description": "Softmax latency per element in ns (optional)" 45 | }, 46 | "SiLU_ns_per_element": { 47 | "type": "number", 48 | "description": "SiLU latency per element in ns (optional)" 49 | }, 50 | "RMSNorm_ns_per_element": { 51 | "type": "number", 52 | "description": "RMSNorm latency per element in ns (optional)" 53 | } 54 | }, 55 | "required": [ 56 | "host_to_device_bw_GBs", 
57 | "host_to_device_pj_per_bit", 58 | "device_to_host_bw_GBs", 59 | "device_to_host_pj_per_bit", 60 | "mem_bw_GBs", 61 | "mem_pj_per_bit", 62 | "tflops", 63 | "pj_per_tflop" 64 | ], 65 | "additionalProperties": false 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /tests/test_upmem_pytorch_simulator.py: -------------------------------------------------------------------------------- 1 | """Tests for `upmem_llm_framework` package.""" 2 | 3 | import warnings 4 | 5 | import typer 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from typer.testing import CliRunner 8 | 9 | import upmem_llm_framework as upmem_layers 10 | 11 | runner = CliRunner() 12 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options) 13 | 14 | gen_length = 64 15 | 16 | layer_mapping = { 17 | "LlamaRMSNorm": "PIM-AI-1chip,t", 18 | "q_proj": "PIM-AI-4chip-duplicated,t", 19 | "k_proj": "PIM-AI-4chip", 20 | "rotatory_emb": "PIM-AI-4chip", 21 | "v_proj": "PIM-AI-4chip", 22 | "o_proj": "PIM-AI-4chip,t", 23 | "output_layernorm": "PIM-AI-1chip,t", 24 | "gate_proj": "PIM-AI-4chip,t", 25 | "up_proj": "PIM-AI-4chip,t", 26 | "down_proj": "PIM-AI-4chip,t", 27 | "norm": "PIM-AI-1chip,t", 28 | "lm_head": "PIM-AI-4chip,t", 29 | } 30 | layer_attn_ctxt = "q_proj" 31 | 32 | ignored_warning = ( 33 | "`pad_token_id` should be positive but got -1. This will cause errors when batch " 34 | "generating, if there is padding. Please set `pad_token_id` explicitly by " 35 | "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, " 36 | "and ensure your `input_ids` input does not have negative values." 37 | ) 38 | 39 | 40 | @app.command("profile") 41 | def run_tiny_llama_model_with_profiler(): 42 | """Run the tiny LLaMA model with the profiler.""" 43 | 44 | # Initialize the profiler 45 | upmem_layers.profiler_init() 46 | # Load the tiny LLaMA model and tokenizer 47 | tokenizer = AutoTokenizer.from_pretrained( 48 | "hf-internal-testing/tiny-random-LlamaForCausalLM" 49 | ) 50 | with warnings.catch_warnings(): 51 | warnings.filterwarnings("ignore", ignored_warning) 52 | model = AutoModelForCausalLM.from_pretrained( 53 | "hf-internal-testing/tiny-random-LlamaForCausalLM" 54 | ) 55 | model.generation_config.pad_token_id = tokenizer.eos_token_id 56 | 57 | # Prepare input data 58 | input_text = "Hello, world!" 
59 | input_token = tokenizer.encode( 60 | input_text, return_tensors="pt", return_token_type_ids=False 61 | ) 62 | input_ids = { 63 | "input_ids": input_token, 64 | "attention_mask": input_token.new_ones(input_token.shape), 65 | } 66 | 67 | # Run the profiler 68 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt) 69 | outputs = model.generate( 70 | **input_ids, 71 | do_sample=True, 72 | temperature=0.9, 73 | min_length=gen_length, 74 | max_length=gen_length, 75 | ) 76 | upmem_layers.profiler_end() 77 | 78 | # Assert the profiler results 79 | assert outputs is not None 80 | assert outputs.shape == (1, gen_length) 81 | 82 | 83 | def test_tiny_llama_model_with_profiler(): 84 | """Test the tiny LLaMA model with the profiler.""" 85 | result = runner.invoke( 86 | app, 87 | [ 88 | "--simulation", 89 | "--report-layers", 90 | "--report-functions", 91 | "--print-log", 92 | "--print-log-summary", 93 | "--extra-archs=tests/extra_arch.yaml", 94 | "profile", 95 | ], 96 | ) 97 | print(result.stdout) 98 | assert result.exit_code == 0 99 | assert "##### UPMEM PROFILER OUTPUT #####" in result.stdout 100 | assert "##### Generation Execution summary #####" in result.stdout 101 | assert "##### All (SUM and GEN) Execution summary #####" in result.stdout 102 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/options.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import typer 7 | from typing_extensions import Annotated 8 | 9 | 10 | @dataclass 11 | class Options: 12 | report_layers: bool = False 13 | report_functions: bool = False 14 | print_log: bool = False 15 | print_log_summary: bool = False 16 | simulation: bool = False 17 | sim_compute: bool = False 18 | sim_data_type: str = "bfloat16" 19 | sim_num_key_value_heads: int = -1 20 | sim_sliding_window: int = -1 21 | sim_verbose: bool = False 22 | extra_archs: Optional[Path] = None 23 | 24 | 25 | options = Options() 26 | 27 | class DataType(str, Enum): 28 | int4 = "int4" 29 | int8 = "int8" 30 | float16 = "float16" 31 | bfloat16 = "bfloat16" 32 | float32 = "float32" 33 | 34 | 35 | def initialize_profiling_options( 36 | report_layers: Annotated[ 37 | bool, 38 | typer.Option( 39 | help="Enable reporting metrics for all executed layers at the end of the forward pass." 40 | ), 41 | ] = False, 42 | report_functions: Annotated[ 43 | bool, 44 | typer.Option( 45 | help="Enable reporting metrics for all executed functions at the end of the forward " 46 | + "pass." 47 | ), 48 | ] = False, 49 | print_log: Annotated[ 50 | bool, 51 | typer.Option( 52 | help="Print a trace of the execution of layers and functions.", 53 | ), 54 | ] = False, 55 | print_log_summary: Annotated[ 56 | bool, 57 | typer.Option( 58 | help="Print a detailed summary of each layer and function executed. For summarization, " 59 | + "generation, and both.", 60 | ), 61 | ] = False, 62 | simulation: Annotated[ 63 | bool, 64 | typer.Option( 65 | help="Enable simulation according to the layer mapping defined", 66 | ), 67 | ] = False, 68 | sim_compute: Annotated[ 69 | bool, 70 | typer.Option( 71 | help="Simulate compute intensive operations. Note that some operations are still " 72 | + "performed due to constraints in inputs/outputs of other layer/functions. 
" 73 | + "CAUTION: Output tokens will be affected", 74 | ), 75 | ] = False, 76 | sim_data_type: Annotated[ 77 | DataType, 78 | typer.Option( 79 | help="Set the datatype for weights and inputs.", 80 | ), 81 | ] = DataType.bfloat16, 82 | sim_num_key_value_heads: Annotated[ 83 | int, 84 | typer.Option( 85 | help="When using GQA, this value is used to simulate fetching the correct KV caches.", 86 | ), 87 | ] = -1, 88 | sim_sliding_window: Annotated[ 89 | int, 90 | typer.Option( 91 | help="When set, a sliding window is simulated according to this value. Note that the " 92 | + "real underlying execution will run according to the model parameter.", 93 | ), 94 | ] = -1, 95 | sim_verbose: Annotated[ 96 | bool, 97 | typer.Option( 98 | help="Set a verbose mode for simulation", 99 | ), 100 | ] = False, 101 | extra_archs: Annotated[ 102 | Optional[Path], 103 | typer.Option( 104 | help="Path to a yaml file containing extra architectures to be used in simulation", 105 | ), 106 | ] = None, 107 | ): 108 | options.report_layers = report_layers 109 | options.report_functions = report_functions 110 | options.print_log = print_log 111 | options.print_log_summary = print_log_summary 112 | options.simulation = simulation 113 | options.sim_compute = sim_compute 114 | options.sim_data_type = sim_data_type 115 | options.sim_num_key_value_heads = sim_num_key_value_heads 116 | options.sim_sliding_window = sim_sliding_window 117 | options.sim_verbose = sim_verbose 118 | options.extra_archs = extra_archs 119 | -------------------------------------------------------------------------------- /scripts_simulation/simulations_llama2_7B.py: -------------------------------------------------------------------------------- 1 | # import time 2 | from typing import Optional 3 | 4 | import torch 5 | import transformers 6 | import typer 7 | from typing_extensions import Annotated 8 | 9 | import upmem_llm_framework as upmem_layers 10 | 11 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options) 12 | 13 | 14 | @app.command() 15 | def profile( 16 | hf_token: Annotated[ 17 | str, typer.Argument(envvar="hf_token", help="Hugging Face API token") 18 | ], 19 | device: Annotated[str, typer.Option(help="Device to simulate for")] = "mixed", 20 | in_tokens: Annotated[int, typer.Option(help="Number of input tokens")] = 64, 21 | out_tokens: Annotated[int, typer.Option(help="Number of output tokens")] = 128, 22 | bs: Annotated[int, typer.Option(help="Batch size")] = 1, 23 | ): 24 | upmem_layers.profiler_init() 25 | 26 | print("Simulating with device...", device) 27 | print("in:", in_tokens, "out:", out_tokens) 28 | 29 | model = transformers.AutoModelForCausalLM.from_pretrained( 30 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token 31 | ) 32 | tokenizer = transformers.AutoTokenizer.from_pretrained( 33 | "meta-llama/Llama-2-7b-chat-hf", token=hf_token 34 | ) 35 | layer_mapping = ( 36 | { 37 | "LlamaRMSNorm": "PIM-AI-1chip,t", 38 | "q_proj": "PIM-AI-4chip,t", 39 | "k_proj": "PIM-AI-4chip", 40 | "rotatory_emb": "PIM-AI-4chip", 41 | "v_proj": "PIM-AI-4chip", 42 | "o_proj": "PIM-AI-4chip,t", 43 | "output_layernorm": "PIM-AI-1chip,t", 44 | "gate_proj": "PIM-AI-4chip,t", 45 | "up_proj": "PIM-AI-4chip,t", 46 | "down_proj": "PIM-AI-4chip,t", 47 | "norm": "PIM-AI-1chip,t", 48 | "lm_head": "PIM-AI-4chip,t", 49 | } 50 | if device == "mixed" 51 | else { 52 | "LlamaRMSNorm": device, 53 | "q_proj": device, 54 | "k_proj": device, 55 | "rotatory_emb": device, 56 | "v_proj": device, 57 | "o_proj": device, 58 | "output_layernorm": device, 59 | 
"gate_proj": device, 60 | "up_proj": device, 61 | "down_proj": device, 62 | "norm": device, 63 | "lm_head": device, 64 | } 65 | ) 66 | layer_attn_ctxt = "q_proj" 67 | 68 | print("Batch 1") 69 | prompt = "placeholder" 70 | prompt_batch = [prompt] * bs 71 | input_ids = tokenizer( 72 | prompt_batch, return_tensors="pt", return_token_type_ids=False 73 | ) 74 | input_ids["input_ids"] = torch.randint(100, [bs, in_tokens]) 75 | input_ids["attention_mask"] = torch.ones([bs, in_tokens], dtype=torch.int) 76 | print(input_ids.data["input_ids"][0].shape) 77 | 78 | model.eval() 79 | print(model) 80 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt) 81 | # start = time.time_ns() 82 | gen_tokens = model.generate( 83 | **input_ids, 84 | do_sample=True, 85 | temperature=0.9, 86 | min_length=out_tokens, 87 | max_length=out_tokens, 88 | ) 89 | # print ( (time.time_ns() - start)/1e6) 90 | upmem_layers.profiler_end() 91 | 92 | gen_text = tokenizer.batch_decode(gen_tokens) 93 | print(gen_text) 94 | 95 | raise typer.Exit() 96 | 97 | print("Batch 10") 98 | prompt_batch = [prompt] * 10 99 | input_ids = tokenizer(prompt_batch, return_tensors="pt") 100 | print(input_ids.data["input_ids"][0].shape) 101 | 102 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt) 103 | gen_tokens = model.generate( 104 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128 105 | ) 106 | upmem_layers.profiler_end() 107 | 108 | print("Batch 30") 109 | prompt_batch = [prompt] * 30 110 | input_ids = tokenizer(prompt_batch, return_tensors="pt") 111 | print(input_ids.data["input_ids"][0].shape) 112 | 113 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt) 114 | gen_tokens = model.generate( 115 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128 116 | ) 117 | upmem_layers.profiler_end() 118 | 119 | print("Batch 40") 120 | prompt_batch = [prompt] * 40 121 | input_ids = tokenizer(prompt_batch, return_tensors="pt") 122 | print(input_ids.data["input_ids"][0].shape) 123 | 124 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt) 125 | gen_tokens = model.generate( 126 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128 127 | ) 128 | upmem_layers.profiler_end() 129 | 130 | print("Batch 200") 131 | prompt_batch = [prompt] * 200 132 | input_ids = tokenizer(prompt_batch, return_tensors="pt") 133 | print(input_ids.data["input_ids"][0].shape) 134 | 135 | upmem_layers.profiler_start(layer_mapping, layer_attn_ctxt=layer_attn_ctxt) 136 | gen_tokens = model.generate( 137 | **input_ids, do_sample=True, temperature=0.9, min_length=128, max_length=128 138 | ) 139 | upmem_layers.profiler_end() 140 | 141 | # Batching (from https://lukesalamone.github.io/posts/what-are-attention-masks/) 142 | # tokenizer.padding_side = "left" 143 | # tokenizer.pad_token = tokenizer.eos_token 144 | # 145 | # sentences = ["It will rain in the", 146 | # "I want to eat a big bowl of", 147 | # "My dog is"] 148 | # inputs = tokenizer(sentences, return_tensors="pt", padding=True) 149 | 150 | # gen_text = tokenizer.batch_decode(gen_tokens)[0] 151 | gen_text = tokenizer.batch_decode(gen_tokens) 152 | print(gen_text) 153 | 154 | ## torch profiler snippet 155 | # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: 156 | # with record_function("forward"): 157 | # gen_text = tokenizer.batch_decode(gen_tokens)[0] 158 | # #model(inputs) 159 | # 160 | # 161 | # 
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=100)) 162 | # 163 | # print ("----- Group by input shape") 164 | # print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)) 165 | # 166 | # prof.export_chrome_trace("trace.json") 167 | 168 | 169 | if __name__ == "__main__": 170 | app() 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | UPMEM LLM framework for profiling / simulation 2 | ============================================== 3 | 4 | [![pytest](https://github.com/upmem/upmem_llm_framework/actions/workflows/pytest.yml/badge.svg)](https://github.com/upmem/upmem_llm_framework/actions/workflows/pytest.yml) 5 | 6 | This library allows 7 | 8 | 1. Profiling PyTorch neural networks on a CPU, 9 | 2. Simulating the execution of the neural network in a target hardware 10 | accelerator. 11 | 12 | Usage 13 | ----- 14 | 15 | 1. Import the `upmem_llm_framework` library and 16 | [`typer`](https://typer.tiangolo.com/). Create a `typer` app to handle the user 17 | input for the profiler. 18 | 19 | ```python 20 | # file: my_profiler.py 21 | import typer 22 | 23 | import upmem_llm_framework as upmem_layers 24 | 25 | app = typer.Typer(callback=upmem_layers.initialize_profiling_options) 26 | ``` 27 | 28 | 2. Define your main function and add the desired user input. 29 | Initialize the library before creating or importing the neural network: 30 | 31 | ```python 32 | @app.command() 33 | def profile(my_input: str): 34 | upmem_layers.profiler_init() 35 | # Create or import the neural network 36 | model = ... 37 | # Define the input tensor 38 | myTensor = ... 39 | ``` 40 | 41 | 3. Call the profiler when doing a forward pass / inference: 42 | 43 | ```python 44 | upmem_layers.profiler_start() 45 | prediction = model.forward(myTensor) 46 | upmem_layers.profiler_end() 47 | ``` 48 | 49 | 4. Call the app: 50 | 51 | ```python 52 | if __name__ == "__main__": 53 | app() 54 | ``` 55 | 56 | 5. See the available options: 57 | 58 | ```bash 59 | python my_profiler.py --help 60 | ``` 61 | 62 | 6. Run the app: 63 | 64 | ```bash 65 | python my_profiler.py --some-option profile my_input 66 | ``` 67 | 68 | ### Examples 69 | 70 | You can find usage examples with a custom PyTorch model in `nn_example.py` and 71 | with a model from HuggingFace in `hf_example.py`. 72 | 73 |
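For reference, the fragments from the Usage section assemble into a complete script along the lines of the following sketch; `TinyNet`, the random input tensor, and the layer mapping are illustrative placeholders for your own model and data:

```python
# file: my_profiler.py -- a minimal sketch assembling Usage steps 1 to 4
# (TinyNet, the random tensor, and the layer mapping are placeholders)
import torch
import typer

import upmem_llm_framework as upmem_layers

app = typer.Typer(callback=upmem_layers.initialize_profiling_options)


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = self.linear(x)
        return x


@app.command()
def profile():
    upmem_layers.profiler_init()  # patch the layer classes before building the model

    model = TinyNet()
    my_tensor = torch.rand(100)

    upmem_layers.profiler_start({"linear": "PIM-AI-1chip"}, last_layer="linear")
    prediction = model.forward(my_tensor)
    upmem_layers.profiler_end()
    print(prediction)


if __name__ == "__main__":
    app()
```

Running `python my_profiler.py profile` prints a profiler summary like the ones shown below.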
73 | <details>
74 | <summary>PyTorch model</summary>
75 | 
76 | ```bash
77 | python3 nn_example.py profile
78 | ```
79 | 
80 | Expected output:
81 | 
82 | ```text
83 | Options: Options(report_layers=False, report_functions=False, print_log=False, print_log_summary=False, simulation=False, sim_compute=False, sim_data_type=<DataType.bfloat16: 'bfloat16'>, sim_num_key_value_heads=-1, sim_sliding_window=-1, sim_verbose=False, extra_archs=None)
84 | The model:
85 | TinyModel(
86 |   (linear1): UPM_Linear(in_features=100, out_features=200, bias=True)
87 |   (activation): ReLU()
88 |   (linear2): UPM_Linear(in_features=200, out_features=10, bias=True)
89 | )
90 | ##### UPMEM PROFILER OUTPUT #####
91 | Total time (SUM + GEN): 0.002975238 s, with data type: bfloat16, batch size: 1
92 | Generated tokens: 0 in 0.002175651 s, with tokens/s: 0.0
93 | Summarization step took: 0.000799587 s, weight in the execution: SUM: 0.268747239716621%, GEN: 0.7312527602833789%
94 | ##### END UPMEM PROFILER OUTPUT #####
95 | tensor([0.0983, 0.0919, 0.1012, 0.0836, 0.0796, 0.1157, 0.1202, 0.0996, 0.0930,
96 |         0.1168], grad_fn=<SoftmaxBackward0>)
97 | ```
98 | 
99 | </details>
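The same example can also be run in simulation mode, using the flags described in the Profiler and Simulation sections below; for instance, the following command additionally reports metrics for each layer mapped to the `PIM-AI-1chip` device (the reported numbers will differ from the plain profiling run above):

```bash
python3 nn_example.py --simulation --report-layers profile
```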
100 | 
101 | <details>
102 | <summary>HuggingFace model</summary>
103 | 
104 | ```bash
105 | python3 hf_example.py profile
106 | ```
107 | 
108 | Expected output:
109 | 
110 | ```text
111 | Options: Options(report_layers=False, report_functions=False, print_log=False, print_log_summary=False, simulation=False, sim_compute=False, sim_data_type=<DataType.bfloat16: 'bfloat16'>, sim_num_key_value_heads=-1, sim_sliding_window=-1, sim_verbose=False, extra_archs=None)
112 | Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.03it/s]
113 | torch.Size([6])
114 | ##### UPMEM PROFILER OUTPUT #####
115 | Total time (SUM + GEN): 42.124470486 s, with data type: bfloat16, batch size: 1
116 | Generated tokens: 57 in 41.404485553 s, with tokens/s: 1.3766624373834302
117 | Summarization step took: 0.719984933 s, weight in the execution: SUM: 0.017091845302584535%, GEN: 0.9829081546974154%
118 | ##### END UPMEM PROFILER OUTPUT #####
119 | How to prepare coffee?
120 | 
121 | There are many ways to prepare coffee, and the method you choose will depend on your personal preferences and the equipment you have available. Here are some common methods for preparing coffee:
122 | 
123 | 1. Drip brewing: This is one of the most common methods of prepar
124 | ```
125 | 
126 | </details>
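Note that the `profile` command of `hf_example.py` requires a Hugging Face API token, read either from the `hf_token` environment variable or from the first positional argument, for example:

```bash
export hf_token=<your Hugging Face token>
python3 hf_example.py profile
```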
127 | 
128 | Profiler
129 | --------
130 | 
131 | The profiler records the start time and end time of a computation layer or
132 | function.
133 | Currently, the profiler doesn't track the real power consumption of the CPU.
134 | 
135 | The profiler identifies a layer or function by 4 parameters:
136 | 
137 | 1. Layer type (e.g. a `Linear` module) or function (e.g. `softmax`),
138 | 2. Context in which the layer or function is called, meaning the variable name
139 |    assigned to the layer or function (e.g. `q_proj = torch.nn.Linear(...)` has a
140 |    context of `q_proj`),
141 | 3. The input dimensions of the layer or function,
142 | 4. Specifically for layers, a unique ID assigned at layer initialization.
143 | 
144 | ### Profiler output
145 | 
146 | By default, the profiler reports a summary with execution time, energy (when
147 | simulating), and power consumption (when simulating) at the end of its
148 | execution.
149 | 
150 | When simulating, this summary breaks down into the summarization (encoding)
151 | phase and the generation (decoding) phase.
152 | 
153 | You can enable the following flags to show more information:
154 | 
155 | * `--report-layers`: reports the layers created in the neural network with their
156 |   associated parameters
157 | * `--report-functions`: reports the functions called during the forward pass of
158 |   the neural network with their associated parameters
159 | * `--print-log`: prints a time-ordered, detailed log of each layer and function
160 |   executed during the forward pass of the neural network
161 | 
162 | Simulation
163 | ----------
164 | 
165 | To run a simulation, library users need to provide a dictionary mapping layers
166 | to a device or hardware accelerator.
167 | 
168 | This dictionary contains `name_of_layer:device,options` key-value pairs.
169 | The name of the layer corresponds to the context concept introduced before.
170 | The device corresponds to one of the accelerators defined in
171 | `sim_architectures.yaml`.
172 | 
173 | Currently supported options:
174 | 
175 | * 't' or transfer point: the input of a layer with this option comes from the
176 |   CPU, meaning that the previous device sent its results back to the CPU and the
177 |   CPU sends them as input to this layer's device.
178 | * 'm' or MoE transfer point: the input of a layer with this option comes from
179 |   the CPU, but only once, since the input is shared across the different MoE experts (see the sketch after this list).
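For the 'm' option, a mapping might look like the following sketch. The expert layer names (`gate`, `w1`, `w2`, `w3`) are hypothetical placeholders, and while the `moe_end` and `experts_per_token` arguments come from `profiler_start`'s signature, treating `moe_end` as the layer that closes an expert block is an assumption here:

```python
# Hypothetical MoE block: w1/w3 receive the shared expert input ('m'), w2 consumes their output.
layer_mapping = {
    "gate": "PIM-AI-1chip,t",  # router input comes from the CPU (transfer point)
    "w1": "PIM-AI-4chip,m",    # shared expert input is sent from the CPU only once
    "w3": "PIM-AI-4chip,m",
    "w2": "PIM-AI-4chip",
}

upmem_layers.profiler_start(
    layer_mapping,
    moe_end="w2",          # assumed: last layer of an expert block
    experts_per_token=2,   # number of experts each token is routed to
)
prediction = model.forward(myTensor)
upmem_layers.profiler_end()
```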
180 | 181 | For instance, for a neural network composed of 2 Linear layers that execute 182 | sequentially in different chips: 183 | 184 | ```python 185 | layer_mapping = { 186 | "linear1":"PIM-AI-1chip,t", 187 | "linear2":"PIM-AI-1chip,t", 188 | } 189 | 190 | upmem_layers.profiler_start(layer_mapping) 191 | prediction = model.forward(myTensor) 192 | upmem_layers.profiler_end() 193 | ``` 194 | 195 | This mapping corresponds to the following scheme 196 | 197 | ```mermaid 198 | graph LR 199 | **CPU** -->|input of *linear1* is sent to **PIM-AI-1chip1** device| PIM-AI-1chip1["`**PIM-AI-1chip1** 200 | Execute *linear1*`"] 201 | PIM-AI-1chip1 -->|output of *linear1* is sent to **CPU**| **CPU** 202 | **CPU** -->|input of *linear2* is sent to **PIM-AI-1chip2** device| PIM-AI-1chip2["`**PIM-AI-1chip2** 203 | Execute *linear2*`"] 204 | ``` 205 | 206 | ### Running a simulation 207 | 208 | After specifying the layer mapping, to run a simulation: 209 | 210 | ```bash 211 | python3 hf_example.py --simulation profile 212 | ``` 213 | 214 | ### Adding a hardware accelerator 215 | 216 | The file `sim_architectures.yaml` contains hardware accelerator profiles. 217 | 218 | To add a new hardware accelerator profile, create a YAML file with the following 219 | structure: 220 | 221 | ```yaml 222 | # yaml-language-server: $schema=/architectures_schema.json 223 | # (The above line is optional, it will enable autocompletion and validation in editors that support 224 | # the YAML language server) 225 | 226 | My_accelerator: 227 | # * Required parameters: 228 | # - HOST communication 229 | host_to_device_bw_GBs: 22 230 | host_to_device_pj_per_bit: 200 231 | device_to_host_bw_GBs: 88 232 | device_to_host_pj_per_bit: 50 233 | # - Device memory (shared memory like) 234 | mem_bw_GBs: 6553.6 235 | mem_pj_per_bit: 0.95 236 | # - Compute 237 | tflops: 320 238 | pj_per_tflop: 0.4e+12 239 | # * Optional parameters: 240 | softmax_ns_per_element: 6.25e-03 241 | SiLU_ns_per_element: 9.375e-03 242 | RMSNorm_ns_per_element: 1.625e-02 243 | 244 | My_accelerator2: 245 | <...> 246 | ``` 247 | 248 | Use the `extra-archs` option to add the new accelerator to the simulation: 249 | 250 | ```bash 251 | python3 simulations_llama2_7B.py --simulation --extra-archs my_archs.yaml profile --device My_accelerator 252 | ``` 253 | 254 | *Note:* underscores in device names such as `new_device` convert to hyphens, 255 | resulting in `new-device` in the layer mapping. 256 | 257 | ### Notes on simulation 258 | 259 | This library makes two assumptions to simplify execution modelling across 260 | hardware profiles: 261 | 262 | 1. Ignored interconnection communication latency: it assumes that 263 | intercommunication between devices finishes fast enough that it can overlap with 264 | compute and get hidden. 265 | For instance, when simulating more than one GPU, it doesn't model the required 266 | data exchange between them. 267 | For an AI-PIM device (DIMM), it doesn't model communication within a DIMM. 268 | 2. Devices always reach peak performance. All hardware profiles perform 269 | operations at their peak performance. This is unrealistic in some scenarios. 270 | Adding a performance ratio to model this is left to future work. 271 | 272 | Installation 273 | ------------ 274 | 275 | ### Environment setup 276 | 277 | #### Python version 278 | 279 | This library expects Python 3.10. 
280 | 281 | If your distribution doesn't provide it, you can use 282 | [`pyenv`](https://github.com/pyenv/pyenv) to install it, or any other Python 283 | version manager: 284 | 285 | ```bash 286 | pyenv install 3.10 287 | pyenv shell 3.10 288 | ``` 289 | 290 | Now your shell runs Python 3.10. 291 | 292 | #### Virtual environment 293 | 294 | Preferably, create a virtual environment to install the library: 295 | 296 | ```bash 297 | python -m venv venv 298 | source venv/bin/activate 299 | ``` 300 | 301 | This avoids conflicts with other Python libraries in your system. 302 | 303 | ### User installation 304 | 305 | To install the library in your current Python environment: 306 | 307 | ```bash 308 | python -m pip install . 309 | ``` 310 | 311 | ### Developer installation 312 | 313 | To install the library for editing in your current Python environment, with the 314 | necessary development dependencies: 315 | 316 | ```bash 317 | python -m pip install -e '.[dev]' 318 | ``` 319 | 320 | #### Running tests 321 | 322 | Run the tests with: 323 | 324 | ```bash 325 | python -m pytest 326 | ``` 327 | 328 | #### Formatting 329 | 330 | This project uses `black` formatting. Please, make sure to run it before 331 | committing: 332 | 333 | ```bash 334 | python -m black src/upmem_llm_framework/*.py 335 | ``` 336 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/pytorch_upmem_layers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2014-2024 - UPMEM 3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA 4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM 5 | # 6 | # This file wraps PyTorch classes and functions into new UPM classes and functions able to 7 | # track the start, inputs, end and outputs of the corresponding function. 8 | # Currently, forward from multiple modules and other minor functions 9 | # (normalizations, softmax, activations, etc.) are tracked and profiled. 
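# The UPM_* wrapper classes below subclass the corresponding torch / transformers layers:
# each __init__ registers the layer with the profiler, and each forward records the input
# and output shapes around the call to the original implementation (in simulation mode,
# some compute-heavy forwards may be replaced with placeholder tensors).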
10 | 11 | from inspect import getframeinfo, stack 12 | from typing import Tuple 13 | 14 | import torch 15 | import transformers 16 | 17 | from upmem_llm_framework.options import options 18 | from upmem_llm_framework.profiler import UPM_Profiler 19 | 20 | profiler: UPM_Profiler = None 21 | profiling = 0 22 | 23 | 24 | def get_context(): 25 | # https://stackoverflow.com/questions/24438976/debugging-get-filename-and-line-number-from-which-a-function-is-called 26 | return getframeinfo(stack()[2][0]).code_context[0].split()[0].replace("self.", "") 27 | 28 | 29 | class UPM_Module(torch.nn.Module): 30 | 31 | def forward(self, x): 32 | x = super().forward(x) 33 | return x 34 | 35 | 36 | class UPM_Linear(torch.nn.Linear): 37 | 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | profiler.add(self, get_context()) 41 | 42 | def forward(self, x): 43 | context = get_context() 44 | profiler.forward_start(x.shape) 45 | if options.sim_compute: 46 | shape = list(x.shape) 47 | shape[-1] = self.out_features 48 | x = torch.zeros(shape) 49 | else: 50 | x = super().forward(x) 51 | profiler.forward_end(x.shape, context, layer_obj=self) 52 | return x 53 | 54 | 55 | class UPM_NonDynamicallyQuantizableLinear( 56 | torch.nn.modules.linear.NonDynamicallyQuantizableLinear 57 | ): 58 | def __init__(self, *args, **kwargs): 59 | super().__init__(*args, **kwargs) 60 | profiler.add(self, get_context()) 61 | print("HERE") 62 | 63 | def forward(self, x): 64 | context = get_context() 65 | profiler.forward_start(x.shape) 66 | if options.sim_compute: 67 | shape = list(x.shape) 68 | shape[-1] = self.out_features 69 | x = torch.zeros(shape) 70 | else: 71 | x = super().forward(x) 72 | profiler.forward_end(x.shape, context, layer_obj=self) 73 | return x 74 | 75 | 76 | class UPM_LayerNorm(torch.nn.LayerNorm): 77 | 78 | def __init__(self, *args, **kwargs): 79 | super().__init__(*args, **kwargs) 80 | profiler.add(self, get_context()) 81 | 82 | def forward(self, x): 83 | context = get_context() 84 | profiler.forward_start(x.shape) 85 | x = super().forward(x) 86 | profiler.forward_end(x.shape, context, layer_obj=self) 87 | return x 88 | 89 | 90 | class UPM_Embedding(torch.nn.Embedding): 91 | 92 | def __init__(self, *args, **kwargs): 93 | super().__init__(*args, **kwargs) 94 | profiler.add(self, get_context()) 95 | 96 | def forward(self, x): 97 | context = get_context() 98 | profiler.forward_start(x.shape) 99 | x = super().forward(x) 100 | profiler.forward_end(x.shape, context, layer_obj=self) 101 | return x 102 | 103 | 104 | class UPM_LlamaRotaryEmbedding( 105 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding 106 | ): 107 | 108 | def __init__(self, *args, **kwargs): 109 | super().__init__(*args, **kwargs) 110 | profiler.add(self, get_context()) 111 | 112 | def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: 113 | context = get_context() 114 | shape = x.shape 115 | profiler.forward_start(shape) 116 | x = super().forward(x, position_ids) 117 | profiler.forward_end(shape, context, layer_obj=self) 118 | return x 119 | 120 | 121 | class UPM_LlamaRMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm): 122 | 123 | def __init__(self, *args, **kwargs): 124 | super().__init__(*args, **kwargs) 125 | profiler.add(self, get_context()) 126 | 127 | def forward(self, hidden_states): 128 | context = get_context() 129 | profiler.forward_start(hidden_states.shape) 130 | hidden_states = super().forward(hidden_states) 131 | profiler.forward_end(hidden_states.shape, context, 
layer_obj=self) 132 | return hidden_states 133 | 134 | 135 | class UPM_SiLUActivation(torch.nn.SiLU): 136 | 137 | def __init__(self, *args, **kwargs): 138 | super().__init__(*args, **kwargs) 139 | profiler.add(self, get_context()) 140 | 141 | def forward(self, x): 142 | context = get_context() 143 | profiler.forward_start(x.shape) 144 | x = super().forward(x) 145 | profiler.forward_end(x.shape, context, layer_obj=self) 146 | return x 147 | 148 | 149 | class UPM_NewGELUActivation(transformers.activations.NewGELUActivation): 150 | 151 | def __init__(self, *args, **kwargs): 152 | super().__init__(*args, **kwargs) 153 | profiler.add(self, get_context()) 154 | 155 | def forward(self, x): 156 | context = get_context() 157 | profiler.forward_start(x.shape) 158 | x = super().forward(x) 159 | profiler.forward_end(x.shape, context, layer_obj=self) 160 | return x 161 | 162 | 163 | # Not used in inference 164 | class UPM_Dropout(torch.nn.Dropout): 165 | 166 | def __init__(self, *args, **kwargs): 167 | super().__init__(*args, **kwargs) 168 | profiler.add(self, get_context()) 169 | 170 | 171 | class UPM_Conv1d(torch.nn.Conv1d): 172 | 173 | def __init__(self, *args, **kwargs): 174 | super().__init__(*args, **kwargs) 175 | profiler.add(self, get_context()) 176 | 177 | 178 | class UPM_Conv2d(torch.nn.Conv2d): 179 | def __init__(self, *args, **kwargs): 180 | super().__init__(*args, **kwargs) 181 | profiler.add(self, get_context()) 182 | 183 | def forward(self, x): 184 | context = get_context() 185 | profiler.forward_start(x.shape) 186 | x = super().forward(x) 187 | profiler.forward_end(x.shape, context, layer_obj=self) 188 | return x 189 | 190 | 191 | class UPM_Conv1D(transformers.pytorch_utils.Conv1D): 192 | 193 | def __init__(self, *args, **kwargs): 194 | super().__init__(*args, **kwargs) 195 | profiler.add(self, get_context()) 196 | 197 | def forward(self, x): 198 | context = get_context() 199 | profiler.forward_start(x.shape) 200 | x = super().forward(x) 201 | profiler.forward_end(x.shape, context) 202 | return x 203 | 204 | 205 | class UPM_Softmax(torch.nn.Softmax): 206 | 207 | def __init__(self, *args, **kwargs): 208 | super().__init__(*args, **kwargs) 209 | profiler.add(self, get_context()) 210 | 211 | def forward(self, x): 212 | context = get_context() 213 | profiler.forward_start(x.shape) 214 | x = super().forward(x) 215 | profiler.forward_end(x.shape, context) 216 | return x 217 | 218 | 219 | class UPM_Tensor(torch.Tensor): 220 | 221 | def transpose(self, dim0, dim1): 222 | print("MyTranpose with input:", self, "dim0", dim0, "dim1", dim1) 223 | super().transpose(dim0, dim1) 224 | 225 | 226 | __pytorch_nn_functional_softmax = torch.nn.functional.softmax 227 | 228 | 229 | # TODO: change logic here to not use stringly types 230 | def UPM_Softmax_functional(input, dim=None, dtype=None): 231 | context = get_context() 232 | profiler.forward_func_start("softmax", context, input.shape) 233 | x = __pytorch_nn_functional_softmax(input, dim=dim, dtype=dtype) 234 | profiler.forward_func_end(__pytorch_nn_functional_softmax, context, x.shape) 235 | return x 236 | 237 | 238 | __pytorch_matmul = torch.matmul 239 | 240 | 241 | # TODO: here too 242 | def UPM_Matmul(input, other, *, out=None): 243 | context = get_context() 244 | profiler.forward_func_start("matmul", context, input.shape) 245 | x = __pytorch_matmul(input, other, out=out) 246 | profiler.forward_func_end(__pytorch_matmul, context, x.shape) 247 | return x 248 | 249 | 250 | __pytorch_scaled_dot_product_attention = ( 251 | 
torch.nn.functional.scaled_dot_product_attention 252 | ) 253 | 254 | 255 | # TODO: here too 256 | def UPM_scaled_dot_product_attention(query, key, value, **kwargs): 257 | context = get_context() 258 | # profiler.forward_func_start("scaled_dot_product_attention", context, [query.shape, key.shape, value.shape]) 259 | profiler.forward_func_start("scaled_dot_product_attention", context, key.shape) 260 | if options.sim_compute: 261 | q_shape = list(query.shape) 262 | v_shape = list(value.shape) 263 | q_shape[-1] = v_shape[-1] 264 | x = torch.zeros(q_shape) 265 | else: 266 | x = __pytorch_scaled_dot_product_attention(query, key, value, **kwargs) 267 | profiler.forward_func_end(__pytorch_scaled_dot_product_attention, context, x.shape) 268 | return x 269 | 270 | 271 | __pytorch_transpose = torch.transpose 272 | 273 | 274 | def UPM_Transpose(input, dim0, dim1): 275 | print("UPM_Transpose with input", input.shape, "dim0:", dim0, "dim1", dim1) 276 | x = __pytorch_transpose(input, dim0, dim1) 277 | return x 278 | 279 | 280 | def profiler_init(): 281 | 282 | print(f"Options: {options}") 283 | 284 | global profiling, profiler 285 | profiling = 1 286 | profiler = UPM_Profiler(options) 287 | 288 | # torch library 289 | torch.nn.Module = UPM_Module 290 | torch.nn.Linear = UPM_Linear 291 | torch.nn.modules.linear.NonDynamicallyQuantizableLinear = ( 292 | UPM_NonDynamicallyQuantizableLinear 293 | ) 294 | torch.nn.LayerNorm = UPM_LayerNorm 295 | torch.nn.Embedding = UPM_Embedding 296 | torch.nn.Dropout = UPM_Dropout 297 | torch.nn.Conv1d = UPM_Conv1d 298 | torch.nn.Conv2d = UPM_Conv2d 299 | torch.nn.Softmax = UPM_Softmax 300 | torch.nn.functional.softmax = UPM_Softmax_functional 301 | torch.matmul = UPM_Matmul 302 | torch.transpose = UPM_Transpose 303 | torch.nn.functional.scaled_dot_product_attention = UPM_scaled_dot_product_attention 304 | # torch.Tensor = UPM_Tensor 305 | 306 | # transformers library 307 | transformers.pytorch_utils.Conv1D = UPM_Conv1D 308 | transformers.activations.NewGELUActivation = UPM_NewGELUActivation 309 | transformers.activations.ACT2FN["gelu_new"] = ( 310 | UPM_NewGELUActivation # classes are hardcoded in ACT2FN 311 | ) 312 | transformers.activations.ACT2FN["silu"] = ( 313 | UPM_SiLUActivation # classes are hardcoded in ACT2FN 314 | ) 315 | transformers.models.llama.modeling_llama.LlamaRMSNorm = UPM_LlamaRMSNorm 316 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = ( 317 | UPM_LlamaRotaryEmbedding 318 | ) 319 | 320 | transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm = ( 321 | UPM_LlamaRMSNorm # miXtral models 322 | ) 323 | transformers.models.mistral.modeling_mistral.MistralRMSNorm = ( 324 | UPM_LlamaRMSNorm # miStral models 325 | ) 326 | 327 | 328 | def profiler_start( 329 | layer_mapping=None, 330 | layer_attn_ctxt="", 331 | last_layer="lm_head", 332 | batch_size=1, 333 | moe_end="", 334 | experts_per_token=2, 335 | ): 336 | profiler.start( 337 | layer_mapping=layer_mapping, 338 | layer_attn_ctxt=layer_attn_ctxt, 339 | last_layer=last_layer, 340 | batch_size=batch_size, 341 | moe_end=moe_end, 342 | experts_per_token=experts_per_token, 343 | ) 344 | 345 | 346 | def profiler_end(): 347 | profiler.end() 348 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/sim_architectures.yaml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=./architectures_schema.json 2 | 3 | DGX100: 4 | host_to_device_bw_GBs: 450 5 | host_to_device_pj_per_bit: 
27 6 | device_to_host_bw_GBs: 450 7 | device_to_host_pj_per_bit: 27 8 | mem_bw_GBs: 26800 9 | mem_pj_per_bit: 7 10 | tflops: 7916 11 | pj_per_tflop: 0.5e+12 12 | 13 | V100: 14 | host_to_device_bw_GBs: 64 15 | host_to_device_pj_per_bit: 27 16 | device_to_host_bw_GBs: 64 17 | device_to_host_pj_per_bit: 27 18 | mem_bw_GBs: 900 19 | mem_pj_per_bit: 7 20 | tflops: 112 21 | pj_per_tflop: 0.5e+12 22 | 23 | H100_x4: 24 | host_to_device_bw_GBs: 64 25 | host_to_device_pj_per_bit: 27 26 | device_to_host_bw_GBs: 64 27 | device_to_host_pj_per_bit: 27 28 | mem_bw_GBs: 8000 # 2000 * 4 29 | mem_pj_per_bit: 7 30 | tflops: 3026 # 756.5 * 4 31 | pj_per_tflop: 0.5e+12 32 | softmax_ns_per_element: 5.208e-4 # 0.4 / (16 * 12) / 4 33 | SiLU_ns_per_element: 7.813e-4 # 0.6 / (16 * 12) / 4 34 | RMSNorm_ns_per_element: 1.354e-3 # 1.04 / (16 * 12) / 4 35 | 36 | H100_x5: 37 | host_to_device_bw_GBs: 64 38 | host_to_device_pj_per_bit: 27 39 | device_to_host_bw_GBs: 64 40 | device_to_host_pj_per_bit: 27 41 | mem_bw_GBs: 10000 # 2000 * 5 42 | mem_pj_per_bit: 7 43 | tflops: 3782.5 # 756.5 * 5 44 | pj_per_tflop: 0.5e+12 45 | softmax_ns_per_element: 4.167e-4 # 0.4 / (16 * 12) / 5 46 | SiLU_ns_per_element: 6.25e-4 # 0.6 / (16 * 12) / 5 47 | RMSNorm_ns_per_element: 1.0833e-3 # 1.04 / (16 * 12) / 5 48 | 49 | H100_x8: 50 | host_to_device_bw_GBs: 450 51 | host_to_device_pj_per_bit: 280 # 40 * (8 - 1) 52 | device_to_host_bw_GBs: 450 53 | device_to_host_pj_per_bit: 40 54 | mem_bw_GBs: 16000 # 2000 * 8 55 | mem_pj_per_bit: 7 56 | tflops: 6052 # 756.5 * 8 57 | pj_per_tflop: 0.5e+12 58 | softmax_ns_per_element: 2.604e-4 # 0.4 / (16 * 12) / 8 59 | SiLU_ns_per_element: 3.906e-4 # 0.6 / (16 * 12) / 8 60 | RMSNorm_ns_per_element: 6.771e-4 # 1.04 / (16 * 12) / 8 61 | 62 | H100_x3: 63 | host_to_device_bw_GBs: 64 64 | host_to_device_pj_per_bit: 27 65 | device_to_host_bw_GBs: 64 66 | device_to_host_pj_per_bit: 27 67 | mem_bw_GBs: 6000 # 2000 * 3 68 | mem_pj_per_bit: 7 69 | tflops: 2269.5 # 756.5 * 3 70 | pj_per_tflop: 0.5e+12 71 | # Assuming a H100 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size 72 | softmax_ns_per_element: 1.0417e-03 # 0.4 / (16 * 2 * 4) / 3 73 | SiLU_ns_per_element: 1.5625e-03 # 0.6 / (16 * 2 * 4) / 3 74 | RMSNorm_ns_per_element: 2.7083e-3 # 1.04 / (16 * 2 * 4) / 3 75 | 76 | H100_x2: 77 | host_to_device_bw_GBs: 64 78 | host_to_device_pj_per_bit: 27 79 | device_to_host_bw_GBs: 64 80 | device_to_host_pj_per_bit: 27 81 | mem_bw_GBs: 4000 # 2000 * 2 82 | mem_pj_per_bit: 7 83 | tflops: 1513 # 756.5 * 2 84 | pj_per_tflop: 0.5e+12 85 | softmax_ns_per_element: 1.5625e-3 # 0.4 / (16 * 2 * 4) / 2 86 | SiLU_ns_per_element: 2.34375e-3 # 0.6 / (16 * 2 * 4) / 2 87 | RMSNorm_ns_per_element: 4.0625e-3 # 1.04 / (16 * 2 * 4) / 2 88 | 89 | A800: 90 | host_to_device_bw_GBs: 64 91 | host_to_device_pj_per_bit: 27 92 | device_to_host_bw_GBs: 64 93 | device_to_host_pj_per_bit: 27 94 | mem_bw_GBs: 1500 95 | mem_pj_per_bit: 7 96 | tflops: 312 97 | pj_per_tflop: 0.5e+12 98 | # Assuming a A800 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size 99 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4) 100 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4) 101 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4) 102 | 103 | H20: 104 | host_to_device_bw_GBs: 64 105 | host_to_device_pj_per_bit: 27 106 | device_to_host_bw_GBs: 64 107 | device_to_host_pj_per_bit: 27 108 | mem_bw_GBs: 4000 109 | mem_pj_per_bit: 7 110 | tflops: 148 111 | pj_per_tflop: 0.5e+12 112 | # Assuming a H20 is equivalent to 128 AI PIM cores (8 DIMMs) 
due to server size 113 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4) 114 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4) 115 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4) 116 | 117 | H200: 118 | host_to_device_bw_GBs: 64 119 | host_to_device_pj_per_bit: 27 120 | device_to_host_bw_GBs: 64 121 | device_to_host_pj_per_bit: 27 122 | mem_bw_GBs: 2860 123 | mem_pj_per_bit: 7 124 | tflops: 989 125 | pj_per_tflop: 0.5e+12 126 | # Assuming a H100 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size 127 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4) 128 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4) 129 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4) 130 | 131 | H100: 132 | host_to_device_bw_GBs: 64 133 | host_to_device_pj_per_bit: 27 134 | device_to_host_bw_GBs: 64 135 | device_to_host_pj_per_bit: 27 136 | mem_bw_GBs: 2000 137 | mem_pj_per_bit: 7 138 | tflops: 756.5 139 | pj_per_tflop: 0.5e+12 140 | # Assuming a H100 is equivalent to 128 AI PIM cores (8 DIMMs) due to server size 141 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4) 142 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4) 143 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4) 144 | 145 | A6000: 146 | host_to_device_bw_GBs: 32 147 | host_to_device_pj_per_bit: 35 148 | device_to_host_bw_GBs: 32 149 | device_to_host_pj_per_bit: 35 150 | mem_bw_GBs: 768 151 | mem_pj_per_bit: 15 152 | tflops: 155 153 | pj_per_tflop: 0.5e+12 154 | 155 | A17Pro: 156 | host_to_device_bw_GBs: 51.2 157 | host_to_device_pj_per_bit: 20 158 | device_to_host_bw_GBs: 51.2 159 | device_to_host_pj_per_bit: 20 160 | mem_bw_GBs: 51.2 161 | mem_pj_per_bit: 20 162 | tflops: 17 # GPU (FP16): 4.3, ANE (INT4): 35, ANE (INT8): 17 163 | pj_per_tflop: 0.4e+12 164 | 165 | Dimensity9300: 166 | host_to_device_bw_GBs: 76.8 167 | host_to_device_pj_per_bit: 10 168 | device_to_host_bw_GBs: 76.8 169 | device_to_host_pj_per_bit: 10 170 | mem_bw_GBs: 76.8 171 | mem_pj_per_bit: 10 172 | tflops: 16 # GPU (FP16): 6, APU (INT4): 33, APU (INT8): 16 173 | pj_per_tflop: 0.4e+12 174 | 175 | Snapdragon8gen3: 176 | host_to_device_bw_GBs: 77 177 | host_to_device_pj_per_bit: 10 178 | device_to_host_bw_GBs: 77 179 | device_to_host_pj_per_bit: 10 180 | mem_bw_GBs: 77 181 | mem_pj_per_bit: 10 182 | tflops: 17 # GPU (FP16): 4.73, Hexagon (INT4): 34, Hexagon (INT8): 17 183 | pj_per_tflop: 0.4e+12 184 | 185 | SAM_LPDDR5PIM: 186 | host_to_device_bw_GBs: 12.8 187 | host_to_device_pj_per_bit: 22 188 | device_to_host_bw_GBs: 12.8 189 | device_to_host_pj_per_bit: 22 190 | mem_bw_GBs: 102.4 191 | mem_pj_per_bit: 0.95 192 | tflops: 0.1024 193 | tflops_int4: 0.4096 # 0.1024 * 4 194 | pj_per_tflop: 0.8e+12 195 | 196 | PIM_AI_1chip: 197 | host_to_device_bw_GBs: 12.8 198 | host_to_device_pj_per_bit: 20 199 | device_to_host_bw_GBs: 12.8 200 | device_to_host_pj_per_bit: 20 201 | mem_bw_GBs: 102.4 202 | mem_pj_per_bit: 0.95 203 | tflops: 5 204 | pj_per_tflop: 0.4e+12 205 | 206 | PIM_AI_4chip: 207 | host_to_device_bw_GBs: 12.8 208 | host_to_device_pj_per_bit: 80 # 20 * 4 209 | device_to_host_bw_GBs: 51.2 # 12.8 * 4 210 | device_to_host_pj_per_bit: 20 211 | mem_bw_GBs: 409.6 # 102.4 * 4 212 | mem_pj_per_bit: 0.95 213 | tflops: 20 # 5 * 4 214 | tflops_int4: 128 # 32 * 4 215 | pj_per_tflop: 0.4e+12 216 | softmax_ns_per_element: 0.1 # 0.4 / 4 217 | SiLU_ns_per_element: 0.15 # 0.6 / 4 218 | RMSNorm_ns_per_element: 0.275 # 1.04 / 4 219 | 220 | PIM_AI_1dimm: 221 | host_to_device_bw_GBs: 44 222 | host_to_device_pj_per_bit: 50 223 | 
device_to_host_bw_GBs: 44 224 | device_to_host_pj_per_bit: 50 225 | mem_bw_GBs: 1638.4 # 102.4 * 16 226 | mem_pj_per_bit: 0.95 227 | tflops: 80 # 5 * 16 228 | pj_per_tflop: 0.4e+12 229 | 230 | PIM_AI_2dimm: 231 | host_to_device_bw_GBs: 22 232 | host_to_device_pj_per_bit: 100 # 50 * 2 233 | device_to_host_bw_GBs: 44 234 | device_to_host_pj_per_bit: 50 235 | mem_bw_GBs: 3276.8 # 102.4 * 16 * 2 236 | mem_pj_per_bit: 0.95 237 | tflops: 160 # 5 * 16 * 2 238 | pj_per_tflop: 0.4e+12 239 | softmax_ns_per_element: 1.25e-02 # 0.4 / (16 * 2) 240 | SiLU_ns_per_element: 1.875e-02 # 0.6 / (16 * 2) 241 | RMSNorm_ns_per_element: 3.25e-02 # 1.04 / (16 * 2) 242 | 243 | PIM_AI_4dimm: 244 | host_to_device_bw_GBs: 22 245 | host_to_device_pj_per_bit: 200 # 50 * 4 246 | device_to_host_bw_GBs: 88 # 44 * 2 247 | device_to_host_pj_per_bit: 50 248 | mem_bw_GBs: 6553.6 # 102.4 * 16 * 4 249 | mem_pj_per_bit: 0.95 250 | tflops: 320 # 5 * 16 * 4 251 | pj_per_tflop: 0.4e+12 252 | softmax_ns_per_element: 6.25e-03 # 0.4 / (16 * 4) 253 | SiLU_ns_per_element: 9.375e-03 # 0.6 / (16 * 4) 254 | RMSNorm_ns_per_element: 1.625e-02 # 1.04 / (16 * 4) 255 | 256 | PIM_AI_16dimm: 257 | host_to_device_bw_GBs: 22 258 | host_to_device_pj_per_bit: 800 # 50 * 16 259 | device_to_host_bw_GBs: 352 # 44 * 8 260 | device_to_host_pj_per_bit: 50 261 | mem_bw_GBs: 26214.4 # 102.4 * 16 * 16 262 | mem_pj_per_bit: 0.95 263 | tflops: 2048 # 8 * 16 * 16 264 | pj_per_tflop: 0.4e+12 265 | softmax_ns_per_element: 1.5625e-03 # 0.4 / (16 * 16) 266 | SiLU_ns_per_element: 2.34375e-03 # 0.6 / (16 * 16) 267 | RMSNorm_ns_per_element: 4.0625e-03 # 1.04 / (16 * 16) 268 | 269 | PIM_AI_8dimm: 270 | host_to_device_bw_GBs: 22 271 | host_to_device_pj_per_bit: 640 # (50 + 2 * 15) * 8 272 | device_to_host_bw_GBs: 176 # 44 * 4 273 | device_to_host_pj_per_bit: 50 274 | mem_bw_GBs: 13107.2 # 102.4 * 16 * 8 275 | mem_pj_per_bit: 0.95 276 | tflops: 1024 # 8 * 16 * 8 277 | pj_per_tflop: 0.5e+12 278 | softmax_ns_per_element: 3.125e-03 # 0.4 / (16 * 8) 279 | SiLU_ns_per_element: 4.6875e-03 # 0.6 / (16 * 8) 280 | RMSNorm_ns_per_element: 8.125e-03 # 1.04 / (16 * 8) 281 | 282 | PIM_AI_10dimm: 283 | host_to_device_bw_GBs: 22 284 | host_to_device_pj_per_bit: 500 # 50 * 10 285 | device_to_host_bw_GBs: 220 # 44 * 5 286 | device_to_host_pj_per_bit: 50 287 | mem_bw_GBs: 16384 # 102.4 * 16 * 10 288 | mem_pj_per_bit: 0.95 289 | tflops: 1280 # 8 * 16 * 10 290 | pj_per_tflop: 0.4e+12 291 | softmax_ns_per_element: 2.5e-03 # 0.4 / (16 * 10) 292 | SiLU_ns_per_element: 3.75e-03 # 0.6 / (16 * 10) 293 | RMSNorm_ns_per_element: 6.5e-03 # 1.04 / (16 * 10) 294 | 295 | PIM_AI_6dimm: 296 | host_to_device_bw_GBs: 22 297 | host_to_device_pj_per_bit: 480 # (50 + 2 * 15) * 6 298 | device_to_host_bw_GBs: 132 # 44 * 3 299 | device_to_host_pj_per_bit: 50 300 | mem_bw_GBs: 9830.4 # 102.4 * 16 * 6 301 | mem_pj_per_bit: 0.95 302 | tflops: 768 # 8 * 16 * 6 303 | pj_per_tflop: 0.5e+12 304 | softmax_ns_per_element: 4.1667e-03 # 0.4 / (16 * 6) 305 | SiLU_ns_per_element: 6.25e-03 # 0.6 / (16 * 6) 306 | RMSNorm_ns_per_element: 1.0833e-02 # 1.04 / (16 * 6) 307 | 308 | CXL_PIM_BC: 309 | # CXL board with: 310 | # 8-lane full duplex PCIe GEN5 311 | # 16 LPDDR controllers, 16 bits, 9.6 GT/s, dual rank (2 devices per IFC) 312 | # A device is a stack of 4 LPDDR-PIM 313 | # 256 GB overall memory (stacking 4 LPDDR-PIM = 8 dies) 314 | # This might be seen as 8x AI PIM DIMM with C2C connection between groups of 4 chips 315 | # Broadcast between LPDDR-PIM is possible 316 | host_to_device_bw_GBs: 19.2 # 8-lane PCIe GEN5, but only 
one LPDDR5 at a time 317 | host_to_device_pj_per_bit: 50 # crossing PCIe and LPDDR interfaces on both host and device 318 | device_to_host_bw_GBs: 19.2 # 8-lane PCIe GEN5, but only one LPDDR5 at a time 319 | device_to_host_pj_per_bit: 50 # crossing PCIe and LPDDR interfaces on both host and device 320 | mem_bw_GBs: 13107.2 # 102.4 * 16 * 2 * 4 321 | mem_pj_per_bit: 0.95 322 | tflops: 640 # 5 * 16 * 2 *4 323 | pj_per_tflop: 0.4e+12 324 | softmax_ns_per_element: 3.125e-03 # 0.4 / (16 * 2 * 4) 325 | SiLU_ns_per_element: 4.6875e-03 # 0.6 / (16 * 2 * 4) 326 | RMSNorm_ns_per_element: 8.125e-03 # 1.04 / (16 * 2 * 4) 327 | 328 | CXL_PIM_nBC: 329 | # CXL board with: 330 | # 8-lane full duplex PCIe GEN5 331 | # 16 LPDDR controllers, 16 bits, 9.6 GT/s, dual rank (2 devices per IFC) 332 | # A device is a stack of 4 LPDDR-PIM 333 | # 256 GB overall memory (stacking 4 LPDDR-PIM = 8 dies) 334 | # This might be seen as 8x AI PIM DIMM with C2C connection between groups of 4 chips 335 | # Broadcast between LPDDR-PIM is not possible 336 | host_to_device_bw_GBs: 0.6 # 19.2 /32: 8-lane PCIe GEN5, but only one LPDDR5 at a time 337 | host_to_device_pj_per_bit: 1600 # 50 * 32: crossing PCIe and LPDDR interfaces on both host and device 338 | device_to_host_bw_GBs: 19.2 # 8-lane PCIe GEN5, but only one LPDDR5 at a time 339 | device_to_host_pj_per_bit: 50 # crossing PCIe and LPDDR interfaces on both host and device 340 | mem_bw_GBs: 13107.2 # 102.4 * 16 * 2 * 4 341 | mem_pj_per_bit: 0.95 342 | tflops: 640 # 5 * 16 * 2 * 4 343 | pj_per_tflop: 0.4e+12 344 | softmax_ns_per_element: 3.125e-3 # 0.4 / (16 * 2 * 4) 345 | SiLU_ns_per_element: 4.6875e-3 # 0.6 / (16 * 2 * 4) 346 | RMSNorm_ns_per_element: 8.125e-3 # 1.04 / (16 * 2 * 4) 347 | 348 | PIM_AI_12dimm: 349 | host_to_device_bw_GBs: 22 350 | host_to_device_pj_per_bit: 960 # (50 + 2 * 15) * 12 351 | device_to_host_bw_GBs: 264 # 44 * 6 352 | device_to_host_pj_per_bit: 50 353 | mem_bw_GBs: 19660.8 # 102.4 * 16 * 12 354 | mem_pj_per_bit: 0.95 355 | tflops: 1536 # 8 * 16 * 12 356 | pj_per_tflop: 0.5e+12 357 | softmax_ns_per_element: 2.0833e-3 # 0.4 / (16 * 12) 358 | SiLU_ns_per_element: 3.125e-3 # 0.6 / (16 * 12) 359 | RMSNorm_ns_per_element: 5.4167e-03 # 1.04 / (16 * 12) 360 | 361 | PIM_AI_24dimm: 362 | host_to_device_bw_GBs: 22 363 | host_to_device_pj_per_bit: 1920 # (50 + 2 * 15) * 24 364 | device_to_host_bw_GBs: 528 # 44 * 12 365 | device_to_host_pj_per_bit: 50 366 | mem_bw_GBs: 39321.6 # 102.4 * 16 * 24 367 | mem_pj_per_bit: 0.95 368 | tflops: 3072 # 8 * 16 * 24 369 | pj_per_tflop: 0.5e+12 370 | softmax_ns_per_element: 1.0417e-3 # 0.4 / (16 * 24) 371 | SiLU_ns_per_element: 1.5625e-3 # 0.6 / (16 * 24) 372 | RMSNorm_ns_per_element: 2.7083e-3 # 1.04 / (16 * 24) 373 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/simulator.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2014-2024 - UPMEM 3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA 4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM 5 | # 6 | # This file implements multiple entry points called by the profiler to simulate the underlying 7 | # hardware. 
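# Example (illustrative only): Simulator.map_layers() below expects a mapping from a
# layer context name to a device string of the form "<ARCH_NAME>[,<flags>]".
# <ARCH_NAME> must be "HOST" or a key of sim_architectures.yaml ("-" is accepted in
# place of "_"). The optional flags are parsed by name_to_device(): "t" forces a
# gather at the host before the layer, and "m" marks a mixture-of-experts expert
# layer (its host transfer is only paid the first time the layer runs in an expert
# group; see check_moe/reset_moe). The context names and the moe_end value below are
# hypothetical; real names depend on the profiled model.
#
# example_layer_mapping = {
#     "q_proj": "PIM_AI_1dimm",       # runs on one PIM DIMM
#     "gate_proj": "PIM_AI_1dimm,m",  # MoE expert layer
#     "lm_head": "H100,t",            # gather at host, then run on the H100 spec
# }
# simulator.map_layers(example_layer_mapping, moe_end="down_proj", experts_per_token=2)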
8 | 9 | 10 | import torch 11 | 12 | from upmem_llm_framework.base_architecture import BaseArchitecture 13 | from upmem_llm_framework.sim_architectures import get_spec 14 | from upmem_llm_framework.utils import add_dictionaries 15 | 16 | 17 | class Simulator: 18 | 19 | def __init__( 20 | self, 21 | data_type_bytes=2.0, 22 | sliding_window=-1, 23 | num_key_value_heads=-1, 24 | verbose=False, 25 | ): 26 | self.data_type_bytes = data_type_bytes 27 | self.layer_mapping = {} 28 | self.layer_attn_ctxt = "" 29 | self.use_kv_cache = True 30 | self.sum = True 31 | self.sum_size = 0 32 | self.batch_size = 0 33 | self.sliding_window = sliding_window 34 | self.num_key_value_heads = num_key_value_heads 35 | self.verbose = verbose 36 | 37 | self.moe_already_sent = {} 38 | self.moe_end = "" 39 | self.experts_per_token = 2 # from Mixtral 8x7B 40 | 41 | self.current_device, _, _ = self.name_to_device("HOST") 42 | 43 | def start_gen(self): 44 | self.sum = False 45 | 46 | def map_layers(self, mapping, layer_attn_ctxt="", moe_end="", experts_per_token=2): 47 | self.layer_mapping = mapping 48 | self.layer_attn_ctxt = layer_attn_ctxt 49 | if self.verbose: 50 | print(self.layer_mapping) 51 | print(self.layer_attn_ctxt) 52 | self.moe_end = moe_end 53 | self.experts_per_token = experts_per_token 54 | 55 | def name_to_device(self, full_name: str): 56 | 57 | device_name = full_name.split(",")[0].replace("-", "_") 58 | flags = full_name.split(",")[1] if len(full_name.split(",")) > 1 else "" 59 | 60 | do_transfer = "t" in flags 61 | moe = "m" in flags 62 | 63 | new_device = BaseArchitecture( 64 | data_type_bytes=self.data_type_bytes, 65 | sliding_window=self.sliding_window, 66 | num_key_value_heads=self.num_key_value_heads, 67 | verbose=self.verbose, 68 | ) 69 | 70 | if device_name == "HOST": 71 | new_device.name = "HOST" 72 | return new_device, do_transfer, moe 73 | 74 | spec = get_spec(device_name) 75 | new_device.load_spec(device_name, spec) 76 | 77 | # new_device.adjust_for_quantization() 78 | 79 | return new_device, do_transfer, moe 80 | 81 | def simulate_attn(self, input_shape): 82 | batch_size = input_shape[0] if (len(input_shape) > 2) else 1 83 | n_rows = input_shape[1] if (len(input_shape) > 1) else 1 84 | n_columns = input_shape[-1] 85 | 86 | # input (n_rows, n_columns) x Wqkv (n_columns, Wqkv), 87 | # where Wqkv is n_columns*3 (all Wq, Wk, Wv together) 88 | # Qall_heads = (n_rows, n_columns), each head computes n_columns/n_heads 89 | # Kall_heads = (n_rows, n_columns) 90 | # Vall_heads = (n_rows, n_columns) 91 | 92 | # TODO: add transpose time, concat time 93 | compute_time_ns = 0 94 | performance = {} 95 | energy_compute = {} 96 | 97 | kt = torch.Size([n_columns, n_rows]) 98 | qkt = torch.Size([batch_size, n_rows, n_rows]) 99 | v = torch.Size([n_rows, n_columns]) 100 | 101 | if self.use_kv_cache: 102 | if self.sum: 103 | self.sum_size = n_rows # Just to keep it updated 104 | else: 105 | # load KV cache 106 | kv_cache = torch.Size([batch_size, self.sum_size, n_columns * 2]) 107 | step_time, step_perf, step_energy = self.current_device.load_data( 108 | kv_cache 109 | ) 110 | compute_time_ns += step_time 111 | performance = add_dictionaries(performance, step_perf) 112 | energy_compute = add_dictionaries(energy_compute, step_energy) 113 | 114 | # concat K + 1 115 | kt = torch.Size([n_columns, self.sum_size + 1]) 116 | # concat V + 1 117 | v = torch.Size([self.sum_size + 1, n_columns]) 118 | qkt = torch.Size([batch_size, n_rows, self.sum_size + 1]) 119 | 120 | # Q x Kt 121 | if self.verbose: 122 | 
print("Computing Q x Kt") 123 | step_time, step_perf, step_energy = self.current_device.compute_ns( 124 | input_shape, 125 | torch.nn.Linear(in_features=n_columns, out_features=n_rows), 126 | kt, 127 | load_input=False, 128 | ) 129 | compute_time_ns += step_time 130 | performance = add_dictionaries(performance, step_perf) 131 | energy_compute = add_dictionaries(energy_compute, step_energy) 132 | 133 | # output = V * QKt 134 | if self.verbose: 135 | print("Computing V x QKt") 136 | step_time, step_perf, step_energy = self.current_device.compute_ns( 137 | qkt, 138 | torch.nn.Linear(in_features=n_rows, out_features=n_rows), 139 | v, 140 | load_input=False, 141 | ) 142 | compute_time_ns += step_time 143 | performance = add_dictionaries(performance, step_perf) 144 | energy_compute = add_dictionaries(energy_compute, step_energy) 145 | 146 | if self.verbose: 147 | print( 148 | f"Attn ({'SUM.' if self.sum else 'GEN.'}): " 149 | f"Q: {input_shape} Kt: {kt} QKt: {qkt} V: {v}" 150 | ) 151 | print(performance) 152 | 153 | return compute_time_ns, performance, energy_compute 154 | 155 | def simulate_end(self, input_shape, generated_tokens=1): 156 | 157 | time_send_ans_to_host = 0 158 | perf_send_ans_to_host = {} 159 | energy_send_ans_to_host = {} 160 | data_send_ans_to_host = {} 161 | 162 | if self.current_device.name != "HOST": 163 | # we asume the new input is what needs to be written back to host from previous layer 164 | ( 165 | time_send_ans_to_host, 166 | perf_send_ans_to_host, 167 | energy_send_ans_to_host, 168 | data_send_ans_to_host, 169 | ) = self.current_device.host_transfer( 170 | input_shape, direction="to_host", generated_tokens=generated_tokens 171 | ) 172 | 173 | return ( 174 | time_send_ans_to_host, 175 | perf_send_ans_to_host, 176 | energy_send_ans_to_host, 177 | data_send_ans_to_host, 178 | ) 179 | 180 | def check_moe(self, context): 181 | num_seen = self.moe_already_sent.get(context, 0) 182 | 183 | self.moe_already_sent[context] = num_seen + 1 184 | 185 | return num_seen == 0 186 | 187 | def reset_moe(self): 188 | all_moe_seen = True 189 | for v in self.moe_already_sent.values(): 190 | all_moe_seen = all_moe_seen and (v == self.experts_per_token) 191 | 192 | if all_moe_seen: 193 | # Reset dict for next iteration 194 | self.moe_already_sent = {} 195 | 196 | return all_moe_seen 197 | 198 | def check_sync_point(self, context, input_shape): 199 | time_send_ans_to_host = 0 200 | time_send_ans_from_host = 0 201 | perf = {} 202 | energy = {} 203 | moved_data = {} 204 | 205 | new_device = None 206 | 207 | # print(f"Try mapping {context}") 208 | # assume that if the layer is not mapped, it stays in the current device 209 | if context in self.layer_mapping: 210 | # print(f"Mapping {context} to {self.layer_mapping[context]}") 211 | new_device, gather_at_host, moe = self.name_to_device( 212 | self.layer_mapping[context] 213 | ) 214 | 215 | # if new_device != current_device --> pay transfer 216 | if ( 217 | new_device.name != self.current_device.name 218 | or gather_at_host 219 | or (moe and self.check_moe(context)) 220 | ): 221 | 222 | # if HOST is not current device, transfer to HOST the output from the last layer 223 | # we asume the new input is what needs to be written back to host from previous 224 | # layer 225 | if self.current_device.name != "HOST": 226 | time_send_ans_to_host, step_perf, step_energy, step_data = ( 227 | self.current_device.host_transfer( 228 | input_shape, direction="to_host" 229 | ) 230 | ) 231 | perf = add_dictionaries(perf, step_perf) 232 | energy = 
add_dictionaries(energy, step_energy) 233 | moved_data = add_dictionaries(moved_data, step_data) 234 | 235 | # change current_device = new_device 236 | old_device = self.current_device 237 | self.current_device = new_device 238 | 239 | # then, host writes into the new device 240 | if self.current_device.name != "HOST": 241 | time_send_ans_from_host, step_perf, step_energy, step_data = ( 242 | self.current_device.host_transfer(input_shape) 243 | ) 244 | perf = add_dictionaries(perf, step_perf) 245 | energy = add_dictionaries(energy, step_energy) 246 | moved_data = add_dictionaries(moved_data, step_data) 247 | 248 | if self.verbose: 249 | print( 250 | f"Changing device from {old_device.name} to {new_device.name} " 251 | f"took {perf} with energy: {energy}" 252 | ) 253 | 254 | return time_send_ans_to_host, time_send_ans_from_host, perf, energy, moved_data 255 | 256 | def simulate_layer(self, layer, input_shape, layer_obj, weight_shape, output_shape): 257 | 258 | time_send_ans_to_host = 0 259 | time_send_ans_from_host = 0 260 | compute_time_ns = 0 261 | 262 | perf_transfer = {} 263 | perf_compute = {} 264 | 265 | energy_transfer = {} 266 | energy_compute = {} 267 | 268 | data_transfer = {} 269 | 270 | if self.verbose: 271 | print("Simulating layer:", layer.context, "n_layer:", layer.n_layer) 272 | 273 | ( 274 | time_send_ans_to_host, 275 | time_send_ans_from_host, 276 | perf_transfer, 277 | energy_transfer, 278 | data_transfer, 279 | ) = self.check_sync_point(layer.context, input_shape) 280 | 281 | if layer.context == self.moe_end and self.moe_end != "": 282 | if self.reset_moe(): 283 | # output_shape expected to be [tokens * batch_size, features] 284 | # Send all experts' output except one, which shall be accounted by the layer mapping 285 | transfer_shape = torch.Size( 286 | [self.experts_per_token - 1, output_shape[0], output_shape[1]] 287 | ) 288 | ( 289 | time_send_ans_to_host_moe, 290 | perf_transfer_moe, 291 | energy_transfer_moe, 292 | data_transfer_moe, 293 | ) = self.current_device.host_transfer( 294 | transfer_shape, direction="to_host" 295 | ) 296 | if self.verbose: 297 | print("Last layer of MoE sends back to HOST: ", transfer_shape) 298 | time_send_ans_to_host += time_send_ans_to_host_moe 299 | perf_transfer = add_dictionaries(perf_transfer, perf_transfer_moe) 300 | energy_transfer = add_dictionaries(energy_transfer, energy_transfer_moe) 301 | data_transfer = add_dictionaries(data_transfer, data_transfer_moe) 302 | 303 | # pay compute 304 | step_time, step_perf, step_energy = self.current_device.compute_ns( 305 | input_shape, layer_obj, weight_shape 306 | ) 307 | compute_time_ns += step_time 308 | perf_compute = add_dictionaries(perf_compute, step_perf) 309 | energy_compute = add_dictionaries(energy_compute, step_energy) 310 | 311 | # If self-attention required, simulate 312 | # if (layer.context == self.layer_attn_ctxt): 313 | # step_time, step_perf, step_energy = self.simulate_attn(input_shape, weight_shape) 314 | # compute_time_ns += step_time 315 | # perf_compute = add_dictionaries(perf_compute , step_perf ) 316 | # energy_compute = add_dictionaries(energy_compute, step_energy) 317 | 318 | if self.verbose: 319 | print("Time send ans to host (ns):", time_send_ans_to_host) 320 | print("Time send ans from host (ns):", time_send_ans_from_host) 321 | print("Compute time (ns):", compute_time_ns) 322 | print("Energy send ans to/from host (pJ):", energy_transfer) 323 | print("Energy Compute (pJ):", energy_compute) 324 | 325 | # we can pipeline TODO: calculate this 326 | # max 
(time_send_ans_from_host, compute_time) 327 | 328 | # return simulated stats 329 | total_time = time_send_ans_to_host + time_send_ans_from_host + compute_time_ns 330 | 331 | total_perf = add_dictionaries(perf_transfer, perf_compute) 332 | 333 | total_energy = add_dictionaries(energy_transfer, energy_compute) 334 | 335 | return total_time, total_perf, total_energy, data_transfer 336 | 337 | def simulate_function(self, function, context, input_shape, output_shape): 338 | 339 | function_name = ( 340 | function.__name__ if hasattr(function, "__name__") else function.name 341 | ) 342 | 343 | if self.verbose: 344 | print( 345 | f"Simulating function: {function_name}, " 346 | f"context: {context}, " 347 | f"input shape: {input_shape}, " 348 | f"output shape: {output_shape}" 349 | ) 350 | 351 | ( 352 | time_send_ans_to_host, 353 | time_send_ans_from_host, 354 | perf_transfer, 355 | energy_transfer, 356 | data_transfer, 357 | ) = self.check_sync_point(function_name, input_shape) 358 | 359 | compute_time_ns, perf_compute, energy_compute = self._compute_function_metrics( 360 | function, context, input_shape, output_shape 361 | ) 362 | 363 | if self.verbose: 364 | print("Time send ans to host (ns):", time_send_ans_to_host) 365 | print("Time send ans from host (ns):", time_send_ans_from_host) 366 | print("Compute time (ns):", compute_time_ns) 367 | print("Energy send ans to/from host (pJ):", energy_transfer) 368 | print("Energy Compute (pJ):", energy_compute) 369 | 370 | # return simulated stats 371 | total_time = time_send_ans_to_host + time_send_ans_from_host + compute_time_ns 372 | 373 | total_perf = add_dictionaries(perf_transfer, perf_compute) 374 | total_energy = add_dictionaries(energy_transfer, energy_compute) 375 | 376 | return total_time, total_perf, total_energy, data_transfer 377 | 378 | def _compute_function_metrics(self, function, context, input_shape, output_shape): 379 | if hasattr(function, "__name__") and function.__name__.endswith("softmax"): 380 | return self.current_device.compute_softmax_ns(input_shape) 381 | if hasattr(function, "name") and function.name.endswith("LlamaRMSNorm"): 382 | return self.current_device.compute_RMSNorm_ns(input_shape, output_shape) 383 | if hasattr(function, "name") and ( 384 | function.name.endswith("SiLU") or function.name.endswith("SiLUActivation") 385 | ): 386 | return self.current_device.compute_activation_ns( 387 | input_shape, activation="SiLU" 388 | ) 389 | if hasattr(function, "__name__") and function.__name__.endswith("matmul"): 390 | return self.current_device.compute_matmul_ns( 391 | context, 392 | input_shape, 393 | output_shape, 394 | summarization=self.sum, 395 | sum_size=self.sum_size, 396 | ) 397 | if hasattr(function, "__name__") and function.__name__.endswith( 398 | "scaled_dot_product_attention" 399 | ): 400 | return self.current_device.compute_scaled_dot_product_ns( 401 | context, 402 | input_shape, 403 | output_shape, 404 | summarization=self.sum, 405 | sum_size=self.sum_size, 406 | ) 407 | raise ValueError( 408 | "Unsupported function: " 409 | f"{function.__name__ if hasattr(function, '__name__') else function.name}, " 410 | f"type: {type(function)}, " 411 | f"string: {function}" 412 | ) 413 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/base_architecture.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2014-2024 - UPMEM 3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA 4 | # For 
UPMEM partner internal use only - no modification allowed without permission of UPMEM 5 | # 6 | # This file implements Base_architecture class 7 | # This class contains a default implementation of the following functions: 8 | # - adjust_for_quantization: scales up/down the TFLOPs depending on the quantization chosen 9 | # - get_tflops: returns the TFLOPs required in an MxM 10 | # - get_moved_data_bytes: returns the required bytes to move in order to do an operation 11 | # - load_data: models loading the KV cache 12 | # - host_transfer: simulates a data transfer with host in any direction 13 | # - compute_ns: simulates the computation of an MxM 14 | # - compute_scaled_dot_product_ns: simulates the computation of function scaled_dot_product where, 15 | # usually, attention computation occurs 16 | # - compute_matmul_ns: simulates the computation of a matmul for self-attention 17 | # - compute_activation_ns: simulates an activation layer 18 | # - compute_RMSNorm_ns: simulates a RMSNorm layer 19 | # - compute_softmax_ns: simulates a softmax operation 20 | # 21 | # Note that all simulations return compute_time_ns, performance_dict, energy_dict: 22 | # - compute_time_ns: the simulated time in ns, 23 | # - performance_dict: dictionary containing the simulated time in ns for each operation simulated, 24 | # - energy_dict: dictionary containing the simulated energy in pJ for each operation simulated. 25 | 26 | import math 27 | from typing import Dict 28 | 29 | import torch 30 | 31 | from upmem_llm_framework.utils import add_dictionaries 32 | 33 | 34 | class BaseArchitecture: 35 | 36 | def __init__( 37 | self, 38 | active_chips=1, 39 | tflops=1, 40 | pj_per_tflop=1, 41 | host_to_device_bw_GBs=1, 42 | device_to_host_bw_GBs=1, 43 | # inter_bw = 1, 44 | memory=1, 45 | mem_bw_GBs=1, 46 | mem_pj_per_bit=1, 47 | data_type_bytes=2.0, # float16 48 | # 3000 cycles per row of 2048 elements --> 1.4 cycles / element 49 | # assuming 1 GHz, 1.5 ns / element, parallelized across 4 chips -> 0.37 50 | softmax_ns_per_element=0.4, # ns, considering it cycles in 1GHz config 51 | SiLU_ns_per_element=0.6, # ns, softmax * 1.5 52 | # (empirical number based on execution of Llama2-7b) 53 | RMSNorm_ns_per_element=1.1, # ns, softmax * 2.6 54 | # (empirical number based on execution of Llama2-7b) 55 | # 3000 cycles per row of 2048 elements with 5 TFLOPs of computing power 56 | # assuming 1 GHz, 0.000003 s --> 3MOPS per row of 2048 --> 1.5kOPS per element 57 | misc_tflops_per_element=1500 / 1e12, 58 | sliding_window=-1, 59 | num_key_value_heads=-1, 60 | verbose=False, 61 | ): 62 | 63 | self.name = "" 64 | self.active_chips = active_chips 65 | # Compute capabilities 66 | self.tflops = tflops 67 | self.pj_per_tflop = 0.4 68 | # Interface with HOST 69 | self.host_to_device_bw_GBs = host_to_device_bw_GBs 70 | self.device_to_host_bw_GBs = device_to_host_bw_GBs 71 | self.host_to_device_pj_per_bit = 25 72 | self.device_to_host_pj_per_bit = 25 73 | # self.inter_bw = inter_bw 74 | 75 | # Device memory (shared memory like) 76 | self.memory = memory # unused at the moment 77 | self.mem_bw_GBs = mem_bw_GBs 78 | self.mem_pj_per_bit = mem_pj_per_bit 79 | self.pj_per_tflop = pj_per_tflop 80 | 81 | self.data_type_bytes = data_type_bytes 82 | 83 | self.softmax_ns_per_element = softmax_ns_per_element 84 | self.RMSNorm_ns_per_element = RMSNorm_ns_per_element 85 | self.SiLU_ns_per_element = SiLU_ns_per_element 86 | 87 | self.misc_tflops_per_element = misc_tflops_per_element 88 | 89 | self.sliding_window = sliding_window 90 | self.num_key_value_heads = 
num_key_value_heads 91 | 92 | self.verbose = verbose 93 | 94 | def load_spec(self, name: str, spec: Dict): 95 | """Load accelerator specification from a dictionary""" 96 | self.name = name 97 | for key, value in spec.items(): 98 | if key == "tflops_int4": 99 | continue 100 | if not hasattr(self, key): 101 | raise ValueError( 102 | f"Warning: {key} is not a valid attribute for {self.__class__}" 103 | ) 104 | setattr(self, key, value) 105 | if "tflops_int4" in spec and self.data_type_bytes == 0.5: 106 | self.tflops = spec["tflops_int4"] 107 | 108 | # TFLOPS values are defined for float16, 109 | # assume that performance is doubled if data type is demoted 110 | def adjust_for_quantization(self): 111 | ratio = 2 / self.data_type_bytes # Assume pj_per_tflop corresponds to float16 112 | self.pj_per_tflop = self.pj_per_tflop / ratio 113 | # self.tflops = self.tflops * (2 / self.data_type_bytes) 114 | 115 | # self.softmax_ns_per_element = self.softmax_ns_per_element / ratio 116 | # self.RMSNorm_ns_per_element = self.RMSNorm_ns_per_element / ratio 117 | # self.SiLU_ns_per_element = self.SiLU_ns_per_element / ratio 118 | 119 | def get_tflops_Linear(self, input_shape, weight_shape): 120 | out_features = weight_shape[0] 121 | tflops = input_shape.numel() * out_features * 2 / 1e12 122 | 123 | return tflops 124 | 125 | # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html 126 | def get_tflops_Conv2d(self, input_shape, layer, weight_shape): 127 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1 128 | n_channels = input_shape[-3] if (len(input_shape) > 2) else 1 129 | n_height = input_shape[-2] if (len(input_shape) > 1) else 1 130 | n_width = input_shape[-1] 131 | 132 | stride = layer.stride 133 | padding = layer.padding 134 | 135 | n_height = n_height + 2 * ( 136 | padding[0] if isinstance(padding, tuple) else padding 137 | ) 138 | n_width = n_width + 2 * (padding[1] if isinstance(padding, tuple) else padding) 139 | 140 | # Example of how many times a kernel is applied depending on stride: 141 | # | 0 | 1 | 2 | 3 | 4 | 5 | 142 | # Stride 1: 143 | # 00001111 144 | # 11112222 145 | # 22223333 146 | # 33334444 147 | # 44445555 148 | # Stride 2: 149 | # 00001111 150 | # 22223333 151 | # 44445555 152 | 153 | width_times = math.ceil( 154 | (n_width - 1) / (stride[1] if isinstance(stride, tuple) else stride) 155 | ) 156 | height_times = math.ceil( 157 | (n_height - 1) / (stride[0] if isinstance(stride, tuple) else stride) 158 | ) 159 | 160 | # TFLOPS for applying the kernel once 161 | tflops_kernel = ( 162 | 2 * batch_size * n_channels * weight_shape[1] * weight_shape[0] 163 | ) / 1e12 164 | 165 | tflops = tflops_kernel * width_times * height_times 166 | 167 | return tflops 168 | 169 | def get_tflops_LayerNorm(self, input_shape): 170 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1 171 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1 172 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1 173 | n_columns = input_shape[-1] 174 | 175 | tflops = ( 176 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element 177 | ) 178 | 179 | return tflops 180 | 181 | # TODO: see DeepSpeed implementation 182 | # def _attn_flops_compute 183 | 184 | def get_tflops(self, input_shape, layer, weight_shape): 185 | if issubclass(torch.nn.Conv2d, type(layer)): 186 | tflops = self.get_tflops_Conv2d(input_shape, layer, weight_shape) 187 | elif issubclass(torch.nn.LayerNorm, type(layer)): 188 | tflops = self.get_tflops_LayerNorm(input_shape) 189 | else: 190 | # 
Treat everything else as Linear 191 | tflops = self.get_tflops_Linear(input_shape, weight_shape) 192 | # print ("get_tflops not defined for layer: ", type(layer)) 193 | # sys.exit(-1) 194 | 195 | return tflops 196 | 197 | def get_moved_data_bytes( 198 | self, input_shape, weight_shape, load_input=False, load_weight=True 199 | ): 200 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1 201 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1 202 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1 203 | n_columns = input_shape[-1] 204 | 205 | weight_size = weight_shape[1] * weight_shape[0] if load_weight else 0 206 | input_size = batch_size * n_heads * n_rows * n_columns if load_input else 0 207 | # output_size = batch_size * n_rows * weight_shape[1] 208 | return self.data_type_bytes * (weight_size + input_size) # + output_size) 209 | 210 | # KV cache load 211 | def load_data(self, input_shape): 212 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1 213 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1 214 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1 215 | n_columns = input_shape[-1] 216 | 217 | # data_size_bytes = self.data_type_bytes * ( 218 | # batch_size * n_rows * n_heads * n_columns 219 | # ) 220 | 221 | # Hardcoded to FP16 222 | data_size_bytes = 2 * (batch_size * n_rows * n_heads * n_columns) 223 | # B / (GB/s) --> s/G --> ns 224 | transfer_time_ns = data_size_bytes / self.mem_bw_GBs 225 | 226 | performance = {"kv_load": transfer_time_ns} 227 | # GB/s * time --> GB * pJ/bit --> J 228 | energy = {"main_mem": data_size_bytes * 8 * self.mem_pj_per_bit} 229 | 230 | if self.verbose: 231 | print( 232 | "Load time for input_shape:", 233 | input_shape, 234 | "=", 235 | transfer_time_ns, 236 | "(ns) with", 237 | energy, 238 | "pj", 239 | performance, 240 | ) 241 | 242 | return transfer_time_ns, performance, energy 243 | 244 | def host_transfer(self, input_shape, direction="to_device", generated_tokens=1): 245 | batch_size = input_shape[-4] if (len(input_shape) > 3) else 1 246 | n_heads = input_shape[-3] if (len(input_shape) > 2) else 1 247 | n_rows = input_shape[-2] if (len(input_shape) > 1) else 1 248 | n_columns = input_shape[-1] 249 | 250 | bandwidth = ( 251 | self.host_to_device_bw_GBs 252 | if direction == "to_device" 253 | else self.device_to_host_bw_GBs 254 | ) 255 | 256 | data_size_bytes = self.data_type_bytes * ( 257 | batch_size * n_heads * n_rows * n_columns * generated_tokens 258 | ) 259 | # B / (GB/s) --> s / G --> ns 260 | transfer_time_ns = data_size_bytes / bandwidth 261 | 262 | name_op = "host_to_device" if direction == "to_device" else "device_to_host" 263 | 264 | performance = {name_op: transfer_time_ns} 265 | moved_data = {name_op: data_size_bytes} 266 | 267 | ddr_pj_per_bit = ( 268 | self.host_to_device_pj_per_bit 269 | if direction == "to_device" 270 | else self.device_to_host_pj_per_bit 271 | ) 272 | energy = {name_op: data_size_bytes * 8 * ddr_pj_per_bit} 273 | 274 | if self.verbose: 275 | print( 276 | f"Transfer time for input_shape: {input_shape} = {transfer_time_ns} (ns) " 277 | f"with {energy} pj perf: {performance} data in bytes: {moved_data}" 278 | ) 279 | 280 | return transfer_time_ns, performance, energy, moved_data 281 | 282 | def compute_ns( 283 | self, input_shape, layer_obj, weight_shape, load_input=False, load_weight=True 284 | ): 285 | tflops = self.get_tflops(input_shape, layer_obj, weight_shape) 286 | 287 | data_size_bytes = self.get_moved_data_bytes( 288 | input_shape, weight_shape, 
load_input=load_input, load_weight=load_weight 289 | ) 290 | 291 | compute_time_ns = (tflops / self.tflops) * 1e9 292 | transfer_time_ns = data_size_bytes / self.mem_bw_GBs 293 | real_time_ns = max(compute_time_ns, transfer_time_ns) 294 | 295 | performance = { 296 | "compute": compute_time_ns, 297 | "mem_transfer": transfer_time_ns, 298 | } 299 | 300 | energy = { 301 | "compute": tflops * self.pj_per_tflop, 302 | "main_mem": data_size_bytes * 8 * self.mem_pj_per_bit, 303 | } 304 | 305 | if self.verbose: 306 | print( 307 | f"Computing {input_shape} x {weight_shape} with TFLOPS: {tflops} " 308 | f"with {data_size_bytes} bytes" 309 | ) 310 | print( 311 | f"takes {real_time_ns} ns with {compute_time_ns} in compute and {transfer_time_ns} " 312 | f"in loading data" 313 | ) 314 | print( 315 | f"and consumes {energy['compute']} pJ for compute and {energy['main_mem']} pJ " 316 | "for loading data" 317 | ) 318 | print(f"performance: {performance}") 319 | 320 | return real_time_ns, performance, energy 321 | 322 | def compute_scaled_dot_product_ns( 323 | self, 324 | context, 325 | key_shape, # same dimensions as value_shape 326 | output_shape, 327 | use_kv_cache=True, 328 | summarization=False, 329 | sum_size=0, 330 | ): 331 | batch_size = key_shape[-4] if (len(key_shape) > 3) else 1 332 | n_heads = key_shape[-3] if (len(key_shape) > 2) else 1 333 | n_rows = key_shape[-2] if (len(key_shape) > 1) else 1 # already concatenated! 334 | n_columns = key_shape[-1] 335 | 336 | q_rows = ( 337 | output_shape[-2] if (len(output_shape) > 2) else 1 338 | ) # 1 when using kv cache in GEN. 339 | 340 | compute_time_ns = 0 341 | load_k_time = 0 342 | load_v_time = 0 343 | performance = {} 344 | energy = {} 345 | 346 | # Load KV cache if GENeration and kv cache is enabled 347 | # Only K is required for next step 348 | if not summarization and use_kv_cache: 349 | if self.sliding_window != -1: 350 | n_rows = self.sliding_window 351 | if self.num_key_value_heads != -1: 352 | n_heads = self.num_key_value_heads 353 | 354 | k_cache = torch.Size([batch_size, n_heads, n_rows, n_columns]) 355 | load_k_time, load_k_perf, load_k_energy = self.load_data(k_cache) 356 | performance = add_dictionaries(performance, load_k_perf) 357 | energy = add_dictionaries(energy, load_k_energy) 358 | 359 | # Q x Kt 360 | # (q_rows, embedding) x (embedding, kv_cache_length) = (q_rows, kv_cache_length) 361 | query_shape = torch.Size([batch_size, n_heads, q_rows, n_columns]) 362 | kt_shape = torch.Size([batch_size, n_heads, n_columns, n_rows]) 363 | # TODO: fix this calculation 364 | step_time, step_perf, step_energy = self.compute_ns( 365 | query_shape, 366 | torch.nn.Linear(in_features=n_columns, out_features=n_rows), 367 | kt_shape, 368 | load_input=False, 369 | load_weight=False, 370 | ) 371 | compute_time_ns += max( 372 | load_k_time, step_time 373 | ) # overlap loading K with Q x Kt computation 374 | performance = add_dictionaries(performance, step_perf) 375 | energy = add_dictionaries(energy, step_energy) 376 | 377 | # Load KV cache if GENeration and kv cache is enabled 378 | # Only V is required for next step 379 | if not summarization and use_kv_cache: 380 | 381 | v_cache = torch.Size([batch_size, n_heads, n_rows, n_columns]) 382 | load_v_time, load_v_perf, load_v_energy = self.load_data(v_cache) 383 | performance = add_dictionaries(performance, load_v_perf) 384 | energy = add_dictionaries(energy, load_v_energy) 385 | 386 | # QxKT x V 387 | # (q_rows, kv_cache_length) x (kv_cache_length, embedding) = (q_rows, embedding) 388 | qxkt_shape = 
torch.Size([batch_size, n_heads, q_rows, n_rows]) 389 | # TODO: fix this calculation 390 | step_time, step_perf, step_energy = self.compute_ns( 391 | query_shape, 392 | torch.nn.Linear(in_features=n_rows, out_features=n_columns), 393 | key_shape, 394 | load_input=False, 395 | load_weight=False, 396 | ) 397 | compute_time_ns += max( 398 | load_v_time, step_time 399 | ) # overlap loading V with QxKt x V computation 400 | performance = add_dictionaries(performance, step_perf) 401 | energy = add_dictionaries(energy, step_energy) 402 | 403 | if self.verbose: 404 | print( 405 | f"Computing scaled_dot_product: {query_shape} x {kt_shape} x {key_shape} " 406 | f"in {compute_time_ns} with {energy}" 407 | ) 408 | 409 | return compute_time_ns, performance, energy 410 | 411 | def compute_matmul_ns( 412 | self, 413 | context, 414 | shape_a, 415 | shape_b, 416 | summarization=False, 417 | sum_size=0, 418 | ): 419 | batch_size = shape_a[-4] if (len(shape_a) > 3) else 1 420 | n_heads = shape_a[-3] if (len(shape_a) > 2) else 1 421 | n_rows = shape_a[-2] if (len(shape_a) > 1) else 1 422 | n_columns = shape_a[-1] 423 | 424 | compute_time_ns = 0 425 | performance = {} 426 | energy = {} 427 | 428 | if context == "attn_weights": 429 | if not summarization: 430 | kv_cache = torch.Size([batch_size, n_heads, sum_size, n_columns * 2]) 431 | step_time, step_perf, step_energy = self.load_data(kv_cache) 432 | compute_time_ns += step_time 433 | performance = add_dictionaries(performance, step_perf) 434 | energy = add_dictionaries(energy, step_energy) 435 | 436 | step_time, step_perf, step_energy = self.compute_ns( 437 | shape_a, 438 | torch.nn.Linear(in_features=n_rows, out_features=n_columns), 439 | shape_b, 440 | load_input=False, 441 | load_weight=False, 442 | ) 443 | 444 | compute_time_ns += step_time 445 | performance = add_dictionaries(performance, step_perf) 446 | energy = add_dictionaries(energy, step_energy) 447 | 448 | if self.verbose: 449 | print( 450 | f"Computing matmul: {shape_a} x {shape_b} in {compute_time_ns} with {energy}" 451 | ) 452 | 453 | return compute_time_ns, performance, energy 454 | 455 | def compute_activation_ns(self, data_shape, activation="SiLU"): 456 | batch_size = data_shape[-4] if (len(data_shape) > 3) else 1 457 | n_heads = data_shape[-3] if (len(data_shape) > 2) else 1 458 | n_rows = data_shape[-2] if (len(data_shape) > 1) else 1 459 | n_columns = data_shape[-1] 460 | 461 | activation_ns_per_element = 0 462 | if activation == "SiLU": 463 | activation_ns_per_element = self.SiLU_ns_per_element 464 | 465 | tflops = ( 466 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element 467 | ) 468 | compute_time_ns = ( 469 | batch_size * n_heads * (n_rows * (activation_ns_per_element * n_columns)) 470 | ) 471 | 472 | performance = {"compute": compute_time_ns} 473 | energy = {"compute": tflops * self.pj_per_tflop} 474 | 475 | if self.verbose: 476 | print( 477 | f"Computing Activation: {activation} : {data_shape} in {compute_time_ns} " 478 | f"with {energy}" 479 | ) 480 | 481 | return compute_time_ns, performance, energy 482 | 483 | def compute_RMSNorm_ns(self, data_shape, dimension): 484 | batch_size = data_shape[-4] if (len(data_shape) > 3) else 1 485 | n_heads = data_shape[-3] if (len(data_shape) > 2) else 1 486 | n_rows = data_shape[-2] if (len(data_shape) > 1) else 1 487 | n_columns = dimension 488 | 489 | tflops = ( 490 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element 491 | ) 492 | compute_time_ns = ( 493 | batch_size * n_heads * n_rows * 
(self.RMSNorm_ns_per_element * n_columns) 494 | ) 495 | 496 | performance = {"compute": compute_time_ns} 497 | energy = {"compute": tflops * self.pj_per_tflop} 498 | 499 | if self.verbose: 500 | print( 501 | "Computing RMSNorm:", data_shape, "in", compute_time_ns, "with", energy 502 | ) 503 | 504 | return compute_time_ns, performance, energy 505 | 506 | def compute_softmax_ns(self, data_shape): 507 | batch_size = data_shape[-4] if (len(data_shape) > 3) else 1 508 | n_heads = data_shape[-3] if (len(data_shape) > 2) else 1 509 | n_rows = data_shape[-2] if (len(data_shape) > 1) else 1 510 | n_columns = data_shape[-1] 511 | 512 | tflops = ( 513 | batch_size * n_heads * n_rows * n_columns * self.misc_tflops_per_element 514 | ) 515 | compute_time_ns = ( 516 | batch_size * n_heads * n_rows * (self.softmax_ns_per_element * n_columns) 517 | ) 518 | 519 | performance = {"compute": compute_time_ns} 520 | energy = {"compute": tflops * self.pj_per_tflop} 521 | 522 | if self.verbose: 523 | print( 524 | "Computing softmax:", data_shape, "in", compute_time_ns, "with", energy 525 | ) 526 | 527 | return compute_time_ns, performance, energy 528 | -------------------------------------------------------------------------------- /src/upmem_llm_framework/profiler.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2014-2024 - UPMEM 3 | # UPMEM S.A.S France property - UPMEM confidential information covered by NDA 4 | # For UPMEM partner internal use only - no modification allowed without permission of UPMEM 5 | # 6 | # This file implements all profiling related classes and functions 7 | 8 | import sys 9 | import time 10 | from collections import OrderedDict 11 | from typing import Mapping, Callable, Tuple 12 | 13 | import torch 14 | from torch.nn import Linear, SiLU, LayerNorm, Embedding, Dropout, Softmax, Conv2d 15 | from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding 16 | from transformers.activations import NewGELUActivation 17 | from transformers.pytorch_utils import Conv1D 18 | 19 | from upmem_llm_framework.simulator import Simulator 20 | from upmem_llm_framework.utils import add_dictionaries 21 | 22 | 23 | class layer_profile: 24 | def __init__(self, uniq_id, name, n_layer, context, dim_in, dim_out, obj=None): 25 | self.id = uniq_id 26 | self.name = name 27 | self.n_layer = n_layer 28 | self.context = context 29 | self.dim_in = dim_in 30 | self.dim_out = dim_out 31 | self.exec_time = 0 32 | self.exec_nums = 0 33 | self.energy = {} 34 | self.obj = obj 35 | 36 | 37 | class layer_log: 38 | def __init__( 39 | self, 40 | uniq_id, 41 | name, 42 | context, 43 | summarization, 44 | start_time, 45 | input, 46 | weights, 47 | output, 48 | exec_time_ms, 49 | performance, 50 | energy_pj, 51 | transfer_bytes, 52 | ): 53 | self.id = uniq_id 54 | self.name = name 55 | self.context = context 56 | self.summarization = summarization 57 | self.start_time = start_time 58 | self.input = input 59 | self.weights = weights 60 | self.output = output 61 | self.exec_time_ms = exec_time_ms 62 | self.performance = performance 63 | self.energy = energy_pj 64 | self.transfer_bytes = transfer_bytes 65 | 66 | 67 | class UPM_Profiler: 68 | layer_dimensions: Mapping[type, Callable[[torch.nn.Module], Tuple[int, int]]] = { 69 | Linear: (lambda l: (l.in_features, l.out_features)), 70 | NewGELUActivation: (lambda l: (1, 1)), 71 | SiLU: (lambda l: (1, 1)), 72 | LlamaRMSNorm: (lambda l: (l.weight.size()[0], l.weight.size()[0])), 73 | 
LlamaRotaryEmbedding: (lambda l: (1, 1)), 74 | LayerNorm: (lambda l: (1, 1)), 75 | Embedding: (lambda l: (1, 1)), 76 | Dropout: (lambda l: (1, 1)), 77 | Softmax: (lambda l: (l.dim, l.dim)), 78 | Conv1D: (lambda l: (l.weight.shape[0], l.weight.shape[1])), 79 | Conv2d: (lambda l: (l.kernel_size[0], l.kernel_size[1])), 80 | } 81 | 82 | functional_layers = {LlamaRMSNorm, SiLU} 83 | 84 | def __init__(self, options): 85 | self.n_layers = 0 86 | self.n_executions = 0 87 | self.layers = {} 88 | self.functions = {} 89 | 90 | self.options = options 91 | self.sim_compute = False 92 | self.sim_sliding_window = -1 93 | self.sim_num_key_value_heads = 8 94 | self.sim_data_type = "float16" 95 | self.sim_data_type_bytes = 2.0 96 | self.simulator = None 97 | self.set_options(options) 98 | 99 | self.start_inference = time.time_ns() 100 | self.inference_time = 0 101 | self.summarization_time = 0 102 | self.sum_perf = {} 103 | self.gen_perf = {} 104 | self.sum_energy = {} 105 | self.gen_energy = {} 106 | self.sum_transfer_bytes = {} 107 | self.gen_transfer_bytes = {} 108 | 109 | self.last_layer = "lm_head" 110 | self.batch_size = 1 111 | 112 | self.layers_start = {} 113 | self.layers_end = {} 114 | self.log = [] 115 | 116 | self.forward_input_shape = None 117 | self.forward_time_start = 0 118 | self.forward_time_end = 0 119 | 120 | self.func_input_shape = None 121 | self.start_func = 0 122 | self.end_func = 0 123 | 124 | def set_options(self, options): 125 | self.options = options 126 | # simulation related 127 | self.sim_compute = options.sim_compute 128 | self.sim_sliding_window = options.sim_sliding_window 129 | self.sim_num_key_value_heads = options.sim_num_key_value_heads 130 | 131 | self.sim_data_type = options.sim_data_type 132 | 133 | self.sim_data_type_bytes = { 134 | "int4": 0.5, 135 | "int8": 1.0, 136 | "float16": 2.0, 137 | "bfloat16": 2.0, 138 | "float32": 4.0, 139 | }.get(self.sim_data_type) 140 | 141 | self.simulator = self.create_arch_simulator() if options.simulation else None 142 | 143 | def create_arch_simulator(self): 144 | return Simulator( 145 | data_type_bytes=self.sim_data_type_bytes, 146 | sliding_window=self.sim_sliding_window, 147 | num_key_value_heads=self.sim_num_key_value_heads, 148 | verbose=self.options.sim_verbose, 149 | ) 150 | 151 | def print_layers_model(self): 152 | print("##### Layers of Model in order of creation #####") 153 | print( 154 | "Layer ID (creation order), Context, Function, Dimensions (rows x columns), " 155 | "times executed, avg. execution time (ms)" 156 | ) 157 | for layer in self.layers.values(): 158 | print( 159 | f"{str(layer.id)}," 160 | f"{layer.context}," 161 | f"{layer.name}," 162 | f"({layer.dim_in}x{layer.dim_out})," 163 | f"{layer.exec_nums}," 164 | f"{layer.exec_time / self.n_executions / 1e6}" 165 | ) 166 | 167 | def print_functions_model(self): 168 | print("##### Functions called by the Model in order of calling #####") 169 | print( 170 | "Function name, Context, Dimensions in (columns), Dimensions out (columns), " 171 | "times executed, avg. 
execution time (ms)" 172 | ) 173 | for name, func in self.functions.items(): 174 | print( 175 | f"{str(name)}," 176 | f"{func.context}," 177 | f"({func.dim_in})," 178 | f"({func.dim_out})," 179 | f"{func.exec_nums}," 180 | f"{func.exec_time / 1e6}" 181 | ) 182 | 183 | def print_log_summary(self, show_summarization=False, show_all=False): 184 | phase = "Generation" 185 | if show_summarization: 186 | phase = "Summarization" 187 | if show_all: 188 | phase = "All (SUM and GEN)" 189 | print("#####", phase, "Execution summary #####") 190 | name_ctxt = [] 191 | summary_time = OrderedDict() 192 | summary_perf = OrderedDict() 193 | summary_energy = OrderedDict() 194 | summary_transfer_bytes = OrderedDict() 195 | summary_nexec = OrderedDict() 196 | input_shapes = OrderedDict() 197 | weights_shapes = OrderedDict() 198 | output_shapes = OrderedDict() 199 | for log in self.log: 200 | if not show_all and not show_summarization and log.summarization: 201 | continue 202 | if not show_all and show_summarization and not log.summarization: 203 | continue 204 | ctxt = log.name + ":" + log.context 205 | if not ctxt in summary_time.keys(): 206 | name_ctxt.append(ctxt) 207 | summary_nexec[ctxt] = 1 + summary_nexec.get(ctxt, 0) 208 | summary_time[ctxt] = log.exec_time_ms + summary_time.get(ctxt, 0) 209 | summary_energy[ctxt] = add_dictionaries( 210 | summary_energy.get(ctxt, {}), log.energy 211 | ) 212 | summary_perf[ctxt] = add_dictionaries( 213 | summary_perf.get(ctxt, {}), log.performance 214 | ) 215 | summary_transfer_bytes[ctxt] = add_dictionaries( 216 | summary_transfer_bytes.get(ctxt, {}), log.transfer_bytes 217 | ) 218 | 219 | input_shapes[ctxt] = "(" + ":".join([str(x) for x in log.input]) + ")" 220 | weights_shapes[ctxt] = "(" + ":".join([str(x) for x in log.weights]) + ")" 221 | output_shapes[ctxt] = "(" + ":".join([str(x) for x in log.output]) + ")" 222 | 223 | executed_times = 1 if show_summarization else (self.n_executions - 1) 224 | print( 225 | "Function: Context: input shape: weights shape: output shape:" 226 | "time(s):H2C(ms):C2H(ms):compute(ms):mem_transfer(ms):kv_load(ms)" 227 | "host_to_device(mJ):device_to_host(mJ):main_mem(mJ):compute(mJ)" 228 | ) 229 | for key in name_ctxt: 230 | 231 | perf_values = [] 232 | for perf_key in [ 233 | "host_to_device", 234 | "device_to_host", 235 | "compute", 236 | "mem_transfer", 237 | "kv_load", 238 | ]: 239 | perf_values.append( 240 | str(summary_perf[key].get(perf_key, 0) / 1e6 / executed_times) 241 | ) 242 | perf_string = ":".join(perf_values) 243 | 244 | energy_values = [] 245 | for ene_key in ["host_to_device", "device_to_host", "main_mem", "compute"]: 246 | energy_values.append( 247 | str(summary_energy[key].get(ene_key, 0) / 1e6 / executed_times) 248 | ) 249 | energy_string = ":".join(energy_values) 250 | print( 251 | key, 252 | input_shapes[key], 253 | weights_shapes[key], 254 | output_shapes[key], 255 | (summary_time[key] / executed_times), 256 | perf_string, 257 | energy_string, 258 | ) 259 | 260 | total_time_explained = sum(summary_time.values()) 261 | total_percentage_explained = 0 262 | for key in name_ctxt: 263 | print( 264 | key, 265 | "explains", 266 | (summary_time[key] / total_time_explained) * 100, 267 | "% of the total inference time (num. 
executions:", 268 | summary_nexec[key], 269 | ") average time:", 270 | summary_time[key] / summary_nexec[key], 271 | "(ms)", 272 | ) 273 | total_percentage_explained += ( 274 | summary_time[key] / total_time_explained 275 | ) * 100 276 | print( 277 | "Profiler captures", total_percentage_explained, "% of the total execution" 278 | ) 279 | print("Profiler captures", total_time_explained, "ms of the total execution") 280 | print(summary_time) 281 | 282 | def print_log(self): 283 | print("##### Execution log #####") 284 | print( 285 | "Start time, exec time, Function, Context, input shape, weights shape, output shape" 286 | ) 287 | for log in self.log: 288 | input_shape = "(" + ",".join([str(x) for x in log.input]) + ")" 289 | weights_shape = "(" + ",".join([str(x) for x in log.weights]) + ")" 290 | output_shape = "(" + ",".join([str(x) for x in log.output]) + ")" 291 | print( 292 | log.start_time / 1e6, 293 | log.exec_time_ms, 294 | log.id, 295 | log.name, 296 | log.context, 297 | input_shape, 298 | weights_shape, 299 | output_shape, 300 | ) 301 | 302 | def update_inference_perf(self, step_perf): 303 | if self.simulator.sum: 304 | for key in step_perf.keys(): 305 | self.sum_perf[key] = self.sum_perf.get(key, 0) + step_perf[key] 306 | else: 307 | for key in step_perf.keys(): 308 | self.gen_perf[key] = self.gen_perf.get(key, 0) + step_perf[key] 309 | 310 | def update_inference_energy(self, step_energy): 311 | if self.simulator.sum: 312 | for key in step_energy.keys(): 313 | self.sum_energy[key] = self.sum_energy.get(key, 0) + step_energy[key] 314 | else: 315 | for key in step_energy.keys(): 316 | self.gen_energy[key] = self.gen_energy.get(key, 0) + step_energy[key] 317 | 318 | def update_inference_transfer_bytes(self, step_transfer_bytes): 319 | if self.simulator.sum: 320 | for key in step_transfer_bytes.keys(): 321 | self.sum_transfer_bytes[key] = ( 322 | self.sum_transfer_bytes.get(key, 0) + step_transfer_bytes[key] 323 | ) 324 | else: 325 | for key in step_transfer_bytes.keys(): 326 | self.gen_transfer_bytes[key] = ( 327 | self.gen_transfer_bytes.get(key, 0) + step_transfer_bytes[key] 328 | ) 329 | 330 | def start( 331 | self, 332 | layer_mapping=None, 333 | layer_attn_ctxt="", 334 | last_layer="lm_head", 335 | batch_size=1, 336 | moe_end="", 337 | experts_per_token=2, 338 | ): 339 | self.start_inference = time.time_ns() 340 | self.n_executions = 0 341 | self.inference_time = 0 342 | self.summarization_time = 0 343 | self.sum_perf = {} 344 | self.gen_perf = {} 345 | self.sum_energy = {} 346 | self.gen_energy = {} 347 | self.sum_transfer_bytes = {} 348 | self.gen_transfer_bytes = {} 349 | 350 | self.last_layer = last_layer 351 | self.batch_size = batch_size 352 | 353 | self.layers_start = {} 354 | self.layers_end = {} 355 | self.log = [] 356 | 357 | if self.simulator: 358 | self.simulator.map_layers( 359 | layer_mapping, 360 | layer_attn_ctxt=layer_attn_ctxt, 361 | moe_end=moe_end, 362 | experts_per_token=experts_per_token, 363 | ) 364 | 365 | def end(self): 366 | if self.simulator: 367 | step_time, step_perf, step_energy, step_transfer_bytes = ( 368 | self.simulator.simulate_end( 369 | self.forward_input_shape, generated_tokens=(self.n_executions) 370 | ) 371 | ) 372 | self.inference_time += step_time 373 | self.update_inference_perf(step_perf) 374 | self.update_inference_energy(step_energy) 375 | self.update_inference_transfer_bytes(step_transfer_bytes) 376 | else: 377 | self.inference_time = time.time_ns() - self.start_inference 378 | 379 | inference_time_sec = self.inference_time / 1e9 
380 | sum_energy_mJ = 0 381 | gen_energy_mJ = 0 382 | sum_time_s = self.summarization_time / 1e9 383 | gen_time_s = inference_time_sec - sum_time_s 384 | gen_n_executions = self.n_executions - 1 385 | 386 | print("##### UPMEM PROFILER OUTPUT #####") 387 | print( 388 | f"Total time (SUM + GEN): {inference_time_sec} s, " 389 | f"with data type: {self.sim_data_type}, " 390 | f"batch size: {self.batch_size}" 391 | ) 392 | print( 393 | f"Generated tokens: {gen_n_executions * self.batch_size} " 394 | f"in {gen_time_s} s, " 395 | f"with tokens/s: {(gen_n_executions * self.batch_size) / gen_time_s}" 396 | ) 397 | print( 398 | f"Summarization step took: {sum_time_s} s, " 399 | f"weight in the execution: SUM: {sum_time_s / inference_time_sec}%, " 400 | f"GEN: {gen_time_s / inference_time_sec}%" 401 | ) 402 | 403 | if self.simulator: 404 | print("SUMMARIZATION summary") 405 | for key, transfer in self.sum_transfer_bytes.items(): 406 | print(f"Transferred data in {key}: {transfer / 1e6} MB") 407 | for key, energy in self.sum_energy.items(): 408 | energy_mj = energy / 1e9 409 | print(f"Energy in {key}: {energy_mj} mJ") 410 | sum_energy_mJ += energy / 1e9 411 | print(f"Energy: {sum_energy_mJ} mJ") 412 | print(f"Power: {sum_energy_mJ / 1e3 / sum_time_s} W") 413 | 414 | if gen_n_executions > 0: 415 | print("GENERATION summary") 416 | for key, transfer in self.gen_transfer_bytes.items(): 417 | print( 418 | f"Transferred data in {key}: {transfer / 1e6} MB, " 419 | f"MB/token: {transfer / 1e6 / self.n_executions / self.batch_size}" 420 | ) 421 | for key, energy in self.gen_energy.items(): 422 | energy_mj = energy / 1e9 423 | print( 424 | f"Energy in {key}: {energy_mj} mJ, " 425 | f"mJ/token: {energy_mj / gen_n_executions / self.batch_size}" 426 | ) 427 | gen_energy_mJ += energy_mj 428 | print( 429 | f"Energy: {gen_energy_mJ} mJ, " 430 | f"mJ/token: {gen_energy_mJ / gen_n_executions / self.batch_size}" 431 | ) 432 | print(f"Power: {gen_energy_mJ / 1e3 / gen_time_s} W") 433 | 434 | print("Execution time breakdown (ms / %)") 435 | print("SUMMARIZATION phase") 436 | for perf_key in [ 437 | "host_to_device", 438 | "device_to_host", 439 | "compute", 440 | "mem_transfer", 441 | "kv_load", 442 | ]: 443 | perf_value = self.sum_perf.get(perf_key, 0) 444 | print( 445 | perf_key, (perf_value / 1e6), "(ms)", perf_value / 1e9 / sum_time_s 446 | ) 447 | 448 | if gen_n_executions > 0: 449 | print("GENERATION phase") 450 | for perf_key in [ 451 | "host_to_device", 452 | "device_to_host", 453 | "compute", 454 | "mem_transfer", 455 | "kv_load", 456 | ]: 457 | perf_value = self.gen_perf.get(perf_key, 0) 458 | print( 459 | f"{perf_key}: {perf_value / 1e6} ms, {perf_value / 1e9 / gen_time_s}" 460 | ) 461 | 462 | if self.options.report_layers: 463 | self.print_layers_model() 464 | 465 | if self.options.report_functions: 466 | self.print_functions_model() 467 | 468 | if self.options.print_log: 469 | self.print_log() 470 | 471 | if self.options.print_log_summary: 472 | self.print_log_summary() 473 | self.print_log_summary(show_summarization=True) 474 | self.print_log_summary(show_all=True) 475 | 476 | print("##### END UPMEM PROFILER OUTPUT #####") 477 | 478 | def add(self, layer, context): 479 | layer_type = type(layer) 480 | dim_in, dim_out = next( 481 | ( 482 | dim_func(layer) 483 | for key, dim_func in self.layer_dimensions.items() 484 | if issubclass(layer_type, key) 485 | ), 486 | (None, None), 487 | ) 488 | 489 | if dim_in is None or dim_out is None: 490 | print(f"Layer: {layer_type} not supported") 491 | sys.exit() 492 | 493 
493 |         name_layer = layer_type.__name__
494 |
495 |         self.layers[layer] = layer_profile(
496 |             self.n_layers,
497 |             name_layer,
498 |             self.n_layers,
499 |             context,
500 |             dim_in,
501 |             dim_out,
502 |             obj=layer,
503 |         )
504 |         self.n_layers += 1
505 |
506 |     def forward_start(self, input_shape):
507 |         self.forward_input_shape = input_shape
508 |
509 |         if self.simulator:
510 |             self.forward_time_start = self.inference_time
511 |         else:
512 |             self.forward_time_start = time.time_ns()
513 |
514 |     def forward_end(self, output_shape, context, layer_obj=None):
515 |         self.forward_time_end = time.time_ns()
516 |
517 |         weights_shape = torch.Size(
518 |             [self.layers[layer_obj].dim_in, self.layers[layer_obj].dim_out]
519 |         )
520 |
521 |         cur_exec_time = self.forward_time_end - self.forward_time_start
522 |         performance = {}
523 |         energy = {}
524 |         relative_start_time = self.forward_time_start - self.start_inference
525 |
526 |         if self.simulator:
527 |             name = self.layers[layer_obj].name
528 |             if any(
529 |                 isinstance(layer_obj, layer_t) for layer_t in self.functional_layers
530 |             ):
531 |                 cur_exec_time, performance, energy, transfer_bytes = (
532 |                     self.simulator.simulate_function(
533 |                         self.layers[layer_obj],
534 |                         context,
535 |                         output_shape,
536 |                         self.layers[layer_obj].dim_out,
537 |                     )
538 |                 )
539 |             else:
540 |                 cur_exec_time, performance, energy, transfer_bytes = (
541 |                     self.simulator.simulate_layer(
542 |                         self.layers[layer_obj],
543 |                         self.forward_input_shape,
544 |                         layer_obj,
545 |                         weights_shape,
546 |                         output_shape,
547 |                     )
548 |                 )  # or context?
549 |             self.inference_time += cur_exec_time
550 |             self.update_inference_perf(performance)
551 |             self.update_inference_energy(energy)
552 |             self.update_inference_transfer_bytes(transfer_bytes)
553 |
554 |             self.layers[layer_obj].exec_time += cur_exec_time
555 |             self.layers[layer_obj].energy = add_dictionaries(
556 |                 self.layers[layer_obj].energy, energy
557 |             )
558 |             self.layers_start[layer_obj] = self.forward_time_start
559 |             self.layers_end[layer_obj] = self.inference_time
560 |             relative_start_time = self.forward_time_start
561 |         else:
562 |             self.inference_time += cur_exec_time
563 |             self.layers[layer_obj].exec_time += cur_exec_time
564 |             self.layers_start[layer_obj] = self.forward_time_start
565 |             self.layers_end[layer_obj] = self.forward_time_end
566 |             transfer_bytes = 0
567 |
568 |         self.layers[layer_obj].exec_nums += 1
569 |
570 |         summarization_phase = self.simulator.sum if self.simulator else False
571 |
572 |         cur_logging = layer_log(
573 |             self.layers[layer_obj].id,
574 |             self.layers[layer_obj].name,
575 |             self.layers[layer_obj].context,
576 |             summarization_phase,
577 |             relative_start_time / 1e6,
578 |             self.forward_input_shape,
579 |             weights_shape,
580 |             output_shape,
581 |             cur_exec_time / 1e6,
582 |             performance,
583 |             energy,
584 |             transfer_bytes,
585 |         )
586 |
587 |         self.log.append(cur_logging)
588 |         if self.layers[layer_obj].context == self.last_layer:
589 |             if self.simulator and not self.simulator.sum:
590 |                 self.simulator.sum_size += 1
591 |             if self.n_executions == 0:
592 |                 if self.simulator:
593 |                     self.simulator.start_gen()
594 |                     self.simulator.sum_size = (
595 |                         output_shape[-2] if (len(output_shape) > 1) else 1
596 |                     )
597 |                 self.summarization_time = self.inference_time
598 |             self.n_executions += 1
599 |             print(f"New token generated ({self.n_executions})", end="\r")
600 |
601 |     def forward_func_start(self, name, context, input_shape):
602 |         self.func_input_shape = input_shape
603 |
604 |         self.start_func = time.time_ns()
605 |         if self.simulator:
606 |             self.start_func = self.inference_time
607 |
608 |     def forward_func_end(self, function, context, output_shape):
609 |         self.end_func = time.time_ns()
610 |
611 |         func_profile = self.functions.get(
612 |             function.__name__,
613 |             layer_profile(
614 |                 0, function.__name__, 0, context, self.func_input_shape[-1], output_shape[-1]
615 |             ),
616 |         )
617 |
618 |         cur_exec_time = self.end_func - self.start_func
619 |         performance = {}
620 |         energy = {}
621 |         relative_time = self.start_func - self.start_inference
622 |
623 |         if self.simulator:
624 |             cur_exec_time, performance, energy, transfer_bytes = (
625 |                 self.simulator.simulate_function(
626 |                     function, context, self.func_input_shape, output_shape
627 |                 )
628 |             )
629 |             self.inference_time += cur_exec_time
630 |             self.update_inference_perf(performance)
631 |             self.update_inference_energy(energy)
632 |             self.update_inference_transfer_bytes(transfer_bytes)
633 |         else:
634 |             transfer_bytes = 0
635 |
636 |         relative_exec_time = self.start_func if self.simulator else relative_time
637 |
638 |         summarization_phase = self.simulator.sum if self.simulator else False
639 |         cur_logging = layer_log(
640 |             0,  # functions have ID set to 0
641 |             function.__name__,
642 |             context,
643 |             summarization_phase,
644 |             relative_exec_time / 1e6,
645 |             self.func_input_shape,
646 |             output_shape,
647 |             output_shape,
648 |             cur_exec_time / 1e6,
649 |             performance,
650 |             energy,
651 |             transfer_bytes,
652 |         )
653 |
654 |         self.log.append(cur_logging)
655 |
656 |         func_profile.exec_nums += 1
657 |         func_profile.exec_time += cur_exec_time
658 |         self.functions[function.__name__] = func_profile
659 |
--------------------------------------------------------------------------------
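A minimal standalone sketch (not part of the package) of the subclass-dispatch pattern used in the add method above: a dictionary maps layer base classes to callables that extract (dim_in, dim_out), and next() over an issubclass filter picks the first matching entry, falling back to (None, None) for unsupported layers. The layer_dimensions registry and get_layer_dims helper are illustrative stand-ins, not framework API.

import torch.nn as nn

# Hypothetical registry: layer base class -> callable returning (dim_in, dim_out),
# mirroring the role of self.layer_dimensions in the profiler.
layer_dimensions = {
    nn.Linear: lambda l: (l.in_features, l.out_features),
    nn.Embedding: lambda l: (l.num_embeddings, l.embedding_dim),
}

def get_layer_dims(layer):
    layer_type = type(layer)
    # First base class that matches wins; (None, None) signals "unsupported layer".
    return next(
        (
            dim_func(layer)
            for key, dim_func in layer_dimensions.items()
            if issubclass(layer_type, key)
        ),
        (None, None),
    )

print(get_layer_dims(nn.Linear(16, 32)))     # (16, 32)
print(get_layer_dims(nn.Embedding(100, 8)))  # (100, 8)
print(get_layer_dims(nn.ReLU()))             # (None, None) -> unsupported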