├── .gitignore
├── oneflow
│   ├── README.md
│   ├── Dockerfile
│   └── test_oneflow.py
├── llm
│   ├── Dockerfile
│   └── test_llm.py
├── AITemplate
│   ├── Dockerfile
│   └── test_ait.py
├── deepspeed
│   ├── Dockerfile
│   └── test_deepspeed.py
├── lcm
│   ├── Dockerfile
│   └── test_lcm.py
├── jax-gpu
│   ├── Dockerfile
│   └── test_jax.py
├── tensorrt_2
│   ├── Dockerfile
│   └── test_trt.py
├── hidet
│   ├── Dockerfile
│   └── test_hidet.py
├── prompts.txt
├── PyTorch
│   ├── Dockerfile
│   ├── run_sd3_compile.py
│   └── test_pt.py
├── tensorrt
│   └── Dockerfile
├── onnxruntime
│   └── Dockerfile
└── README.md
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | workspace/
--------------------------------------------------------------------------------
/oneflow/README.md:
--------------------------------------------------------------------------------
1 | ## OneFlow deep learning framework
2 | 1. https://github.com/Oneflow-Inc/oneflow
3 | 2. https://github.com/Oneflow-Inc/diffusers
--------------------------------------------------------------------------------
/llm/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:latest
2 |
3 | RUN apt-get update && apt-get install -y \
4 |     && apt-get install --no-install-recommends -y git curl python3-pip \
5 |     && pip install --upgrade pip \
6 |     && apt-get clean \
7 |     && rm -rf /var/lib/apt/lists/*
8 |
9 | WORKDIR /code
10 |
11 | RUN pip install transformers torch
12 |
13 | COPY test_llm.py /code/test_llm.py
14 |
15 | CMD python3 /code/test_llm.py
--------------------------------------------------------------------------------
/AITemplate/Dockerfile:
--------------------------------------------------------------------------------
1 |
2 | FROM nvidia/cuda:11.6.2-devel-ubuntu20.04
3 |
4 | RUN apt-get update && apt-get install --no-install-recommends -y curl && apt-get -y install git
5 |
6 | RUN apt-get update -y && \
7 |     apt-get install -y python3 python3-dev python3-pip
8 |
9 | RUN git clone --recursive https://github.com/facebookincubator/AITemplate
10 |
11 | RUN bash /AITemplate/docker/install/install_ait.sh
12 |
13 | COPY test_ait.py /test_ait.py
14 |
15 | CMD [ "python3", "/test_ait.py"]
16 |
--------------------------------------------------------------------------------
/deepspeed/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.10-py3
2 |
3 | RUN apt-get update && apt-get install -y \
4 |     && apt-get install --no-install-recommends -y git curl python3-pip \
5 |     && pip install --upgrade pip \
6 |     && apt-get clean \
7 |     && rm -rf /var/lib/apt/lists/*
8 |
9 | WORKDIR /code
10 |
11 | RUN pip install deepspeed diffusers transformers accelerate
12 |
13 | COPY test_deepspeed.py /code/test_deepspeed.py
14 |
15 | CMD [ "python3", "/code/test_deepspeed.py"]
16 |
--------------------------------------------------------------------------------
/lcm/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.10-py3
2 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
3 |
4 | RUN apt-get update && apt-get install -y \
5 |     && apt-get install --no-install-recommends -y git curl python3-pip \
6 |     && pip install --upgrade pip \
7 |     && apt-get clean \
8 |     && rm -rf /var/lib/apt/lists/*
9 |
10 | WORKDIR /code
11 |
12 | RUN pip install diffusers transformers accelerate
13 |
14 | COPY test_lcm.py /code/test_lcm.py
15 |
16 | CMD python3 /code/test_lcm.py
17 |
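18 | # Example usage (the "test_lcm" tag is arbitrary; flags mirror the run command in
19 | # the top-level README, and mounting ./workspace keeps the HF model cache across runs):
20 | #   docker build -f ./Dockerfile --network=host -t test_lcm .
21 | #   docker run -it --network=host -v ${PWD}/workspace:/workspace -w /workspace --gpus all --ipc=host test_lcm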
--------------------------------------------------------------------------------
/jax-gpu/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:latest
2 |
3 | RUN apt-get update && apt-get install -y \
4 |     && apt-get install --no-install-recommends -y git curl python3-pip \
5 |     && pip install --upgrade pip \
6 |     && apt-get clean \
7 |     && rm -rf /var/lib/apt/lists/*
8 |
9 | RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html flax
10 |
11 | WORKDIR /code
12 |
13 | RUN pip install diffusers transformers
14 |
15 | COPY test_jax.py /code/test_jax.py
16 |
17 | CMD [ "python3", "/code/test_jax.py"]
18 |
--------------------------------------------------------------------------------
/llm/test_llm.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['HF_HOME']='/workspace/.cache/huggingface'
3 |
4 | import torch
5 |
6 | from transformers import AutoTokenizer, AutoModelForCausalLM
7 |
8 | def main():
9 |     tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-chat-hf")
10 |     model = AutoModelForCausalLM.from_pretrained("daryl149/llama-2-7b-chat-hf", torch_dtype=torch.float16)
11 |
12 |     print(model)
13 |
14 |     num_parameters = sum(p.numel() for p in model.parameters())
15 |     print(f"Number of parameters in Llama-2-7B: {num_parameters}")
16 |
17 |
18 | if __name__ == "__main__":
19 |     main()
20 |
--------------------------------------------------------------------------------
/tensorrt_2/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.06-py3
2 | # NGC PyTorch container with H100 support; see
3 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html for the bundled torch version
4 |
5 | RUN apt-get update -y && \
6 |     apt-get install --no-install-recommends -y git curl python3-pip && \
7 |     pip install --upgrade pip
8 |
9 | WORKDIR /code
10 |
11 | RUN pip install diffusers transformers accelerate colored
12 | RUN pip install --upgrade "tensorrt>=8.6.1"  # quoted so the shell does not treat ">" as a redirect
13 | RUN pip install --upgrade "polygraphy>=0.47.0" onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
14 | RUN pip install --upgrade onnxruntime
15 |
16 | COPY test_trt.py /code/test_trt.py
17 |
18 | CMD [ "python3", "/code/test_trt.py"]
19 |
--------------------------------------------------------------------------------
/hidet/Dockerfile:
--------------------------------------------------------------------------------
1 |
2 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
3 |
4 | RUN apt-get update && apt-get install --no-install-recommends -y curl && apt-get -y install git
5 |
6 | RUN apt-get update -y && \
7 |     apt-get install -y python3-pip
8 |
9 | RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
10 | RUN pip install --pre --extra-index-url https://download.hidet.org/whl hidet
11 | RUN pip install hidet numpy==1.23.0
12 | RUN pip install diffusers transformers accelerate xformers
13 |
14 | WORKDIR /code
15 |
16 | COPY test_hidet.py /code/test_hidet.py
17 |
18 | CMD [ "python3", "/code/test_hidet.py"]
19 |
--------------------------------------------------------------------------------
/oneflow/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.08-py3
2 |
3 | WORKDIR /code
4 |
5 | RUN pip install --pre oneflow -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu117
6 | # RUN pip install -f https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu117 oneflow
7 |
8 | RUN python3 -m pip install "torch" "transformers==4.27.1" "diffusers[torch]==0.19.3" && \
9 |     python3 -m pip uninstall accelerate -y && \
10 |     git clone https://github.com/Oneflow-Inc/diffusers.git onediff && \
11 |     python3 -m pip install -e /code/onediff
12 |
13 | COPY test_oneflow.py /code/test_oneflow.py
14 |
15 | CMD [ "python3", "/code/test_oneflow.py"]
16 |
--------------------------------------------------------------------------------
/prompts.txt:
--------------------------------------------------------------------------------
1 | an aerial drone shot of the breathtaking landscape of the Bora Bora islands, with sparkling waters under the sun
2 | a bottle of perfume on a clean backdrop, surrounded by fragrant white flowers, product photography, minimalistic, natural light
3 | Custom sticker design on an isolated white background with the words "Rachel" written in an elegant font decorated by watercolor butterflies, daisies and soft pastel hues
4 | Simple flat vector illustration of a woman sitting at the desk with her laptop with a puppy, isolated on white background
5 | a bedroom with large windows and modern furniture, gray and gold, luxurious, mid century modern style
6 | A cat holding a sign that says hello world
--------------------------------------------------------------------------------
/PyTorch/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 | # nvcr.io/nvidia/pytorch:24.05-py3
3 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
4 |
5 | RUN apt-get update && apt-get install -y \
6 |     && apt-get install --no-install-recommends -y git curl python3-pip \
7 |     && pip install --upgrade pip \
8 |     && apt-get clean \
9 |     && rm -rf /var/lib/apt/lists/*
10 |
11 | ARG HF_TOKEN
12 | ENV HF_TOKEN=$HF_TOKEN
13 | ENV HF_HOME="/workspace/.cache/huggingface"
14 |
15 | WORKDIR /code
16 |
17 | RUN pip install torch diffusers transformers accelerate sentencepiece protobuf
18 |
19 | COPY test_pt.py /code/test_pt.py
20 | COPY run_sd3_compile.py /code/run_sd3_compile.py
21 |
22 | CMD python3 /code/test_pt.py
23 | # CMD python3 /code/run_sd3_compile.py
24 |
--------------------------------------------------------------------------------
/tensorrt/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 | # nvcr.io/nvidia/cuda:12.4.1-devel-ubuntu22.04
3 |
4 | RUN apt-get update && apt-get install -y \
5 |     && apt-get install --no-install-recommends -y git curl python3-pip python3-virtualenv cmake \
6 |     # && pip install --upgrade pip \
7 |     && apt-get clean \
8 |     && rm -rf /var/lib/apt/lists/*
9 |
10 | WORKDIR /code
11 |
12 | # Create a virtualenv for dependencies. This isolates these packages from
13 | # system-level packages.
14 | # Use -p to select the python version; on Ubuntu 22.04 virtualenv defaults to python3.
15 | RUN virtualenv /env
16 |
17 | # Setting these environment variables is the same as running
18 | # source /env/bin/activate.
19 | ENV VIRTUAL_ENV /env
20 | ENV PATH /env/bin:$PATH
21 |
22 | # !! Pinned to release/10.0; check https://github.com/NVIDIA/TensorRT/releases for newer releases
23 | RUN git clone https://github.com/NVIDIA/TensorRT.git && cd TensorRT && git checkout release/10.0
24 |
25 | RUN pip install tensorrt
26 |
27 | RUN pip install -r TensorRT/demo/Diffusion/requirements.txt
28 |
29 | CMD mkdir -p /workspace/onnx /workspace/engine /workspace/output && \
30 |     python3 /code/TensorRT/demo/Diffusion/demo_txt2img.py "a beautiful photograph of Mt. Fuji during cherry blossom" \
31 |     --batch-size 1 --build-static-batch --use-cuda-graph --num-warmup-runs 5 \
32 |     --version 2.1 --onnx-dir /workspace/onnx --engine-dir /workspace/engine --output-dir /workspace/output -v \
33 |     --height 768 --width 768 \
34 |     --denoising-steps 50
35 |
--------------------------------------------------------------------------------
/PyTorch/run_sd3_compile.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_float32_matmul_precision("high")
4 |
5 | from diffusers import StableDiffusion3Pipeline
6 | import time
7 |
8 | model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
9 | pipeline = StableDiffusion3Pipeline.from_pretrained(
10 |     model_id,
11 |     torch_dtype=torch.float16
12 | ).to("cuda")
13 | pipeline.set_progress_bar_config(disable=True)
14 |
15 | torch._inductor.config.conv_1x1_as_mm = True
16 | torch._inductor.config.coordinate_descent_tuning = True
17 | torch._inductor.config.epilogue_fusion = False
18 | torch._inductor.config.coordinate_descent_check_all_directions = True
19 |
20 | pipeline.transformer.to(memory_format=torch.channels_last)
21 | pipeline.vae.to(memory_format=torch.channels_last)
22 | pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
23 | pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
24 |
25 | prompt = "A cat holding a sign that says hello world"
26 | for _ in range(3):
27 |     _ = pipeline(
28 |         prompt=prompt,
29 |         num_inference_steps=50,
30 |         guidance_scale=5.0,
31 |         generator=torch.manual_seed(1),
32 |     )
33 |
34 | start = time.time()
35 | for _ in range(10):
36 |     _ = pipeline(
37 |         prompt=prompt,
38 |         num_inference_steps=50,
39 |         guidance_scale=5.0,
40 |         generator=torch.manual_seed(1),
41 |     )
42 | end = time.time()
43 | avg_inference_time = (end - start) / 10
44 | print(f"Average inference time: {avg_inference_time:.3f} seconds.")
45 |
46 | image = pipeline(
47 |     prompt=prompt,
48 |     num_inference_steps=50,
49 |     guidance_scale=5.0,
50 |     generator=torch.manual_seed(1),
51 | ).images[0]
52 | filename = "_".join(prompt.split(" "))
53 | image.save(f"diffusers_{filename}.png")
--------------------------------------------------------------------------------
/lcm/test_lcm.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["HF_HOME"] = "/workspace/.cache/huggingface"
4 |
5 | import torch
6 | import torch.utils.benchmark as benchmark
7 | import argparse
8 | from diffusers import DiffusionPipeline, LCMScheduler
9 |
10 | PROMPT = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
11 | MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
12 | LORA_ID = "latent-consistency/lcm-lora-sdxl"
13 |
14 |
15 | def benchmark_fn(f, *args, **kwargs):
16 |     t0 = benchmark.Timer(
17 |         stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
18 |     )
19 |     return t0.blocked_autorange().mean * 1e6
20 |
21 |
22 | def load_pipeline(standard_sdxl=False):
23 |     pipe = DiffusionPipeline.from_pretrained(MODEL_ID, variant="fp16")
24 |     if not standard_sdxl:
25 |         pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
26 |         pipe.load_lora_weights(LORA_ID)
27 |
28 |     pipe.to(device="cuda", dtype=torch.float16)
29 |     return pipe
30 |
31 |
32 | def call_pipeline(pipe, batch_size, num_inference_steps, guidance_scale):
33 |     images = pipe(
34 |         prompt=PROMPT,
35 |         num_inference_steps=num_inference_steps,
36 |         num_images_per_prompt=batch_size,
37 |         guidance_scale=guidance_scale,
38 |     ).images[0]
39 |
40 |
41 | def main():
42 |     standard_sdxl = False
43 |     batch_size = 1
44 |
45 |     pipeline = load_pipeline(standard_sdxl)
46 |     if standard_sdxl:
47 |         num_inference_steps = 25
48 |         guidance_scale = 5
49 |     else:
50 |         num_inference_steps = 4
51 |         guidance_scale = 1
52 |
53 |     time = benchmark_fn(call_pipeline, pipeline, batch_size, num_inference_steps, guidance_scale)
54 |
55 |     print(f"Batch size: {batch_size} in {time/1e6:.3f} seconds")
56 |
57 | if __name__ == "__main__":
58 |     main()
59 |
--------------------------------------------------------------------------------
/onnxruntime/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.10-py3
2 | # NGC PyTorch container with H100 support; see
3 | # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html for the bundled torch version
4 |
5 | RUN apt-get update && apt-get install -y \
6 |     && apt-get install --no-install-recommends -y git curl python3-pip \
7 |     && pip install --upgrade pip \
8 |     && apt-get clean \
9 |     && rm -rf /var/lib/apt/lists/*
10 |
11 | WORKDIR /code
12 |
13 | RUN git clone https://github.com/microsoft/onnxruntime
14 |
15 | WORKDIR /code/onnxruntime
16 |
17 | RUN export CUDACXX=/usr/local/cuda-12.2/bin/nvcc && \
18 |     git config --global --add safe.directory '*' && \
19 |     sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_version 12.2 \
20 |     --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/lib/x86_64-linux-gnu/ --build_wheel --skip_tests \
21 |     --use_tensorrt --tensorrt_home /usr/src/tensorrt \
22 |     --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \
23 |     --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \
24 |     --allow_running_as_root && \
25 |     python3 -m pip install --upgrade pip && \
26 |     python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall && \
27 |     rm -rf ./build
28 |
29 | WORKDIR /code/onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion
30 |
31 | RUN python3 -m pip install -r requirements-cuda12.txt && \
32 |     python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
33 |
34 | CMD cd /code/onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion && \
35 |     python3 demo_txt2img_xl.py --disable-refiner --denoising-steps 50 \
36 |     --engine ORT_CUDA --work-dir /workspace/ "starry night over Golden Gate Bridge by van gogh"
37 |
38 | # CMD cd /code/onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion && \
39 | #     python3 demo_txt2img.py --engine ORT_CUDA --version 2.1 --height 768 --width 768 --batch-size 1 \
40 | #     --disable-refiner --work-dir /workspace/ "starry night over Golden Gate Bridge by van gogh"
--------------------------------------------------------------------------------
/oneflow/test_oneflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['HF_HOME']='/workspace/.cache/huggingface'
3 |
4 | from onediff.infer_compiler import oneflow_compile
5 | from onediff.optimization import rewrite_self_attention
6 | from diffusers import StableDiffusionPipeline
7 | import oneflow as flow
8 | import torch
9 |
10 | from time import perf_counter
11 | import numpy as np
12 |
13 | sd_args = {"width": 512, "height": 512, "guidance_scale": 7.5, "num_inference_steps": 50}
14 |
15 |
16 | def measure_latency(pipe, prompt):
17 |     # warm up
18 |     # pipe.set_progress_bar_config(disable=True)
19 |     for _ in range(2):
20 |         _ = pipe(prompt, **sd_args)
21 |     flow._oneflow_internal.eager.Sync()
22 |
23 |     # Timed run
24 |     latencies = []
25 |     for _ in range(10):
26 |         start_time = perf_counter()
27 |         with flow.autocast("cuda"):
28 |             _ = pipe(prompt, **sd_args)
29 |         flow._oneflow_internal.eager.Sync()
30 |         latency = perf_counter() - start_time
31 |         latencies.append(latency)
32 |     # Compute run statistics
33 |     time_avg_s = np.mean(latencies)
34 |     time_std_s = np.std(latencies)
35 |     time_p95_s = np.percentile(latencies,95)
36 |     return f"P95 latency (seconds) - {time_p95_s:.2f}; Average latency (seconds) - {time_avg_s:.2f} +/- {time_std_s:.2f};", time_p95_s
37 |
38 | def main():
39 |     # model_id = "stabilityai/stable-diffusion-2-1"
40 |     model_id = "runwayml/stable-diffusion-v1-5"
41 |     pipe = StableDiffusionPipeline.from_pretrained(
42 |         model_id,
43 |         revision="fp16",
44 |         variant="fp16",
45 |         torch_dtype=torch.float16,
46 |         safety_checker=None,
47 |     )
48 |     pipe = pipe.to("cuda")
49 |
50 |     rewrite_self_attention(pipe.unet)
51 |     pipe.unet = oneflow_compile(pipe.unet)
52 |
53 |     prompt = "a photo of an astronaut riding a horse on mars"
54 |     with flow.autocast("cuda"):
55 |         images = pipe(prompt, **sd_args).images
56 |         for i, image in enumerate(images):
57 |             image.save(f"{prompt}-of-{i}.png")
58 |
59 |     oneflow_results = measure_latency(pipe, prompt)
60 |
61 |     print(f"OneFlow pipeline: {oneflow_results[0]}")
62 |
63 | if __name__ == "__main__":
64 |     main()
--------------------------------------------------------------------------------
/deepspeed/test_deepspeed.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['HF_HOME']='/workspace/.cache/huggingface'
3 |
4 | import deepspeed
5 | import torch
6 |
7 | from diffusers import DiffusionPipeline
8 |
9 | from time import perf_counter
10 | import numpy as np
11 |
12 | sd_args_v15 = {"_model_id_": "runwayml/stable-diffusion-v1-5", "width": 512, "height": 512, "guidance_scale": 7.5, "num_inference_steps": 50}
13 | sd_args_v21 = {"_model_id_": "stabilityai/stable-diffusion-2-1", "width": 768, "height": 768, "guidance_scale": 7.5, "num_inference_steps": 50}
14 |
15 | sd_args = sd_args_v21
16 |
17 | @torch.inference_mode()
18 | def measure_latency(pipe, prompt):
19 |     latencies = []
20 |     # warm up
21 |     # pipe.set_progress_bar_config(disable=True)
22 |     for _ in range(2):
23 |         _ = pipe(prompt, **sd_args)
24 |     # Timed run
25 |     for _ in range(10):
26 |         start_time = perf_counter()
27 |         _ = pipe(prompt, **sd_args)
28 |         latency = perf_counter() - start_time
29 |         latencies.append(latency)
30 |     # Compute run statistics
31 |     time_avg_s = np.mean(latencies)
32 |     time_std_s = np.std(latencies)
33 |     time_p95_s = np.percentile(latencies,95)
34 |     return f"P95 latency (seconds) - {time_p95_s:.2f}; Average latency (seconds) - {time_avg_s:.2f} +/- {time_std_s:.2f};", time_p95_s
35 |
36 | def main():
37 |     prompt = "A majestic lion jumping from a big stone at night"
38 |
39 |     pipe_ds = DiffusionPipeline.from_pretrained(
40 |         sd_args.pop("_model_id_"),  # pop so "_model_id_" is not passed to the pipeline calls below
41 |         torch_dtype=torch.half
42 |     ).to("cuda")
43 |
44 |     # NOTE: DeepSpeed inference supports local CUDA graphs for replaced SD modules.
45 |     # Local CUDA graphs for replaced SD modules will only be enabled when `mp_size==1`
46 |     pipe_ds = deepspeed.init_inference(
47 |         pipe_ds,
48 |         mp_size=1,
49 |         dtype=torch.half,
50 |         replace_with_kernel_inject=True,
51 |         enable_cuda_graph=True,
52 |     )
53 |
54 |     deepspeed_image = pipe_ds(prompt, **sd_args).images[0]
55 |     deepspeed_image.save("deepspeed.png")
56 |
57 |     prompt = "a photo of an astronaut riding a horse on mars"
58 |
59 |     deepspeed_results = measure_latency(pipe_ds, prompt)
60 |
61 |     print(f"Deepspeed pipeline: {deepspeed_results[0]}")
62 |
63 | if __name__ == "__main__":
64 |     main()
--------------------------------------------------------------------------------
/tensorrt_2/test_trt.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['HF_HOME']='/workspace/.cache/huggingface'
3 |
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | from diffusers import DDIMScheduler
9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
10 |
11 | # @torch.inference_mode()
12 | def benchmark_func(pipe, prompt):
13 |     for _ in range(5):
14 |         _ = pipe(prompt)
15 |     # Start benchmark.
16 |     torch.cuda.synchronize()
17 |
18 |     # Timed run
19 |     n_runs = 10
20 |     latencies = []
21 |     for _ in range(n_runs):
22 |         start = time.perf_counter_ns()
23 |         _ = pipe(prompt)
24 |         torch.cuda.synchronize()
25 |         end = time.perf_counter_ns() - start
26 |         latencies.append(end)
27 |
28 |     time_avg_ns = np.average(latencies)
29 |     return int(time_avg_ns / 1000000.0)  # ns -> ms
30 |
31 | def main():
32 |     model_id = "runwayml/stable-diffusion-v1-5"
33 |     # Use the DDIMScheduler here instead
34 |     scheduler = DDIMScheduler.from_pretrained(model_id,
35 |                                               subfolder="scheduler")
36 |
37 |     pipe = StableDiffusionPipeline.from_pretrained(model_id,
38 |         custom_pipeline="stable_diffusion_tensorrt_txt2img",
39 |         revision='fp16',
40 |         torch_dtype=torch.float16,
41 |         scheduler=scheduler,
42 |         image_height=512,
43 |         image_width=512,
44 |         max_batch_size=1
45 |     )
46 |
47 |     # re-use cached folder to save ONNX models and TensorRT Engines
48 |     pipe.set_cached_folder(model_id, revision='fp16',)
49 |
50 |     pipe = pipe.to("cuda")
51 |
52 |     prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
53 |     image = pipe(prompt).images[0]
54 |     image.save('tensorrt_mt_fuji.png')
55 |
56 |     prompt = ["a beautiful photograph of Mt. Fuji during cherry blossom"] * 1
57 |
58 |     latency_ms = benchmark_func(pipe, prompt)
59 |
60 |     print("Pipeline latency:", latency_ms, "ms")
61 |
62 |
63 | if __name__ == "__main__":
64 |     main()
65 |
--------------------------------------------------------------------------------
/AITemplate/test_ait.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | from diffusers import DiffusionPipeline
5 |
6 | from time import perf_counter
7 | import numpy as np
8 |
9 |
10 | def measure_latency(pipe, prompt):
11 |     latencies = []
12 |     # warm up
13 |     # pipe.set_progress_bar_config(disable=True)
14 |     for _ in range(2):
15 |         _ = pipe(prompt)
16 |     # Timed run
17 |     for _ in range(10):
18 |         start_time = perf_counter()
19 |         with torch.inference_mode():
20 |             _ = pipe(prompt)
21 |         latency = perf_counter() - start_time
22 |         latencies.append(latency)
23 |     # Compute run statistics
24 |     time_avg_s = np.mean(latencies)
25 |     time_std_s = np.std(latencies)
26 |     time_p95_s = np.percentile(latencies,95)
27 |     return f"P95 latency (seconds) - {time_p95_s:.2f}; Average latency (seconds) - {time_avg_s:.2f} +/- {time_std_s:.2f};", time_p95_s
28 |
29 | def main():
30 |     import torch
31 |     import os
32 |     import functools
33 |
34 |     import torch._dynamo.config
35 |     torch._dynamo.config.suppress_errors = True
36 |
37 |     import hidet
38 |
39 |     # more search
40 |     hidet.torch.dynamo_config.search_space(2)
41 |     # automatically transform the model to use float16 data type
42 |     hidet.torch.dynamo_config.use_fp16(True)
43 |     # use float16 data type as the accumulate data type in operators with reduction
44 |     hidet.torch.dynamo_config.use_fp16_reduction(True)
45 |     # use tensorcore
46 |     hidet.torch.dynamo_config.use_tensor_core()
47 |
48 |     torch.backends.cudnn.benchmark = True
49 |
50 |     from diffusers import DiffusionPipeline
51 |     from diffusers import DPMSolverMultistepScheduler
52 |
53 |
54 |     prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
55 |     prompt = "a grey cat sitting on a chair in the kitchen, animated"
56 |
57 |     model = "runwayml/stable-diffusion-v1-5"
58 |
59 |     generator = torch.Generator(device="cuda").manual_seed(21)
60 |     dpm = DPMSolverMultistepScheduler.from_pretrained(model, subfolder="scheduler")
61 |     pipe_base = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half, scheduler=dpm).to("cuda")
62 |     # pipe_base = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half).to("cuda")
63 |
64 |     unet = pipe_base.unet
65 |     unet.eval()
66 |     # unet.to(memory_format=torch.channels_last) # use channels_last memory format
67 |     # unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
68 |
69 |     pipe_base.unet = torch.compile(pipe_base.unet, backend='hidet')
70 |
71 |     baseline_image = pipe_base(prompt, guidance_scale=7.5, generator=generator, num_inference_steps=120).images[0]
72 |     baseline_image.save("baseline.png")
73 |
74 | if __name__ == "__main__":
75 |     main()
--------------------------------------------------------------------------------
/jax-gpu/test_jax.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['HF_HOME']='/workspace/.cache/huggingface'
3 |
4 | import numpy as np
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from pathlib import Path
9 | from jax import pmap
10 | from flax.jax_utils import replicate
11 | from flax.training.common_utils import shard
12 | from PIL import Image
13 |
14 | from diffusers import FlaxStableDiffusionPipeline
15 |
16 | import time
17 |
18 | os.environ['XLA_FLAGS']='--xla_dump_to=/workspace/xla_dump/'
19 |
20 | sd_args_v15 = {"_model_id_": "runwayml/stable-diffusion-v1-5", "width": 512, "height": 512, "guidance_scale": 7.5, "num_inference_steps": 50}
21 | sd_args_v21 = {"_model_id_": "stabilityai/stable-diffusion-2-1", "width": 768, "height": 768, "guidance_scale": 7.5, "num_inference_steps": 50}
22 |
23 | sd_args = sd_args_v21
24 |
25 | def benchmark_func(pipeline, prompts, p_params, rng):
26 |     for _ in range(5):
27 |         rng = jax.random.split(rng[0], jax.device_count())
28 |         _ = pipeline(prompts, p_params, rng, jit=True, **sd_args)
29 |
30 |     # Start benchmark.
31 |
32 |     # Timed run
33 |     n_runs = 10
34 |     latencies = []
35 |     for _ in range(n_runs):
36 |         start = time.perf_counter()
37 |         rng = jax.random.split(rng[0], jax.device_count())
38 |         _ = pipeline(prompts, p_params, rng, jit=True, **sd_args)
39 |         end = time.perf_counter() - start
40 |         latencies.append(end)
41 |
42 |     # average latency in seconds
43 |     time_avg_s = np.average(latencies)
44 |     return time_avg_s
45 |
46 |
47 | def main():
48 |     num_devices = jax.device_count()
49 |     device_type = jax.devices()[0].device_kind
50 |
51 |     print(f"Found {num_devices} JAX devices of type {device_type}.")
52 |
53 |     pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
54 |         sd_args["_model_id_"],
55 |         revision="bf16",
56 |         dtype=jax.numpy.bfloat16,
57 |         safety_checker=None,
58 |         feature_extractor=None
59 |     )
60 |
61 |     del sd_args["_model_id_"]
62 |
63 |     prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
64 |     prompt = [prompt] * jax.device_count()
65 |     prompt_ids = pipeline.prepare_inputs(prompt)
66 |
67 |     p_params = replicate(params)
68 |     prompt_ids = shard(prompt_ids)
69 |
70 |     def create_key(seed=0):
71 |         return jax.random.PRNGKey(seed)
72 |     rng = create_key(0)
73 |     rng = [rng]
74 |     rng = jax.random.split(rng[0], jax.device_count())
75 |
76 |     images = pipeline(prompt_ids, p_params, rng, jit=True, **sd_args).images
77 |     images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
78 |     images = pipeline.numpy_to_pil(images)
79 |     images[0].save('example.png')
80 |
81 |     latency_s = benchmark_func(pipeline, prompt_ids, p_params, rng)
82 |
83 |     print(f"Pipeline latency: {latency_s:.2f} s")
84 |
85 | if __name__ == "__main__":
86 |     main()
87 |
--------------------------------------------------------------------------------
/hidet/test_hidet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | from diffusers import DiffusionPipeline
5 |
6 | from time import perf_counter
7 | import numpy as np
8 |
9 |
10 | def measure_latency(pipe, prompt):
11 |     latencies = []
12 |     # warm up
13 |     # pipe.set_progress_bar_config(disable=True)
14 |     for _ in range(2):
15 |         _ = pipe(prompt)
16 |     # Timed run
17 |     for _ in range(10):
18 |         start_time = perf_counter()
19 |         with torch.inference_mode():
20 |             _ = pipe(prompt)
21 |         latency = perf_counter() - start_time
22 |         latencies.append(latency)
23 |     # Compute run statistics
24 |     time_avg_s = np.mean(latencies)
25 |     time_std_s = np.std(latencies)
26 |     time_p95_s = np.percentile(latencies,95)
27 |     return f"P95 latency (seconds) - {time_p95_s:.2f}; Average latency (seconds) - {time_avg_s:.2f} +/- {time_std_s:.2f};", time_p95_s
28 |
29 | def main():
30 |     import torch
31 |     import os
32 |     import functools
33 |
34 |     import torch._dynamo.config
35 |     torch._dynamo.config.suppress_errors = True
36 |
37 |     import hidet
38 |
39 |     # more search
40 |     hidet.torch.dynamo_config.search_space(2)
41 |     # automatically transform the model to use float16 data type
42 |     hidet.torch.dynamo_config.use_fp16(True)
43 |     # use float16 data type as the accumulate data type in operators with reduction
44 |     hidet.torch.dynamo_config.use_fp16_reduction(True)
45 |     # use tensorcore
46 |     hidet.torch.dynamo_config.use_tensor_core()
47 |
48 |     torch.backends.cudnn.benchmark = True
49 |
50 |     from diffusers import DiffusionPipeline
51 |     from diffusers import DPMSolverMultistepScheduler
52 |
53 |
54 |     prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
55 |     prompt = "a grey cat sitting on a chair in the kitchen, animated"
56 |
57 |     model = "runwayml/stable-diffusion-v1-5"
58 |
59 |     generator = torch.Generator(device="cuda").manual_seed(21)
60 |     dpm = DPMSolverMultistepScheduler.from_pretrained(model, subfolder="scheduler")
61 |     pipe_base = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half, scheduler=dpm).to("cuda")
62 |     # pipe_base = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half).to("cuda")
63 |
64 |     unet = pipe_base.unet
65 |     unet.eval()
66 |     # unet.to(memory_format=torch.channels_last) # use channels_last memory format
67 |     # unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
68 |
69 |     pipe_base.unet = torch.compile(pipe_base.unet, backend='hidet')
70 |
71 |     baseline_image = pipe_base(prompt, guidance_scale=7.5, generator=generator, num_inference_steps=120).images[0]
72 |     baseline_image.save("baseline.png")
73 |
74 |     prompt = "a photo of an astronaut riding a horse on mars"
75 |
76 |     hidet_results = measure_latency(pipe_base, prompt)
77 |
78 |     print(f"Hidet pipeline: {hidet_results[0]}")
79 |
80 | if __name__ == "__main__":
81 |     main()
--------------------------------------------------------------------------------
/PyTorch/test_pt.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # os.environ["HF_HOME"] = "/workspace/.cache/huggingface"
4 |
5 | import time
6 | import numpy as np
7 |
8 | import torch
9 | from diffusers import DiffusionPipeline
10 |
11 | sd_args_v15 = {
12 |     "_model_id_": "runwayml/stable-diffusion-v1-5",
13 |     "width": 512,
14 |     "height": 512,
15 |     "guidance_scale": 7.5,
16 |     "num_inference_steps": 50,
17 | }
18 | sd_args_v21 = {
19 |     "_model_id_": "stabilityai/stable-diffusion-2-1",
20 |     "width": 768,
21 |     "height": 768,
22 |     "guidance_scale": 7.5,
23 |     "num_inference_steps": 50,
24 | }
25 | sd_args_xl = {
26 |     "_model_id_": "stabilityai/stable-diffusion-xl-base-1.0",
27 |     "width": 1024,
28 |     "height": 1024,
29 |     "guidance_scale": 7.5,
30 |     "num_inference_steps": 50,
31 | }
32 | sd_args_v3 = {
33 |     "_model_id_": "stabilityai/stable-diffusion-3-medium-diffusers",
34 |     "width": 1024,
35 |     "height": 1024,
36 |     "guidance_scale": 7.0,
37 |     "num_inference_steps": 28,
38 | }
39 |
40 | sd_args = sd_args_v3
41 |
42 |
43 | # @torch.inference_mode()
44 | def benchmark_func(pipe, compiled, prompt, sd_args_local):
45 |     for _ in range(5):
46 |         _ = pipe(prompt, **sd_args_local)
47 |     # Start benchmark.
48 |     torch.cuda.synchronize()
49 |
50 |     # Timed run
51 |     n_runs = 10
52 |     latencies = []
53 |     for _ in range(n_runs):
54 |         start = time.perf_counter_ns()
55 |         if not compiled:
56 |             with torch.inference_mode():
57 |                 _ = pipe(prompt, **sd_args_local)
58 |         else:
59 |             _ = pipe(prompt, **sd_args_local)
60 |         torch.cuda.synchronize()
61 |         end = time.perf_counter_ns() - start
62 |         latencies.append(end)
63 |
64 |     time_avg_ns = np.average(latencies)
65 |     return int(time_avg_ns / 1000000.0)  # ns -> ms
66 |
67 |
68 | run_compile = False  # Set True / False
69 |
70 |
71 | def main():
72 |     batch_size = 1
73 |     prompt = ["A cat holding a sign that says hello world"] * batch_size
74 |
75 |     pipe_base = DiffusionPipeline.from_pretrained(
76 |         sd_args["_model_id_"],
77 |         torch_dtype=torch.float16,
78 |         use_safetensors=True,
79 |         # variant="fp16",
80 |     ).to("cuda")
81 |
82 |     if run_compile:
83 |         print("Run torch compile")
84 |         pipe_base.unet = torch.compile(  # NOTE: SD3 pipelines have .transformer instead of .unet (see run_sd3_compile.py)
85 |             pipe_base.unet, mode="reduce-overhead", fullgraph=True
86 |         )
87 |
88 |     sd_args_copy = dict(sd_args)  # copy so deleting the key below does not mutate sd_args
89 |     del sd_args_copy["_model_id_"]
90 |     baseline_image = pipe_base(prompt, **sd_args_copy).images
91 |     for idx, im in enumerate(baseline_image):
92 |         im.save(f"{idx:06}.jpg")
93 |
94 |     # A cat holding a sign that says hello world
95 |     # A majestic lion jumping from a big stone at night
96 |     prompt = ["A cat holding a sign that says hello world"] * batch_size
97 |
98 |     latency_ms = benchmark_func(pipe_base, run_compile, prompt, sd_args_copy)
99 |
100 |     print("Pipeline latency:", latency_ms, "ms")
101 |     print(sd_args)
102 |
103 |
104 | if __name__ == "__main__":
105 |     main()
106 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Stable Diffusion XL inference benchmarks
2 |
3 | model:
4 |
5 | https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0
6 |
7 | batch size = 1, image size 1024x1024, 50 denoising steps
8 |
9 | #### A100-SXM, 40GB
10 | | Engine | Time |
11 | | :--- | :--- |
12 | | PT2.0,fp16 + compile | 5.35 s |
13 | | Onnxruntime,fp16,ORT_CUDA | 4.28 s |
14 |
15 | #### RTX 4090, 24GB
16 | | Engine | Time |
17 | | :--- | :--- |
18 | | PT2.0,fp16 + compile | 6.02 s |
19 |
20 | #### RTX 6000 Ada, 48GB
21 | | Engine | Time |
22 | | :--- | :--- |
23 | | PT2.0,fp16 + compile | 9.07 s |
24 |
25 | ## Stable Diffusion inference benchmarks
26 |
27 | model:
28 |
29 | https://huggingface.co/runwayml/stable-diffusion-v1-5 - sd1.5
30 |
31 | https://huggingface.co/stabilityai/stable-diffusion-2-1 - sd2.1
32 |
33 | batch size = 1, image size 512x512 (unless noted), 50 denoising steps
34 |
35 | #### A100-SXM, 40GB
36 | | Engine | Time, sd1.5 | Time, sd2.1, 512x512 | Time, sd2.1, 768x768 |
37 | | :--- | :--- | :--- | :--- |
38 | | PT2.0,fp16 | 1.96 s (4.54 gb VRAM) | | |
39 | | PT2.0,fp16 + compile | 1.36 s (5.96 gb) | | 2.37 s |
40 | | AITemplate,fp16 | 1.01 s (4.06 gb) | | |
41 | | DeepSpeed,fp16 | 1.18 s | | 2.28 s |
42 | | Oneflow,fp16 | 0.98 s (5.62 gb) | | |
43 | | TensorRT 8.6.1, fp16 | 0.98 s | 0.81 s | 1.88 s |
44 | | TensorRT 10.0, fp16 | | | 1.57 s |
45 | | Onnxruntime,fp16,ORT_CUDA | 0.85 s | 0.76 s | 1.63 s |
46 | | Jax,XLA,bf16 | 1.58 s | 1.35 s | 3.61 s |
47 |
48 | #### H100-PCIe, 80GB
49 |
50 | | Engine | Time, sd1.5 | Time, sd2.1 |
51 | | :--- | :--- | :--- |
52 | | PT2.0,fp16 | 1.44 s | |
53 | | PT2.0,fp16,compile | 1.11 s | |
54 | | TensorRT 8.6.1,fp16 | 0.75 s | 0.68 s |
55 | | Jax,XLA,bf16 | 1.18 s | |
56 |
57 | #### H100-SXM, 80GB
58 | | Engine | Time, sd1.5 | Time, sd2.1 | Time, sd2.1, 768x768 |
59 | | :--- | :--- | :--- | :--- |
60 | | PT2.0,fp16,compile | 0.83 s | 0.70 s | 1.39 s |
61 | | TensorRT 8.6.1,fp16 | 0.49 s | 0.48 s | 1.05 s |
62 | | Jax,XLA,bf16 | 1.00 s | 0.79 s | |
63 |
64 | #### RTX 4090, 24GB
65 | | Engine | Time, sd1.5 | Time, sd2.1 | Time, sd2.1, 768x768 |
66 | | :--- | :--- | :--- | :--- |
67 | | PT2.0,fp16,compile | 1.17 s | | 2.26 s |
68 | | TensorRT 8.6.1, fp16 | 0.74 s | 0.68 s | 1.52 s |
69 |
70 | #### L40, 48GB
71 | | Engine | Time, sd1.5 | Time, sd2.1 | Time, sd2.1, 768x768 |
72 | | :--- | :--- | :--- | :--- |
73 | | PT2.0,fp16,compile | 2.09 s | | 3.08 s |
74 | | TensorRT 8.6.2,fp16 | 0.91 s | | 2.19 s |
75 |
76 | #### RTX 6000 Ada, 48GB
77 | | Engine | Time, sd1.5 | Time, sd2.1 | Time, sd2.1, 768x768 |
78 | | :--- | :--- | :--- | :--- |
79 | | PT2.0,fp16,compile | 1.28 s | | 2.77 s |
80 | | TensorRT 8.6.2,fp16 | 0.90 s | | 2.25 s |
81 |
82 | #### V100, T4
83 | | GPU | PT2.0,fp16,xformers |
84 | | :--- | :--- |
85 | | V100, 16gb | 2.96 s |
86 | | T4, 16gb | 7.83 s |
87 |
88 | ## How to run
89 | Ubuntu/Debian VM setup: https://gist.github.com/alexeigor/b4c21b5e1fe62d670c433d4ac8c9fd83
90 |
91 | Build from the directory of the engine you want to test (the `HF_TOKEN` build arg is only used by the PyTorch image):
92 | ```bash
93 | docker build -f ./Dockerfile --network=host --build-arg HF_TOKEN=xxxxx -t test_engine .
94 | ```
95 |
96 | ```bash
97 | docker run -it --network=host -v ${PWD}/workspace:/workspace -w /workspace --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 test_engine
98 | ```
99 |
100 | ## References:
101 | - https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/benchmark/txt2img
102 | - https://github.com/Oneflow-Inc/oneflow
103 | - https://github.com/Oneflow-Inc/diffusers
104 | - https://github.com/facebookincubator/AITemplate
105 | - https://arxiv.org/abs/2304.11267
106 | - https://github.com/dbolya/tomesd
107 | - https://huggingface.co/docs/diffusers/main/en/optimization/fp16
108 | - https://github.com/hidet-org/hidet
109 | - https://github.com/stochasticai/x-stable-diffusion
110 | - https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion
111 | - https://github.com/microsoft/Olive
112 | - https://github.com/NVIDIA/TensorRT/tree/main/demo/Diffusion (kudos to Denis Timonin https://www.linkedin.com/in/denistimonin/)
113 | - https://medium.com/microsoftazure/accelerating-stable-diffusion-inference-with-onnx-runtime-203bd7728540
114 | - https://huggingface.co/blog/sdxl_jax
115 | - https://huggingface.co/blog/simple_sdxl_optimizations
116 | - https://pytorch.org/blog/accelerating-generative-ai-3/
117 |
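118 | ## Benchmark methodology
119 |
120 | Most of the test scripts above share the same timing pattern: a few warm-up calls (so compilation, autotuning, and engine builds happen outside the timed region), then 10 timed runs that are averaged; some scripts (e.g. test_oneflow.py, test_deepspeed.py) also report P95. A condensed sketch of that pattern (names here are illustrative, not taken from any one script):
121 |
122 | ```python
123 | import time
124 |
125 | import numpy as np
126 | import torch
127 |
128 |
129 | def benchmark(pipe, prompt, n_warmup=5, n_runs=10, **pipe_kwargs):
130 |     # Warm-up: lets torch.compile / engine building / autotuning finish
131 |     # before timing starts.
132 |     for _ in range(n_warmup):
133 |         _ = pipe(prompt, **pipe_kwargs)
134 |     torch.cuda.synchronize()
135 |
136 |     latencies = []
137 |     for _ in range(n_runs):
138 |         start = time.perf_counter()
139 |         _ = pipe(prompt, **pipe_kwargs)
140 |         # Wait for queued CUDA work before stopping the clock.
141 |         torch.cuda.synchronize()
142 |         latencies.append(time.perf_counter() - start)
143 |
144 |     return float(np.mean(latencies))  # mean end-to-end latency in seconds
145 | ```
146 |
147 | One-time costs (weight download, ONNX export, engine build) are therefore excluded from the numbers in the tables above.
148 |
--------------------------------------------------------------------------------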