├── src ├── __init__.py ├── fast_init.py ├── profile.py ├── utils.py ├── metrics.py ├── parse_results.py ├── main.py └── pipeline.py ├── .gitignore ├── README.md ├── pyproject.toml ├── .dockerignore ├── .gitmodules ├── setup.cfg ├── requirements.txt ├── scripts ├── run_grid.sh └── run_benchmark.sh ├── .pre-commit-config.yaml ├── Dockerfile ├── Makefile └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bigcode-inference-benchmark 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py35'] 4 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !src 3 | !scripts 4 | !transformers 5 | !requirements.txt 6 | !Makefile 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "transformers"] 2 | path = transformers 3 | url = https://github.com/bigcode-project/transformers.git 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | ensure_newline_before_comments = True 3 | force_grid_wrap = 0 4 | include_trailing_comma = True 5 | line_length = 119 6 | lines_after_imports = 2 7 | multi_line_output = 3 8 | use_parentheses = True 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.15.0 2 | bitsandbytes 3 | safetensors 4 | deepspeed==0.7.7 5 | -e ./transformers 6 | 7 | # TODO: Analysis only 8 | py-markdown-table 9 | 10 | # TODO: Dev only 11 | isort>=5.5.4 12 | black~=22.0 13 | -------------------------------------------------------------------------------- /scripts/run_grid.sh: -------------------------------------------------------------------------------- 1 | for bs in $1 2 | do 3 | for seq in $2 4 | do 5 | for tok in $3 6 | do 7 | "${@:5}" --save="$4"_bs_"$bs"_seq_"$seq"_tok_"$tok".json --batch_size=$bs --max_input_length=$seq --max_new_tokens=$tok 8 | done 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/isort 3 | rev: 5.10.1 4 | hooks: 5 | - id: isort 6 | name: isort (python) 7 | - repo: https://github.com/psf/black 8 | rev: 22.8.0 9 | hooks: 10 | - id: black 11 | args: [--line-length=119,--target-version=py35] 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM 
nvcr.io/nvidia/pytorch:23.01-py3 2 | 3 | ARG USER=1000 4 | ARG USERNAME=user 5 | 6 | WORKDIR /app 7 | ENV PYTHONPATH=/app 8 | 9 | RUN useradd -m -u $USER -s /bin/bash $USERNAME \ 10 | && chown $USERNAME /app 11 | 12 | # git-lfs is needed to interact with the huggingface hub 13 | RUN apt-get update \ 14 | && apt-get install git-lfs \ 15 | && rm -rf /var/lib/apt/lists/* \ 16 | && git lfs install 17 | 18 | COPY --chown=$USERNAME ./requirements.txt ./ 19 | COPY --chown=$USERNAME transformers/ ./transformers 20 | 21 | # Stock version of pip doesn't work with editable transformers. 22 | RUN pip install --upgrade pip --no-cache-dir && pip install -r requirements.txt --no-cache-dir 23 | 24 | COPY --chown=$USERNAME Makefile . 25 | COPY --chown=$USERNAME src/ ./src 26 | COPY --chown=$USERNAME scripts/ ./scripts 27 | -------------------------------------------------------------------------------- /src/fast_init.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Callable, Dict, Type 3 | 4 | import torch 5 | 6 | from transformers import Conv1D 7 | from transformers.modeling_utils import no_init_weights 8 | 9 | 10 | def _conv1d_init(self, nf, nx, device=None): 11 | super(Conv1D, self).__init__() 12 | self.nf = nf 13 | w = torch.empty(nx, nf, device=device) 14 | torch.nn.init.normal_(w, std=0.02) 15 | self.weight = torch.nn.Parameter(w) 16 | b = torch.empty(nf, device=device) 17 | torch.nn.init.zeros_(b) 18 | self.bias = torch.nn.Parameter(b) 19 | 20 | 21 | _ORIGINAL_INITS: Dict[Type[torch.nn.Module], Callable] = { 22 | Conv1D: _conv1d_init, 23 | torch.nn.Linear: torch.nn.Linear.__init__, 24 | torch.nn.Embedding: torch.nn.Embedding.__init__, 25 | torch.nn.LayerNorm: torch.nn.LayerNorm.__init__, 26 | } 27 | 28 | 29 | def _get_fast_init(cls: Type[torch.nn.Module], device: torch.device): 30 | assert cls in _ORIGINAL_INITS 31 | 32 | def _fast_init(self, *args, **kwargs): 33 | # Same as torch.nn.utils.skip_init, excluding checks 34 | _ORIGINAL_INITS[cls](self, *args, **kwargs, device="meta") 35 | self.to_empty(device=device) 36 | 37 | return _fast_init 38 | 39 | 40 | @contextlib.contextmanager 41 | def fast_init(device: torch.device, init_weights: bool = False): 42 | """ 43 | Avoid multiple slow initializations on cpu. 
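    Temporarily patches the constructors listed in _ORIGINAL_INITS (Linear, Embedding,
    LayerNorm and transformers' Conv1D) so that parameters are allocated on the meta device
    and materialized directly on `device`, skipping the slow default CPU initialization.
    Unless `init_weights=True`, transformers' `no_init_weights()` is also active, leaving the
    actual weight initialization to a later explicit call. The original constructors are
    restored when the context manager exits.

    Illustrative usage, a sketch mirroring how the Pipeline class uses it (names taken
    from src/pipeline.py):

        with fast_init(torch.device("cuda")):
            model = AutoModelForCausalLM.from_config(config)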
44 | """ 45 | for cls in _ORIGINAL_INITS: 46 | cls.__init__ = _get_fast_init(cls, device) 47 | 48 | with contextlib.nullcontext() if init_weights else no_init_weights(): 49 | yield 50 | 51 | for cls in _ORIGINAL_INITS: 52 | cls.__init__ = _ORIGINAL_INITS[cls] 53 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := src scripts 2 | 3 | .PHONY: style 4 | style: 5 | black --preview $(check_dirs) 6 | isort $(check_dirs) 7 | 8 | BATCH_SIZE ?= 1 9 | DTYPE ?= float16 10 | HIDDEN_SIZE ?= 2048 11 | N_HEAD ?= 16 12 | N_LAYER ?= 24 13 | N_POSITION ?= 2048 14 | MAX_INPUT_LENGTH ?= -1 15 | 16 | RUN_HF := python3 src/main.py --pipeline_class=HF_Pipeline 17 | RUN_DS := deepspeed --num_gpus 1 src/main.py --pipeline_class=DS_Pipeline 18 | EXP_ARGS := --dtype=${DTYPE} --batch_size=${BATCH_SIZE} --max_input_length=${MAX_INPUT_LENGTH} ${EXTRA_ARGS} 19 | COMMON_ARGS := ${EXP_ARGS} n_head=${N_HEAD} n_layer=${N_LAYER} 20 | BLOOM_ARGS := --model_type=bloom ${COMMON_ARGS} hidden_size=${HIDDEN_SIZE} 21 | GPT2_ARGS := --model_type=gpt2 ${COMMON_ARGS} n_embd=${HIDDEN_SIZE} 22 | BIGCODE_ARGS := --model_type=gpt_bigcode ${COMMON_ARGS} n_embd=${HIDDEN_SIZE} 23 | 24 | 25 | .PHONY: install 26 | install: 27 | git submodule update --init 28 | pip install -r requirements.txt 29 | 30 | .PHONY: bloom 31 | bloom: 32 | ${RUN_HF} ${BLOOM_ARGS} 33 | 34 | .PHONY: bloom-ds 35 | bloom-ds: 36 | ${RUN_DS} ${BLOOM_ARGS} 37 | 38 | .PHONY: gpt2 39 | gpt2: 40 | ${RUN_HF} ${GPT2_ARGS} 41 | 42 | .PHONY: gpt2-ds 43 | gpt2-ds: 44 | ${RUN_DS} ${GPT2_ARGS} 45 | 46 | .PHONY: gpt-bigcode-mha 47 | gpt-bigcode-mha: 48 | ${RUN_HF} ${BIGCODE_ARGS} attention_type=1 49 | 50 | .PHONY: gpt-bigcode-mqa1 51 | gpt-bigcode-mqa1: 52 | ${RUN_HF} ${BIGCODE_ARGS} attention_type=2 53 | 54 | .PHONY: gpt-bigcode-mqa2 55 | gpt-bigcode-mqa2: 56 | ${RUN_HF} ${BIGCODE_ARGS} attention_type=3 57 | 58 | .PHONY: santacoder-original 59 | santacoder-original: 60 | ${RUN_HF} --pretrained_model=bigcode/santacoder --tokenizer=bigcode/santacoder --trust_remote_code ${EXP_ARGS} 61 | 62 | .PHONY: santacoder 63 | santacoder: 64 | ${RUN_HF} --pretrained_model=bigcode/santacoder-fast-inference --tokenizer=bigcode/santacoder ${EXP_ARGS} 65 | 66 | .PHONY: optimized-santacoder 67 | optimized-santacoder: 68 | ${RUN_HF} --pretrained_model=olivierdehaene/optimized-santacoder --tokenizer=bigcode/santacoder --trust_remote_code ${EXP_ARGS} 69 | -------------------------------------------------------------------------------- /src/profile.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | from typing import Union 4 | 5 | import torch 6 | 7 | from src.utils import log_rank_n 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1): 14 | def trace_fn( 15 | p: torch.profiler.profile, 16 | ): 17 | averages = p.key_averages() 18 | if full_trace: 19 | # Show every GPU op. 20 | # Exclude CPU cuda ops to shorten the table. 21 | events = torch.autograd.profiler.EventList( 22 | [evt for evt in p.profiler.function_events if evt.self_cuda_time_total > 0] 23 | ) 24 | log_rank_n(events.table(row_limit=-1, max_src_column_width=1000), logger.info, rank) 25 | 26 | if show_op_names: 27 | # Show non-cropped names, in the same order as in the table. 
28 | averages_sorted = torch.autograd.profiler.EventList( 29 | sorted(averages, key=lambda evt: evt.self_cuda_time_total, reverse=True) 30 | ) 31 | for entry in averages_sorted: 32 | log_rank_n(entry.key, logger.info, rank) 33 | 34 | # Try to avoid name cropping, still hard-coded to max 55 characters 35 | log_rank_n( 36 | averages.table(sort_by="self_cuda_time_total", row_limit=-1, max_src_column_width=1000), logger.info, rank 37 | ) 38 | 39 | return trace_fn 40 | 41 | 42 | def get_profiler( 43 | skip: int, 44 | warmup: int, 45 | cycles: int, 46 | full_trace: bool = False, 47 | show_op_names: bool = False, 48 | ) -> Union[torch.profiler.profile, contextlib.nullcontext]: 49 | schedule = torch.profiler.schedule( 50 | # Warmup is a must if measuring speed as it's when all the optimizations are performed 51 | # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs 52 | skip_first=skip, 53 | # Warmup for the profiler 54 | warmup=warmup, 55 | wait=0, 56 | active=cycles, 57 | ) 58 | return torch.profiler.profile( 59 | schedule=schedule, 60 | activities=[torch.profiler.ProfilerActivity.CUDA], 61 | on_trace_ready=get_trace_fn(full_trace, show_op_names), 62 | ) 63 | -------------------------------------------------------------------------------- /scripts/run_benchmark.sh: -------------------------------------------------------------------------------- 1 | SAVE_DIR=data/benchmarks/v1 2 | BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256" 3 | RUN="python3 src/main.py --tokenizer=bigcode/santacoder --max_log_outputs=1 --dtype=float16 --device=cuda" 4 | RUN_DEEPSPEED="deepspeed --num_gpus 1 src/main.py --pipeline_class=DS_Pipeline --tokenizer=bigcode/santacoder --max_log_outputs=1 --dtype=float16 --device=cuda" 5 | 6 | SANTACODER="--pretrained_model=bigcode/santacoder --trust_remote_code" 7 | GPT_BIGCODE="--pretrained_model=bigcode/santacoder-fast-inference:linear" 8 | PRE_ALLOCATE="--pretrained_model=bigcode/santacoder-fast-inference:linear pre_allocate_kv_cache=True" 9 | INFERENCE_RUNNER="--pretrained_model=bigcode/santacoder-fast-inference:linear pre_allocate_kv_cache=True inference_runner=1" 10 | CUDA_GRAPH="--pretrained_model=bigcode/santacoder-fast-inference:linear pre_allocate_kv_cache=True inference_runner=3" 11 | MHA_GPT2="--model_type=gpt2 n_positions=2048 n_embd=2048 n_head=16 n_layer=24" 12 | MHA_GPT_BIGCODE="--model_type=gpt_bigcode attention_type=1 n_positions=2048 n_embd=2048 n_head=16 n_layer=24" 13 | MHA_PRE_ALLOCATE="--model_type=gpt_bigcode n_positions=2048 n_embd=2048 n_head=16 n_layer=24 pre_allocate_kv_cache=True max_sequence_length=1024" 14 | MQA2_GPT_BIGCODE="--model_type=gpt_bigcode attention_type=3 n_positions=2048 n_embd=2048 n_head=16 n_layer=24" 15 | MQA2_PRE_ALLOCATE="--model_type=gpt_bigcode attention_type=3 n_positions=2048 n_embd=2048 n_head=16 n_layer=24 pre_allocate_kv_cache=True" 16 | 17 | SEQ_TOK=("-1 1" "-1 100" "501 1" "504 1" "-1 1000") 18 | 19 | for seq_tok in "${SEQ_TOK[@]}" 20 | do 21 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/santacoder $RUN --cycles=10 $SANTACODER 22 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/gpt_bigcode $RUN --cycles=10 $GPT_BIGCODE 23 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/pre_allocate $RUN --cycles=10 $PRE_ALLOCATE 24 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/inference_runner $RUN --cycles=10 $INFERENCE_RUNNER 25 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/cuda_graph $RUN --cycles=10 $CUDA_GRAPH 26 | ./scripts/run_grid.sh 
"$BATCH_SIZES" $seq_tok $SAVE_DIR/mha_deepspeed $RUN_DEEPSPEED --cycles=10 $MHA_GPT2 27 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/mha_gpt2 $RUN --cycles=10 $MHA_GPT2 28 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/mha_gpt_bigcode $RUN --cycles=10 $MHA_GPT_BIGCODE 29 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/mha_pre_allocate $RUN --cycles=10 $MHA_PRE_ALLOCATE 30 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/mqa2_gpt_bigcode $RUN --cycles=10 $MQA2_GPT_BIGCODE 31 | ./scripts/run_grid.sh "$BATCH_SIZES" $seq_tok $SAVE_DIR/mqa2_pre_allocate $RUN --cycles=10 $MQA2_PRE_ALLOCATE 32 | done 33 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import logging.config 4 | import math 5 | import typing 6 | from typing import Any, Callable, List, Optional, Tuple 7 | 8 | from torch import distributed as dist 9 | 10 | 11 | def parse_revision(pretrained_model: Optional[str]) -> Tuple[Optional[str], Optional[str]]: 12 | revision = None 13 | if pretrained_model is not None: 14 | pretrained_split = pretrained_model.split(":", 1) 15 | if len(pretrained_split) == 2: 16 | pretrained_model, revision = pretrained_split 17 | return pretrained_model, revision 18 | 19 | 20 | def parse_config_arg(config_arg: str) -> Tuple[str, Any]: 21 | split_arg = [x.strip() for x in config_arg.split("=", 1)] 22 | if len(split_arg) != 2: 23 | raise ValueError(f"Cannot parse argument (not in 'key=value' format): {config_arg}") 24 | key, value = split_arg 25 | if not key.isidentifier(): 26 | raise ValueError(f"Invalid argument (not a python identifier): {key}") 27 | if value.lower() == "true": 28 | value = True 29 | elif value.lower() == "false": 30 | value = False 31 | elif value.lower() == "none": 32 | value = None 33 | else: 34 | try: 35 | value = int(value) 36 | except ValueError: 37 | try: 38 | value = float(value) 39 | except ValueError: 40 | pass 41 | return key, value 42 | 43 | 44 | def parse_config_args(config_args: List[str]) -> typing.Dict[str, Any]: 45 | parsed_config_args = {} 46 | for config_arg in config_args: 47 | key, value = parse_config_arg(config_arg) 48 | if key in parsed_config_args: 49 | raise ValueError(f"Duplicate argument: {key}") 50 | parsed_config_args[key] = value 51 | return parsed_config_args 52 | 53 | 54 | def configure_logging(name=None): 55 | logging_config = { 56 | "version": 1, 57 | "disable_existing_loggers": False, 58 | "formatters": { 59 | "default": { 60 | "format": f"%(asctime)s{'' if name is None else ' ['+name+']'}: %(message)s", 61 | "use_colors": True, 62 | } 63 | }, 64 | "handlers": { 65 | "default": { 66 | "level": "INFO", 67 | "formatter": "default", 68 | "class": "logging.StreamHandler", 69 | "stream": "ext://sys.stdout", 70 | } 71 | }, 72 | "loggers": {"default": {"level": "DEBUG", "handlers": ["default"]}}, 73 | "root": {"handlers": ["default"], "level": "INFO"}, 74 | } 75 | logging.config.dictConfig(logging_config) 76 | 77 | 78 | def log_rank_n(msg: str, logger: Callable = logging.info, rank: int = 0): 79 | if rank < 0 or not dist.is_initialized() or dist.get_rank() == rank: 80 | # Multi-line logs break formatting 81 | for line in msg.splitlines(): 82 | logger(line) 83 | 84 | 85 | def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0): 86 | for key, value in data.items(): 87 | log_rank_n(f"{key}: {value}", logger, rank) 88 | 89 | 90 | 
dummy_input_sentences = [ 91 | "DeepSpeed is a machine learning framework", 92 | "He is working on", 93 | "He has a", 94 | "He got all", 95 | "Everyone is happy and I can", 96 | "The new movie that got Oscar this year", 97 | "In the far far distance from our galaxy,", 98 | "Peace is the only way", 99 | ] 100 | 101 | 102 | def get_dummy_batch(batch_size: int, max_input_length: int = -1) -> List[str]: 103 | if max_input_length == -1: 104 | input_sentences = copy.deepcopy(dummy_input_sentences) 105 | else: 106 | input_sentences = batch_size * [" Hello" * max_input_length] 107 | 108 | if batch_size > len(input_sentences): 109 | input_sentences *= math.ceil(batch_size / len(input_sentences)) 110 | input_sentences = input_sentences[:batch_size] 111 | 112 | return input_sentences 113 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict 2 | 3 | 4 | def format_round(x: float) -> str: 5 | return str(round(x)) 6 | 7 | 8 | def format_throughput(x: float) -> str: 9 | return f"{x:.2f} tokens/s" 10 | 11 | 12 | def format_inverse_throughput(x: float) -> str: 13 | return f"{format_ms(x)}/token" 14 | 15 | 16 | def format_ms(t: float) -> str: 17 | return f"{1000 * t:.2f} ms" 18 | 19 | 20 | def format_mib(m: float) -> str: 21 | return f"{m/2**20:.0f} MiB" 22 | 23 | 24 | class Metrics: 25 | LATENCY_E2E = "Latency (end to end)" 26 | LATENCY_TOKEN = "Latency (tokenization)" 27 | LATENCY_MODEL = "Latency (model)" 28 | LATENCY_DECODE = "Latency (decode)" 29 | LATENCY_MAX = "Latency (max)" 30 | LATENCY_MIN = "Latency (min)" 31 | LATENCY_STD = "Latency (std)" 32 | BATCH_SIZE = "Batch size" 33 | INPUT_LENGTH = "Input sequence length" 34 | OUTPUT_LENGTH = "Output sequence length" 35 | TOKENS_SAMPLE = "Tokens generated (sample)" 36 | TOKENS_BATCH = "Tokens generated (batch)" 37 | THROUGHPUT_MODEL = "Throughput (model)" 38 | THROUGHPUT_E2E = "Throughput (end to end)" 39 | TOKEN_TIME = "Token time (end to end)" 40 | INIT_TOKEN = "Initialization time (tokenizer)" 41 | INIT_CONFIG = "Initialization time (configuration)" 42 | INIT_DEVICE = "Initialization time (move to device)" 43 | INIT_TOTAL = "Initialization time (total)" 44 | INIT_CREATE = "Initialization time (create model)" 45 | INIT_WEIGHTS = "Initialization time (init weights)" 46 | INIT_SAVE = "Initialization time (save model)" 47 | INIT_LOAD = "Initialization time (load model)" 48 | RUNTIME_WARMUP = "Runtime time (warmup)" 49 | RUNTIME_BENCHMARK = "Runtime time (benchmark)" 50 | RUNTIME_TOTAL = "Runtime time (total)" 51 | MEMORY_USED_INIT = "Memory used (init)" 52 | MEMORY_USED_END = "Memory used (end)" 53 | MEMORY_USED_MAX = "Memory used (max)" 54 | MEMORY_RESERVED_INIT = "Memory reserved (init)" 55 | MEMORY_RESERVED_END = "Memory reserved (end)" 56 | MEMORY_RESERVED_MAX = "Memory reserved (max)" 57 | 58 | _METRIC_ORDER_AND_FORMAT: Dict[str, Callable[[Any], str]] = { 59 | LATENCY_E2E: format_ms, 60 | LATENCY_TOKEN: format_ms, 61 | LATENCY_MODEL: format_ms, 62 | LATENCY_DECODE: format_ms, 63 | LATENCY_MAX: format_ms, 64 | LATENCY_MIN: format_ms, 65 | LATENCY_STD: format_ms, 66 | BATCH_SIZE: format_round, 67 | INPUT_LENGTH: format_round, 68 | OUTPUT_LENGTH: format_round, 69 | TOKENS_SAMPLE: format_round, 70 | TOKENS_BATCH: format_round, 71 | THROUGHPUT_MODEL: format_throughput, 72 | THROUGHPUT_E2E: format_throughput, 73 | TOKEN_TIME: format_inverse_throughput, 74 | INIT_TOKEN: format_ms, 75 | 
INIT_CONFIG: format_ms, 76 | INIT_DEVICE: format_ms, 77 | INIT_TOTAL: format_ms, 78 | INIT_CREATE: format_ms, 79 | INIT_WEIGHTS: format_ms, 80 | INIT_SAVE: format_ms, 81 | INIT_LOAD: format_ms, 82 | RUNTIME_WARMUP: format_ms, 83 | RUNTIME_BENCHMARK: format_ms, 84 | RUNTIME_TOTAL: format_ms, 85 | MEMORY_USED_INIT: format_mib, 86 | MEMORY_USED_END: format_mib, 87 | MEMORY_USED_MAX: format_mib, 88 | MEMORY_RESERVED_INIT: format_mib, 89 | MEMORY_RESERVED_END: format_mib, 90 | MEMORY_RESERVED_MAX: format_mib, 91 | } 92 | 93 | @classmethod 94 | def reorder_metrics(cls, metrics: Dict[str, Any]) -> Dict[str, Any]: 95 | metrics = metrics.copy() 96 | reordered_metrics = {} 97 | for name, format_fn in cls._METRIC_ORDER_AND_FORMAT.items(): 98 | if name in metrics: 99 | reordered_metrics[name] = metrics.pop(name) 100 | reordered_metrics.update(metrics) 101 | return reordered_metrics 102 | 103 | @classmethod 104 | def format_metric(cls, key: str, value: Any) -> str: 105 | return cls._METRIC_ORDER_AND_FORMAT.get(key, str)(value) 106 | 107 | @classmethod 108 | def format_metrics(cls, metrics: Dict[str, Any]) -> Dict[str, str]: 109 | return {key: cls.format_metric(key, value) for key, value in metrics.items()} 110 | -------------------------------------------------------------------------------- /src/parse_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | from typing import List, Optional 5 | 6 | from src.metrics import Metrics 7 | from src.utils import parse_config_args, parse_config_arg 8 | 9 | 10 | def get_arg_parser() -> ArgumentParser: 11 | parser = ArgumentParser() 12 | parser.add_argument("input_dir", type=Path) 13 | parser.add_argument("--filter", action="append") 14 | parser.add_argument("--column", "--col", action="append") 15 | parser.add_argument("--compare_value") 16 | parser.add_argument("--compare_col", default="Setting") 17 | parser.add_argument("--table", action="store_true") 18 | parser.add_argument("--plot", action="store_true") 19 | parser.add_argument("-x", "--x_axis", default=Metrics.BATCH_SIZE) 20 | parser.add_argument("-y", "--y_axis", default=Metrics.THROUGHPUT_E2E) 21 | parser.add_argument("-z", "--z_axis") 22 | parser.add_argument("--title") 23 | return parser 24 | 25 | 26 | DEFAULT_COLUMNS = ( 27 | "Setting", 28 | Metrics.INPUT_LENGTH, 29 | Metrics.TOKENS_SAMPLE, 30 | Metrics.BATCH_SIZE, 31 | Metrics.THROUGHPUT_E2E, 32 | Metrics.LATENCY_E2E, 33 | ) 34 | 35 | 36 | def read_data(input_file: Path): 37 | try: 38 | with input_file.open("r") as f: 39 | data = json.load(f) 40 | data = {**data["config"], **data["results"]} 41 | except (ValueError, OSError) as e: 42 | raise ValueError(f"Cannot parse file {input_file} ({e})") 43 | try: 44 | setting, bs_, bs, seq_, seq, tok_, tok = input_file.stem.rsplit("_", 6) 45 | assert bs_ == "bs" 46 | assert data[Metrics.BATCH_SIZE] == int(bs) 47 | assert seq_ == "seq" 48 | assert data[Metrics.INPUT_LENGTH] == int(seq) or int(seq) < 0 49 | assert tok_ == "tok" 50 | assert data[Metrics.TOKENS_SAMPLE] == int(tok) 51 | except (ValueError, AssertionError) as e: 52 | raise ValueError(f"Cannot parse filename {input_file} ({e})") 53 | data["Setting"] = setting 54 | return data 55 | 56 | 57 | def parse_key(key: Optional[str]) -> Optional[str]: 58 | if key is None: 59 | return key 60 | return getattr(Metrics, key.upper(), key) 61 | 62 | 63 | def make_table(data, cols): 64 | from markdownTable import markdownTable 65 | 66 | data = 
[Metrics.format_metrics({col: x[col] for col in cols}) for x in data] 67 | return markdownTable(data).getMarkdown() 68 | 69 | 70 | def make_compare_table(data, cols, compare_value, compare_col): 71 | from markdownTable import markdownTable 72 | 73 | compare_value = parse_key(compare_value) 74 | compare_col = parse_key(compare_col) 75 | compare_data = {} 76 | all_compare_index = set() 77 | # Aggregate by the cols entries, then map compare_key to compare 78 | for x in data: 79 | index = tuple(x[col] for col in cols) 80 | if index not in compare_data: 81 | compare_data[index] = {} 82 | compare_index = x[compare_col] 83 | all_compare_index.add(compare_index) 84 | if compare_index in compare_data[index]: 85 | print(f"Duplicate entry {compare_index} for index {index}") 86 | compare_data[index][compare_index] = Metrics.format_metric(compare_value, x[compare_value]) 87 | 88 | table_data = [] 89 | for index in sorted(compare_data): 90 | # Merge the index and values 91 | table_data.append( 92 | { 93 | **Metrics.format_metrics({col: v for col, v in zip(cols, index)}), 94 | **{ 95 | compare_index: compare_data[index].get(compare_index, "N.A.") 96 | for compare_index in sorted(all_compare_index) 97 | }, 98 | } 99 | ) 100 | 101 | return markdownTable(table_data).getMarkdown() 102 | 103 | 104 | def filter_data(data, filters): 105 | if filters is None: 106 | return data 107 | 108 | parsed_filters = {} 109 | for filter in filters: 110 | key, value = parse_config_arg(filter) 111 | key = parse_key(key) 112 | if key not in parsed_filters: 113 | parsed_filters[key] = [] 114 | parsed_filters[key].append(value) 115 | 116 | filtered_data = [] 117 | for x in data: 118 | filter = True 119 | for key, value in parsed_filters.items(): 120 | filter = filter and x[key] in value 121 | if filter: 122 | filtered_data.append(x) 123 | return filtered_data 124 | 125 | 126 | def plot(data, x_axis, y_axis, z_axis, title=None): 127 | import matplotlib.pyplot as plt 128 | 129 | x_axis = parse_key(x_axis) 130 | y_axis = parse_key(y_axis) 131 | z_axis = parse_key(z_axis) 132 | x = [d[x_axis] for d in data] 133 | y = [d[y_axis] for d in data] 134 | 135 | fig = plt.figure() 136 | ax = fig.add_subplot() 137 | 138 | # z = None if z_axis is None else [d[z_axis] for d in data] 139 | if z_axis is None: 140 | ax.scatter(x, y) 141 | else: 142 | z = [d[z_axis] for d in data] 143 | for z_value in set(z): 144 | xx, yy = tuple(zip(*sorted((x_, y_) for x_, y_, z_ in zip(x, y, z) if z_ == z_value))) 145 | ax.plot(xx, yy, label=z_value, linewidth=1, linestyle=":", markersize=4, marker="o") 146 | # ax.scatter(x,y, label=z_value) 147 | # handles, labels = scatter.legend_elements() 148 | ax.legend(loc="upper left") # handles=handles, labels=labels, title=z_axis) 149 | 150 | ax.set_title(y_axis if title is None else title) 151 | ax.set_xlabel(x_axis) 152 | ax.set_ylabel(y_axis) 153 | fig.show() 154 | input("Press enter to continue") 155 | 156 | 157 | def main(argv: Optional[List[str]] = None) -> None: 158 | parser = get_arg_parser() 159 | args = parser.parse_args(argv) 160 | data = [read_data(input_file) for input_file in args.input_dir.iterdir()] 161 | 162 | data = filter_data(data, args.filter) 163 | 164 | if len(data) == 0: 165 | raise RuntimeError(f"No data to show.") 166 | 167 | cols = DEFAULT_COLUMNS if args.column is None else [parse_key(col) for col in args.column] 168 | 169 | if args.table: 170 | if args.compare_value: 171 | print(make_compare_table(data, cols, args.compare_value, args.compare_col)) 172 | else: 173 | print(make_table(data, 
cols)) 174 | 175 | if args.plot: 176 | plot(data, args.x_axis, args.y_axis, args.z_axis, args.title) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import json 4 | import time 5 | from argparse import ArgumentParser 6 | from pathlib import Path 7 | from typing import List, Optional 8 | 9 | import torch 10 | 11 | from src.metrics import Metrics 12 | from src.pipeline import Pipeline, get_pipeline_class 13 | from src.profile import get_profiler, logger 14 | from src.utils import configure_logging, get_dummy_batch, log_dict, log_rank_n, parse_config_args 15 | 16 | 17 | def get_arg_parser() -> ArgumentParser: 18 | parser = ArgumentParser() 19 | 20 | # Model 21 | parser.add_argument("--model_type") 22 | parser.add_argument("--pretrained_config") 23 | parser.add_argument("--pretrained_model") 24 | parser.add_argument("--tokenizer", default="gpt2") 25 | parser.add_argument("--trust_remote_code", action="store_true") 26 | parser.add_argument("config_args", nargs="*") 27 | 28 | # Runtime 29 | parser.add_argument("--pipeline_class", default="HF_Pipeline") 30 | parser.add_argument("--device", default="cuda", type=torch.device) 31 | parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x)) 32 | parser.add_argument("--local_rank", type=int) 33 | parser.add_argument("--no_fast_init", dest="fast_init", action="store_false") 34 | 35 | # Input and output 36 | parser.add_argument("--batch_size", default=1, type=int) 37 | parser.add_argument("--max_input_length", default=-1, type=int) 38 | parser.add_argument("--max_new_tokens", default=100, type=int) 39 | 40 | # Cleanup 41 | parser.add_argument("--clear_every_run", action="store_true") 42 | 43 | # Benchmark cycles 44 | parser.add_argument("--skip", type=int, default=1) 45 | parser.add_argument("--warmup", type=int, default=None) 46 | parser.add_argument("--cycles", type=int, default=5) 47 | 48 | # Profiling and logging 49 | parser.add_argument("--max_log_outputs", type=int) 50 | parser.add_argument("--profile", action="store_true") 51 | parser.add_argument("--profile_cycles", type=int) 52 | parser.add_argument("--full_trace", action="store_true") 53 | parser.add_argument("--show_op_names", action="store_true") 54 | parser.add_argument("--save", type=Path) 55 | 56 | return parser 57 | 58 | 59 | def main(argv: Optional[List[str]] = None) -> None: 60 | t0 = time.perf_counter() 61 | parser = get_arg_parser() 62 | args = parser.parse_args(argv) 63 | config_args = parse_config_args(args.config_args) 64 | generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False} 65 | inputs = get_dummy_batch(args.batch_size, args.max_input_length) 66 | separate_profile = args.profile and args.profile_cycles is not None 67 | warmup = args.profile if args.warmup is None else args.warmup 68 | if separate_profile: 69 | pre_warmup_cycles = args.cycles 70 | post_warmup_cycles = args.profile_cycles 71 | benchmark_begin = args.skip 72 | else: 73 | pre_warmup_cycles = 0 74 | post_warmup_cycles = args.cycles 75 | benchmark_begin = args.skip + warmup 76 | benchmark_end = benchmark_begin + args.cycles 77 | 78 | max_log_outputs = args.batch_size if args.max_log_outputs is None else args.max_log_outputs 79 | 80 | pipeline_class = get_pipeline_class(args.pipeline_class) 81 | pipeline: Pipeline = pipeline_class( 82 | 
model_type=args.model_type, 83 | pretrained_model=args.pretrained_model, 84 | pretrained_config=args.pretrained_config, 85 | config_args=config_args, 86 | tokenizer=args.tokenizer, 87 | device=args.device, 88 | dtype=args.dtype, 89 | fast_init=args.fast_init, 90 | trust_remote_code=args.trust_remote_code, 91 | ) 92 | 93 | all_metrics = [] 94 | 95 | if args.profile: 96 | profiler = get_profiler( 97 | skip=args.skip + pre_warmup_cycles, 98 | warmup=warmup, 99 | cycles=post_warmup_cycles, 100 | full_trace=args.full_trace, 101 | show_op_names=args.show_op_names, 102 | ) 103 | else: 104 | profiler = contextlib.nullcontext() 105 | 106 | benchmark_metrics = { 107 | **generate_kwargs, 108 | "Model parameters": pipeline.get_num_parameters(), 109 | "Cycles (warmup)": args.skip + warmup, 110 | "Cycles (benchmark)": args.cycles, 111 | } 112 | if args.profile: 113 | benchmark_metrics["Cycles (profile)"] = post_warmup_cycles 114 | benchmark_metrics["Cycles (total)"] = args.skip + warmup + pre_warmup_cycles + post_warmup_cycles 115 | 116 | if pipeline.device.type == "cuda": 117 | benchmark_metrics[Metrics.MEMORY_USED_INIT] = torch.cuda.memory_allocated() 118 | benchmark_metrics[Metrics.MEMORY_RESERVED_INIT] = torch.cuda.memory_reserved() 119 | torch.cuda.reset_peak_memory_stats() 120 | 121 | t1 = time.perf_counter() 122 | with profiler as p: 123 | for step in range(args.skip + warmup + args.cycles): 124 | if step == args.skip + warmup: 125 | t2 = time.perf_counter() 126 | benchmark_metrics[Metrics.RUNTIME_WARMUP] = t2 - t1 127 | generated_text, metrics = pipeline(inputs, **generate_kwargs) 128 | if args.profile: 129 | p.step() 130 | 131 | if step == 0: 132 | for i, o, _ in zip(inputs, generated_text, range(max_log_outputs)): 133 | log_rank_n(f"{'-' * 60}\nINPUT = {i}\nOUTPUT = {o}", logger.info) 134 | 135 | if benchmark_begin <= step < benchmark_end: 136 | all_metrics.append(metrics) 137 | 138 | if args.clear_every_run: 139 | torch.cuda.synchronize() 140 | gc.collect() 141 | torch.cuda.empty_cache() 142 | if pipeline.device.type == "cuda": 143 | benchmark_metrics[Metrics.MEMORY_USED_END] = torch.cuda.memory_allocated() 144 | benchmark_metrics[Metrics.MEMORY_RESERVED_END] = torch.cuda.memory_reserved() 145 | benchmark_metrics[Metrics.MEMORY_USED_MAX] = torch.cuda.max_memory_allocated() 146 | benchmark_metrics[Metrics.MEMORY_RESERVED_MAX] = torch.cuda.max_memory_reserved() 147 | 148 | t3 = time.perf_counter() 149 | benchmark_metrics[Metrics.RUNTIME_BENCHMARK] = t3 - t2 150 | benchmark_metrics[Metrics.RUNTIME_TOTAL] = t3 - t0 151 | 152 | if len(all_metrics) > 0: 153 | benchmark_metrics.update(pipeline.aggregate_metrics(all_metrics)) 154 | 155 | benchmark_metrics = Metrics.reorder_metrics(benchmark_metrics) 156 | 157 | log_rank_n("*** Benchmark results:", logger.info) 158 | log_dict(Metrics.format_metrics(benchmark_metrics), logger.info) 159 | 160 | if args.save: 161 | save_path = Path(args.save).resolve() 162 | print(f"*** Saving results to {save_path}") 163 | save_path.parent.mkdir(parents=True, exist_ok=True) 164 | with save_path.open("w") as f: 165 | json.dump( 166 | { 167 | "config": pipeline.config.to_dict(), 168 | "results": benchmark_metrics, 169 | }, 170 | f, 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | configure_logging() 176 | main() 177 | -------------------------------------------------------------------------------- /src/pipeline.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import logging 4 | import os 
5 | import time 6 | from typing import Any, Dict, List, Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from src.fast_init import fast_init 12 | from src.metrics import Metrics 13 | from src.utils import log_rank_n, parse_revision 14 | from transformers import ( 15 | CONFIG_MAPPING, 16 | AutoConfig, 17 | AutoModelForCausalLM, 18 | AutoTokenizer, 19 | PretrainedConfig, 20 | PreTrainedModel, 21 | ) 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class Pipeline: 28 | def __init__( 29 | self, 30 | *, 31 | model_type: Optional[str] = None, 32 | pretrained_config: Optional[str] = None, 33 | pretrained_model: Optional[str] = None, 34 | config_args: Dict[str, Any], 35 | tokenizer: str, 36 | device: torch.device, 37 | dtype: torch.dtype, 38 | fast_init: bool = True, 39 | trust_remote_code: bool = False, 40 | ): 41 | self.global_metrics = {} 42 | log_rank_n("*** Setting up tokenizer", logger.info) 43 | t0 = time.perf_counter() 44 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) 45 | 46 | self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 47 | t1 = time.perf_counter() 48 | 49 | self.device = device 50 | self.dtype = dtype 51 | self.is_int8 = self.dtype == torch.int8 52 | self.fast_init = fast_init 53 | self.trust_remote_code = trust_remote_code 54 | if self.is_int8 and self.device != torch.device("cuda"): 55 | raise ValueError(f"Model quantization not supported on device {self.device}") 56 | 57 | self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args) 58 | t2 = time.perf_counter() 59 | 60 | logger.info(f"Model configuration: {self.config}") 61 | 62 | if pretrained_model is None: 63 | self.model = self._create_model() 64 | if self.is_int8: 65 | self._reload_model() 66 | else: 67 | self.model = self._load_pretrained(pretrained_model) 68 | 69 | self.model.eval() 70 | t3 = time.perf_counter() 71 | self.global_metrics[Metrics.INIT_TOKEN] = t1 - t0 72 | self.global_metrics[Metrics.INIT_CONFIG] = t2 - t1 73 | self.global_metrics[Metrics.INIT_TOTAL] = t3 - t0 74 | 75 | def _create_model(self) -> PreTrainedModel: 76 | t0 = time.perf_counter() 77 | log_rank_n("*** Creating model", logger.info) 78 | with fast_init(self.device) if self.fast_init else contextlib.nullcontext(): 79 | torch_dtype = torch.float16 if self.is_int8 else self.dtype 80 | model = AutoModelForCausalLM.from_config( 81 | config=self.config, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code 82 | ) 83 | t1 = time.perf_counter() 84 | log_rank_n("*** Moving to device", logger.info) 85 | model.to(self.device) 86 | t2 = time.perf_counter() 87 | log_rank_n("*** Initializing weights", logger.info) 88 | # Initialization is ~1000x faster on GPU. 
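# When fast_init is active, the parameters above were allocated uninitialized directly on
# the target device and transformers' weight init was skipped, so this explicit call
# performs the random initialization on the GPU instead of the CPU.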
89 | model.init_weights() 90 | t3 = time.perf_counter() 91 | self.global_metrics[Metrics.INIT_CREATE] = t1 - t0 92 | self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1 93 | self.global_metrics[Metrics.INIT_WEIGHTS] = t3 - t2 94 | 95 | return model 96 | 97 | def _reload_model(self): 98 | self._save_pretrained("tmp") 99 | del self.model 100 | gc.collect() 101 | self.model = self._load_pretrained("tmp") 102 | 103 | def _save_pretrained(self, pretrained_model: str): 104 | t0 = time.perf_counter() 105 | log_rank_n(f"*** Saving model to {pretrained_model}", logger.info) 106 | t1 = time.perf_counter() 107 | self.global_metrics[Metrics.INIT_SAVE] = t1 - t0 108 | self.model.save_pretrained(pretrained_model) 109 | 110 | def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel: 111 | t0 = time.perf_counter() 112 | log_rank_n(f"*** Loading model from {pretrained_model}", logger.info) 113 | kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype} 114 | with fast_init(self.device) if self.fast_init else contextlib.nullcontext(): 115 | pretrained_model, revision = parse_revision(pretrained_model) 116 | model = AutoModelForCausalLM.from_pretrained( 117 | pretrained_model, 118 | revision=revision, 119 | config=self.config, 120 | trust_remote_code=self.trust_remote_code, 121 | **kwargs, 122 | ) 123 | t1 = time.perf_counter() 124 | self.global_metrics["load pretrained model"] = t1 - t0 125 | if not self.is_int8: 126 | log_rank_n("*** Moving to device", logger.info) 127 | model = model.to(self.device) 128 | t2 = time.perf_counter() 129 | self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1 130 | return model 131 | 132 | def _get_config( 133 | self, 134 | model_type: Optional[str], 135 | pretrained_config: Optional[str], 136 | config_args: Dict[str, Any], 137 | ) -> PretrainedConfig: 138 | config_args = { 139 | "use_cache": True, 140 | "return_unused_kwargs": True, 141 | **config_args, 142 | } 143 | 144 | if model_type is None: 145 | if pretrained_config is None: 146 | raise ValueError("You need to provide either --model_type or --pretrained_model") 147 | config_class = AutoConfig 148 | elif model_type not in CONFIG_MAPPING: 149 | raise ValueError(f"Unknown model type: {model_type}") 150 | else: 151 | config_class = CONFIG_MAPPING[model_type] 152 | config_args["model_type"] = model_type 153 | 154 | if pretrained_config is None: 155 | config_args.update( 156 | { 157 | "bos_token_id": self.tokenizer.bos_token_id, 158 | "eos_token_id": self.tokenizer.eos_token_id, 159 | "vocab_size": len(self.tokenizer), 160 | } 161 | ) 162 | config, unused = config_class.from_dict({}, **config_args) 163 | else: 164 | pretrained_config, revision = parse_revision(pretrained_config) 165 | config, unused = config_class.from_pretrained( 166 | pretrained_config, revision=revision, trust_remote_code=self.trust_remote_code, **config_args 167 | ) 168 | 169 | if unused: 170 | raise ValueError(f"There were unused configuration parameters: {tuple(unused)}") 171 | 172 | return config 173 | 174 | def __call__(self, text: List[str], **generate_kwargs) -> Tuple[List[str], Dict[str, Any]]: 175 | t0 = time.perf_counter() 176 | inputs = self.tokenizer(text, return_tensors="pt", padding=True) 177 | 178 | inputs = {key: value.to(self.device) if torch.is_tensor(value) else value for key, value in inputs.items()} 179 | 180 | t1 = time.perf_counter() 181 | with torch.inference_mode(): 182 | output = self.model.generate(**inputs, return_dict_in_generate=True, **generate_kwargs) 183 | t2 = 
time.perf_counter() 184 | 185 | output_tokens = output.sequences 186 | 187 | batch_size, input_length = inputs["input_ids"].shape 188 | output_length = output_tokens.size(1) 189 | 190 | output_text = self.tokenizer.batch_decode(output_tokens.cpu(), skip_special_tokens=True) 191 | t3 = time.perf_counter() 192 | 193 | metrics = { 194 | Metrics.BATCH_SIZE: batch_size, 195 | Metrics.INPUT_LENGTH: input_length, 196 | Metrics.OUTPUT_LENGTH: output_length, 197 | Metrics.TOKENS_SAMPLE: output_length - input_length, 198 | Metrics.TOKENS_BATCH: batch_size * (output_length - input_length), 199 | Metrics.LATENCY_TOKEN: t1 - t0, 200 | Metrics.LATENCY_MODEL: t2 - t1, 201 | Metrics.LATENCY_DECODE: t3 - t2, 202 | Metrics.LATENCY_E2E: t3 - t0, 203 | } 204 | 205 | return output_text, metrics 206 | 207 | def get_num_parameters(self) -> int: 208 | return sum(p.numel() for p in self.model.parameters()) 209 | 210 | def aggregate_metrics(self, metrics: List[Dict[str, Any]]): 211 | all_metrics = { 212 | key: [metrics_[key] for metrics_ in metrics if key in metrics_] 213 | for key in ( 214 | Metrics.BATCH_SIZE, 215 | Metrics.INPUT_LENGTH, 216 | Metrics.OUTPUT_LENGTH, 217 | Metrics.TOKENS_SAMPLE, 218 | Metrics.TOKENS_BATCH, 219 | Metrics.LATENCY_TOKEN, 220 | Metrics.LATENCY_MODEL, 221 | Metrics.LATENCY_DECODE, 222 | Metrics.LATENCY_E2E, 223 | ) 224 | } 225 | mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0} 226 | throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E] 227 | model_throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_MODEL] 228 | 229 | return { 230 | **self.global_metrics, 231 | **mean_metrics, 232 | Metrics.LATENCY_MAX: max(all_metrics[Metrics.LATENCY_E2E]), 233 | Metrics.LATENCY_MIN: min(all_metrics[Metrics.LATENCY_E2E]), 234 | Metrics.LATENCY_STD: np.std(all_metrics[Metrics.LATENCY_E2E]).item(), 235 | Metrics.THROUGHPUT_MODEL: model_throughput, 236 | Metrics.THROUGHPUT_E2E: throughput, 237 | Metrics.TOKEN_TIME: throughput**-1, 238 | } 239 | 240 | 241 | class HF_Pipeline(Pipeline): 242 | pass 243 | 244 | 245 | class DS_Pipeline(Pipeline): 246 | def __init__(self, **kwargs): 247 | import deepspeed 248 | 249 | super().__init__(**kwargs) 250 | 251 | if self.device != torch.device("cuda"): 252 | raise ValueError(f"Deepspeed does not support device {self.device}") 253 | 254 | if self.dtype not in (torch.float32, torch.float16, torch.bfloat16): 255 | raise ValueError(f"Deepspeed does not support dtype {self.dtype}") 256 | 257 | if self.config.model_type not in ("bloom", "gpt2"): 258 | raise ValueError(f"Deepspeed does not support model type {self.config.model_type}") 259 | 260 | self.model = deepspeed.init_inference( 261 | self.model, 262 | mp_size=int(os.getenv("WORLD_SIZE", "1")), 263 | # base_dir="./", 264 | dtype=self.dtype, 265 | replace_with_kernel_inject=True, 266 | ) 267 | 268 | 269 | _PIPELINE_CLASS_MAP = { 270 | "HF_Pipeline": HF_Pipeline, 271 | "DS_Pipeline": DS_Pipeline, 272 | } 273 | 274 | 275 | def get_pipeline_class(name): 276 | if name not in _PIPELINE_CLASS_MAP: 277 | raise NotImplementedError(f"Unsupported pipeline class: {name}") 278 | return _PIPELINE_CLASS_MAP[name] 279 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND 
DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------