├── .gitignore ├── LICENSE.txt ├── README.md ├── assets ├── idea.jpg ├── quality.jpg ├── speedups.jpg └── teaser.jpg ├── distrifuser ├── __init__.py ├── __version__.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── distri_sdxl_unet_pp.py │ ├── distri_sdxl_unet_tp.py │ └── naive_patch_sdxl.py ├── modules │ ├── __init__.py │ ├── base_module.py │ ├── pp │ │ ├── __init__.py │ │ ├── attn.py │ │ ├── conv2d.py │ │ └── groupnorm.py │ └── tp │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── conv2d.py │ │ ├── feed_forward.py │ │ └── resnet.py ├── pipelines.py └── utils.py ├── scripts ├── compute_metrics.py ├── dump_coco.py ├── export_html.py ├── generate_coco.py ├── profile_macs.py ├── run_sdxl.py ├── sd_example.py └── sdxl_example.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | build 4 | dist 5 | *.egg-info -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MIT HAN Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models 2 | 3 | ### [Paper](http://arxiv.org/abs/2402.19481) | [Project](https://hanlab.mit.edu/projects/distrifusion) | [Blog](https://hanlab.mit.edu/blog/distrifusion) | [Slides](https://www.dropbox.com/scl/fi/yv98hi2kdoh27ej4jqlbp/slides.key?rlkey=3rmfxpezqt3co5x2hgqvxv09i&st=ve4z9w6t&dl=0) | [Youtube](https://www.youtube.com/watch?v=EZX7srDDmW0&list=PL80kAHvQbh-pKRxcSS6xjds7U7Yc0gDQI&index=1) | [Poster](https://www.dropbox.com/scl/fi/labhefjwi9r01e3o9eob0/poster.pdf?rlkey=rjj1jj179enln92h8kygrftmg&st=0ddego10&dl=0) 4 | 5 | **[Dec 1, 2024]** DistriFusion is integrated in NVIDIA's [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/sdxl/README.md) for distributed inference on high-resolution image generation. 6 | 7 | **[Jul 29, 2024]** DistriFusion is supported in [ColossalAI](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/inference/README.md)! 8 | 9 | **[Apr 4, 2024]** DistriFusion is selected as a **highlight** poster in CVPR 2024! 10 | 11 | **[Feb 29, 2024]** DistriFusion is accepted by CVPR 2024! 
Our code is publicly available! 12 | 13 | ![teaser](https://github.com/mit-han-lab/distrifuser/blob/main/assets/teaser.jpg) 14 | *We introduce DistriFusion, a training-free algorithm to harness multiple GPUs to accelerate diffusion model inference without sacrificing image quality. Naïve Patch (Overview (b)) suffers from the fragmentation issue due to the lack of patch interaction. The presented examples are generated with SDXL using a 50-step Euler sampler at 1280×1920 resolution, and latency is measured on A100 GPUs.* 15 | 16 | DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models
17 | [Muyang Li](https://lmxyy.me/)\*, [Tianle Cai](https://www.tianle.website/)\*, [Jiaxin Cao](https://www.linkedin.com/in/jiaxin-cao-2166081b3/), [Qinsheng Zhang](https://qsh-zh.github.io), [Han Cai](https://han-cai.github.io), [Junjie Bai](https://www.linkedin.com/in/junjiebai/), [Yangqing Jia](https://daggerfs.com), [Ming-Yu Liu](https://mingyuliu.net), [Kai Li](https://www.cs.princeton.edu/~li/), and [Song Han](https://hanlab.mit.edu/songhan)
18 | MIT, Princeton, Lepton AI, and NVIDIA
19 | In CVPR 2024.
20 |
21 | ## Overview
22 | ![idea](https://github.com/mit-han-lab/distrifuser/blob/main/assets/idea.jpg)
23 | **(a)** Original diffusion model running on a single device. **(b)** Naïvely splitting the image into 2 patches across 2 GPUs produces an evident seam at the boundary due to the absence of interaction across patches. **(c)** Our DistriFusion employs synchronous communication for patch interaction at the first step. After that, we reuse the activations from the previous step via asynchronous communication. In this way, the communication overhead can be hidden within the computation pipeline.
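To make this concrete, here is a tiny, self-contained PyTorch sketch of the boundary-reuse idea (illustrative only; the names and shapes are ours, not the library's, and the real modules live under `distrifuser/modules/pp`). A row-split 3×3 convolution is exact when the boundary rows are freshly exchanged, and approximate when it reuses stale boundary rows from the previous denoising step:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(1, 1, 3, 3)  # a single 3x3 conv standing in for a UNet layer

def conv_patch(patch, top, bottom):
    # Stitch the neighbor's boundary rows (or zeros at the image border)
    # onto the local patch, then convolve without vertical padding.
    x = torch.cat([top, patch, bottom], dim=2)
    return F.conv2d(x, weight, padding=(0, 1))

x = torch.randn(1, 1, 8, 8)
top_half, bottom_half = x[:, :, :4], x[:, :, 4:]
zeros = torch.zeros(1, 1, 1, 8)

# Warmup step: boundaries are exchanged synchronously, so the split conv is exact.
cached_top_rows = top_half[:, :, -1:]       # what the "bottom device" caches
cached_bottom_rows = bottom_half[:, :, :1]  # what the "top device" caches
y = torch.cat([
    conv_patch(top_half, zeros, cached_bottom_rows),
    conv_patch(bottom_half, cached_top_rows, zeros),
], dim=2)
assert torch.allclose(y, F.conv2d(x, weight, padding=(1, 1)))

# Later steps: activations drift only slightly between adjacent denoising steps,
# so each patch reuses the *stale* cached boundaries while fresh ones are
# exchanged asynchronously in the background, hiding the communication cost.
top_half = top_half + 0.01 * torch.randn_like(top_half)
bottom_half = bottom_half + 0.01 * torch.randn_like(bottom_half)
y_stale = torch.cat([
    conv_patch(top_half, zeros, cached_bottom_rows),   # stale boundary
    conv_patch(bottom_half, cached_top_rows, zeros),   # stale boundary
], dim=2)  # close to the exact result because the boundaries barely changed
```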
24 |
25 | ## Performance
26 | ### Speedups
27 |
28 | ![speedups](https://github.com/mit-han-lab/distrifuser/blob/main/assets/speedups.jpg)
29 |
30 |
Measured total latency of DistriFusion with SDXL using a 50-step DDIM sampler for generating a single image on NVIDIA A100 GPUs. When scaling up the resolution, the GPU devices are better utilized. Remarkably, when generating 3840×3840 images, DistriFusion achieves 1.8×, 3.4× and 6.1× speedups with 2, 4, and 8 A100s, respectively.
31 |
32 |
33 |
34 | ### Quality
35 |
36 | ![quality](https://github.com/mit-han-lab/distrifuser/blob/main/assets/quality.jpg)
37 | Qualitative results of SDXL. FID is computed against the ground-truth images. Our DistriFusion reduces latency in proportion to the number of devices used while preserving visual fidelity.
38 |
39 | References:
40 |
41 | * Denoising Diffusion Implicit Model (DDIM), Song *et al.*, ICLR 2021
42 | * Elucidating the Design Space of Diffusion-Based Generative Models, Karras *et al.*, NeurIPS 2022
43 | * Parallel Sampling of Diffusion Models, Shih *et al.*, NeurIPS 2023
44 | * SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis, Podell *et al.*, ICLR 2024
45 |
46 | ## Prerequisites
47 |
48 | * Python 3
49 | * NVIDIA GPU + CUDA >= 12.0 and corresponding cuDNN
50 | * [PyTorch](https://pytorch.org) >= 2.2.
51 |
52 | ## Getting Started
53 |
54 | ### Installation
55 |
56 | After installing [PyTorch](https://pytorch.org), you should be able to install `distrifuser` from PyPI:
57 |
58 | ```shell
59 | pip install distrifuser
60 | ```
61 |
62 | or via GitHub:
63 |
64 | ```shell
65 | pip install git+https://github.com/mit-han-lab/distrifuser.git
66 | ```
67 |
68 | or locally for development:
69 |
70 | ```shell
71 | git clone git@github.com:mit-han-lab/distrifuser.git
72 | cd distrifuser
73 | pip install -e .
74 | ```
75 |
76 | ### Usage Example
77 |
78 | In [`scripts/sdxl_example.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/sdxl_example.py), we provide a minimal script for running [SDXL](https://huggingface.co/docs/diffusers/en/using-diffusers/sdxl) with DistriFusion.
79 |
80 | ```python
81 | import torch
82 |
83 | from distrifuser.pipelines import DistriSDXLPipeline
84 | from distrifuser.utils import DistriConfig
85 |
86 | distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
87 | pipeline = DistriSDXLPipeline.from_pretrained(
88 |     distri_config=distri_config,
89 |     pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
90 |     variant="fp16",
91 |     use_safetensors=True,
92 | )
93 |
94 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
95 | image = pipeline(
96 |     prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
97 |     generator=torch.Generator(device="cuda").manual_seed(233),
98 | ).images[0]
99 | if distri_config.rank == 0:
100 |     image.save("astronaut.png")
101 | ```
102 |
103 | Specifically, our `distrifuser` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way. You just need to define a `DistriConfig` and use our wrapped `DistriSDXLPipeline` to load the pretrained SDXL model. Then, you can generate images just as you would with the `StableDiffusionXLPipeline` in [diffusers](https://github.com/huggingface/diffusers). The running command is
104 |
105 | ```shell
106 | torchrun --nproc_per_node=$N_GPUS scripts/sdxl_example.py
107 | ```
108 |
109 | where `$N_GPUS` is the number of GPUs you want to use.
110 |
111 | We also provide a minimal script for running SD1.4/2 with DistriFusion in [`scripts/sd_example.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/sd_example.py). The usage is the same.
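As a reference, the SD counterpart looks like the sketch below, assuming a `DistriSDPipeline` wrapper analogous to `DistriSDXLPipeline` (the class name, model id, and arguments here are illustrative; `scripts/sd_example.py` is the authoritative version):

```python
import torch

from distrifuser.pipelines import DistriSDPipeline  # assumed SD analogue of DistriSDXLPipeline
from distrifuser.utils import DistriConfig

# 512x512 is the native SD1.4 resolution; the warmup matches the SDXL example.
distri_config = DistriConfig(height=512, width=512, warmup_steps=4)
pipeline = DistriSDPipeline.from_pretrained(
    distri_config=distri_config,
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",  # illustrative model id
)

pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    generator=torch.Generator(device="cuda").manual_seed(233),
).images[0]
if distri_config.rank == 0:
    image.save("astronaut_sd.png")
```

As with SDXL, launch it with `torchrun --nproc_per_node=$N_GPUS scripts/sd_example.py`.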
112 |
113 | ### Benchmark
114 |
115 | Our benchmark results are measured with [PyTorch](https://pytorch.org) 2.2 and [diffusers](https://github.com/huggingface/diffusers) 0.24.0. First, you may need to install some additional dependencies:
116 |
117 | ```shell
118 | pip install git+https://github.com/zhijian-liu/torchprofile datasets torchmetrics dominate clean-fid
119 | ```
120 |
121 | #### COCO Quality
122 |
123 | You can use [`scripts/generate_coco.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/generate_coco.py) to generate images with COCO captions. The command is
124 |
125 | ```shell
126 | torchrun --nproc_per_node=$N_GPUS scripts/generate_coco.py --no_split_batch
127 | ```
128 |
129 | where `$N_GPUS` is the number of GPUs you want to use. By default, the generated results will be stored in `results/coco`. You can also customize the location with `--output_root`. Some additional arguments that you may want to tune:
130 |
131 | * `--num_inference_steps`: The number of inference steps. We use 50 by default.
132 | * `--guidance_scale`: The classifier-free guidance scale. We use 5 by default.
133 | * `--scheduler`: The diffusion sampler. We use the [DDIM sampler](https://huggingface.co/docs/diffusers/v0.26.3/en/api/schedulers/ddim#ddimscheduler) by default. You can also use `euler` for the [Euler sampler](https://huggingface.co/docs/diffusers/v0.26.3/en/api/schedulers/euler#eulerdiscretescheduler) and `dpm-solver` for the [DPM solver](https://huggingface.co/docs/diffusers/en/api/schedulers/multistep_dpm_solver).
134 | * `--warmup_steps`: The number of additional warmup steps (4 by default).
135 | * `--sync_mode`: Different GroupNorm synchronization modes. By default, it uses our corrected asynchronous GroupNorm.
136 | * `--parallelism`: The parallelism paradigm you use. By default, it is patch parallelism. You can use `tensor` for tensor parallelism and `naive_patch` for naïve patch.
137 |
138 | After you generate all the images, you can use our script [`scripts/compute_metrics.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py) to calculate PSNR, LPIPS and FID. The usage is
139 |
140 | ```shell
141 | python scripts/compute_metrics.py --input_root0 $IMAGE_ROOT0 --input_root1 $IMAGE_ROOT1
142 | ```
143 |
144 | where `$IMAGE_ROOT0` and `$IMAGE_ROOT1` are paths to the image folders you are trying to compare. If `$IMAGE_ROOT0` is the ground-truth folder, please add the `--is_gt` flag for resizing. We also provide a script [`scripts/dump_coco.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/dump_coco.py) to dump the ground-truth images.
145 |
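For a quick sanity check of the PSNR/LPIPS side of this comparison, the torchmetrics calls the script relies on look roughly like the following (a minimal sketch, not `scripts/compute_metrics.py` itself; random tensors stand in for images decoded from `$IMAGE_ROOT0`/`$IMAGE_ROOT1`):

```python
import torch
from torchmetrics.image import PeakSignalNoiseRatio
# LPIPS needs the torchmetrics image extras, e.g. pip install "torchmetrics[image]"
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

psnr = PeakSignalNoiseRatio(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True)

pred = torch.rand(4, 3, 256, 256)    # images from one method, scaled to [0, 1]
target = torch.rand(4, 3, 256, 256)  # reference images, same size and order

print("PSNR:", psnr(pred, target).item())
print("LPIPS:", lpips(pred, target).item())
# FID is a distribution-level metric, computed over the two folders with clean-fid.
```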
146 | #### Latency
147 |
148 | You can use [`scripts/run_sdxl.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/run_sdxl.py) to benchmark the latency of our different methods. The command is
149 |
150 | ```shell
151 | torchrun --nproc_per_node=$N_GPUS scripts/run_sdxl.py --mode benchmark --output_type latent
152 | ```
153 |
154 | where `$N_GPUS` is the number of GPUs you want to use. Similar to [`scripts/generate_coco.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/generate_coco.py), you can also change some arguments:
155 |
156 | * `--num_inference_steps`: The number of inference steps. We use 50 by default.
157 | * `--image_size`: The generated image size. By default, it is 1024×1024.
158 | * `--no_split_batch`: Disable the batch splitting for classifier-free guidance.
159 | * `--warmup_steps`: The number of additional warmup steps (4 by default).
160 | * `--sync_mode`: Different GroupNorm synchronization modes. By default, it uses our corrected asynchronous GroupNorm.
161 | * `--parallelism`: The parallelism paradigm you use. By default, it is patch parallelism. You can use `tensor` for tensor parallelism and `naive_patch` for naïve patch.
162 | * `--warmup_times`/`--test_times`: The number of warmup/test runs. By default, they are 5 and 20, respectively.
163 |
164 |
165 | ## Citation
166 |
167 | If you use this code for your research, please cite our paper.
168 |
169 | ```bibtex
170 | @inproceedings{li2023distrifusion,
171 |   title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
172 |   author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Liu, Ming-Yu and Li, Kai and Han, Song},
173 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
174 |   year={2024}
175 | }
176 | ```
177 |
178 | ## Acknowledgments
179 |
180 | Our code is developed based on [huggingface/diffusers](https://github.com/huggingface/diffusers) and [lmxyy/sige](https://github.com/lmxyy/sige). We thank [torchprofile](https://github.com/zhijian-liu/torchprofile) for MACs measurement, [clean-fid](https://github.com/GaParmar/clean-fid) for FID computation, and [Lightning-AI/torchmetrics](https://github.com/Lightning-AI/torchmetrics) for PSNR and LPIPS.
181 |
182 | We thank Jun-Yan Zhu and Ligeng Zhu for their helpful discussion and valuable feedback. The project is supported by MIT-IBM Watson AI Lab, Amazon, MIT Science Hub, and National Science Foundation.
183 |
-------------------------------------------------------------------------------- /assets/idea.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/idea.jpg -------------------------------------------------------------------------------- /assets/quality.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/quality.jpg -------------------------------------------------------------------------------- /assets/speedups.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/speedups.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/teaser.jpg -------------------------------------------------------------------------------- /distrifuser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/__init__.py -------------------------------------------------------------------------------- /distrifuser/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1beta1" 2 | -------------------------------------------------------------------------------- /distrifuser/models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/models/__init__.py -------------------------------------------------------------------------------- /distrifuser/models/base_model.py: -------------------------------------------------------------------------------- 1 | from diffusers import ConfigMixin, ModelMixin 2 | from torch import nn 3 | 4 | from distrifuser.modules.base_module import BaseModule 5 | from ..utils import PatchParallelismCommManager, DistriConfig 6 | 7 | 8 | class BaseModel(ModelMixin, ConfigMixin): 9 | def __init__(self, model: nn.Module, distri_config: DistriConfig): 10 | super(BaseModel, self).__init__() 11 | self.model = model 12 | self.distri_config = distri_config 13 | self.comm_manager = None 14 | 15 | self.buffer_list = None 16 | self.output_buffer = None 17 | self.counter = 0 18 | 19 | # for cuda graph 20 | self.static_inputs = None 21 | self.static_outputs = None 22 | self.cuda_graphs = None 23 | 24 | def forward(self, *args, **kwargs): 25 | raise NotImplementedError 26 | 27 | def set_counter(self, counter: int = 0): 28 | self.counter = counter 29 | for module in self.model.modules(): 30 | if isinstance(module, BaseModule): 31 | module.set_counter(counter) 32 | 33 | def set_comm_manager(self, comm_manager: PatchParallelismCommManager): 34 | self.comm_manager = comm_manager 35 | for module in self.model.modules(): 36 | if isinstance(module, BaseModule): 37 | module.set_comm_manager(comm_manager) 38 | 39 | def setup_cuda_graph(self, static_outputs, cuda_graphs): 40 | self.static_outputs = static_outputs 41 | self.cuda_graphs = cuda_graphs 42 | 43 | @property 44 | def config(self): 45 | return self.model.config 46 | 47 | def synchronize(self): 48 | if self.comm_manager is not None and self.comm_manager.handles is not None: 49 | for i in range(len(self.comm_manager.handles)): 50 | if self.comm_manager.handles[i] is not None: 51 | self.comm_manager.handles[i].wait() 52 | self.comm_manager.handles[i] = None 53 | -------------------------------------------------------------------------------- /distrifuser/models/distri_sdxl_unet_pp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import UNet2DConditionModel 3 | from diffusers.models.attention_processor import Attention 4 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 5 | from torch import distributed as dist, nn 6 | 7 | from .base_model import BaseModel 8 | from distrifuser.modules.pp.attn import DistriCrossAttentionPP, DistriSelfAttentionPP 9 | from distrifuser.modules.base_module import BaseModule 10 | from distrifuser.modules.pp.conv2d import DistriConv2dPP 11 | from distrifuser.modules.pp.groupnorm import DistriGroupNorm 12 | from ..utils import DistriConfig 13 | 14 | 15 | class DistriUNetPP(BaseModel): # for Patch Parallelism 16 | def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig): 17 | assert isinstance(model, UNet2DConditionModel) 18 | if distri_config.world_size > 1 and distri_config.n_device_per_batch > 1: 19 | for name, module in model.named_modules(): 20 | if isinstance(module, BaseModule): 21 | continue 22 | for subname, submodule in module.named_children(): 23 | if isinstance(submodule, nn.Conv2d): 24 | kernel_size = submodule.kernel_size 25 | if kernel_size == (1, 1) or kernel_size == 1: 26 | continue 27 | wrapped_submodule = DistriConv2dPP( 28 
| submodule, distri_config, is_first_layer=subname == "conv_in" 29 | ) 30 | setattr(module, subname, wrapped_submodule) 31 | elif isinstance(submodule, Attention): 32 | if subname == "attn1": # self attention 33 | wrapped_submodule = DistriSelfAttentionPP(submodule, distri_config) 34 | else: # cross attention 35 | assert subname == "attn2" 36 | wrapped_submodule = DistriCrossAttentionPP(submodule, distri_config) 37 | setattr(module, subname, wrapped_submodule) 38 | elif isinstance(submodule, nn.GroupNorm): 39 | wrapped_submodule = DistriGroupNorm(submodule, distri_config) 40 | setattr(module, subname, wrapped_submodule) 41 | 42 | super(DistriUNetPP, self).__init__(model, distri_config) 43 | 44 | def forward( 45 | self, 46 | sample: torch.FloatTensor, 47 | timestep: torch.Tensor or float or int, 48 | encoder_hidden_states: torch.Tensor, 49 | class_labels: torch.Tensor or None = None, 50 | timestep_cond: torch.Tensor or None = None, 51 | attention_mask: torch.Tensor or None = None, 52 | cross_attention_kwargs: dict[str, any] or None = None, 53 | added_cond_kwargs: dict[str, torch.Tensor] or None = None, 54 | down_block_additional_residuals: tuple[torch.Tensor] or None = None, 55 | mid_block_additional_residual: torch.Tensor or None = None, 56 | down_intrablock_additional_residuals: tuple[torch.Tensor] or None = None, 57 | encoder_attention_mask: torch.Tensor or None = None, 58 | return_dict: bool = True, 59 | record: bool = False, 60 | ): 61 | distri_config = self.distri_config 62 | b, c, h, w = sample.shape 63 | assert ( 64 | class_labels is None 65 | and timestep_cond is None 66 | and attention_mask is None 67 | and cross_attention_kwargs is None 68 | and down_block_additional_residuals is None 69 | and mid_block_additional_residual is None 70 | and down_intrablock_additional_residuals is None 71 | and encoder_attention_mask is None 72 | ) 73 | 74 | if distri_config.use_cuda_graph and not record: 75 | static_inputs = self.static_inputs 76 | 77 | if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch: 78 | assert b == 2 79 | batch_idx = distri_config.batch_idx() 80 | sample = sample[batch_idx : batch_idx + 1] 81 | timestep = ( 82 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep 83 | ) 84 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1] 85 | if added_cond_kwargs is not None: 86 | for k in added_cond_kwargs: 87 | added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1] 88 | 89 | assert static_inputs["sample"].shape == sample.shape 90 | static_inputs["sample"].copy_(sample) 91 | if torch.is_tensor(timestep): 92 | if timestep.ndim == 0: 93 | for b in range(static_inputs["timestep"].shape[0]): 94 | static_inputs["timestep"][b] = timestep.item() 95 | else: 96 | assert static_inputs["timestep"].shape == timestep.shape 97 | static_inputs["timestep"].copy_(timestep) 98 | else: 99 | for b in range(static_inputs["timestep"].shape[0]): 100 | static_inputs["timestep"][b] = timestep 101 | assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape 102 | static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states) 103 | if added_cond_kwargs is not None: 104 | for k in added_cond_kwargs: 105 | assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape 106 | static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k]) 107 | 108 | if self.counter <= distri_config.warmup_steps: 109 | graph_idx = 0 110 | elif 
self.counter == distri_config.warmup_steps + 1: 111 | graph_idx = 1 112 | else: 113 | graph_idx = 2 114 | 115 | self.cuda_graphs[graph_idx].replay() 116 | output = self.static_outputs[graph_idx] 117 | else: 118 | if distri_config.world_size == 1: 119 | output = self.model( 120 | sample, 121 | timestep, 122 | encoder_hidden_states, 123 | class_labels=class_labels, 124 | timestep_cond=timestep_cond, 125 | attention_mask=attention_mask, 126 | cross_attention_kwargs=cross_attention_kwargs, 127 | added_cond_kwargs=added_cond_kwargs, 128 | down_block_additional_residuals=down_block_additional_residuals, 129 | mid_block_additional_residual=mid_block_additional_residual, 130 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 131 | encoder_attention_mask=encoder_attention_mask, 132 | return_dict=False, 133 | )[0] 134 | elif distri_config.do_classifier_free_guidance and distri_config.split_batch: 135 | assert b == 2 136 | batch_idx = distri_config.batch_idx() 137 | sample = sample[batch_idx : batch_idx + 1] 138 | timestep = ( 139 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep 140 | ) 141 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1] 142 | if added_cond_kwargs is not None: 143 | new_added_cond_kwargs = {} 144 | for k in added_cond_kwargs: 145 | new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1] 146 | added_cond_kwargs = new_added_cond_kwargs 147 | output = self.model( 148 | sample, 149 | timestep, 150 | encoder_hidden_states, 151 | class_labels=class_labels, 152 | timestep_cond=timestep_cond, 153 | attention_mask=attention_mask, 154 | cross_attention_kwargs=cross_attention_kwargs, 155 | added_cond_kwargs=added_cond_kwargs, 156 | down_block_additional_residuals=down_block_additional_residuals, 157 | mid_block_additional_residual=mid_block_additional_residual, 158 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 159 | encoder_attention_mask=encoder_attention_mask, 160 | return_dict=False, 161 | )[0] 162 | if self.output_buffer is None: 163 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype) 164 | if self.buffer_list is None: 165 | self.buffer_list = [torch.empty_like(output) for _ in range(distri_config.world_size)] 166 | dist.all_gather(self.buffer_list, output.contiguous(), async_op=False) 167 | torch.cat(self.buffer_list[: distri_config.n_device_per_batch], dim=2, out=self.output_buffer[0:1]) 168 | torch.cat(self.buffer_list[distri_config.n_device_per_batch :], dim=2, out=self.output_buffer[1:2]) 169 | output = self.output_buffer 170 | else: 171 | output = self.model( 172 | sample, 173 | timestep, 174 | encoder_hidden_states, 175 | class_labels=class_labels, 176 | timestep_cond=timestep_cond, 177 | attention_mask=attention_mask, 178 | cross_attention_kwargs=cross_attention_kwargs, 179 | added_cond_kwargs=added_cond_kwargs, 180 | down_block_additional_residuals=down_block_additional_residuals, 181 | mid_block_additional_residual=mid_block_additional_residual, 182 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 183 | encoder_attention_mask=encoder_attention_mask, 184 | return_dict=False, 185 | )[0] 186 | if self.output_buffer is None: 187 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype) 188 | if self.buffer_list is None: 189 | self.buffer_list = [torch.empty_like(output) for _ in range(distri_config.world_size)] 190 | output = 
output.contiguous()
191 | dist.all_gather(self.buffer_list, output, async_op=False)
192 | torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
193 | output = self.output_buffer
194 | if record:
195 | if self.static_inputs is None:
196 | self.static_inputs = {
197 | "sample": sample,
198 | "timestep": timestep,
199 | "encoder_hidden_states": encoder_hidden_states,
200 | "added_cond_kwargs": added_cond_kwargs,
201 | }
202 | self.synchronize()
203 |
204 | if return_dict:
205 | output = UNet2DConditionOutput(sample=output)
206 | else:
207 | output = (output,)
208 |
209 | self.counter += 1
210 | return output
211 |
212 | @property
213 | def add_embedding(self):
214 | return self.model.add_embedding
215 | -------------------------------------------------------------------------------- /distrifuser/models/distri_sdxl_unet_tp.py: -------------------------------------------------------------------------------- 1 | import torch
2 | from diffusers import UNet2DConditionModel
3 | from diffusers.models.attention import Attention, FeedForward
4 | from diffusers.models.resnet import ResnetBlock2D
5 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput
6 | from torch import distributed as dist, nn
7 |
8 | from distrifuser.modules.base_module import BaseModule
9 | from .base_model import BaseModel
10 | from ..modules.tp.attention import DistriAttentionTP
11 | from ..modules.tp.conv2d import DistriConv2dTP
12 | from ..modules.tp.feed_forward import DistriFeedForwardTP
13 | from ..modules.tp.resnet import DistriResnetBlock2DTP
14 | from ..utils import DistriConfig
15 |
16 |
17 | class DistriUNetTP(BaseModel): # for Tensor Parallelism
18 | def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
19 | assert isinstance(model, UNet2DConditionModel)
20 | if distri_config.world_size > 1 and distri_config.n_device_per_batch > 1:
21 | for name, module in model.named_modules():
22 | if isinstance(module, BaseModule):
23 | continue
24 | for subname, submodule in module.named_children():
25 | if isinstance(submodule, Attention):
26 | wrapped_submodule = DistriAttentionTP(submodule, distri_config)
27 | setattr(module, subname, wrapped_submodule)
28 | elif isinstance(submodule, FeedForward):
29 | wrapped_submodule = DistriFeedForwardTP(submodule, distri_config)
30 | setattr(module, subname, wrapped_submodule)
31 | elif isinstance(submodule, ResnetBlock2D):
32 | wrapped_submodule = DistriResnetBlock2DTP(submodule, distri_config)
33 | setattr(module, subname, wrapped_submodule)
34 | elif isinstance(submodule, nn.Conv2d) and (
35 | subname == "conv_out" or "downsamplers" in name or "upsamplers" in name
36 | ):
37 | wrapped_submodule = DistriConv2dTP(submodule, distri_config)
38 | setattr(module, subname, wrapped_submodule)
39 |
40 | super(DistriUNetTP, self).__init__(model, distri_config)
41 |
42 | def forward(
43 | self,
44 | sample: torch.FloatTensor,
45 | timestep: torch.Tensor or float or int,
46 | encoder_hidden_states: torch.Tensor,
47 | class_labels: torch.Tensor or None = None,
48 | timestep_cond: torch.Tensor or None = None,
49 | attention_mask: torch.Tensor or None = None,
50 | cross_attention_kwargs: dict[str, any] or None = None,
51 | added_cond_kwargs: dict[str, torch.Tensor] or None = None,
52 | down_block_additional_residuals: tuple[torch.Tensor] or None = None,
53 | mid_block_additional_residual: torch.Tensor or None = None,
54 | down_intrablock_additional_residuals: tuple[torch.Tensor] or None = None,
55 | encoder_attention_mask: torch.Tensor or None = None,
56 |
return_dict: bool = True, 57 | record: bool = False, 58 | ): 59 | distri_config = self.distri_config 60 | b, c, h, w = sample.shape 61 | assert ( 62 | class_labels is None 63 | and timestep_cond is None 64 | and attention_mask is None 65 | and cross_attention_kwargs is None 66 | and down_block_additional_residuals is None 67 | and mid_block_additional_residual is None 68 | and down_intrablock_additional_residuals is None 69 | and encoder_attention_mask is None 70 | ) 71 | 72 | if distri_config.use_cuda_graph and not record: 73 | static_inputs = self.static_inputs 74 | 75 | if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch: 76 | assert b == 2 77 | batch_idx = distri_config.batch_idx() 78 | sample = sample[batch_idx : batch_idx + 1] 79 | timestep = ( 80 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep 81 | ) 82 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1] 83 | if added_cond_kwargs is not None: 84 | for k in added_cond_kwargs: 85 | added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1] 86 | 87 | assert static_inputs["sample"].shape == sample.shape 88 | static_inputs["sample"].copy_(sample) 89 | if torch.is_tensor(timestep): 90 | if timestep.ndim == 0: 91 | for b in range(static_inputs["timestep"].shape[0]): 92 | static_inputs["timestep"][b] = timestep.item() 93 | else: 94 | assert static_inputs["timestep"].shape == timestep.shape 95 | static_inputs["timestep"].copy_(timestep) 96 | else: 97 | for b in range(static_inputs["timestep"].shape[0]): 98 | static_inputs["timestep"][b] = timestep 99 | assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape 100 | static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states) 101 | if added_cond_kwargs is not None: 102 | for k in added_cond_kwargs: 103 | assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape 104 | static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k]) 105 | 106 | graph_idx = 0 107 | 108 | self.cuda_graphs[graph_idx].replay() 109 | output = self.static_outputs[graph_idx] 110 | else: 111 | if distri_config.world_size == 1: 112 | output = self.model( 113 | sample, 114 | timestep, 115 | encoder_hidden_states, 116 | class_labels=class_labels, 117 | timestep_cond=timestep_cond, 118 | attention_mask=attention_mask, 119 | cross_attention_kwargs=cross_attention_kwargs, 120 | added_cond_kwargs=added_cond_kwargs, 121 | down_block_additional_residuals=down_block_additional_residuals, 122 | mid_block_additional_residual=mid_block_additional_residual, 123 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 124 | encoder_attention_mask=encoder_attention_mask, 125 | return_dict=False, 126 | )[0] 127 | elif distri_config.do_classifier_free_guidance and distri_config.split_batch: 128 | assert b == 2 129 | batch_idx = distri_config.batch_idx() 130 | sample = sample[batch_idx : batch_idx + 1] 131 | timestep = ( 132 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep 133 | ) 134 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1] 135 | if added_cond_kwargs is not None: 136 | new_added_cond_kwargs = {} 137 | for k in added_cond_kwargs: 138 | new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1] 139 | added_cond_kwargs = new_added_cond_kwargs 140 | output = self.model( 141 | sample, 142 | timestep, 143 | 
encoder_hidden_states, 144 | class_labels=class_labels, 145 | timestep_cond=timestep_cond, 146 | attention_mask=attention_mask, 147 | cross_attention_kwargs=cross_attention_kwargs, 148 | added_cond_kwargs=added_cond_kwargs, 149 | down_block_additional_residuals=down_block_additional_residuals, 150 | mid_block_additional_residual=mid_block_additional_residual, 151 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 152 | encoder_attention_mask=encoder_attention_mask, 153 | return_dict=False, 154 | )[0] 155 | if self.output_buffer is None: 156 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype) 157 | if self.buffer_list is None: 158 | self.buffer_list = [torch.empty_like(output) for _ in range(2)] 159 | dist.all_gather( 160 | self.buffer_list, output.contiguous(), group=distri_config.split_group(), async_op=False 161 | ) 162 | torch.cat(self.buffer_list, dim=0, out=self.output_buffer) 163 | output = self.output_buffer 164 | else: 165 | output = self.model( 166 | sample, 167 | timestep, 168 | encoder_hidden_states, 169 | class_labels=class_labels, 170 | timestep_cond=timestep_cond, 171 | attention_mask=attention_mask, 172 | cross_attention_kwargs=cross_attention_kwargs, 173 | added_cond_kwargs=added_cond_kwargs, 174 | down_block_additional_residuals=down_block_additional_residuals, 175 | mid_block_additional_residual=mid_block_additional_residual, 176 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 177 | encoder_attention_mask=encoder_attention_mask, 178 | return_dict=False, 179 | )[0] 180 | if self.output_buffer is None: 181 | self.output_buffer = torch.empty_like(output) 182 | self.output_buffer.copy_(output) 183 | output = self.output_buffer 184 | if record: 185 | if self.static_inputs is None: 186 | self.static_inputs = { 187 | "sample": sample, 188 | "timestep": timestep, 189 | "encoder_hidden_states": encoder_hidden_states, 190 | "added_cond_kwargs": added_cond_kwargs, 191 | } 192 | self.synchronize() 193 | 194 | if return_dict: 195 | output = UNet2DConditionOutput(sample=output) 196 | else: 197 | output = (output,) 198 | 199 | self.counter += 1 200 | return output 201 | 202 | @property 203 | def add_embedding(self): 204 | return self.model.add_embedding 205 | -------------------------------------------------------------------------------- /distrifuser/models/naive_patch_sdxl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import UNet2DConditionModel 3 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 4 | from torch import distributed as dist 5 | 6 | from .base_model import BaseModel 7 | from ..utils import DistriConfig 8 | 9 | 10 | class NaivePatchUNet(BaseModel): # for Patch Parallelism 11 | def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig): 12 | assert isinstance(model, UNet2DConditionModel) 13 | super(NaivePatchUNet, self).__init__(model, distri_config) 14 | 15 | def forward( 16 | self, 17 | sample: torch.FloatTensor, 18 | timestep: torch.Tensor or float or int, 19 | encoder_hidden_states: torch.Tensor, 20 | class_labels: torch.Tensor or None = None, 21 | timestep_cond: torch.Tensor or None = None, 22 | attention_mask: torch.Tensor or None = None, 23 | cross_attention_kwargs: dict[str, any] or None = None, 24 | added_cond_kwargs: dict[str, torch.Tensor] or None = None, 25 | down_block_additional_residuals: tuple[torch.Tensor] or None = None, 26 | 
mid_block_additional_residual: torch.Tensor or None = None, 27 | down_intrablock_additional_residuals: tuple[torch.Tensor] or None = None, 28 | encoder_attention_mask: torch.Tensor or None = None, 29 | return_dict: bool = True, 30 | record: bool = False, 31 | ): 32 | distri_config = self.distri_config 33 | b, c, h, w = sample.shape 34 | assert ( 35 | class_labels is None 36 | and timestep_cond is None 37 | and attention_mask is None 38 | and cross_attention_kwargs is None 39 | and down_block_additional_residuals is None 40 | and mid_block_additional_residual is None 41 | and down_intrablock_additional_residuals is None 42 | and encoder_attention_mask is None 43 | ) 44 | 45 | if distri_config.use_cuda_graph and not record: 46 | static_inputs = self.static_inputs 47 | 48 | if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch: 49 | assert b == 2 50 | batch_idx = distri_config.batch_idx() 51 | sample = sample[batch_idx : batch_idx + 1] 52 | timestep = ( 53 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep 54 | ) 55 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1] 56 | if added_cond_kwargs is not None: 57 | for k in added_cond_kwargs: 58 | added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1] 59 | 60 | assert static_inputs["sample"].shape == sample.shape 61 | static_inputs["sample"].copy_(sample) 62 | if torch.is_tensor(timestep): 63 | if timestep.ndim == 0: 64 | for b in range(static_inputs["timestep"].shape[0]): 65 | static_inputs["timestep"][b] = timestep.item() 66 | else: 67 | assert static_inputs["timestep"].shape == timestep.shape 68 | static_inputs["timestep"].copy_(timestep) 69 | else: 70 | for b in range(static_inputs["timestep"].shape[0]): 71 | static_inputs["timestep"][b] = timestep 72 | assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape 73 | static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states) 74 | if added_cond_kwargs is not None: 75 | for k in added_cond_kwargs: 76 | assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape 77 | static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k]) 78 | 79 | graph_idx = 0 80 | if distri_config.split_scheme == "alternate": 81 | graph_idx = self.counter % 2 82 | self.cuda_graphs[graph_idx].replay() 83 | output = self.static_outputs[graph_idx] 84 | else: 85 | if distri_config.world_size == 1: 86 | output = self.model( 87 | sample, 88 | timestep, 89 | encoder_hidden_states, 90 | class_labels=class_labels, 91 | timestep_cond=timestep_cond, 92 | attention_mask=attention_mask, 93 | cross_attention_kwargs=cross_attention_kwargs, 94 | added_cond_kwargs=added_cond_kwargs, 95 | down_block_additional_residuals=down_block_additional_residuals, 96 | mid_block_additional_residual=mid_block_additional_residual, 97 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 98 | encoder_attention_mask=encoder_attention_mask, 99 | return_dict=False, 100 | )[0] 101 | elif distri_config.do_classifier_free_guidance and distri_config.split_batch: 102 | assert b == 2 103 | batch_idx = distri_config.batch_idx() 104 | sample = sample[batch_idx : batch_idx + 1] 105 | timestep = ( 106 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep 107 | ) 108 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1] 109 | if added_cond_kwargs is not None: 110 | 
new_added_cond_kwargs = {} 111 | for k in added_cond_kwargs: 112 | new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1] 113 | added_cond_kwargs = new_added_cond_kwargs 114 | 115 | if distri_config.split_scheme == "row": 116 | split_dim = 2 117 | elif distri_config.split_scheme == "col": 118 | split_dim = 3 119 | elif distri_config.split_scheme == "alternate": 120 | split_dim = 2 if self.counter % 2 == 0 else 3 121 | else: 122 | raise NotImplementedError 123 | 124 | if split_dim == 2: 125 | sample = sample.view(1, c, distri_config.n_device_per_batch, -1, w)[:, :, distri_config.split_idx()] 126 | else: 127 | assert split_dim == 3 128 | sample = sample.view(1, c, h, distri_config.n_device_per_batch, -1)[ 129 | ..., distri_config.split_idx(), : 130 | ] 131 | 132 | output = self.model( 133 | sample, 134 | timestep, 135 | encoder_hidden_states, 136 | class_labels=class_labels, 137 | timestep_cond=timestep_cond, 138 | attention_mask=attention_mask, 139 | cross_attention_kwargs=cross_attention_kwargs, 140 | added_cond_kwargs=added_cond_kwargs, 141 | down_block_additional_residuals=down_block_additional_residuals, 142 | mid_block_additional_residual=mid_block_additional_residual, 143 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 144 | encoder_attention_mask=encoder_attention_mask, 145 | return_dict=False, 146 | )[0] 147 | if self.output_buffer is None: 148 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype) 149 | if self.buffer_list is None: 150 | self.buffer_list = [torch.empty_like(output.view(-1)) for _ in range(distri_config.world_size)] 151 | dist.all_gather(self.buffer_list, output.contiguous().view(-1), async_op=False) 152 | buffer_list = [buffer.view(output.shape) for buffer in self.buffer_list] 153 | torch.cat(buffer_list[: distri_config.n_device_per_batch], dim=split_dim, out=self.output_buffer[0:1]) 154 | torch.cat(buffer_list[distri_config.n_device_per_batch :], dim=split_dim, out=self.output_buffer[1:2]) 155 | output = self.output_buffer 156 | else: 157 | if distri_config.split_scheme == "row": 158 | split_dim = 2 159 | elif distri_config.split_scheme == "col": 160 | split_dim = 3 161 | elif distri_config.split_scheme == "alternate": 162 | split_dim = 2 if self.counter % 2 == 0 else 3 163 | else: 164 | raise NotImplementedError 165 | 166 | if split_dim == 2: 167 | sliced_sample = sample.view(b, c, distri_config.n_device_per_batch, -1, w)[ 168 | :, :, distri_config.split_idx() 169 | ] 170 | else: 171 | assert split_dim == 3 172 | sliced_sample = sample.view(b, c, h, distri_config.n_device_per_batch, -1)[ 173 | ..., distri_config.split_idx(), : 174 | ] 175 | 176 | output = self.model( 177 | sliced_sample, 178 | timestep, 179 | encoder_hidden_states, 180 | class_labels=class_labels, 181 | timestep_cond=timestep_cond, 182 | attention_mask=attention_mask, 183 | cross_attention_kwargs=cross_attention_kwargs, 184 | added_cond_kwargs=added_cond_kwargs, 185 | down_block_additional_residuals=down_block_additional_residuals, 186 | mid_block_additional_residual=mid_block_additional_residual, 187 | down_intrablock_additional_residuals=down_intrablock_additional_residuals, 188 | encoder_attention_mask=encoder_attention_mask, 189 | return_dict=False, 190 | )[0] 191 | if self.output_buffer is None: 192 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype) 193 | if self.buffer_list is None: 194 | self.buffer_list = [torch.empty_like(output.view(-1)) for _ in 
range(distri_config.world_size)] 195 | dist.all_gather(self.buffer_list, output.contiguous().view(-1), async_op=False) 196 | buffer_list = [buffer.view(output.shape) for buffer in self.buffer_list] 197 | torch.cat(buffer_list, dim=split_dim, out=self.output_buffer) 198 | output = self.output_buffer 199 | if record: 200 | if self.static_inputs is None: 201 | self.static_inputs = { 202 | "sample": sample, 203 | "timestep": timestep, 204 | "encoder_hidden_states": encoder_hidden_states, 205 | "added_cond_kwargs": added_cond_kwargs, 206 | } 207 | self.synchronize() 208 | 209 | if return_dict: 210 | output = UNet2DConditionOutput(sample=output) 211 | else: 212 | output = (output,) 213 | 214 | self.counter += 1 215 | return output 216 | 217 | @property 218 | def add_embedding(self): 219 | return self.model.add_embedding 220 | -------------------------------------------------------------------------------- /distrifuser/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/modules/__init__.py -------------------------------------------------------------------------------- /distrifuser/modules/base_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from distrifuser.utils import DistriConfig 4 | 5 | 6 | class BaseModule(nn.Module): 7 | def __init__( 8 | self, 9 | module: nn.Module, 10 | distri_config: DistriConfig, 11 | ): 12 | super(BaseModule, self).__init__() 13 | self.module = module 14 | self.distri_config = distri_config 15 | self.comm_manager = None 16 | 17 | self.counter = 0 18 | 19 | self.buffer_list = None 20 | self.idx = None 21 | 22 | def forward(self, *args, **kwargs): 23 | raise NotImplementedError 24 | 25 | def set_counter(self, counter: int = 0): 26 | self.counter = counter 27 | 28 | def set_comm_manager(self, comm_manager): 29 | self.comm_manager = comm_manager 30 | -------------------------------------------------------------------------------- /distrifuser/modules/pp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/modules/pp/__init__.py -------------------------------------------------------------------------------- /distrifuser/modules/pp/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.models.attention import Attention 3 | from diffusers.utils import USE_PEFT_BACKEND 4 | from torch import distributed as dist 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from distrifuser.modules.base_module import BaseModule 9 | from distrifuser.utils import DistriConfig 10 | 11 | 12 | class DistriAttentionPP(BaseModule): 13 | def __init__(self, module: Attention, distri_config: DistriConfig): 14 | super(DistriAttentionPP, self).__init__(module, distri_config) 15 | 16 | to_k = module.to_k 17 | to_v = module.to_v 18 | assert isinstance(to_k, nn.Linear) 19 | assert isinstance(to_v, nn.Linear) 20 | assert (to_k.bias is None) == (to_v.bias is None) 21 | assert to_k.weight.shape == to_v.weight.shape 22 | 23 | in_size, out_size = to_k.in_features, to_k.out_features 24 | to_kv = nn.Linear( 25 | in_size, 26 | out_size * 2, 27 | bias=to_k.bias is not None, 28 | device=to_k.weight.device, 29 | dtype=to_k.weight.dtype, 30 
| ) 31 | to_kv.weight.data[:out_size].copy_(to_k.weight.data) 32 | to_kv.weight.data[out_size:].copy_(to_v.weight.data) 33 | 34 | if to_k.bias is not None: 35 | assert to_v.bias is not None 36 | to_kv.bias.data[:out_size].copy_(to_k.bias.data) 37 | to_kv.bias.data[out_size:].copy_(to_v.bias.data) 38 | 39 | self.to_kv = to_kv 40 | 41 | 42 | class DistriCrossAttentionPP(DistriAttentionPP): 43 | def __init__(self, module: Attention, distri_config: DistriConfig): 44 | super(DistriCrossAttentionPP, self).__init__(module, distri_config) 45 | self.kv_cache = None 46 | 47 | def forward( 48 | self, 49 | hidden_states: torch.FloatTensor, 50 | encoder_hidden_states: torch.FloatTensor or None = None, 51 | scale: float = 1.0, 52 | *args, 53 | **kwargs, 54 | ): 55 | assert encoder_hidden_states is not None 56 | recompute_kv = self.counter == 0 57 | 58 | attn = self.module 59 | assert isinstance(attn, Attention) 60 | 61 | residual = hidden_states 62 | 63 | batch_size, sequence_length, _ = ( 64 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 65 | ) 66 | 67 | args = () if USE_PEFT_BACKEND else (scale,) 68 | query = attn.to_q(hidden_states, *args) 69 | 70 | if encoder_hidden_states is None: 71 | encoder_hidden_states = hidden_states 72 | 73 | if recompute_kv or self.kv_cache is None: 74 | kv = self.to_kv(encoder_hidden_states) 75 | self.kv_cache = kv 76 | else: 77 | kv = self.kv_cache 78 | key, value = torch.split(kv, kv.shape[-1] // 2, dim=-1) 79 | 80 | inner_dim = key.shape[-1] 81 | head_dim = inner_dim // attn.heads 82 | 83 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 84 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 85 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 86 | 87 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 88 | 89 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 90 | hidden_states = hidden_states.to(query.dtype) 91 | 92 | # linear proj 93 | hidden_states = attn.to_out[0](hidden_states, *args) 94 | # dropout 95 | hidden_states = attn.to_out[1](hidden_states) 96 | 97 | if attn.residual_connection: 98 | hidden_states = hidden_states + residual 99 | 100 | hidden_states = hidden_states / attn.rescale_output_factor 101 | 102 | self.counter += 1 103 | 104 | return hidden_states 105 | 106 | 107 | class DistriSelfAttentionPP(DistriAttentionPP): 108 | def __init__(self, module: Attention, distri_config: DistriConfig): 109 | super(DistriSelfAttentionPP, self).__init__(module, distri_config) 110 | 111 | def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): 112 | attn = self.module 113 | distri_config = self.distri_config 114 | assert isinstance(attn, Attention) 115 | 116 | residual = hidden_states 117 | 118 | batch_size, sequence_length, _ = hidden_states.shape 119 | 120 | args = () if USE_PEFT_BACKEND else (scale,) 121 | query = attn.to_q(hidden_states, *args) 122 | 123 | encoder_hidden_states = hidden_states 124 | 125 | kv = self.to_kv(encoder_hidden_states) 126 | 127 | if distri_config.n_device_per_batch == 1: 128 | full_kv = kv 129 | else: 130 | if self.buffer_list is None: # buffer not created 131 | full_kv = torch.cat([kv for _ in range(distri_config.n_device_per_batch)], dim=1) 132 | elif distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps: 133 | dist.all_gather(self.buffer_list, kv, group=distri_config.batch_group, async_op=False) 
134 | full_kv = torch.cat(self.buffer_list, dim=1) 135 | else: 136 | new_buffer_list = [buffer for buffer in self.buffer_list] 137 | new_buffer_list[distri_config.split_idx()] = kv 138 | full_kv = torch.cat(new_buffer_list, dim=1) 139 | if distri_config.mode != "no_sync": 140 | self.comm_manager.enqueue(self.idx, kv) 141 | 142 | key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) 143 | 144 | inner_dim = key.shape[-1] 145 | head_dim = inner_dim // attn.heads 146 | 147 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 148 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 149 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 150 | 151 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 152 | # TODO: add support for attn.scale when we move to Torch 2.1 153 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 154 | 155 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 156 | hidden_states = hidden_states.to(query.dtype) 157 | 158 | # linear proj 159 | hidden_states = attn.to_out[0](hidden_states, *args) 160 | # dropout 161 | hidden_states = attn.to_out[1](hidden_states) 162 | 163 | if attn.residual_connection: 164 | hidden_states = hidden_states + residual 165 | 166 | hidden_states = hidden_states / attn.rescale_output_factor 167 | 168 | return hidden_states 169 | 170 | def forward( 171 | self, 172 | hidden_states: torch.FloatTensor, 173 | encoder_hidden_states: torch.FloatTensor or None = None, 174 | scale: float = 1.0, 175 | *args, 176 | **kwargs, 177 | ) -> torch.FloatTensor: 178 | distri_config = self.distri_config 179 | if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None: 180 | if self.comm_manager.handles[self.idx] is not None: 181 | self.comm_manager.handles[self.idx].wait() 182 | self.comm_manager.handles[self.idx] = None 183 | 184 | b, l, c = hidden_states.shape 185 | if distri_config.n_device_per_batch > 1 and self.buffer_list is None: 186 | if self.comm_manager.buffer_list is None: 187 | self.idx = self.comm_manager.register_tensor( 188 | shape=(b, l, self.to_kv.out_features), torch_dtype=hidden_states.dtype, layer_type="attn" 189 | ) 190 | else: 191 | self.buffer_list = self.comm_manager.get_buffer_list(self.idx) 192 | output = self._forward(hidden_states, scale=scale) 193 | 194 | self.counter += 1 195 | return output 196 | -------------------------------------------------------------------------------- /distrifuser/modules/pp/conv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed as dist 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from distrifuser.modules.base_module import BaseModule 7 | from distrifuser.utils import DistriConfig 8 | 9 | 10 | class DistriConv2dPP(BaseModule): 11 | def __init__(self, module: nn.Conv2d, distri_config: DistriConfig, is_first_layer: bool = False): 12 | super(DistriConv2dPP, self).__init__(module, distri_config) 13 | self.is_first_layer = is_first_layer 14 | 15 | def naive_forward(self, x: torch.Tensor) -> torch.Tensor: 16 | # x: [B, C, H, W] 17 | output = self.module(x) 18 | return output 19 | 20 | def sliced_forward(self, x: torch.Tensor) -> torch.Tensor: 21 | config = self.distri_config 22 | b, c, h, w = x.shape 23 | assert h % config.n_device_per_batch == 0 24 | 25 | stride = self.module.stride[0] 26 | 
padding = self.module.padding[0] 27 | 28 | output_h = x.shape[2] // stride // config.n_device_per_batch 29 | idx = config.split_idx() 30 | h_begin = output_h * idx * stride - padding 31 | h_end = output_h * (idx + 1) * stride + padding 32 | final_padding = [padding, padding, 0, 0] 33 | if h_begin < 0: 34 | h_begin = 0 35 | final_padding[2] = padding 36 | if h_end > h: 37 | h_end = h 38 | final_padding[3] = padding 39 | sliced_input = x[:, :, h_begin:h_end, :] 40 | padded_input = F.pad(sliced_input, final_padding, mode="constant") 41 | return F.conv2d(padded_input, self.module.weight, self.module.bias, stride=stride, padding="valid") 42 | 43 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 44 | distri_config = self.distri_config 45 | 46 | if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None: 47 | if self.comm_manager.handles[self.idx] is not None: 48 | self.comm_manager.handles[self.idx].wait() 49 | self.comm_manager.handles[self.idx] = None 50 | 51 | if distri_config.n_device_per_batch == 1: 52 | output = self.naive_forward(x) 53 | else: 54 | if self.is_first_layer: 55 | full_x = x 56 | output = self.sliced_forward(full_x) 57 | else: 58 | boundary_size = self.module.padding[0] 59 | if self.buffer_list is None: 60 | if self.comm_manager.buffer_list is None: 61 | self.idx = self.comm_manager.register_tensor( 62 | shape=[2, x.shape[0], x.shape[1], boundary_size, x.shape[3]], 63 | torch_dtype=x.dtype, 64 | layer_type="conv2d", 65 | ) 66 | else: 67 | self.buffer_list = self.comm_manager.get_buffer_list(self.idx) 68 | if self.buffer_list is None: 69 | output = self.naive_forward(x) 70 | else: 71 | 72 | def create_padded_x(): 73 | if distri_config.split_idx() == 0: 74 | concat_x = torch.cat([x, self.buffer_list[distri_config.split_idx() + 1][0]], dim=2) 75 | padded_x = F.pad(concat_x, [0, 0, boundary_size, 0], mode="constant") 76 | elif distri_config.split_idx() == distri_config.n_device_per_batch - 1: 77 | concat_x = torch.cat([self.buffer_list[distri_config.split_idx() - 1][1], x], dim=2) 78 | padded_x = F.pad(concat_x, [0, 0, 0, boundary_size], mode="constant") 79 | else: 80 | padded_x = torch.cat( 81 | [ 82 | self.buffer_list[distri_config.split_idx() - 1][1], 83 | x, 84 | self.buffer_list[distri_config.split_idx() + 1][0], 85 | ], 86 | dim=2, 87 | ) 88 | return padded_x 89 | 90 | boundary = torch.stack([x[:, :, :boundary_size, :], x[:, :, -boundary_size:, :]], dim=0) 91 | 92 | if distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps: 93 | dist.all_gather(self.buffer_list, boundary, group=distri_config.batch_group, async_op=False) 94 | padded_x = create_padded_x() 95 | output = F.conv2d( 96 | padded_x, 97 | self.module.weight, 98 | self.module.bias, 99 | stride=self.module.stride[0], 100 | padding=(0, self.module.padding[1]), 101 | ) 102 | else: 103 | padded_x = create_padded_x() 104 | output = F.conv2d( 105 | padded_x, 106 | self.module.weight, 107 | self.module.bias, 108 | stride=self.module.stride[0], 109 | padding=(0, self.module.padding[1]), 110 | ) 111 | if distri_config.mode != "no_sync": 112 | self.comm_manager.enqueue(self.idx, boundary) 113 | 114 | self.counter += 1 115 | return output 116 | -------------------------------------------------------------------------------- /distrifuser/modules/pp/groupnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed as dist 3 | from torch import nn 4 | 5 | from 
distrifuser.modules.base_module import BaseModule 6 | from distrifuser.utils import DistriConfig 7 | 8 | 9 | class DistriGroupNorm(BaseModule): 10 | def __init__(self, module: nn.GroupNorm, distri_config: DistriConfig): 11 | assert isinstance(module, nn.GroupNorm) 12 | super(DistriGroupNorm, self).__init__(module, distri_config) 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | module = self.module 16 | assert isinstance(module, nn.GroupNorm) 17 | distri_config = self.distri_config 18 | 19 | if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None: 20 | if self.comm_manager.handles[self.idx] is not None: 21 | self.comm_manager.handles[self.idx].wait() 22 | self.comm_manager.handles[self.idx] = None 23 | 24 | assert x.ndim == 4 25 | n, c, h, w = x.shape 26 | num_groups = module.num_groups 27 | group_size = c // num_groups 28 | 29 | if distri_config.mode in ["stale_gn", "corrected_async_gn"]: 30 | if self.buffer_list is None: 31 | if self.comm_manager.buffer_list is None: 32 | n, c, h, w = x.shape 33 | self.idx = self.comm_manager.register_tensor( 34 | shape=[2, n, num_groups, 1, 1, 1], torch_dtype=x.dtype, layer_type="gn" 35 | ) 36 | else: 37 | self.buffer_list = self.comm_manager.get_buffer_list(self.idx) 38 | x = x.view([n, num_groups, group_size, h, w]) 39 | x_mean = x.mean(dim=[2, 3, 4], keepdim=True) # [1, num_groups, 1, 1, 1] 40 | x2_mean = (x**2).mean(dim=[2, 3, 4], keepdim=True) # [1, num_groups, 1, 1, 1] 41 | slice_mean = torch.stack([x_mean, x2_mean], dim=0) 42 | 43 | if self.buffer_list is None: 44 | full_mean = slice_mean 45 | elif self.counter <= distri_config.warmup_steps: 46 | dist.all_gather(self.buffer_list, slice_mean, group=distri_config.batch_group, async_op=False) 47 | full_mean = sum(self.buffer_list) / distri_config.n_device_per_batch 48 | else: 49 | if distri_config.mode == "corrected_async_gn": 50 | correction = slice_mean - self.buffer_list[distri_config.split_idx()] 51 | full_mean = sum(self.buffer_list) / distri_config.n_device_per_batch + correction 52 | else: 53 | new_buffer_list = [buffer for buffer in self.buffer_list] 54 | new_buffer_list[distri_config.split_idx()] = slice_mean 55 | full_mean = sum(new_buffer_list) / distri_config.n_device_per_batch 56 | self.comm_manager.enqueue(self.idx, slice_mean) 57 | 58 | full_x_mean, full_x2_mean = full_mean[0], full_mean[1] 59 | var = full_x2_mean - full_x_mean**2 60 | if distri_config.mode == "corrected_async_gn": 61 | slice_x_mean, slice_x2_mean = slice_mean[0], slice_mean[1] 62 | slice_var = slice_x2_mean - slice_x_mean**2 63 | var = torch.where(var < 0, slice_var, var) # Correct negative variance 64 | 65 | num_elements = group_size * h * w 66 | var = var * (num_elements / (num_elements - 1)) 67 | std = (var + module.eps).sqrt() 68 | output = (x - full_x_mean) / std 69 | output = output.view([n, c, h, w]) 70 | if module.affine: 71 | output = output * module.weight.view([1, -1, 1, 1]) 72 | output = output + module.bias.view([1, -1, 1, 1]) 73 | else: 74 | if self.counter <= distri_config.warmup_steps or distri_config.mode in ["sync_gn", "full_sync"]: 75 | x = x.view([n, num_groups, group_size, h, w]) 76 | x_mean = x.mean(dim=[2, 3, 4], keepdim=True) # [1, num_groups, 1, 1, 1] 77 | x2_mean = (x**2).mean(dim=[2, 3, 4], keepdim=True) # [1, num_groups, 1, 1, 1] 78 | mean = torch.stack([x_mean, x2_mean], dim=0) 79 | dist.all_reduce(mean, op=dist.ReduceOp.SUM, group=distri_config.batch_group) 80 | mean = mean / distri_config.n_device_per_batch 81 | x_mean = 
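# Illustrative aside (made-up sizes, no distributed calls): the two stale-stat
# recipes used by DistriGroupNorm above, for the device with split index 0.
# "stale_gn" swaps the fresh local stats into last step's buffer before
# averaging; "corrected_async_gn" adds an undivided fresh-minus-stale local
# correction to the stale average.
import torch

n_dev, g = 4, 8
stale = [torch.randn(g) for _ in range(n_dev)]  # last step's per-device stats
fresh = torch.randn(g)                          # this device's fresh stats
stale_gn = (sum(stale) - stale[0] + fresh) / n_dev
corrected_async_gn = sum(stale) / n_dev + (fresh - stale[0])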
mean[0] 82 | x2_mean = mean[1] 83 | var = x2_mean - x_mean**2 84 | num_elements = group_size * h * w 85 | var = var * (num_elements / (num_elements - 1)) 86 | std = (var + module.eps).sqrt() 87 | output = (x - x_mean) / std 88 | output = output.view([n, c, h, w]) 89 | if module.affine: 90 | output = output * module.weight.view([1, -1, 1, 1]) 91 | output = output + module.bias.view([1, -1, 1, 1]) 92 | elif distri_config.mode in ["separate_gn", "no_sync"]: 93 | output = module(x) 94 | else: 95 | raise NotImplementedError 96 | self.counter += 1 97 | return output 98 | -------------------------------------------------------------------------------- /distrifuser/modules/tp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/modules/tp/__init__.py -------------------------------------------------------------------------------- /distrifuser/modules/tp/attention.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | from diffusers.models.attention_processor import Attention 3 | from torch import distributed as dist 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from distrifuser.modules.base_module import BaseModule 8 | from distrifuser.utils import DistriConfig 9 | 10 | 11 | class DistriAttentionTP(BaseModule): 12 | def __init__(self, module: Attention, distri_config: DistriConfig): 13 | super(DistriAttentionTP, self).__init__(module, distri_config) 14 | 15 | heads = module.heads 16 | sliced_heads = heads // distri_config.n_device_per_batch 17 | remainder_heads = heads % distri_config.n_device_per_batch 18 | if distri_config.split_idx() < remainder_heads: 19 | sliced_heads += 1 20 | self.sliced_heads = sliced_heads 21 | 22 | if sliced_heads > 0: 23 | if distri_config.split_idx() < remainder_heads: 24 | start_head = distri_config.split_idx() * sliced_heads 25 | else: 26 | start_head = ( 27 | remainder_heads * (sliced_heads + 1) + (distri_config.split_idx() - remainder_heads) * sliced_heads 28 | ) 29 | end_head = start_head + sliced_heads 30 | 31 | dim = module.to_q.out_features // heads 32 | 33 | sharded_to_q = nn.Linear( 34 | module.to_q.in_features, 35 | sliced_heads * dim, 36 | bias=module.to_q.bias is not None, 37 | device=module.to_q.weight.device, 38 | dtype=module.to_q.weight.dtype, 39 | ) 40 | sharded_to_q.weight.data.copy_(module.to_q.weight.data[start_head * dim : end_head * dim]) 41 | if module.to_q.bias is not None: 42 | sharded_to_q.bias.data.copy_(module.to_q.bias.data[start_head * dim : end_head * dim]) 43 | 44 | sharded_to_k = nn.Linear( 45 | module.to_k.in_features, 46 | sliced_heads * dim, 47 | bias=module.to_k.bias is not None, 48 | device=module.to_k.weight.device, 49 | dtype=module.to_k.weight.dtype, 50 | ) 51 | sharded_to_k.weight.data.copy_(module.to_k.weight.data[start_head * dim : end_head * dim]) 52 | if module.to_k.bias is not None: 53 | sharded_to_k.bias.data.copy_(module.to_k.bias.data[start_head * dim : end_head * dim]) 54 | 55 | sharded_to_v = nn.Linear( 56 | module.to_v.in_features, 57 | sliced_heads * dim, 58 | bias=module.to_v.bias is not None, 59 | device=module.to_v.weight.device, 60 | dtype=module.to_v.weight.dtype, 61 | ) 62 | sharded_to_v.weight.data.copy_(module.to_v.weight.data[start_head * dim : end_head * dim]) 63 | if module.to_v.bias is not None: 64 | sharded_to_v.bias.data.copy_(module.to_v.bias.data[start_head * dim : 
end_head * dim]) 65 | 66 | sharded_to_out = nn.Linear( 67 | sliced_heads * dim, 68 | module.to_out[0].out_features, 69 | bias=module.to_out[0].bias is not None, 70 | device=module.to_out[0].weight.device, 71 | dtype=module.to_out[0].weight.dtype, 72 | ) 73 | sharded_to_out.weight.data.copy_(module.to_out[0].weight.data[:, start_head * dim : end_head * dim]) 74 | if module.to_out[0].bias is not None: 75 | sharded_to_out.bias.data.copy_(module.to_out[0].bias.data) 76 | 77 | del module.to_q 78 | del module.to_k 79 | del module.to_v 80 | 81 | old_to_out = module.to_out[0] 82 | 83 | module.to_q = sharded_to_q 84 | module.to_k = sharded_to_k 85 | module.to_v = sharded_to_v 86 | module.to_out[0] = sharded_to_out 87 | module.heads = sliced_heads 88 | 89 | del old_to_out 90 | 91 | torch.cuda.empty_cache() 92 | 93 | def forward( 94 | self, 95 | hidden_states: torch.FloatTensor, 96 | encoder_hidden_states: torch.FloatTensor or None = None, 97 | attention_mask: torch.FloatTensor or None = None, 98 | **cross_attention_kwargs, 99 | ) -> torch.Tensor: 100 | distri_config = self.distri_config 101 | module = self.module 102 | residual = hidden_states 103 | 104 | if self.sliced_heads > 0: 105 | input_ndim = hidden_states.ndim 106 | 107 | assert input_ndim == 3 108 | 109 | batch_size, sequence_length, _ = ( 110 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 111 | ) 112 | 113 | if attention_mask is not None: 114 | attention_mask = module.prepare_attention_mask(attention_mask, sequence_length, batch_size) 115 | # scaled_dot_product_attention expects attention_mask shape to be 116 | # (batch, heads, source_length, target_length) 117 | attention_mask = attention_mask.view(batch_size, module.heads, -1, attention_mask.shape[-1]) 118 | 119 | if module.group_norm is not None: 120 | hidden_states = module.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 121 | 122 | query = module.to_q(hidden_states) 123 | 124 | if encoder_hidden_states is None: 125 | encoder_hidden_states = hidden_states 126 | elif module.norm_cross: 127 | encoder_hidden_states = module.norm_encoder_hidden_states(encoder_hidden_states) 128 | 129 | key = module.to_k(encoder_hidden_states) 130 | value = module.to_v(encoder_hidden_states) 131 | 132 | inner_dim = key.shape[-1] 133 | head_dim = inner_dim // module.heads 134 | 135 | query = query.view(batch_size, -1, module.heads, head_dim).transpose(1, 2) 136 | 137 | key = key.view(batch_size, -1, module.heads, head_dim).transpose(1, 2) 138 | value = value.view(batch_size, -1, module.heads, head_dim).transpose(1, 2) 139 | 140 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 141 | # TODO: add support for attn.scale when we move to Torch 2.1 142 | hidden_states = F.scaled_dot_product_attention( 143 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 144 | ) 145 | 146 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, module.heads * head_dim) 147 | hidden_states = hidden_states.to(query.dtype) 148 | 149 | # linear proj 150 | hidden_states = F.linear(hidden_states, module.to_out[0].weight, bias=None) 151 | # dropout 152 | hidden_states = module.to_out[1](hidden_states) 153 | else: 154 | hidden_states = torch.zeros( 155 | [hidden_states.shape[0], hidden_states.shape[1], module.to_out[0].out_features], 156 | device=hidden_states.device, 157 | dtype=hidden_states.dtype, 158 | ) 159 | dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False) 160 | if 
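# Illustrative aside (single process, made-up sizes): why the head-sharded
# output projection in DistriAttentionTP needs only one all-reduce. Each
# device projects its own heads with a column slice of to_out; the partial
# products sum to the full projection, and the bias is added once afterwards.
import torch

heads, head_dim, seq = 4, 8, 16
w_out = torch.randn(32, heads * head_dim)        # full to_out weight
h = torch.randn(1, seq, heads * head_dim)        # concatenated head outputs
full = h @ w_out.T
span = 2 * head_dim                              # two heads per "device"
parts = [h[..., s : s + span] @ w_out[:, s : s + span].T for s in (0, span)]
assert torch.allclose(sum(parts), full, atol=1e-4)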
module.to_out[0].bias is not None: 161 | hidden_states = hidden_states + module.to_out[0].bias.view(1, 1, -1) 162 | 163 | if module.residual_connection: 164 | hidden_states = hidden_states + residual 165 | 166 | hidden_states = hidden_states / module.rescale_output_factor 167 | 168 | self.counter += 1 169 | 170 | return hidden_states 171 | -------------------------------------------------------------------------------- /distrifuser/modules/tp/conv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed as dist 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from distrifuser.modules.base_module import BaseModule 7 | from distrifuser.utils import DistriConfig 8 | 9 | 10 | class DistriConv2dTP(BaseModule): 11 | def __init__(self, module: nn.Conv2d, distri_config: DistriConfig): 12 | super(DistriConv2dTP, self).__init__(module, distri_config) 13 | assert module.in_channels % distri_config.n_device_per_batch == 0 14 | 15 | sharded_module = nn.Conv2d( 16 | module.in_channels // distri_config.n_device_per_batch, 17 | module.out_channels, 18 | module.kernel_size, 19 | module.stride, 20 | module.padding, 21 | module.dilation, 22 | module.groups, 23 | module.bias is not None, 24 | module.padding_mode, 25 | device=module.weight.device, 26 | dtype=module.weight.dtype, 27 | ) 28 | start_idx = distri_config.split_idx() * (module.in_channels // distri_config.n_device_per_batch) 29 | end_idx = (distri_config.split_idx() + 1) * (module.in_channels // distri_config.n_device_per_batch) 30 | sharded_module.weight.data.copy_(module.weight.data[:, start_idx:end_idx]) 31 | if module.bias is not None: 32 | sharded_module.bias.data.copy_(module.bias.data) 33 | 34 | self.module = sharded_module 35 | del module 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | distri_config = self.distri_config 39 | 40 | b, c, h, w = x.shape 41 | start_idx = distri_config.split_idx() * (c // distri_config.n_device_per_batch) 42 | end_idx = (distri_config.split_idx() + 1) * (c // distri_config.n_device_per_batch) 43 | output = F.conv2d( 44 | x[:, start_idx:end_idx], 45 | self.module.weight, 46 | bias=None, 47 | stride=self.module.stride, 48 | padding=self.module.padding, 49 | dilation=self.module.dilation, 50 | groups=self.module.groups, 51 | ) 52 | dist.all_reduce(output, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False) 53 | if self.module.bias is not None: 54 | output = output + self.module.bias.view(1, -1, 1, 1) 55 | 56 | self.counter += 1 57 | return output 58 | -------------------------------------------------------------------------------- /distrifuser/modules/tp/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | from diffusers.models.attention import FeedForward, GEGLU 3 | from torch import distributed as dist 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from ..base_module import BaseModule 8 | from ...utils import DistriConfig 9 | 10 | 11 | class DistriFeedForwardTP(BaseModule): 12 | def __init__(self, module: FeedForward, distri_config: DistriConfig): 13 | super(DistriFeedForwardTP, self).__init__(module, distri_config) 14 | assert isinstance(module.net[0], GEGLU) 15 | assert module.net[0].proj.out_features % (distri_config.n_device_per_batch * 2) == 0 16 | assert module.net[2].in_features % distri_config.n_device_per_batch == 0 17 | 18 | mid_features = module.net[2].in_features 
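# Illustrative aside (single process, made-up sizes): DistriConv2dTP's scheme
# in miniature. Summing convolutions over input-channel shards, with the bias
# added once at the end, equals the full convolution.
import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(8, 4, kernel_size=3, padding=1)
x = torch.randn(1, 8, 16, 16)
full = conv(x)
partial = sum(
    F.conv2d(x[:, s : s + 4], conv.weight[:, s : s + 4], bias=None, padding=1)
    for s in (0, 4)
)
assert torch.allclose(partial + conv.bias.view(1, -1, 1, 1), full, atol=1e-4)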
// distri_config.n_device_per_batch 19 | 20 | sharded_fc1 = nn.Linear( 21 | module.net[0].proj.in_features, 22 | mid_features * 2, 23 | bias=module.net[0].proj.bias is not None, 24 | device=module.net[0].proj.weight.device, 25 | dtype=module.net[0].proj.weight.dtype, 26 | ) 27 | start_idx = distri_config.split_idx() * mid_features 28 | end_idx = (distri_config.split_idx() + 1) * mid_features 29 | sharded_fc1.weight.data[:mid_features].copy_(module.net[0].proj.weight.data[start_idx:end_idx]) 30 | if module.net[0].proj.bias is not None: 31 | sharded_fc1.bias.data[:mid_features].copy_(module.net[0].proj.bias.data[start_idx:end_idx]) 32 | start_idx = (distri_config.n_device_per_batch + distri_config.split_idx()) * mid_features 33 | end_idx = (distri_config.n_device_per_batch + distri_config.split_idx() + 1) * mid_features 34 | sharded_fc1.weight.data[mid_features:].copy_(module.net[0].proj.weight.data[start_idx:end_idx]) 35 | if module.net[0].proj.bias is not None: 36 | sharded_fc1.bias.data[mid_features:].copy_(module.net[0].proj.bias.data[start_idx:end_idx]) 37 | 38 | sharded_fc2 = nn.Linear( 39 | mid_features, 40 | module.net[2].out_features, 41 | bias=module.net[2].bias is not None, 42 | device=module.net[2].weight.device, 43 | dtype=module.net[2].weight.dtype, 44 | ) 45 | sharded_fc2.weight.data.copy_( 46 | module.net[2].weight.data[ 47 | :, distri_config.split_idx() * mid_features : (distri_config.split_idx() + 1) * mid_features 48 | ] 49 | ) 50 | if module.net[2].bias is not None: 51 | sharded_fc2.bias.data.copy_(module.net[2].bias.data) 52 | 53 | old_fc1 = module.net[0].proj 54 | old_fc2 = module.net[2] 55 | 56 | module.net[0].proj = sharded_fc1 57 | module.net[2] = sharded_fc2 58 | 59 | del old_fc1 60 | del old_fc2 61 | torch.cuda.empty_cache() 62 | 63 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: 64 | distri_config = self.distri_config 65 | module = self.module 66 | 67 | assert scale == 1.0 68 | for i, submodule in enumerate(module.net): 69 | if i == 0: 70 | hidden_states, gate = submodule.proj(hidden_states).chunk(2, dim=-1) 71 | hidden_states = hidden_states * submodule.gelu(gate) 72 | elif i == 2: 73 | hidden_states = F.linear(hidden_states, submodule.weight, None) 74 | else: 75 | hidden_states = submodule(hidden_states) 76 | 77 | dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False) 78 | if module.net[2].bias is not None: 79 | hidden_states = hidden_states + module.net[2].bias.view(1, 1, -1) 80 | 81 | self.counter += 1 82 | 83 | return hidden_states 84 | -------------------------------------------------------------------------------- /distrifuser/modules/tp/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D, USE_PEFT_BACKEND 3 | from torch import distributed as dist 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from ..base_module import BaseModule 8 | from ...utils import DistriConfig 9 | 10 | 11 | class DistriResnetBlock2DTP(BaseModule): 12 | def __init__(self, module: ResnetBlock2D, distri_config: DistriConfig): 13 | super(DistriResnetBlock2DTP, self).__init__(module, distri_config) 14 | assert module.conv1.out_channels % distri_config.n_device_per_batch == 0 15 | 16 | mid_channels = module.conv1.out_channels // distri_config.n_device_per_batch 17 | 18 | sharded_conv1 = nn.Conv2d( 19 | module.conv1.in_channels, 20 | 
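# Illustrative aside (single process, made-up sizes): why DistriFeedForwardTP
# copies two row slices into the sharded fc1. GEGLU chunks its projection into
# (hidden, gate) halves, so each device needs matching rows from both halves
# to compute hidden * gelu(gate) locally.
import torch
import torch.nn.functional as F

d_in, d_mid, n_dev = 8, 12, 2
w = torch.randn(2 * d_mid, d_in)                  # unsharded GEGLU projection
x = torch.randn(1, d_in)
h, gate = (x @ w.T).chunk(2, dim=-1)
full = h * F.gelu(gate)

shard = d_mid // n_dev
outs = []
for r in range(n_dev):
    rows = torch.cat([w[r * shard : (r + 1) * shard], w[d_mid + r * shard : d_mid + (r + 1) * shard]])
    hh, gg = (x @ rows.T).chunk(2, dim=-1)
    outs.append(hh * F.gelu(gg))
assert torch.allclose(torch.cat(outs, dim=-1), full, atol=1e-5)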
mid_channels, 21 | module.conv1.kernel_size, 22 | module.conv1.stride, 23 | module.conv1.padding, 24 | module.conv1.dilation, 25 | module.conv1.groups, 26 | module.conv1.bias is not None, 27 | module.conv1.padding_mode, 28 | device=module.conv1.weight.device, 29 | dtype=module.conv1.weight.dtype, 30 | ) 31 | sharded_conv1.weight.data.copy_( 32 | module.conv1.weight.data[ 33 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 34 | ] 35 | ) 36 | if module.conv1.bias is not None: 37 | sharded_conv1.bias.data.copy_( 38 | module.conv1.bias.data[ 39 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 40 | ] 41 | ) 42 | 43 | sharded_conv2 = nn.Conv2d( 44 | mid_channels, 45 | module.conv2.out_channels, 46 | module.conv2.kernel_size, 47 | module.conv2.stride, 48 | module.conv2.padding, 49 | module.conv2.dilation, 50 | module.conv2.groups, 51 | module.conv2.bias is not None, 52 | module.conv2.padding_mode, 53 | device=module.conv2.weight.device, 54 | dtype=module.conv2.weight.dtype, 55 | ) 56 | sharded_conv2.weight.data.copy_( 57 | module.conv2.weight.data[ 58 | :, distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 59 | ] 60 | ) 61 | if module.conv2.bias is not None: 62 | sharded_conv2.bias.data.copy_(module.conv2.bias.data) 63 | 64 | assert module.time_emb_proj is not None 65 | assert module.time_embedding_norm == "default" 66 | 67 | sharded_time_emb_proj = nn.Linear( 68 | module.time_emb_proj.in_features, 69 | mid_channels, 70 | bias=module.time_emb_proj.bias is not None, 71 | device=module.time_emb_proj.weight.device, 72 | dtype=module.time_emb_proj.weight.dtype, 73 | ) 74 | sharded_time_emb_proj.weight.data.copy_( 75 | module.time_emb_proj.weight.data[ 76 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 77 | ] 78 | ) 79 | if module.time_emb_proj.bias is not None: 80 | sharded_time_emb_proj.bias.data.copy_( 81 | module.time_emb_proj.bias.data[ 82 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 83 | ] 84 | ) 85 | 86 | sharded_norm2 = nn.GroupNorm( 87 | module.norm2.num_groups // distri_config.n_device_per_batch, 88 | mid_channels, 89 | module.norm2.eps, 90 | module.norm2.affine, 91 | device=module.norm2.weight.device, 92 | dtype=module.norm2.weight.dtype, 93 | ) 94 | if module.norm2.affine: 95 | sharded_norm2.weight.data.copy_( 96 | module.norm2.weight.data[ 97 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 98 | ] 99 | ) 100 | sharded_norm2.bias.data.copy_( 101 | module.norm2.bias.data[ 102 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels 103 | ] 104 | ) 105 | 106 | del module.conv1 107 | del module.conv2 108 | del module.time_emb_proj 109 | del module.norm2 110 | module.conv1 = sharded_conv1 111 | module.conv2 = sharded_conv2 112 | module.time_emb_proj = sharded_time_emb_proj 113 | module.norm2 = sharded_norm2 114 | 115 | torch.cuda.empty_cache() 116 | 117 | def forward( 118 | self, 119 | input_tensor: torch.FloatTensor, 120 | temb: torch.FloatTensor, 121 | scale: float = 1.0, 122 | ) -> torch.FloatTensor: 123 | assert scale == 1.0 124 | 125 | distri_config = self.distri_config 126 | module = self.module 127 | 128 | hidden_states = input_tensor 129 | hidden_states = module.norm1(hidden_states) 130 | 131 | hidden_states = module.nonlinearity(hidden_states) 132 | 133 | if module.upsample is not None: 134 | # 
upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
135 |             if hidden_states.shape[0] >= 64:
136 |                 input_tensor = input_tensor.contiguous()
137 |                 hidden_states = hidden_states.contiguous()
138 |             input_tensor = (
139 |                 module.upsample(input_tensor, scale=scale)
140 |                 if isinstance(module.upsample, Upsample2D)
141 |                 else module.upsample(input_tensor)
142 |             )
143 |             hidden_states = (
144 |                 module.upsample(hidden_states, scale=scale)
145 |                 if isinstance(module.upsample, Upsample2D)
146 |                 else module.upsample(hidden_states)
147 |             )
148 |         elif module.downsample is not None:
149 |             input_tensor = (
150 |                 module.downsample(input_tensor, scale=scale)
151 |                 if isinstance(module.downsample, Downsample2D)
152 |                 else module.downsample(input_tensor)
153 |             )
154 |             hidden_states = (
155 |                 module.downsample(hidden_states, scale=scale)
156 |                 if isinstance(module.downsample, Downsample2D)
157 |                 else module.downsample(hidden_states)
158 |             )
159 | 
160 |         hidden_states = module.conv1(hidden_states)
161 | 
162 |         if module.time_emb_proj is not None:
163 |             if not module.skip_time_act:
164 |                 temb = module.nonlinearity(temb)
165 |             temb = module.time_emb_proj(temb)[:, :, None, None]
166 | 
167 |         if temb is not None and module.time_embedding_norm == "default":
168 |             hidden_states = hidden_states + temb
169 | 
170 |         hidden_states = module.norm2(hidden_states)
171 | 
172 |         if temb is not None and module.time_embedding_norm == "scale_shift":
173 |             scale, shift = torch.chunk(temb, 2, dim=1)
174 |             hidden_states = hidden_states * (1 + scale) + shift
175 | 
176 |         hidden_states = module.nonlinearity(hidden_states)
177 | 
178 |         hidden_states = module.dropout(hidden_states)
179 |         hidden_states = F.conv2d(
180 |             hidden_states,
181 |             module.conv2.weight,
182 |             bias=None,
183 |             stride=module.conv2.stride,
184 |             padding=module.conv2.padding,
185 |             dilation=module.conv2.dilation,
186 |             groups=module.conv2.groups,
187 |         )
188 | 
189 |         dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
190 |         if module.conv2.bias is not None:
191 |             hidden_states = hidden_states + module.conv2.bias.view(1, -1, 1, 1)
192 | 
193 |         if module.conv_shortcut is not None:
194 |             input_tensor = (
195 |                 module.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else module.conv_shortcut(input_tensor)
196 |             )
197 | 
198 |         output_tensor = (input_tensor + hidden_states) / module.output_scale_factor
199 | 
200 |         self.counter += 1
201 | 
202 |         return output_tensor
203 | 
--------------------------------------------------------------------------------
/distrifuser/pipelines.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel
3 | 
4 | from .models.distri_sdxl_unet_pp import DistriUNetPP
5 | from .models.distri_sdxl_unet_tp import DistriUNetTP
6 | from .models.naive_patch_sdxl import NaivePatchUNet
7 | from .utils import DistriConfig, PatchParallelismCommManager
8 | 
9 | 
10 | class DistriSDXLPipeline:
11 |     def __init__(self, pipeline: StableDiffusionXLPipeline, module_config: DistriConfig):
12 |         self.pipeline = pipeline
13 |         self.distri_config = module_config
14 | 
15 |         self.static_inputs = None
16 | 
17 |         self.prepare()
18 | 
19 |     @staticmethod
20 |     def from_pretrained(distri_config: DistriConfig, **kwargs):
21 |         device = distri_config.device
22 |         pretrained_model_name_or_path = kwargs.pop(
23 |             "pretrained_model_name_or_path", 
"stabilityai/stable-diffusion-xl-base-1.0" 24 | ) 25 | torch_dtype = kwargs.pop("torch_dtype", torch.float16) 26 | unet = UNet2DConditionModel.from_pretrained( 27 | pretrained_model_name_or_path, torch_dtype=torch_dtype, subfolder="unet" 28 | ).to(device) 29 | 30 | if distri_config.parallelism == "patch": 31 | unet = DistriUNetPP(unet, distri_config) 32 | elif distri_config.parallelism == "tensor": 33 | unet = DistriUNetTP(unet, distri_config) 34 | elif distri_config.parallelism == "naive_patch": 35 | unet = NaivePatchUNet(unet, distri_config) 36 | else: 37 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}") 38 | 39 | pipeline = StableDiffusionXLPipeline.from_pretrained( 40 | pretrained_model_name_or_path, torch_dtype=torch_dtype, unet=unet, **kwargs 41 | ).to(device) 42 | return DistriSDXLPipeline(pipeline, distri_config) 43 | 44 | def set_progress_bar_config(self, **kwargs): 45 | self.pipeline.set_progress_bar_config(**kwargs) 46 | 47 | @torch.no_grad() 48 | def __call__(self, *args, **kwargs): 49 | assert "height" not in kwargs, "height should not be in kwargs" 50 | assert "width" not in kwargs, "width should not be in kwargs" 51 | config = self.distri_config 52 | if not config.do_classifier_free_guidance: 53 | if "guidance_scale" not in kwargs: 54 | kwargs["guidance_scale"] = 1 55 | else: 56 | assert kwargs["guidance_scale"] == 1 57 | self.pipeline.unet.set_counter(0) 58 | return self.pipeline(height=config.height, width=config.width, *args, **kwargs) 59 | 60 | @torch.no_grad() 61 | def prepare(self, **kwargs): 62 | distri_config = self.distri_config 63 | 64 | static_inputs = {} 65 | static_outputs = [] 66 | cuda_graphs = [] 67 | pipeline = self.pipeline 68 | 69 | height = distri_config.height 70 | width = distri_config.width 71 | assert height % 8 == 0 and width % 8 == 0 72 | 73 | original_size = (height, width) 74 | target_size = (height, width) 75 | crops_coords_top_left = (0, 0) 76 | 77 | device = distri_config.device 78 | 79 | prompt_embeds, _, pooled_prompt_embeds, _ = pipeline.encode_prompt( 80 | prompt="", 81 | prompt_2=None, 82 | device=device, 83 | num_images_per_prompt=1, 84 | do_classifier_free_guidance=False, 85 | negative_prompt=None, 86 | negative_prompt_2=None, 87 | prompt_embeds=None, 88 | negative_prompt_embeds=None, 89 | pooled_prompt_embeds=None, 90 | negative_pooled_prompt_embeds=None, 91 | ) 92 | batch_size = 2 if distri_config.do_classifier_free_guidance else 1 93 | 94 | num_channels_latents = pipeline.unet.config.in_channels 95 | latents = pipeline.prepare_latents( 96 | batch_size, num_channels_latents, height, width, prompt_embeds.dtype, device, None 97 | ) 98 | 99 | # 7. 
Prepare added time ids & embeddings 100 | add_text_embeds = pooled_prompt_embeds 101 | if pipeline.text_encoder_2 is None: 102 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 103 | else: 104 | text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim 105 | 106 | add_time_ids = pipeline._get_add_time_ids( 107 | original_size, 108 | crops_coords_top_left, 109 | target_size, 110 | dtype=prompt_embeds.dtype, 111 | text_encoder_projection_dim=text_encoder_projection_dim, 112 | ) 113 | 114 | prompt_embeds = prompt_embeds.to(device) 115 | add_text_embeds = add_text_embeds.to(device) 116 | add_time_ids = add_time_ids.to(device).repeat(1, 1) 117 | 118 | if batch_size > 1: 119 | prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1) 120 | add_text_embeds = add_text_embeds.repeat(batch_size, 1) 121 | add_time_ids = add_time_ids.repeat(batch_size, 1) 122 | 123 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 124 | t = torch.zeros([batch_size], device=device, dtype=torch.long) 125 | 126 | static_inputs["sample"] = latents 127 | static_inputs["timestep"] = t 128 | static_inputs["encoder_hidden_states"] = prompt_embeds 129 | static_inputs["added_cond_kwargs"] = added_cond_kwargs 130 | 131 | # Used to create communication buffer 132 | comm_manager = None 133 | if distri_config.n_device_per_batch > 1: 134 | comm_manager = PatchParallelismCommManager(distri_config) 135 | pipeline.unet.set_comm_manager(comm_manager) 136 | 137 | # Only used for creating the communication buffer 138 | pipeline.unet.set_counter(0) 139 | pipeline.unet(**static_inputs, return_dict=False, record=True) 140 | if comm_manager.numel > 0: 141 | comm_manager.create_buffer() 142 | 143 | # Pre-run 144 | pipeline.unet.set_counter(0) 145 | pipeline.unet(**static_inputs, return_dict=False, record=True) 146 | 147 | if distri_config.use_cuda_graph: 148 | if comm_manager is not None: 149 | comm_manager.clear() 150 | if distri_config.parallelism == "naive_patch": 151 | counters = [0, 1] 152 | elif distri_config.parallelism == "patch": 153 | counters = [0, distri_config.warmup_steps + 1, distri_config.warmup_steps + 2] 154 | elif distri_config.parallelism == "tensor": 155 | counters = [0] 156 | else: 157 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}") 158 | for counter in counters: 159 | graph = torch.cuda.CUDAGraph() 160 | with torch.cuda.graph(graph): 161 | pipeline.unet.set_counter(counter) 162 | output = pipeline.unet(**static_inputs, return_dict=False, record=True)[0] 163 | static_outputs.append(output) 164 | cuda_graphs.append(graph) 165 | pipeline.unet.setup_cuda_graph(static_outputs, cuda_graphs) 166 | 167 | self.static_inputs = static_inputs 168 | 169 | 170 | class DistriSDPipeline: 171 | def __init__(self, pipeline: StableDiffusionPipeline, module_config: DistriConfig): 172 | self.pipeline = pipeline 173 | self.distri_config = module_config 174 | 175 | self.static_inputs = None 176 | 177 | self.prepare() 178 | 179 | @staticmethod 180 | def from_pretrained(distri_config: DistriConfig, **kwargs): 181 | device = distri_config.device 182 | pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", "CompVis/stable-diffusion-v1-4") 183 | torch_dtype = kwargs.pop("torch_dtype", torch.float16) 184 | unet = UNet2DConditionModel.from_pretrained( 185 | pretrained_model_name_or_path, torch_dtype=torch_dtype, subfolder="unet" 186 | ).to(device) 187 | 188 | if distri_config.parallelism == "patch": 189 | unet = DistriUNetPP(unet, 
distri_config)
190 |         elif distri_config.parallelism == "tensor":
191 |             unet = DistriUNetTP(unet, distri_config)
192 |         elif distri_config.parallelism == "naive_patch":
193 |             unet = NaivePatchUNet(unet, distri_config)
194 |         else:
195 |             raise ValueError(f"Unknown parallelism: {distri_config.parallelism}")
196 | 
197 |         pipeline = StableDiffusionPipeline.from_pretrained(
198 |             pretrained_model_name_or_path, torch_dtype=torch_dtype, unet=unet, **kwargs
199 |         ).to(device)
200 |         return DistriSDPipeline(pipeline, distri_config)
201 | 
202 |     def set_progress_bar_config(self, **kwargs):
203 |         self.pipeline.set_progress_bar_config(**kwargs)
204 | 
205 |     @torch.no_grad()
206 |     def __call__(self, *args, **kwargs):
207 |         assert "height" not in kwargs, "height should not be in kwargs"
208 |         assert "width" not in kwargs, "width should not be in kwargs"
209 |         config = self.distri_config
210 |         if not config.do_classifier_free_guidance:
211 |             if "guidance_scale" not in kwargs:
212 |                 kwargs["guidance_scale"] = 1
213 |             else:
214 |                 assert kwargs["guidance_scale"] == 1
215 |         self.pipeline.unet.set_counter(0)
216 |         return self.pipeline(height=config.height, width=config.width, *args, **kwargs)
217 | 
218 |     @torch.no_grad()
219 |     def prepare(self, **kwargs):
220 |         distri_config = self.distri_config
221 | 
222 |         static_inputs = {}
223 |         static_outputs = []
224 |         cuda_graphs = []
225 |         pipeline = self.pipeline
226 | 
227 |         height = distri_config.height
228 |         width = distri_config.width
229 |         assert height % 8 == 0 and width % 8 == 0
230 | 
231 |         device = distri_config.device
232 | 
233 |         prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
234 |             "",
235 |             device,
236 |             num_images_per_prompt=1,
237 |             do_classifier_free_guidance=False,
238 |             negative_prompt=None,
239 |             prompt_embeds=None,
240 |             negative_prompt_embeds=None,
241 |             lora_scale=None,
242 |             clip_skip=kwargs.get("clip_skip", None),
243 |         )
244 | 
245 |         batch_size = 2 if distri_config.do_classifier_free_guidance else 1
246 | 
247 |         num_channels_latents = pipeline.unet.config.in_channels
248 |         latents = pipeline.prepare_latents(
249 |             batch_size, num_channels_latents, height, width, prompt_embeds.dtype, device, None
250 |         )
251 | 
252 |         prompt_embeds = prompt_embeds.to(device)
253 | 
254 |         if batch_size > 1:
255 |             prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1)
256 | 
257 |         t = torch.zeros([batch_size], device=device, dtype=torch.long)
258 | 
259 |         static_inputs["sample"] = latents
260 |         static_inputs["timestep"] = t
261 |         static_inputs["encoder_hidden_states"] = prompt_embeds
262 | 
263 |         # Used to create communication buffer
264 |         comm_manager = None
265 |         if distri_config.n_device_per_batch > 1:
266 |             comm_manager = PatchParallelismCommManager(distri_config)
267 |             pipeline.unet.set_comm_manager(comm_manager)
268 | 
269 |             # Only used for creating the communication buffer
270 |             pipeline.unet.set_counter(0)
271 |             pipeline.unet(**static_inputs, return_dict=False, record=True)
272 |             if comm_manager.numel > 0:
273 |                 comm_manager.create_buffer()
274 | 
275 |         # Pre-run
276 |         pipeline.unet.set_counter(0)
277 |         pipeline.unet(**static_inputs, return_dict=False, record=True)
278 | 
279 |         if distri_config.use_cuda_graph:
280 |             if comm_manager is not None:
281 |                 comm_manager.clear()
282 |             if distri_config.parallelism == "naive_patch":
283 |                 counters = [0, 1]
284 |             elif distri_config.parallelism == "patch":
285 |                 counters = [0, distri_config.warmup_steps + 1, distri_config.warmup_steps + 2]
286 |             elif distri_config.parallelism == "tensor":
287 |                 counters = [0]
288 | 
else: 289 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}") 290 | for counter in counters: 291 | graph = torch.cuda.CUDAGraph() 292 | with torch.cuda.graph(graph): 293 | pipeline.unet.set_counter(counter) 294 | output = pipeline.unet(**static_inputs, return_dict=False, record=True)[0] 295 | static_outputs.append(output) 296 | cuda_graphs.append(graph) 297 | pipeline.unet.setup_cuda_graph(static_outputs, cuda_graphs) 298 | 299 | self.static_inputs = static_inputs 300 | -------------------------------------------------------------------------------- /distrifuser/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from packaging import version 3 | from torch import distributed as dist 4 | 5 | 6 | def check_env(): 7 | if version.parse(torch.version.cuda) < version.parse("11.3"): 8 | # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html 9 | raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above") 10 | if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.2.0"): 11 | # https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/ 12 | raise RuntimeError( 13 | "CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. " 14 | "If it is not released yet, please install nightly built PyTorch with " 15 | "`pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121`" 16 | ) 17 | 18 | 19 | def is_power_of_2(n: int) -> bool: 20 | return (n & (n - 1) == 0) and n != 0 21 | 22 | 23 | class DistriConfig: 24 | def __init__( 25 | self, 26 | height: int = 1024, 27 | width: int = 1024, 28 | do_classifier_free_guidance: bool = True, 29 | split_batch: bool = True, 30 | warmup_steps: int = 4, 31 | comm_checkpoint: int = 60, 32 | mode: str = "corrected_async_gn", 33 | use_cuda_graph: bool = True, 34 | parallelism: str = "patch", 35 | split_scheme: str = "row", 36 | verbose: bool = False, 37 | ): 38 | try: 39 | # Initialize the process group 40 | dist.init_process_group("nccl") 41 | # Get the rank and world_size 42 | rank = dist.get_rank() 43 | world_size = dist.get_world_size() 44 | except Exception as e: 45 | rank = 0 46 | world_size = 1 47 | print(f"Failed to initialize process group: {e}, falling back to single GPU") 48 | 49 | assert is_power_of_2(world_size) 50 | check_env() 51 | 52 | self.world_size = world_size 53 | self.rank = rank 54 | self.height = height 55 | self.width = width 56 | self.do_classifier_free_guidance = do_classifier_free_guidance 57 | self.split_batch = split_batch 58 | self.warmup_steps = warmup_steps 59 | self.comm_checkpoint = comm_checkpoint 60 | self.mode = mode 61 | self.use_cuda_graph = use_cuda_graph 62 | 63 | self.parallelism = parallelism 64 | self.split_scheme = split_scheme 65 | 66 | self.verbose = verbose 67 | 68 | if do_classifier_free_guidance and split_batch: 69 | n_device_per_batch = world_size // 2 70 | if n_device_per_batch == 0: 71 | n_device_per_batch = 1 72 | else: 73 | n_device_per_batch = world_size 74 | 75 | self.n_device_per_batch = n_device_per_batch 76 | 77 | self.height = height 78 | self.width = width 79 | 80 | device = torch.device(f"cuda:{rank}") 81 | torch.cuda.set_device(device) 82 | self.device = device 83 | 84 | batch_group = None 85 | split_group = None 86 | if do_classifier_free_guidance and split_batch and world_size >= 2: 87 | batch_groups = [] 88 | for i in range(2): 89 | batch_groups.append(dist.new_group(list(range(i * (world_size // 2), (i + 
1) * (world_size // 2))))) 90 | batch_group = batch_groups[self.batch_idx()] 91 | split_groups = [] 92 | for i in range(world_size // 2): 93 | split_groups.append(dist.new_group([i, i + world_size // 2])) 94 | split_group = split_groups[self.split_idx()] 95 | self.batch_group = batch_group 96 | self.split_group = split_group 97 | 98 | def batch_idx(self, rank: int or None = None) -> int: 99 | if rank is None: 100 | rank = self.rank 101 | if self.do_classifier_free_guidance and self.split_batch: 102 | return 1 - int(rank < (self.world_size // 2)) 103 | else: 104 | return 0 # raise NotImplementedError 105 | 106 | def split_idx(self, rank: int or None = None) -> int: 107 | if rank is None: 108 | rank = self.rank 109 | return rank % self.n_device_per_batch 110 | 111 | 112 | class PatchParallelismCommManager: 113 | def __init__(self, distri_config: DistriConfig): 114 | self.distri_config = distri_config 115 | 116 | self.torch_dtype = None 117 | self.numel = 0 118 | self.numel_dict = {} 119 | 120 | self.buffer_list = None 121 | 122 | self.starts = [] 123 | self.ends = [] 124 | self.shapes = [] 125 | 126 | self.idx_queue = [] 127 | 128 | self.handles = None 129 | 130 | def register_tensor( 131 | self, shape: tuple[int, ...] or list[int], torch_dtype: torch.dtype, layer_type: str = None 132 | ) -> int: 133 | if self.torch_dtype is None: 134 | self.torch_dtype = torch_dtype 135 | else: 136 | assert self.torch_dtype == torch_dtype 137 | self.starts.append(self.numel) 138 | numel = 1 139 | for dim in shape: 140 | numel *= dim 141 | self.numel += numel 142 | if layer_type is not None: 143 | if layer_type not in self.numel_dict: 144 | self.numel_dict[layer_type] = 0 145 | self.numel_dict[layer_type] += numel 146 | 147 | self.ends.append(self.numel) 148 | self.shapes.append(shape) 149 | return len(self.starts) - 1 150 | 151 | def create_buffer(self): 152 | distri_config = self.distri_config 153 | if distri_config.rank == 0 and distri_config.verbose: 154 | print( 155 | f"Create buffer with {self.numel / 1e6:.3f}M parameters for {len(self.starts)} tensors on each device." 
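# Illustrative aside: the rank layout implied by batch_idx()/split_idx() when
# classifier-free guidance splits the batch. With a hypothetical world_size of
# 4, ranks {0, 1} handle one half of the batch and ranks {2, 3} the other.
world_size = 4
n_device_per_batch = world_size // 2
for rank in range(world_size):
    batch_idx = 1 - int(rank < world_size // 2)
    split_idx = rank % n_device_per_batch
    print(rank, batch_idx, split_idx)  # 0 0 0 / 1 0 1 / 2 1 0 / 3 1 1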
156 | ) 157 | for layer_type, numel in self.numel_dict.items(): 158 | print(f" {layer_type}: {numel / 1e6:.3f}M parameters") 159 | 160 | self.buffer_list = [ 161 | torch.empty(self.numel, dtype=self.torch_dtype, device=self.distri_config.device) 162 | for _ in range(self.distri_config.n_device_per_batch) 163 | ] 164 | self.handles = [None for _ in range(len(self.starts))] 165 | 166 | def get_buffer_list(self, idx: int) -> list[torch.Tensor]: 167 | buffer_list = [t[self.starts[idx] : self.ends[idx]].view(self.shapes[idx]) for t in self.buffer_list] 168 | return buffer_list 169 | 170 | def communicate(self): 171 | distri_config = self.distri_config 172 | start = self.starts[self.idx_queue[0]] 173 | end = self.ends[self.idx_queue[-1]] 174 | tensor = self.buffer_list[distri_config.split_idx()][start:end] 175 | buffer_list = [t[start:end] for t in self.buffer_list] 176 | handle = dist.all_gather(buffer_list, tensor, group=self.distri_config.batch_group, async_op=True) 177 | for i in self.idx_queue: 178 | self.handles[i] = handle 179 | self.idx_queue = [] 180 | 181 | def enqueue(self, idx: int, tensor: torch.Tensor): 182 | distri_config = self.distri_config 183 | if idx == 0 and len(self.idx_queue) > 0: 184 | self.communicate() 185 | assert len(self.idx_queue) == 0 or self.idx_queue[-1] == idx - 1 186 | self.idx_queue.append(idx) 187 | self.buffer_list[distri_config.split_idx()][self.starts[idx] : self.ends[idx]].copy_(tensor.flatten()) 188 | 189 | if len(self.idx_queue) == distri_config.comm_checkpoint: 190 | self.communicate() 191 | 192 | def clear(self): 193 | if len(self.idx_queue) > 0: 194 | self.communicate() 195 | if self.handles is not None: 196 | for i in range(len(self.handles)): 197 | if self.handles[i] is not None: 198 | self.handles[i].wait() 199 | self.handles[i] = None 200 | -------------------------------------------------------------------------------- /scripts/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from cleanfid import fid 7 | from PIL import Image 8 | from torch.utils.data import DataLoader, Dataset 9 | from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio 10 | from torchvision.transforms import Resize 11 | from tqdm import tqdm 12 | 13 | 14 | def read_image(path: str): 15 | """ 16 | input: path 17 | output: tensor (C, H, W) 18 | """ 19 | img = np.asarray(Image.open(path)) 20 | if len(img.shape) == 2: 21 | img = np.repeat(img[:, :, None], 3, axis=2) 22 | img = torch.from_numpy(img).permute(2, 0, 1) 23 | return img 24 | 25 | 26 | class MultiImageDataset(Dataset): 27 | def __init__(self, root0, root1, is_gt=False): 28 | super().__init__() 29 | self.root0 = root0 30 | self.root1 = root1 31 | file_names0 = os.listdir(root0) 32 | file_names1 = os.listdir(root1) 33 | 34 | self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")]) 35 | self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")]) 36 | self.is_gt = is_gt 37 | assert len(self.image_names0) == len(self.image_names1) 38 | 39 | def __len__(self): 40 | return len(self.image_names0) 41 | 42 | def __getitem__(self, idx): 43 | img0 = read_image(os.path.join(self.root0, self.image_names0[idx])) 44 | if self.is_gt: 45 | # resize to 1024 x 1024 46 | img0 = Resize((1024, 1024))(img0) 47 | img1 = read_image(os.path.join(self.root1, self.image_names1[idx])) 
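# Illustrative aside (made-up shapes, no distributed calls): the bookkeeping
# behind PatchParallelismCommManager. Each registered tensor claims a
# contiguous [start, end) slice of one flat per-device buffer and is later
# viewed back out by its index, so a single all-gather can cover many layers.
import torch

starts, ends, shapes, numel = [], [], [], 0
for shape in ([2, 3, 4], [5, 5]):
    starts.append(numel)
    n = 1
    for d in shape:
        n *= d
    numel += n
    ends.append(numel)
    shapes.append(shape)
flat = torch.empty(numel)
view = flat[starts[1] : ends[1]].view(shapes[1])  # aliases the flat buffer
assert view.shape == (5, 5)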
48 | 49 | batch_list = [img0, img1] 50 | return batch_list 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--batch_size", type=int, default=64) 56 | parser.add_argument("--num_workers", type=int, default=8) 57 | parser.add_argument("--is_gt", action="store_true") 58 | parser.add_argument("--input_root0", type=str, required=True) 59 | parser.add_argument("--input_root1", type=str, required=True) 60 | args = parser.parse_args() 61 | 62 | psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda") 63 | lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda") 64 | 65 | dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt) 66 | dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) 67 | 68 | progress_bar = tqdm(dataloader) 69 | with torch.inference_mode(): 70 | for i, batch in enumerate(progress_bar): 71 | batch = [img.to("cuda") / 255 for img in batch] 72 | batch_size = batch[0].shape[0] 73 | psnr.update(batch[0], batch[1]) 74 | lpips.update(batch[0], batch[1]) 75 | fid_score = fid.compute_fid(args.input_root0, args.input_root1) 76 | 77 | print("PSNR:", psnr.compute().item()) 78 | print("LPIPS:", lpips.compute().item()) 79 | print("FID:", fid_score) 80 | -------------------------------------------------------------------------------- /scripts/dump_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | from datasets import load_dataset 6 | from tqdm import tqdm, trange 7 | 8 | if __name__ == "__main__": 9 | parser = ArgumentParser() 10 | parser.add_argument("--output_root", type=str, default="./coco") 11 | args = parser.parse_args() 12 | 13 | dataset = load_dataset("HuggingFaceM4/COCO", name="2014_captions", split="validation") 14 | 15 | prompt_list = [] 16 | for i in trange(len(dataset["sentences_raw"])): 17 | prompt = dataset["sentences_raw"][i][i % len(dataset["sentences_raw"][i])] 18 | prompt_list.append(prompt) 19 | 20 | os.makedirs(args.output_root, exist_ok=True) 21 | prompt_path = os.path.join(args.output_root, "prompts.json") 22 | with open(prompt_path, "w") as f: 23 | json.dump(prompt_list, f, indent=4) 24 | 25 | os.makedirs(os.path.join(args.output_root, "images"), exist_ok=True) 26 | 27 | dataset = load_dataset("HuggingFaceM4/COCO", name="2014_captions", split="validation") 28 | for i, image in enumerate(tqdm(dataset["image"])): 29 | image.save(os.path.join(args.output_root, "images", f"{i:04}.png")) 30 | -------------------------------------------------------------------------------- /scripts/export_html.py: -------------------------------------------------------------------------------- 1 | """ 2 | An auxiliary script to generate HTML files for image visualization. 
3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | import random 9 | import shutil 10 | 11 | import dominate 12 | from dominate.tags import h3, img, table, td, tr 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--image_dirs", type=str, nargs="+", required=True) 18 | parser.add_argument("--caption_path", type=str, default=None) 19 | parser.add_argument("--output_root", type=str, required=True) 20 | parser.add_argument("--aliases", type=str, default=None, nargs="+") 21 | parser.add_argument("--title", type=str, default=None) 22 | parser.add_argument("--hard_copy", action="store_true") 23 | parser.add_argument("--max_images", type=int, default=None) 24 | parser.add_argument("--seed", type=int, default=0) 25 | args = parser.parse_args() 26 | if args.aliases is not None: 27 | assert len(args.image_dirs) == len(args.aliases) 28 | else: 29 | args.aliases = [str(i) for i in range(len(args.image_dirs))] 30 | return args 31 | 32 | 33 | def check_existence(image_dirs, filename): 34 | for image_dir in image_dirs: 35 | if not os.path.exists(os.path.join(image_dir, filename)): 36 | print(os.path.join(image_dir, filename)) 37 | return False 38 | return True 39 | 40 | 41 | if __name__ == "__main__": 42 | args = get_args() 43 | filenames = sorted(os.listdir(args.image_dirs[0])) 44 | filenames = [ 45 | filename 46 | for filename in filenames 47 | if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg") 48 | ] 49 | if args.max_images is not None: 50 | random.seed(args.seed) 51 | random.shuffle(filenames) 52 | filenames = filenames[: args.max_images] 53 | filenames = sorted(filenames) 54 | doc = dominate.document(title="Visualization" if args.title is None else args.title) 55 | if args.title: 56 | with doc: 57 | h3(args.title) 58 | t_main = table(border=1, style="table-layout: fixed;") 59 | prompts = json.load(open(args.caption_path, "r")) 60 | for i, filename in enumerate(filenames): 61 | bname = os.path.splitext(filename)[0] 62 | if not check_existence(args.image_dirs, filename): 63 | continue 64 | title_row = tr() 65 | _tr = tr() 66 | title_row.add(td(f"{bname}")) 67 | _tr.add(td(prompts[int(bname)])) 68 | for image_dir, alias in zip(args.image_dirs, args.aliases): 69 | title_row.add(td(f"{alias}")) 70 | _td = td(style="word-wrap: break-word;", halign="center", valign="top") 71 | source_path = os.path.abspath(os.path.join(image_dir, filename)) 72 | target_path = os.path.abspath(os.path.join(args.output_root, "images", alias, filename)) 73 | os.makedirs(os.path.dirname(os.path.abspath(target_path)), exist_ok=True) 74 | if args.hard_copy: 75 | shutil.copy(source_path, target_path) 76 | else: 77 | os.symlink(source_path, target_path) 78 | _td.add(img(style="width:256px", src=os.path.relpath(target_path, args.output_root))) 79 | _tr.add(_td) 80 | t_main.add(title_row) 81 | t_main.add(_tr) 82 | with open(os.path.join(args.output_root, "index.html"), "w") as f: 83 | f.write(t_main.render()) 84 | -------------------------------------------------------------------------------- /scripts/generate_coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from datasets import load_dataset 6 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler 7 | from tqdm import trange 8 | 9 | from distrifuser.pipelines import DistriSDXLPipeline 10 | from distrifuser.utils import DistriConfig 11 | 12 | 13 | def 
get_args() -> argparse.Namespace: 14 | parser = argparse.ArgumentParser() 15 | # Diffuser specific arguments 16 | parser.add_argument("--output_root", type=str, default=None) 17 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps") 18 | parser.add_argument("--image_size", type=int, nargs="*", default=1024, help="Image size of generation") 19 | parser.add_argument("--guidance_scale", type=float, default=5.0) 20 | parser.add_argument("--scheduler", type=str, default="ddim", choices=["euler", "dpm-solver", "ddim"]) 21 | 22 | # DistriFuser specific arguments 23 | parser.add_argument( 24 | "--no_split_batch", action="store_true", help="Disable the batch splitting for classifier-free guidance" 25 | ) 26 | parser.add_argument("--warmup_steps", type=int, default=4, help="Number of warmup steps") 27 | parser.add_argument( 28 | "--sync_mode", 29 | type=str, 30 | default="corrected_async_gn", 31 | choices=["separate_gn", "stale_gn", "corrected_async_gn", "sync_gn", "full_sync", "no_sync"], 32 | help="Different GroupNorm synchronization modes", 33 | ) 34 | parser.add_argument( 35 | "--parallelism", 36 | type=str, 37 | default="patch", 38 | choices=["patch", "tensor", "naive_patch"], 39 | help="patch parallelism, tensor parallelism or naive patch", 40 | ) 41 | parser.add_argument( 42 | "--split_scheme", 43 | type=str, 44 | default="alternate", 45 | choices=["row", "col", "alternate"], 46 | help="Split scheme for naive patch", 47 | ) 48 | parser.add_argument("--no_cuda_graph", action="store_true", help="Disable CUDA graph") 49 | 50 | parser.add_argument("--split", nargs=2, type=int, default=None, help="Split the dataset into chunks") 51 | 52 | args = parser.parse_args() 53 | return args 54 | 55 | 56 | def main(): 57 | args = get_args() 58 | 59 | if isinstance(args.image_size, int): 60 | args.image_size = [args.image_size, args.image_size] 61 | else: 62 | if len(args.image_size) == 1: 63 | args.image_size = [args.image_size[0], args.image_size[0]] 64 | else: 65 | assert len(args.image_size) == 2 66 | distri_config = DistriConfig( 67 | height=args.image_size[0], 68 | width=args.image_size[1], 69 | do_classifier_free_guidance=args.guidance_scale > 1, 70 | split_batch=not args.no_split_batch, 71 | warmup_steps=args.warmup_steps, 72 | mode=args.sync_mode, 73 | use_cuda_graph=not args.no_cuda_graph, 74 | parallelism=args.parallelism, 75 | split_scheme=args.split_scheme, 76 | ) 77 | 78 | pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" 79 | if args.scheduler == "euler": 80 | scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") 81 | elif args.scheduler == "dpm-solver": 82 | scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") 83 | elif args.scheduler == "ddim": 84 | scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") 85 | else: 86 | raise NotImplementedError 87 | pipeline = DistriSDXLPipeline.from_pretrained( 88 | pretrained_model_name_or_path=pretrained_model_name_or_path, 89 | distri_config=distri_config, 90 | variant="fp16", 91 | use_safetensors=True, 92 | scheduler=scheduler, 93 | ) 94 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0, position=1, leave=False) 95 | 96 | if args.output_root is None: 97 | args.output_root = os.path.join( 98 | "results", 99 | "coco", 100 | f"{args.scheduler}-{args.num_inference_steps}", 101 | f"gpus{distri_config.world_size if 
args.no_split_batch else distri_config.world_size // 2}-" 102 | f"warmup{args.warmup_steps}-{args.sync_mode}", 103 | ) 104 | if distri_config.rank == 0: 105 | os.makedirs(args.output_root, exist_ok=True) 106 | 107 | dataset = load_dataset("HuggingFaceM4/COCO", name="2014_captions", split="validation", trust_remote_code=True) 108 | 109 | if args.split is not None: 110 | assert args.split[0] < args.split[1] 111 | chunk_size = (5000 + args.split[1] - 1) // args.split[1] 112 | start_idx = args.split[0] * chunk_size 113 | end_idx = min((args.split[0] + 1) * chunk_size, 5000) 114 | else: 115 | start_idx = 0 116 | end_idx = 5000 117 | 118 | for i in trange(start_idx, end_idx, disable=distri_config.rank != 0, position=0, leave=False): 119 | prompt = dataset["sentences_raw"][i][i % len(dataset["sentences_raw"][i])] 120 | seed = i 121 | 122 | image = pipeline( 123 | prompt=prompt, 124 | generator=torch.Generator(device="cuda").manual_seed(seed), 125 | num_inference_steps=args.num_inference_steps, 126 | guidance_scale=args.guidance_scale, 127 | ).images[0] 128 | if distri_config.rank == 0: 129 | output_path = os.path.join(args.output_root, f"{i:04d}.png") 130 | image.save(output_path) 131 | 132 | 133 | if __name__ == "__main__": 134 | main() 135 | -------------------------------------------------------------------------------- /scripts/profile_macs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from diffusers import StableDiffusionXLPipeline 5 | from torchprofile import profile_macs 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--image_size", type=int, nargs="*", default=1024, help="Image size of generation") 10 | args = parser.parse_args() 11 | 12 | if isinstance(args.image_size, int): 13 | args.image_size = [args.image_size // 8, args.image_size // 8] 14 | elif len(args.image_size) == 1: 15 | args.image_size = [args.image_size[0] // 8, args.image_size[0] // 8] 16 | else: 17 | assert len(args.image_size) == 2 18 | args.image_size = [args.image_size[0] // 8, args.image_size[1] // 8] 19 | 20 | pipeline = StableDiffusionXLPipeline.from_pretrained( 21 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True 22 | ).to("cuda") 23 | 24 | unet = pipeline.unet 25 | 26 | latent_model_input = torch.randn(2, 4, *args.image_size, dtype=unet.dtype).to("cuda") 27 | t = torch.randn(1).to("cuda") 28 | prompt_embeds = torch.randn(2, 77, 2048, dtype=unet.dtype).to("cuda") 29 | add_text_embeds = torch.randn(2, 1280, dtype=unet.dtype).to("cuda") 30 | add_time_ids = torch.randint(0, 1024, (2, 6)).to("cuda") 31 | 32 | with torch.no_grad(): 33 | macs = profile_macs( 34 | unet, 35 | args=( 36 | latent_model_input, 37 | t, 38 | prompt_embeds, 39 | None, 40 | None, 41 | None, 42 | None, 43 | {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, 44 | ), 45 | ) 46 | print(f"MACs: {macs / 1e9:.3f}G") 47 | -------------------------------------------------------------------------------- /scripts/run_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler 7 | from tqdm import trange 8 | 9 | from distrifuser.pipelines import DistriSDXLPipeline 10 | from distrifuser.utils import DistriConfig 11 | 12 | 13 | def get_args() -> argparse.Namespace: 14 | 
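# Note (assumption about intended usage, not stated in this file): these
# scripts expect one process per GPU, e.g. launched with something like
#   torchrun --nproc_per_node=2 scripts/run_sdxl.py
# since DistriConfig initializes an NCCL process group at construction time.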
/scripts/run_sdxl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | 
5 | import torch
6 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler
7 | from tqdm import trange
8 | 
9 | from distrifuser.pipelines import DistriSDXLPipeline
10 | from distrifuser.utils import DistriConfig
11 | 
12 | 
13 | def get_args() -> argparse.Namespace:
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument(
16 |         "--mode",
17 |         type=str,
18 |         default="generation",
19 |         choices=["generation", "benchmark"],
20 |         help="Purpose of running the script",
21 |     )
22 | 
23 |     # Diffuser-specific arguments
24 |     parser.add_argument(
25 |         "--prompt", type=str, default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
26 |     )
27 |     parser.add_argument("--output_path", type=str, default=None)
28 |     parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
29 |     parser.add_argument("--image_size", type=int, nargs="*", default=1024, help="Image size for generation")
30 |     parser.add_argument("--guidance_scale", type=float, default=5.0)
31 |     parser.add_argument("--scheduler", type=str, default="ddim", choices=["euler", "dpm-solver", "ddim"])
32 |     parser.add_argument("--seed", type=int, default=1234, help="Random seed")
33 | 
34 |     # DistriFuser-specific arguments
35 |     parser.add_argument(
36 |         "--no_split_batch", action="store_true", help="Disable batch splitting for classifier-free guidance"
37 |     )
38 |     parser.add_argument("--warmup_steps", type=int, default=4, help="Number of warmup steps")
39 |     parser.add_argument(
40 |         "--sync_mode",
41 |         type=str,
42 |         default="corrected_async_gn",
43 |         choices=["separate_gn", "stale_gn", "corrected_async_gn", "sync_gn", "full_sync", "no_sync"],
44 |         help="GroupNorm synchronization mode",
45 |     )
46 |     parser.add_argument(
47 |         "--parallelism",
48 |         type=str,
49 |         default="patch",
50 |         choices=["patch", "tensor", "naive_patch"],
51 |         help="Patch parallelism, tensor parallelism, or naive patch",
52 |     )
53 |     parser.add_argument("--no_cuda_graph", action="store_true", help="Disable CUDA graph")
54 |     parser.add_argument(
55 |         "--split_scheme",
56 |         type=str,
57 |         default="alternate",
58 |         choices=["row", "col", "alternate"],
59 |         help="Split scheme for naive patch",
60 |     )
61 | 
62 |     # Benchmark-specific arguments
63 |     parser.add_argument("--output_type", type=str, default="pil", choices=["latent", "pil"])
64 |     parser.add_argument("--warmup_times", type=int, default=5, help="Number of warmup runs before timing")
65 |     parser.add_argument("--test_times", type=int, default=20, help="Number of timed benchmark runs")
66 |     parser.add_argument(
67 |         "--ignore_ratio", type=float, default=0.2, help="Total fraction of the slowest and fastest runs to discard"
68 |     )
69 | 
70 |     args = parser.parse_args()
71 |     return args
72 | 
73 | 
74 | def main():
75 |     args = get_args()
76 | 
77 |     if isinstance(args.image_size, int):
78 |         args.image_size = [args.image_size, args.image_size]
79 |     elif len(args.image_size) == 1:
80 |         args.image_size = [args.image_size[0], args.image_size[0]]
81 |     else:
82 |         assert len(args.image_size) == 2
83 | 
84 |     distri_config = DistriConfig(
85 |         height=args.image_size[0],
86 |         width=args.image_size[1],
87 |         do_classifier_free_guidance=args.guidance_scale > 1,
88 |         split_batch=not args.no_split_batch,
89 |         warmup_steps=args.warmup_steps,
90 |         mode=args.sync_mode,
91 |         use_cuda_graph=not args.no_cuda_graph,
92 |         parallelism=args.parallelism,
93 |         split_scheme=args.split_scheme,
94 |     )
95 | 
96 |     pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
97 |     if args.scheduler == "euler":
98 |         scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
99 |     elif args.scheduler == "dpm-solver":
100 |         scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
101 |     elif args.scheduler == "ddim":
102 |         scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
103 |     else:
104 |         raise NotImplementedError
105 |     pipeline = DistriSDXLPipeline.from_pretrained(
106 |         pretrained_model_name_or_path=pretrained_model_name_or_path,
107 |         distri_config=distri_config,
108 |         variant="fp16",
109 |         use_safetensors=True,
110 |         scheduler=scheduler,
111 |     )
112 | 
113 |     if args.mode == "generation":
114 |         assert args.output_path is not None
115 |         pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
116 |         image = pipeline(
117 |             prompt=args.prompt,
118 |             generator=torch.Generator(device="cuda").manual_seed(args.seed),
119 |             num_inference_steps=args.num_inference_steps,
120 |             guidance_scale=args.guidance_scale,
121 |         ).images[0]
122 |         os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
123 |         image.save(args.output_path)
124 |     elif args.mode == "benchmark":
125 |         pipeline.set_progress_bar_config(position=1, desc="Generation", leave=False, disable=distri_config.rank != 0)
126 |         for i in trange(args.warmup_times, position=0, desc="Warmup", leave=False, disable=distri_config.rank != 0):
127 |             pipeline(
128 |                 prompt=args.prompt,
129 |                 generator=torch.Generator(device="cuda").manual_seed(args.seed),
130 |                 num_inference_steps=args.num_inference_steps,
131 |                 guidance_scale=args.guidance_scale,
132 |                 output_type=args.output_type,
133 |             )
134 |         torch.cuda.synchronize()
135 |         latency_list = []
136 |         for i in trange(args.test_times, position=0, desc="Test", leave=False, disable=distri_config.rank != 0):
137 |             start_time = time.time()
138 |             pipeline(
139 |                 prompt=args.prompt,
140 |                 generator=torch.Generator(device="cuda").manual_seed(args.seed),
141 |                 num_inference_steps=args.num_inference_steps,
142 |                 guidance_scale=args.guidance_scale,
143 |                 output_type=args.output_type,
144 |             )
145 |             torch.cuda.synchronize()
146 |             end_time = time.time()
147 |             latency_list.append(end_time - start_time)
148 |         # Drop the fastest and slowest runs (ignore_ratio / 2 from each end) before averaging.
149 |         latency_list = sorted(latency_list)
150 |         ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
151 |         if ignored_count > 0:
152 |             latency_list = latency_list[ignored_count:-ignored_count]
153 |         if distri_config.rank == 0:
154 |             print(f"Latency: {sum(latency_list) / len(latency_list):.5f} s")
155 |     else:
156 |         raise NotImplementedError
157 | 
158 | 
159 | if __name__ == "__main__":
160 |     main()
161 | 
--------------------------------------------------------------------------------
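Note: scripts/run_sdxl.py builds a DistriConfig and therefore runs under torch.distributed, so it should be started with a distributed launcher rather than plain python. Assuming the standard torchrun launcher that ships with the required torch>=2.2, a two-GPU benchmark might look like:

    torchrun --nproc_per_node=2 scripts/run_sdxl.py --mode benchmark

Generation mode additionally requires --output_path for the saved image.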
/scripts/sd_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from distrifuser.pipelines import DistriSDPipeline
4 | from distrifuser.utils import DistriConfig
5 | 
6 | # 512x512 generation with stale GroupNorm statistics ("stale_gn") after 4 warmup steps.
7 | distri_config = DistriConfig(height=512, width=512, warmup_steps=4, mode="stale_gn")
8 | pipeline = DistriSDPipeline.from_pretrained(
9 |     distri_config=distri_config,
10 |     pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
11 | )
12 | 
13 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
14 | image = pipeline(
15 |     prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
16 |     generator=torch.Generator(device="cuda").manual_seed(233),
17 | ).images[0]
18 | # All ranks run the pipeline; only rank 0 saves the resulting image.
19 | if distri_config.rank == 0:
20 |     image.save("astronaut.png")
--------------------------------------------------------------------------------
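The synchronization mode in this example can be swapped for any of the --sync_mode choices listed in run_sdxl.py above. A minimal sketch of the fully synchronous baseline, assuming DistriConfig accepts the same mode strings here as run_sdxl.py passes to it (the rest of the script would proceed exactly as above):

    from distrifuser.pipelines import DistriSDPipeline
    from distrifuser.utils import DistriConfig

    # "full_sync" is the fully synchronous baseline among the --sync_mode
    # choices in run_sdxl.py; it avoids stale activations at the cost of speed.
    distri_config = DistriConfig(height=512, width=512, warmup_steps=4, mode="full_sync")
    pipeline = DistriSDPipeline.from_pretrained(
        distri_config=distri_config,
        pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    )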
/scripts/sdxl_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from distrifuser.pipelines import DistriSDXLPipeline
4 | from distrifuser.utils import DistriConfig
5 | 
6 | # 1024x1024 generation using DistriConfig's default synchronization mode.
7 | distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
8 | pipeline = DistriSDXLPipeline.from_pretrained(
9 |     distri_config=distri_config,
10 |     pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
11 |     variant="fp16",
12 |     use_safetensors=True,
13 | )
14 | 
15 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
16 | image = pipeline(
17 |     prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
18 |     generator=torch.Generator(device="cuda").manual_seed(233),
19 | ).images[0]
20 | if distri_config.rank == 0:
21 |     image.save("astronaut.png")
22 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import ast
2 | 
3 | from setuptools import find_packages, setup
4 | 
5 | 
6 | if __name__ == "__main__":
7 |     with open("README.md", "r") as f:
8 |         long_description = f.read()
9 |     # Parse the version string from a line of the form `__version__ = "x.y.z"`.
10 |     with open("distrifuser/__version__.py", "r") as f:
11 |         version = ast.literal_eval(f.read().strip().split()[-1])
12 | 
13 |     setup(
14 |         name="distrifuser",
15 |         author="Muyang Li, Tianle Cai, Jiaxin Cao, Qinsheng Zhang, Han Cai, Junjie Bai, Yangqing Jia, Ming-Yu Liu, Kai Li and Song Han",
16 |         author_email="muyangli@mit.edu",
17 |         packages=find_packages(),
18 |         install_requires=["torch>=2.2", "diffusers==0.24.0", "transformers", "tqdm"],
19 |         url="https://github.com/mit-han-lab/distrifuser",
20 |         description="DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models",
21 |         long_description=long_description,
22 |         long_description_content_type="text/markdown",
23 |         version=version,
24 |         classifiers=[
25 |             "Programming Language :: Python :: 3",
26 |             "Operating System :: OS Independent",
27 |         ],
28 |         include_package_data=True,
29 |         python_requires=">=3.10",
30 |     )
31 | 
--------------------------------------------------------------------------------
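Given the setup.py above, the package installs in the standard setuptools fashion; an editable install from the repository root, e.g.

    pip install -e .

pulls in torch>=2.2, diffusers==0.24.0, transformers, and tqdm, after which the distrifuser package used by the scripts above is importable.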