├── .gitignore
├── LICENSE.txt
├── README.md
├── assets
│   ├── idea.jpg
│   ├── quality.jpg
│   ├── speedups.jpg
│   └── teaser.jpg
├── distrifuser
│   ├── __init__.py
│   ├── __version__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── distri_sdxl_unet_pp.py
│   │   ├── distri_sdxl_unet_tp.py
│   │   └── naive_patch_sdxl.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── base_module.py
│   │   ├── pp
│   │   │   ├── __init__.py
│   │   │   ├── attn.py
│   │   │   ├── conv2d.py
│   │   │   └── groupnorm.py
│   │   └── tp
│   │       ├── __init__.py
│   │       ├── attention.py
│   │       ├── conv2d.py
│   │       ├── feed_forward.py
│   │       └── resnet.py
│   ├── pipelines.py
│   └── utils.py
├── scripts
│   ├── compute_metrics.py
│   ├── dump_coco.py
│   ├── export_html.py
│   ├── generate_coco.py
│   ├── profile_macs.py
│   ├── run_sdxl.py
│   ├── sd_example.py
│   └── sdxl_example.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .DS_Store
3 | build
4 | dist
5 | *.egg-info
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MIT HAN Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models
2 |
3 | ### [Paper](http://arxiv.org/abs/2402.19481) | [Project](https://hanlab.mit.edu/projects/distrifusion) | [Blog](https://hanlab.mit.edu/blog/distrifusion) | [Slides](https://www.dropbox.com/scl/fi/yv98hi2kdoh27ej4jqlbp/slides.key?rlkey=3rmfxpezqt3co5x2hgqvxv09i&st=ve4z9w6t&dl=0) | [Youtube](https://www.youtube.com/watch?v=EZX7srDDmW0&list=PL80kAHvQbh-pKRxcSS6xjds7U7Yc0gDQI&index=1) | [Poster](https://www.dropbox.com/scl/fi/labhefjwi9r01e3o9eob0/poster.pdf?rlkey=rjj1jj179enln92h8kygrftmg&st=0ddego10&dl=0)
4 |
5 | **[Dec 1, 2024]** DistriFusion is integrated in NVIDIA's [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/sdxl/README.md) for distributed inference on high-resolution image generation.
6 |
7 | **[Jul 29, 2024]** DistriFusion is supported in [ColossalAI](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/inference/README.md)!
8 |
9 | **[Apr 4, 2024]** DistriFusion is selected as a **highlight** poster in CVPR 2024!
10 |
11 | **[Feb 29, 2024]** DistriFusion is accepted by CVPR 2024! Our code is publicly available!
12 |
13 | ![teaser](assets/teaser.jpg)
14 | *We introduce DistriFusion, a training-free algorithm to harness multiple GPUs to accelerate diffusion model inference without sacrificing image quality. Naïve Patch (Overview (b)) suffers from the fragmentation issue due to the lack of patch interaction. The presented examples are generated with SDXL using a 50-step Euler sampler at 1280×1920 resolution, and latency is measured on A100 GPUs.*
15 |
16 | DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models
17 | [Muyang Li](https://lmxyy.me/)\*, [Tianle Cai](https://www.tianle.website/)\*, [Jiaxin Cao](https://www.linkedin.com/in/jiaxin-cao-2166081b3/), [Qinsheng Zhang](https://qsh-zh.github.io), [Han Cai](https://han-cai.github.io), [Junjie Bai](https://www.linkedin.com/in/junjiebai/), [Yangqing Jia](https://daggerfs.com), [Ming-Yu Liu](https://mingyuliu.net), [Kai Li](https://www.cs.princeton.edu/~li/), and [Song Han](https://hanlab.mit.edu/songhan)
18 | MIT, Princeton, Lepton AI, and NVIDIA
19 | In CVPR 2024.
20 |
21 | ## Overview
22 | ![idea](assets/idea.jpg)
23 | **(a)** Original diffusion model running on a single device. **(b)** Naïvely splitting the image into 2 patches across 2 GPUs has an evident seam at the boundary due to the absence of interaction across patches. **(c)** Our DistriFusion employs synchronous communication for patch interaction at the first step. After that, we reuse the activations from the previous step via asynchronous communication. In this way, the communication overhead can be hidden within the computation pipeline.
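A minimal sketch of this asynchronous reuse pattern, loosely following how the patch-parallel modules in `distrifuser/modules/pp/` wait on and re-enqueue activation buffers (the function below is illustrative rather than our actual API, and it assumes `torch.distributed` has already been initialized, e.g. via `torchrun`):

```python
import torch
import torch.distributed as dist


def exchange_activations(fresh_slice: torch.Tensor, gathered: list[torch.Tensor], handle):
    """One layer's step of stale-activation reuse; patches are split along height (dim 2)."""
    # 1) Make sure the previous step's asynchronous all-gather has landed in `gathered`
    #    (it normally finishes while the previous step is still computing).
    if handle is not None:
        handle.wait()
    # 2) Splice this device's fresh slice into the one-step-stale slices from the others.
    mixed = list(gathered)
    mixed[dist.get_rank()] = fresh_slice
    full_activation = torch.cat(mixed, dim=2)
    # 3) Broadcast the fresh slice asynchronously; the transfer overlaps with the rest of
    #    this step's computation and is only waited on at the next step.
    new_handle = dist.all_gather(gathered, fresh_slice.contiguous(), async_op=True)
    return full_activation, new_handle
```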
24 |
25 | ## Performance
26 | ### Speedups
27 |
28 | ![speedups](assets/speedups.jpg)
29 |
30 | Measured total latency of DistriFusion with SDXL using a 50-step DDIM sampler for generating a single image on NVIDIA A100 GPUs. When scaling up the resolution, the GPU devices are better utilized. Remarkably, when generating 3840×3840 images, DistriFusion achieves 1.8×, 3.4× and 6.1× speedups with 2, 4, and 8 A100s, respectively.
31 |
32 |
33 |
34 | ### Quality
35 |
36 | ![quality](assets/quality.jpg)
37 | Qualitative results of SDXL. FID is computed against the ground-truth images. Our DistriFusion can reduce the latency according to the number of used devices while preserving visual fidelity.
38 |
39 | References:
40 |
41 | * Denoising Diffusion Implicit Model (DDIM), Song *et al.*, ICLR 2021
42 | * Elucidating the Design Space of Diffusion-Based Generative Models, Karras *et al.*, NeurIPS 2022
43 | * Parallel Sampling of Diffusion Models, Shih *et al.*, NeurIPS 2023
44 | * SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis, Podell *et al.*, ICLR 2024
45 |
46 | ## Prerequisites
47 |
48 | * Python3
49 | * NVIDIA GPU + CUDA >= 12.0 and the corresponding cuDNN
50 | * [PyTorch](https://pytorch.org) >= 2.2
51 |
52 | ## Getting Started
53 |
54 | ### Installation
55 |
56 | After installing [PyTorch](https://pytorch.org), you should be able to install `distrifuser` from PyPI:
57 |
58 | ```shell
59 | pip install distrifuser
60 | ```
61 |
62 | or via GitHub:
63 |
64 | ```shell
65 | pip install git+https://github.com/mit-han-lab/distrifuser.git
66 | ```
67 |
68 | or locally for development:
69 |
70 | ```shell
71 | git clone git@github.com:mit-han-lab/distrifuser.git
72 | cd distrifuser
73 | pip install -e .
74 | ```
75 |
76 | ### Usage Example
77 |
78 | In [`scripts/sdxl_example.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/sdxl_example.py), we provide a minimal script for running [SDXL](https://huggingface.co/docs/diffusers/en/using-diffusers/sdxl) with DistriFusion.
79 |
80 | ```python
81 | import torch
82 |
83 | from distrifuser.pipelines import DistriSDXLPipeline
84 | from distrifuser.utils import DistriConfig
85 |
86 | distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
87 | pipeline = DistriSDXLPipeline.from_pretrained(
88 | distri_config=distri_config,
89 | pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
90 | variant="fp16",
91 | use_safetensors=True,
92 | )
93 |
94 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
95 | image = pipeline(
96 | prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
97 | generator=torch.Generator(device="cuda").manual_seed(233),
98 | ).images[0]
99 | if distri_config.rank == 0:
100 | image.save("astronaut.png")
101 | ```
102 |
103 | Specifically, our `distrifuser` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way. You just need to define a `DistriConfig` and use our wrapped `DistriSDXLPipeline` to load the pretrained SDXL model. Then, you can generate images just as you would with the `StableDiffusionXLPipeline` in [diffusers](https://github.com/huggingface/diffusers). The running command is
104 |
105 | ```shell
106 | torchrun --nproc_per_node=$N_GPUS scripts/sdxl_example.py
107 | ```
108 |
109 | where `$N_GPUS` is the number of GPUs you want to use.
110 |
111 | We also provide a minimal script for running SD1.4/2 with DistriFusion in [`scripts/sd_example.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/sd_example.py). The usage is the same.
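For reference, the SD script looks very similar to the SDXL one. A rough sketch, assuming the `DistriSDPipeline` wrapper in `distrifuser.pipelines` mirrors `DistriSDXLPipeline` (the pipeline class name and the checkpoint below are assumptions, not copied from the script):

```python
import torch

from distrifuser.pipelines import DistriSDPipeline  # assumed counterpart of DistriSDXLPipeline
from distrifuser.utils import DistriConfig

distri_config = DistriConfig(height=512, width=512, warmup_steps=4)
pipeline = DistriSDPipeline.from_pretrained(
    distri_config=distri_config,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",  # any SD 1.4/2 checkpoint
)

pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    generator=torch.Generator(device="cuda").manual_seed(233),
).images[0]
if distri_config.rank == 0:
    image.save("astronaut_sd.png")
```

It is launched with `torchrun` in exactly the same way as the SDXL example.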
112 |
113 | ### Benchmark
114 |
115 | Our benchmark results were obtained with [PyTorch](https://pytorch.org) 2.2 and [diffusers](https://github.com/huggingface/diffusers) 0.24.0. First, you may need to install some additional dependencies:
116 |
117 | ```shell
118 | pip install git+https://github.com/zhijian-liu/torchprofile datasets torchmetrics dominate clean-fid
119 | ```
120 |
121 | #### COCO Quality
122 |
123 | You can use [`scripts/generate_coco.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/generate_coco.py) to generate images with COCO captions. The command is
124 |
125 | ```shell
126 | torchrun --nproc_per_node=$N_GPUS scripts/generate_coco.py --no_split_batch
127 | ```
128 |
129 | where `$N_GPUS` is the number of GPUs you want to use. By default, the generated results will be stored in `results/coco`. You can also customize it with `--output_root`. Some additional arguments that you may want to tune are listed below, followed by a sketch of how they map onto `DistriConfig`:
130 |
131 | * `--num_inference_steps`: The number of inference steps. We use 50 by default.
132 | * `--guidance_scale`: The classifier-free guidance scale. We use 5 by default.
133 | * `--scheduler`: The diffusion sampler. We use [DDIM sampler](https://huggingface.co/docs/diffusers/v0.26.3/en/api/schedulers/ddim#ddimscheduler) by default. You can also use `euler` for [Euler sampler](https://huggingface.co/docs/diffusers/v0.26.3/en/api/schedulers/euler#eulerdiscretescheduler) and `dpm-solver` for [DPM solver](https://huggingface.co/docs/diffusers/en/api/schedulers/multistep_dpm_solver).
134 | * `--warmup_steps`: The number of additional warmup steps (4 by default).
135 | * `--sync_mode`: Different GroupNorm synchronization modes. By default, it is using our corrected asynchronous GroupNorm.
136 | * `--parallelism`: The parallelism paradigm you use. By default, it is patch parallelism. You can use `tensor` for tensor parallelism and `naive_patch` for naïve patch.
137 |
138 | After you generate all the images, you can use our script [`scripts/compute_metrics.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py) to calculate PSNR, LPIPS and FID. The usage is
139 |
140 | ```shell
141 | python scripts/compute_metrics.py --input_root0 $IMAGE_ROOT0 --input_root1 $IMAGE_ROOT1
142 | ```
143 |
144 | where `$IMAGE_ROOT0` and `$IMAGE_ROOT1` are paths to the image folders you are trying to compare. If `$IMAGE_ROOT0` is the ground-truth folder, please add the `--is_gt` flag for resizing. We also provide a script [`scripts/dump_coco.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/dump_coco.py) to dump the ground-truth images.
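For a quick sanity check outside our script, the same metrics can also be computed directly with the underlying libraries. A rough sketch (this is not `compute_metrics.py`; the folder paths are placeholders, and our script may pair, resize, and normalize images differently):

```python
import torch
from cleanfid import fid
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# FID between two image folders (clean-fid loads the images itself).
print("FID:", fid.compute_fid("results/coco_gt", "results/coco"))  # placeholder paths

# PSNR/LPIPS expect paired, equally sized image tensors in [0, 1].
psnr = PeakSignalNoiseRatio(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True)
img0 = torch.rand(1, 3, 1024, 1024)  # stand-in for a loaded ground-truth image
img1 = torch.rand(1, 3, 1024, 1024)  # stand-in for the corresponding generated image
print("PSNR:", psnr(img0, img1).item(), "LPIPS:", lpips(img0, img1).item())
```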
145 |
146 | #### Latency
147 |
148 | You can use [`scripts/run_sdxl.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/run_sdxl.py) to benchmark the latency of our different methods. The command is
149 |
150 | ```shell
151 | torchrun --nproc_per_node=$N_GPUS scripts/run_sdxl.py --mode benchmark --output_type latent
152 | ```
153 |
154 | where `$N_GPUS` is the number of GPUs you want to use. Similar to [`scripts/generate_coco.py`](https://github.com/mit-han-lab/distrifuser/blob/main/scripts/generate_coco.py), you can also change some arguments (listed below; a sketch of the warmup/test timing loop follows the list):
155 |
156 | * `--num_inference_steps`: The number of inference steps. We use 50 by default.
157 | * `--image_size`: The generated image size. By default, it is 1024×1024.
158 | * `--no_split_batch`: Disable the batch splitting for classifier-free guidance.
159 | * `--warmup_steps`: The number of additional warmup steps (4 by default).
160 | * `--sync_mode`: Different GroupNorm synchronization modes. By default, it is using our corrected asynchronous GroupNorm.
161 | * `--parallelism`: The parallelism paradigm you use. By default, it is patch parallelism. You can use `tensor` for tensor parallelism and `naive_patch` for naïve patch.
162 | * `--warmup_times`/`--test_times`: The number of warmup/test runs. By default, they are 5 and 20, respectively.
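Latency numbers of this kind are typically averaged over several timed runs after a few warmup runs. A generic sketch of such a timing loop (this is not `run_sdxl.py`; it only illustrates the warmup/test pattern that `--warmup_times` and `--test_times` control):

```python
import torch


def measure_latency(run_once, warmup_times: int = 5, test_times: int = 20) -> float:
    """Average milliseconds per call, measured with CUDA events after warmup runs."""
    for _ in range(warmup_times):
        run_once()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(test_times):
        run_once()
    end.record()
    torch.cuda.synchronize()  # make sure all timed runs have finished
    return start.elapsed_time(end) / test_times
```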
163 |
164 |
165 | ## Citation
166 |
167 | If you use this code for your research, please cite our paper.
168 |
169 | ```bibtex
170 | @inproceedings{li2023distrifusion,
171 | title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
172 | author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Liu, Ming-Yu and Li, Kai and Han, Song},
173 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
174 | year={2024}
175 | }
176 | ```
177 |
178 | ## Acknowledgments
179 |
180 | Our code is developed based on [huggingface/diffusers](https://github.com/huggingface/diffusers) and [lmxyy/sige](https://github.com/lmxyy/sige). We thank [torchprofile](https://github.com/zhijian-liu/torchprofile) for MACs measurement, [clean-fid](https://github.com/GaParmar/clean-fid) for FID computation and [Lightning-AI/torchmetrics](https://github.com/Lightning-AI/torchmetrics) for PSNR and LPIPS.
181 |
182 | We thank Jun-Yan Zhu and Ligeng Zhu for their helpful discussion and valuable feedback. The project is supported by MIT-IBM Watson AI Lab, Amazon, MIT Science Hub, and National Science Foundation.
183 |
--------------------------------------------------------------------------------
/assets/idea.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/idea.jpg
--------------------------------------------------------------------------------
/assets/quality.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/quality.jpg
--------------------------------------------------------------------------------
/assets/speedups.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/speedups.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/assets/teaser.jpg
--------------------------------------------------------------------------------
/distrifuser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/__init__.py
--------------------------------------------------------------------------------
/distrifuser/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1beta1"
2 |
--------------------------------------------------------------------------------
/distrifuser/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/models/__init__.py
--------------------------------------------------------------------------------
/distrifuser/models/base_model.py:
--------------------------------------------------------------------------------
1 | from diffusers import ConfigMixin, ModelMixin
2 | from torch import nn
3 |
4 | from distrifuser.modules.base_module import BaseModule
5 | from ..utils import PatchParallelismCommManager, DistriConfig
6 |
7 |
8 | class BaseModel(ModelMixin, ConfigMixin):
9 | def __init__(self, model: nn.Module, distri_config: DistriConfig):
10 | super(BaseModel, self).__init__()
11 | self.model = model
12 | self.distri_config = distri_config
13 | self.comm_manager = None
14 |
15 | self.buffer_list = None
16 | self.output_buffer = None
17 | self.counter = 0
18 |
19 | # for cuda graph
20 | self.static_inputs = None
21 | self.static_outputs = None
22 | self.cuda_graphs = None
23 |
24 | def forward(self, *args, **kwargs):
25 | raise NotImplementedError
26 |
27 | def set_counter(self, counter: int = 0):
28 | self.counter = counter
29 | for module in self.model.modules():
30 | if isinstance(module, BaseModule):
31 | module.set_counter(counter)
32 |
33 | def set_comm_manager(self, comm_manager: PatchParallelismCommManager):
34 | self.comm_manager = comm_manager
35 | for module in self.model.modules():
36 | if isinstance(module, BaseModule):
37 | module.set_comm_manager(comm_manager)
38 |
39 | def setup_cuda_graph(self, static_outputs, cuda_graphs):
40 | self.static_outputs = static_outputs
41 | self.cuda_graphs = cuda_graphs
42 |
43 | @property
44 | def config(self):
45 | return self.model.config
46 |
47 | def synchronize(self):
48 | if self.comm_manager is not None and self.comm_manager.handles is not None:
49 | for i in range(len(self.comm_manager.handles)):
50 | if self.comm_manager.handles[i] is not None:
51 | self.comm_manager.handles[i].wait()
52 | self.comm_manager.handles[i] = None
53 |
--------------------------------------------------------------------------------
/distrifuser/models/distri_sdxl_unet_pp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import UNet2DConditionModel
3 | from diffusers.models.attention_processor import Attention
4 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput
5 | from torch import distributed as dist, nn
6 |
7 | from .base_model import BaseModel
8 | from distrifuser.modules.pp.attn import DistriCrossAttentionPP, DistriSelfAttentionPP
9 | from distrifuser.modules.base_module import BaseModule
10 | from distrifuser.modules.pp.conv2d import DistriConv2dPP
11 | from distrifuser.modules.pp.groupnorm import DistriGroupNorm
12 | from ..utils import DistriConfig
13 |
14 |
15 | class DistriUNetPP(BaseModel): # for Patch Parallelism
16 | def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
17 | assert isinstance(model, UNet2DConditionModel)
18 | if distri_config.world_size > 1 and distri_config.n_device_per_batch > 1:
19 | for name, module in model.named_modules():
20 | if isinstance(module, BaseModule):
21 | continue
22 | for subname, submodule in module.named_children():
23 | if isinstance(submodule, nn.Conv2d):
24 | kernel_size = submodule.kernel_size
25 | if kernel_size == (1, 1) or kernel_size == 1:
26 | continue
27 | wrapped_submodule = DistriConv2dPP(
28 | submodule, distri_config, is_first_layer=subname == "conv_in"
29 | )
30 | setattr(module, subname, wrapped_submodule)
31 | elif isinstance(submodule, Attention):
32 | if subname == "attn1": # self attention
33 | wrapped_submodule = DistriSelfAttentionPP(submodule, distri_config)
34 | else: # cross attention
35 | assert subname == "attn2"
36 | wrapped_submodule = DistriCrossAttentionPP(submodule, distri_config)
37 | setattr(module, subname, wrapped_submodule)
38 | elif isinstance(submodule, nn.GroupNorm):
39 | wrapped_submodule = DistriGroupNorm(submodule, distri_config)
40 | setattr(module, subname, wrapped_submodule)
41 |
42 | super(DistriUNetPP, self).__init__(model, distri_config)
43 |
44 | def forward(
45 | self,
46 | sample: torch.FloatTensor,
47 | timestep: torch.Tensor or float or int,
48 | encoder_hidden_states: torch.Tensor,
49 | class_labels: torch.Tensor or None = None,
50 | timestep_cond: torch.Tensor or None = None,
51 | attention_mask: torch.Tensor or None = None,
52 | cross_attention_kwargs: dict[str, any] or None = None,
53 | added_cond_kwargs: dict[str, torch.Tensor] or None = None,
54 | down_block_additional_residuals: tuple[torch.Tensor] or None = None,
55 | mid_block_additional_residual: torch.Tensor or None = None,
56 | down_intrablock_additional_residuals: tuple[torch.Tensor] or None = None,
57 | encoder_attention_mask: torch.Tensor or None = None,
58 | return_dict: bool = True,
59 | record: bool = False,
60 | ):
61 | distri_config = self.distri_config
62 | b, c, h, w = sample.shape
63 | assert (
64 | class_labels is None
65 | and timestep_cond is None
66 | and attention_mask is None
67 | and cross_attention_kwargs is None
68 | and down_block_additional_residuals is None
69 | and mid_block_additional_residual is None
70 | and down_intrablock_additional_residuals is None
71 | and encoder_attention_mask is None
72 | )
73 |
74 | if distri_config.use_cuda_graph and not record:
75 | static_inputs = self.static_inputs
76 |
77 | if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch:
78 | assert b == 2
79 | batch_idx = distri_config.batch_idx()
80 | sample = sample[batch_idx : batch_idx + 1]
81 | timestep = (
82 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
83 | )
84 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
85 | if added_cond_kwargs is not None:
86 | for k in added_cond_kwargs:
87 | added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
88 |
89 | assert static_inputs["sample"].shape == sample.shape
90 | static_inputs["sample"].copy_(sample)
91 | if torch.is_tensor(timestep):
92 | if timestep.ndim == 0:
93 | for b in range(static_inputs["timestep"].shape[0]):
94 | static_inputs["timestep"][b] = timestep.item()
95 | else:
96 | assert static_inputs["timestep"].shape == timestep.shape
97 | static_inputs["timestep"].copy_(timestep)
98 | else:
99 | for b in range(static_inputs["timestep"].shape[0]):
100 | static_inputs["timestep"][b] = timestep
101 | assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
102 | static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)
103 | if added_cond_kwargs is not None:
104 | for k in added_cond_kwargs:
105 | assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape
106 | static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k])
107 |
108 | if self.counter <= distri_config.warmup_steps:
109 | graph_idx = 0
110 | elif self.counter == distri_config.warmup_steps + 1:
111 | graph_idx = 1
112 | else:
113 | graph_idx = 2
114 |
115 | self.cuda_graphs[graph_idx].replay()
116 | output = self.static_outputs[graph_idx]
117 | else:
118 | if distri_config.world_size == 1:
119 | output = self.model(
120 | sample,
121 | timestep,
122 | encoder_hidden_states,
123 | class_labels=class_labels,
124 | timestep_cond=timestep_cond,
125 | attention_mask=attention_mask,
126 | cross_attention_kwargs=cross_attention_kwargs,
127 | added_cond_kwargs=added_cond_kwargs,
128 | down_block_additional_residuals=down_block_additional_residuals,
129 | mid_block_additional_residual=mid_block_additional_residual,
130 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
131 | encoder_attention_mask=encoder_attention_mask,
132 | return_dict=False,
133 | )[0]
134 | elif distri_config.do_classifier_free_guidance and distri_config.split_batch:
135 | assert b == 2
136 | batch_idx = distri_config.batch_idx()
137 | sample = sample[batch_idx : batch_idx + 1]
138 | timestep = (
139 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
140 | )
141 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
142 | if added_cond_kwargs is not None:
143 | new_added_cond_kwargs = {}
144 | for k in added_cond_kwargs:
145 | new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
146 | added_cond_kwargs = new_added_cond_kwargs
147 | output = self.model(
148 | sample,
149 | timestep,
150 | encoder_hidden_states,
151 | class_labels=class_labels,
152 | timestep_cond=timestep_cond,
153 | attention_mask=attention_mask,
154 | cross_attention_kwargs=cross_attention_kwargs,
155 | added_cond_kwargs=added_cond_kwargs,
156 | down_block_additional_residuals=down_block_additional_residuals,
157 | mid_block_additional_residual=mid_block_additional_residual,
158 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
159 | encoder_attention_mask=encoder_attention_mask,
160 | return_dict=False,
161 | )[0]
162 | if self.output_buffer is None:
163 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
164 | if self.buffer_list is None:
165 | self.buffer_list = [torch.empty_like(output) for _ in range(distri_config.world_size)]
166 | dist.all_gather(self.buffer_list, output.contiguous(), async_op=False)
167 | torch.cat(self.buffer_list[: distri_config.n_device_per_batch], dim=2, out=self.output_buffer[0:1])
168 | torch.cat(self.buffer_list[distri_config.n_device_per_batch :], dim=2, out=self.output_buffer[1:2])
169 | output = self.output_buffer
170 | else:
171 | output = self.model(
172 | sample,
173 | timestep,
174 | encoder_hidden_states,
175 | class_labels=class_labels,
176 | timestep_cond=timestep_cond,
177 | attention_mask=attention_mask,
178 | cross_attention_kwargs=cross_attention_kwargs,
179 | added_cond_kwargs=added_cond_kwargs,
180 | down_block_additional_residuals=down_block_additional_residuals,
181 | mid_block_additional_residual=mid_block_additional_residual,
182 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
183 | encoder_attention_mask=encoder_attention_mask,
184 | return_dict=False,
185 | )[0]
186 | if self.output_buffer is None:
187 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
188 | if self.buffer_list is None:
189 | self.buffer_list = [torch.empty_like(output) for _ in range(distri_config.world_size)]
190 | output = output.contiguous()
191 | dist.all_gather(self.buffer_list, output, async_op=False)
192 | torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
193 | output = self.output_buffer
194 | if record:
195 | if self.static_inputs is None:
196 | self.static_inputs = {
197 | "sample": sample,
198 | "timestep": timestep,
199 | "encoder_hidden_states": encoder_hidden_states,
200 | "added_cond_kwargs": added_cond_kwargs,
201 | }
202 | self.synchronize()
203 |
204 | if return_dict:
205 | output = UNet2DConditionOutput(sample=output)
206 | else:
207 | output = (output,)
208 |
209 | self.counter += 1
210 | return output
211 |
212 | @property
213 | def add_embedding(self):
214 | return self.model.add_embedding
215 |
--------------------------------------------------------------------------------
/distrifuser/models/distri_sdxl_unet_tp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import UNet2DConditionModel
3 | from diffusers.models.attention import Attention, FeedForward
4 | from diffusers.models.resnet import ResnetBlock2D
5 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput
6 | from torch import distributed as dist, nn
7 |
8 | from distrifuser.modules.base_module import BaseModule
9 | from .base_model import BaseModel
10 | from ..modules.tp.attention import DistriAttentionTP
11 | from ..modules.tp.conv2d import DistriConv2dTP
12 | from ..modules.tp.feed_forward import DistriFeedForwardTP
13 | from ..modules.tp.resnet import DistriResnetBlock2DTP
14 | from ..utils import DistriConfig
15 |
16 |
17 | class DistriUNetTP(BaseModel): # for Tensor Parallelism
18 | def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
19 | assert isinstance(model, UNet2DConditionModel)
20 | if distri_config.world_size > 1 and distri_config.n_device_per_batch > 1:
21 | for name, module in model.named_modules():
22 | if isinstance(module, BaseModule):
23 | continue
24 | for subname, submodule in module.named_children():
25 | if isinstance(submodule, Attention):
26 | wrapped_submodule = DistriAttentionTP(submodule, distri_config)
27 | setattr(module, subname, wrapped_submodule)
28 | elif isinstance(submodule, FeedForward):
29 | wrapped_submodule = DistriFeedForwardTP(submodule, distri_config)
30 | setattr(module, subname, wrapped_submodule)
31 | elif isinstance(submodule, ResnetBlock2D):
32 | wrapped_submodule = DistriResnetBlock2DTP(submodule, distri_config)
33 | setattr(module, subname, wrapped_submodule)
34 | elif isinstance(submodule, nn.Conv2d) and (
35 | subname == "conv_out" or "downsamplers" in name or "upsamplers" in name
36 | ):
37 | wrapped_submodule = DistriConv2dTP(submodule, distri_config)
38 | setattr(module, subname, wrapped_submodule)
39 |
40 | super(DistriUNetTP, self).__init__(model, distri_config)
41 |
42 | def forward(
43 | self,
44 | sample: torch.FloatTensor,
45 | timestep: torch.Tensor or float or int,
46 | encoder_hidden_states: torch.Tensor,
47 | class_labels: torch.Tensor or None = None,
48 | timestep_cond: torch.Tensor or None = None,
49 | attention_mask: torch.Tensor or None = None,
50 | cross_attention_kwargs: dict[str, any] or None = None,
51 | added_cond_kwargs: dict[str, torch.Tensor] or None = None,
52 | down_block_additional_residuals: tuple[torch.Tensor] or None = None,
53 | mid_block_additional_residual: torch.Tensor or None = None,
54 | down_intrablock_additional_residuals: tuple[torch.Tensor] or None = None,
55 | encoder_attention_mask: torch.Tensor or None = None,
56 | return_dict: bool = True,
57 | record: bool = False,
58 | ):
59 | distri_config = self.distri_config
60 | b, c, h, w = sample.shape
61 | assert (
62 | class_labels is None
63 | and timestep_cond is None
64 | and attention_mask is None
65 | and cross_attention_kwargs is None
66 | and down_block_additional_residuals is None
67 | and mid_block_additional_residual is None
68 | and down_intrablock_additional_residuals is None
69 | and encoder_attention_mask is None
70 | )
71 |
72 | if distri_config.use_cuda_graph and not record:
73 | static_inputs = self.static_inputs
74 |
75 | if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch:
76 | assert b == 2
77 | batch_idx = distri_config.batch_idx()
78 | sample = sample[batch_idx : batch_idx + 1]
79 | timestep = (
80 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
81 | )
82 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
83 | if added_cond_kwargs is not None:
84 | for k in added_cond_kwargs:
85 | added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
86 |
87 | assert static_inputs["sample"].shape == sample.shape
88 | static_inputs["sample"].copy_(sample)
89 | if torch.is_tensor(timestep):
90 | if timestep.ndim == 0:
91 | for b in range(static_inputs["timestep"].shape[0]):
92 | static_inputs["timestep"][b] = timestep.item()
93 | else:
94 | assert static_inputs["timestep"].shape == timestep.shape
95 | static_inputs["timestep"].copy_(timestep)
96 | else:
97 | for b in range(static_inputs["timestep"].shape[0]):
98 | static_inputs["timestep"][b] = timestep
99 | assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
100 | static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)
101 | if added_cond_kwargs is not None:
102 | for k in added_cond_kwargs:
103 | assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape
104 | static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k])
105 |
106 | graph_idx = 0
107 |
108 | self.cuda_graphs[graph_idx].replay()
109 | output = self.static_outputs[graph_idx]
110 | else:
111 | if distri_config.world_size == 1:
112 | output = self.model(
113 | sample,
114 | timestep,
115 | encoder_hidden_states,
116 | class_labels=class_labels,
117 | timestep_cond=timestep_cond,
118 | attention_mask=attention_mask,
119 | cross_attention_kwargs=cross_attention_kwargs,
120 | added_cond_kwargs=added_cond_kwargs,
121 | down_block_additional_residuals=down_block_additional_residuals,
122 | mid_block_additional_residual=mid_block_additional_residual,
123 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
124 | encoder_attention_mask=encoder_attention_mask,
125 | return_dict=False,
126 | )[0]
127 | elif distri_config.do_classifier_free_guidance and distri_config.split_batch:
128 | assert b == 2
129 | batch_idx = distri_config.batch_idx()
130 | sample = sample[batch_idx : batch_idx + 1]
131 | timestep = (
132 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
133 | )
134 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
135 | if added_cond_kwargs is not None:
136 | new_added_cond_kwargs = {}
137 | for k in added_cond_kwargs:
138 | new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
139 | added_cond_kwargs = new_added_cond_kwargs
140 | output = self.model(
141 | sample,
142 | timestep,
143 | encoder_hidden_states,
144 | class_labels=class_labels,
145 | timestep_cond=timestep_cond,
146 | attention_mask=attention_mask,
147 | cross_attention_kwargs=cross_attention_kwargs,
148 | added_cond_kwargs=added_cond_kwargs,
149 | down_block_additional_residuals=down_block_additional_residuals,
150 | mid_block_additional_residual=mid_block_additional_residual,
151 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
152 | encoder_attention_mask=encoder_attention_mask,
153 | return_dict=False,
154 | )[0]
155 | if self.output_buffer is None:
156 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
157 | if self.buffer_list is None:
158 | self.buffer_list = [torch.empty_like(output) for _ in range(2)]
159 | dist.all_gather(
160 | self.buffer_list, output.contiguous(), group=distri_config.split_group(), async_op=False
161 | )
162 | torch.cat(self.buffer_list, dim=0, out=self.output_buffer)
163 | output = self.output_buffer
164 | else:
165 | output = self.model(
166 | sample,
167 | timestep,
168 | encoder_hidden_states,
169 | class_labels=class_labels,
170 | timestep_cond=timestep_cond,
171 | attention_mask=attention_mask,
172 | cross_attention_kwargs=cross_attention_kwargs,
173 | added_cond_kwargs=added_cond_kwargs,
174 | down_block_additional_residuals=down_block_additional_residuals,
175 | mid_block_additional_residual=mid_block_additional_residual,
176 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
177 | encoder_attention_mask=encoder_attention_mask,
178 | return_dict=False,
179 | )[0]
180 | if self.output_buffer is None:
181 | self.output_buffer = torch.empty_like(output)
182 | self.output_buffer.copy_(output)
183 | output = self.output_buffer
184 | if record:
185 | if self.static_inputs is None:
186 | self.static_inputs = {
187 | "sample": sample,
188 | "timestep": timestep,
189 | "encoder_hidden_states": encoder_hidden_states,
190 | "added_cond_kwargs": added_cond_kwargs,
191 | }
192 | self.synchronize()
193 |
194 | if return_dict:
195 | output = UNet2DConditionOutput(sample=output)
196 | else:
197 | output = (output,)
198 |
199 | self.counter += 1
200 | return output
201 |
202 | @property
203 | def add_embedding(self):
204 | return self.model.add_embedding
205 |
--------------------------------------------------------------------------------
/distrifuser/models/naive_patch_sdxl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import UNet2DConditionModel
3 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput
4 | from torch import distributed as dist
5 |
6 | from .base_model import BaseModel
7 | from ..utils import DistriConfig
8 |
9 |
10 | class NaivePatchUNet(BaseModel): # for Naive Patch
11 | def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
12 | assert isinstance(model, UNet2DConditionModel)
13 | super(NaivePatchUNet, self).__init__(model, distri_config)
14 |
15 | def forward(
16 | self,
17 | sample: torch.FloatTensor,
18 | timestep: torch.Tensor or float or int,
19 | encoder_hidden_states: torch.Tensor,
20 | class_labels: torch.Tensor or None = None,
21 | timestep_cond: torch.Tensor or None = None,
22 | attention_mask: torch.Tensor or None = None,
23 | cross_attention_kwargs: dict[str, any] or None = None,
24 | added_cond_kwargs: dict[str, torch.Tensor] or None = None,
25 | down_block_additional_residuals: tuple[torch.Tensor] or None = None,
26 | mid_block_additional_residual: torch.Tensor or None = None,
27 | down_intrablock_additional_residuals: tuple[torch.Tensor] or None = None,
28 | encoder_attention_mask: torch.Tensor or None = None,
29 | return_dict: bool = True,
30 | record: bool = False,
31 | ):
32 | distri_config = self.distri_config
33 | b, c, h, w = sample.shape
34 | assert (
35 | class_labels is None
36 | and timestep_cond is None
37 | and attention_mask is None
38 | and cross_attention_kwargs is None
39 | and down_block_additional_residuals is None
40 | and mid_block_additional_residual is None
41 | and down_intrablock_additional_residuals is None
42 | and encoder_attention_mask is None
43 | )
44 |
45 | if distri_config.use_cuda_graph and not record:
46 | static_inputs = self.static_inputs
47 |
48 | if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch:
49 | assert b == 2
50 | batch_idx = distri_config.batch_idx()
51 | sample = sample[batch_idx : batch_idx + 1]
52 | timestep = (
53 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
54 | )
55 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
56 | if added_cond_kwargs is not None:
57 | for k in added_cond_kwargs:
58 | added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
59 |
60 | assert static_inputs["sample"].shape == sample.shape
61 | static_inputs["sample"].copy_(sample)
62 | if torch.is_tensor(timestep):
63 | if timestep.ndim == 0:
64 | for b in range(static_inputs["timestep"].shape[0]):
65 | static_inputs["timestep"][b] = timestep.item()
66 | else:
67 | assert static_inputs["timestep"].shape == timestep.shape
68 | static_inputs["timestep"].copy_(timestep)
69 | else:
70 | for b in range(static_inputs["timestep"].shape[0]):
71 | static_inputs["timestep"][b] = timestep
72 | assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
73 | static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)
74 | if added_cond_kwargs is not None:
75 | for k in added_cond_kwargs:
76 | assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape
77 | static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k])
78 |
79 | graph_idx = 0
80 | if distri_config.split_scheme == "alternate":
81 | graph_idx = self.counter % 2
82 | self.cuda_graphs[graph_idx].replay()
83 | output = self.static_outputs[graph_idx]
84 | else:
85 | if distri_config.world_size == 1:
86 | output = self.model(
87 | sample,
88 | timestep,
89 | encoder_hidden_states,
90 | class_labels=class_labels,
91 | timestep_cond=timestep_cond,
92 | attention_mask=attention_mask,
93 | cross_attention_kwargs=cross_attention_kwargs,
94 | added_cond_kwargs=added_cond_kwargs,
95 | down_block_additional_residuals=down_block_additional_residuals,
96 | mid_block_additional_residual=mid_block_additional_residual,
97 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
98 | encoder_attention_mask=encoder_attention_mask,
99 | return_dict=False,
100 | )[0]
101 | elif distri_config.do_classifier_free_guidance and distri_config.split_batch:
102 | assert b == 2
103 | batch_idx = distri_config.batch_idx()
104 | sample = sample[batch_idx : batch_idx + 1]
105 | timestep = (
106 | timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
107 | )
108 | encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
109 | if added_cond_kwargs is not None:
110 | new_added_cond_kwargs = {}
111 | for k in added_cond_kwargs:
112 | new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
113 | added_cond_kwargs = new_added_cond_kwargs
114 |
115 | if distri_config.split_scheme == "row":
116 | split_dim = 2
117 | elif distri_config.split_scheme == "col":
118 | split_dim = 3
119 | elif distri_config.split_scheme == "alternate":
120 | split_dim = 2 if self.counter % 2 == 0 else 3
121 | else:
122 | raise NotImplementedError
123 |
124 | if split_dim == 2:
125 | sample = sample.view(1, c, distri_config.n_device_per_batch, -1, w)[:, :, distri_config.split_idx()]
126 | else:
127 | assert split_dim == 3
128 | sample = sample.view(1, c, h, distri_config.n_device_per_batch, -1)[
129 | ..., distri_config.split_idx(), :
130 | ]
131 |
132 | output = self.model(
133 | sample,
134 | timestep,
135 | encoder_hidden_states,
136 | class_labels=class_labels,
137 | timestep_cond=timestep_cond,
138 | attention_mask=attention_mask,
139 | cross_attention_kwargs=cross_attention_kwargs,
140 | added_cond_kwargs=added_cond_kwargs,
141 | down_block_additional_residuals=down_block_additional_residuals,
142 | mid_block_additional_residual=mid_block_additional_residual,
143 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
144 | encoder_attention_mask=encoder_attention_mask,
145 | return_dict=False,
146 | )[0]
147 | if self.output_buffer is None:
148 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
149 | if self.buffer_list is None:
150 | self.buffer_list = [torch.empty_like(output.view(-1)) for _ in range(distri_config.world_size)]
151 | dist.all_gather(self.buffer_list, output.contiguous().view(-1), async_op=False)
152 | buffer_list = [buffer.view(output.shape) for buffer in self.buffer_list]
153 | torch.cat(buffer_list[: distri_config.n_device_per_batch], dim=split_dim, out=self.output_buffer[0:1])
154 | torch.cat(buffer_list[distri_config.n_device_per_batch :], dim=split_dim, out=self.output_buffer[1:2])
155 | output = self.output_buffer
156 | else:
157 | if distri_config.split_scheme == "row":
158 | split_dim = 2
159 | elif distri_config.split_scheme == "col":
160 | split_dim = 3
161 | elif distri_config.split_scheme == "alternate":
162 | split_dim = 2 if self.counter % 2 == 0 else 3
163 | else:
164 | raise NotImplementedError
165 |
166 | if split_dim == 2:
167 | sliced_sample = sample.view(b, c, distri_config.n_device_per_batch, -1, w)[
168 | :, :, distri_config.split_idx()
169 | ]
170 | else:
171 | assert split_dim == 3
172 | sliced_sample = sample.view(b, c, h, distri_config.n_device_per_batch, -1)[
173 | ..., distri_config.split_idx(), :
174 | ]
175 |
176 | output = self.model(
177 | sliced_sample,
178 | timestep,
179 | encoder_hidden_states,
180 | class_labels=class_labels,
181 | timestep_cond=timestep_cond,
182 | attention_mask=attention_mask,
183 | cross_attention_kwargs=cross_attention_kwargs,
184 | added_cond_kwargs=added_cond_kwargs,
185 | down_block_additional_residuals=down_block_additional_residuals,
186 | mid_block_additional_residual=mid_block_additional_residual,
187 | down_intrablock_additional_residuals=down_intrablock_additional_residuals,
188 | encoder_attention_mask=encoder_attention_mask,
189 | return_dict=False,
190 | )[0]
191 | if self.output_buffer is None:
192 | self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
193 | if self.buffer_list is None:
194 | self.buffer_list = [torch.empty_like(output.view(-1)) for _ in range(distri_config.world_size)]
195 | dist.all_gather(self.buffer_list, output.contiguous().view(-1), async_op=False)
196 | buffer_list = [buffer.view(output.shape) for buffer in self.buffer_list]
197 | torch.cat(buffer_list, dim=split_dim, out=self.output_buffer)
198 | output = self.output_buffer
199 | if record:
200 | if self.static_inputs is None:
201 | self.static_inputs = {
202 | "sample": sample,
203 | "timestep": timestep,
204 | "encoder_hidden_states": encoder_hidden_states,
205 | "added_cond_kwargs": added_cond_kwargs,
206 | }
207 | self.synchronize()
208 |
209 | if return_dict:
210 | output = UNet2DConditionOutput(sample=output)
211 | else:
212 | output = (output,)
213 |
214 | self.counter += 1
215 | return output
216 |
217 | @property
218 | def add_embedding(self):
219 | return self.model.add_embedding
220 |
--------------------------------------------------------------------------------
/distrifuser/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/modules/__init__.py
--------------------------------------------------------------------------------
/distrifuser/modules/base_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from distrifuser.utils import DistriConfig
4 |
5 |
6 | class BaseModule(nn.Module):
7 | def __init__(
8 | self,
9 | module: nn.Module,
10 | distri_config: DistriConfig,
11 | ):
12 | super(BaseModule, self).__init__()
13 | self.module = module
14 | self.distri_config = distri_config
15 | self.comm_manager = None
16 |
17 | self.counter = 0
18 |
19 | self.buffer_list = None
20 | self.idx = None
21 |
22 | def forward(self, *args, **kwargs):
23 | raise NotImplementedError
24 |
25 | def set_counter(self, counter: int = 0):
26 | self.counter = counter
27 |
28 | def set_comm_manager(self, comm_manager):
29 | self.comm_manager = comm_manager
30 |
--------------------------------------------------------------------------------
/distrifuser/modules/pp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/modules/pp/__init__.py
--------------------------------------------------------------------------------
/distrifuser/modules/pp/attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.models.attention import Attention
3 | from diffusers.utils import USE_PEFT_BACKEND
4 | from torch import distributed as dist
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | from distrifuser.modules.base_module import BaseModule
9 | from distrifuser.utils import DistriConfig
10 |
11 |
12 | class DistriAttentionPP(BaseModule):
13 | def __init__(self, module: Attention, distri_config: DistriConfig):
14 | super(DistriAttentionPP, self).__init__(module, distri_config)
15 |
16 | to_k = module.to_k
17 | to_v = module.to_v
18 | assert isinstance(to_k, nn.Linear)
19 | assert isinstance(to_v, nn.Linear)
20 | assert (to_k.bias is None) == (to_v.bias is None)
21 | assert to_k.weight.shape == to_v.weight.shape
22 |
23 | in_size, out_size = to_k.in_features, to_k.out_features
24 | to_kv = nn.Linear(
25 | in_size,
26 | out_size * 2,
27 | bias=to_k.bias is not None,
28 | device=to_k.weight.device,
29 | dtype=to_k.weight.dtype,
30 | )
31 | to_kv.weight.data[:out_size].copy_(to_k.weight.data)
32 | to_kv.weight.data[out_size:].copy_(to_v.weight.data)
33 |
34 | if to_k.bias is not None:
35 | assert to_v.bias is not None
36 | to_kv.bias.data[:out_size].copy_(to_k.bias.data)
37 | to_kv.bias.data[out_size:].copy_(to_v.bias.data)
38 |
39 | self.to_kv = to_kv
40 |
41 |
42 | class DistriCrossAttentionPP(DistriAttentionPP):
43 | def __init__(self, module: Attention, distri_config: DistriConfig):
44 | super(DistriCrossAttentionPP, self).__init__(module, distri_config)
45 | self.kv_cache = None
46 |
47 | def forward(
48 | self,
49 | hidden_states: torch.FloatTensor,
50 | encoder_hidden_states: torch.FloatTensor or None = None,
51 | scale: float = 1.0,
52 | *args,
53 | **kwargs,
54 | ):
55 | assert encoder_hidden_states is not None
56 | recompute_kv = self.counter == 0
57 |
58 | attn = self.module
59 | assert isinstance(attn, Attention)
60 |
61 | residual = hidden_states
62 |
63 | batch_size, sequence_length, _ = (
64 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
65 | )
66 |
67 | args = () if USE_PEFT_BACKEND else (scale,)
68 | query = attn.to_q(hidden_states, *args)
69 |
70 | if encoder_hidden_states is None:
71 | encoder_hidden_states = hidden_states
72 |
73 | if recompute_kv or self.kv_cache is None:
74 | kv = self.to_kv(encoder_hidden_states)
75 | self.kv_cache = kv
76 | else:
77 | kv = self.kv_cache
78 | key, value = torch.split(kv, kv.shape[-1] // 2, dim=-1)
79 |
80 | inner_dim = key.shape[-1]
81 | head_dim = inner_dim // attn.heads
82 |
83 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
84 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
85 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
86 |
87 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
88 |
89 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
90 | hidden_states = hidden_states.to(query.dtype)
91 |
92 | # linear proj
93 | hidden_states = attn.to_out[0](hidden_states, *args)
94 | # dropout
95 | hidden_states = attn.to_out[1](hidden_states)
96 |
97 | if attn.residual_connection:
98 | hidden_states = hidden_states + residual
99 |
100 | hidden_states = hidden_states / attn.rescale_output_factor
101 |
102 | self.counter += 1
103 |
104 | return hidden_states
105 |
106 |
107 | class DistriSelfAttentionPP(DistriAttentionPP):
108 | def __init__(self, module: Attention, distri_config: DistriConfig):
109 | super(DistriSelfAttentionPP, self).__init__(module, distri_config)
110 |
111 | def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
112 | attn = self.module
113 | distri_config = self.distri_config
114 | assert isinstance(attn, Attention)
115 |
116 | residual = hidden_states
117 |
118 | batch_size, sequence_length, _ = hidden_states.shape
119 |
120 | args = () if USE_PEFT_BACKEND else (scale,)
121 | query = attn.to_q(hidden_states, *args)
122 |
123 | encoder_hidden_states = hidden_states
124 |
125 | kv = self.to_kv(encoder_hidden_states)
126 |
127 | if distri_config.n_device_per_batch == 1:
128 | full_kv = kv
129 | else:
130 | if self.buffer_list is None: # buffer not created
131 | full_kv = torch.cat([kv for _ in range(distri_config.n_device_per_batch)], dim=1)
132 | elif distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps:
133 | dist.all_gather(self.buffer_list, kv, group=distri_config.batch_group, async_op=False)
134 | full_kv = torch.cat(self.buffer_list, dim=1)
135 | else:
136 | new_buffer_list = [buffer for buffer in self.buffer_list]
137 | new_buffer_list[distri_config.split_idx()] = kv
138 | full_kv = torch.cat(new_buffer_list, dim=1)
139 | if distri_config.mode != "no_sync":
140 | self.comm_manager.enqueue(self.idx, kv)
141 |
142 | key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
143 |
144 | inner_dim = key.shape[-1]
145 | head_dim = inner_dim // attn.heads
146 |
147 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
148 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
149 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
150 |
151 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
152 | # TODO: add support for attn.scale when we move to Torch 2.1
153 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
154 |
155 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
156 | hidden_states = hidden_states.to(query.dtype)
157 |
158 | # linear proj
159 | hidden_states = attn.to_out[0](hidden_states, *args)
160 | # dropout
161 | hidden_states = attn.to_out[1](hidden_states)
162 |
163 | if attn.residual_connection:
164 | hidden_states = hidden_states + residual
165 |
166 | hidden_states = hidden_states / attn.rescale_output_factor
167 |
168 | return hidden_states
169 |
170 | def forward(
171 | self,
172 | hidden_states: torch.FloatTensor,
173 | encoder_hidden_states: torch.FloatTensor or None = None,
174 | scale: float = 1.0,
175 | *args,
176 | **kwargs,
177 | ) -> torch.FloatTensor:
178 | distri_config = self.distri_config
179 | if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
180 | if self.comm_manager.handles[self.idx] is not None:
181 | self.comm_manager.handles[self.idx].wait()
182 | self.comm_manager.handles[self.idx] = None
183 |
184 | b, l, c = hidden_states.shape
185 | if distri_config.n_device_per_batch > 1 and self.buffer_list is None:
186 | if self.comm_manager.buffer_list is None:
187 | self.idx = self.comm_manager.register_tensor(
188 | shape=(b, l, self.to_kv.out_features), torch_dtype=hidden_states.dtype, layer_type="attn"
189 | )
190 | else:
191 | self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
192 | output = self._forward(hidden_states, scale=scale)
193 |
194 | self.counter += 1
195 | return output
196 |
--------------------------------------------------------------------------------
/distrifuser/modules/pp/conv2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributed as dist
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from distrifuser.modules.base_module import BaseModule
7 | from distrifuser.utils import DistriConfig
8 |
9 |
10 | class DistriConv2dPP(BaseModule):
11 | def __init__(self, module: nn.Conv2d, distri_config: DistriConfig, is_first_layer: bool = False):
12 | super(DistriConv2dPP, self).__init__(module, distri_config)
13 | self.is_first_layer = is_first_layer
14 |
15 | def naive_forward(self, x: torch.Tensor) -> torch.Tensor:
16 | # x: [B, C, H, W]
17 | output = self.module(x)
18 | return output
19 |
20 | def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
21 | config = self.distri_config
22 | b, c, h, w = x.shape
23 | assert h % config.n_device_per_batch == 0
24 |
25 | stride = self.module.stride[0]
26 | padding = self.module.padding[0]
27 |
28 | output_h = x.shape[2] // stride // config.n_device_per_batch
29 | idx = config.split_idx()
30 | h_begin = output_h * idx * stride - padding
31 | h_end = output_h * (idx + 1) * stride + padding
32 | final_padding = [padding, padding, 0, 0]
33 | if h_begin < 0:
34 | h_begin = 0
35 | final_padding[2] = padding
36 | if h_end > h:
37 | h_end = h
38 | final_padding[3] = padding
39 | sliced_input = x[:, :, h_begin:h_end, :]
40 | padded_input = F.pad(sliced_input, final_padding, mode="constant")
41 | return F.conv2d(padded_input, self.module.weight, self.module.bias, stride=stride, padding="valid")
42 |
43 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
44 | distri_config = self.distri_config
45 |
46 | if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
47 | if self.comm_manager.handles[self.idx] is not None:
48 | self.comm_manager.handles[self.idx].wait()
49 | self.comm_manager.handles[self.idx] = None
50 |
51 | if distri_config.n_device_per_batch == 1:
52 | output = self.naive_forward(x)
53 | else:
54 | if self.is_first_layer:
55 | full_x = x
56 | output = self.sliced_forward(full_x)
57 | else:
58 | boundary_size = self.module.padding[0]
59 | if self.buffer_list is None:
60 | if self.comm_manager.buffer_list is None:
61 | self.idx = self.comm_manager.register_tensor(
62 | shape=[2, x.shape[0], x.shape[1], boundary_size, x.shape[3]],
63 | torch_dtype=x.dtype,
64 | layer_type="conv2d",
65 | )
66 | else:
67 | self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
68 | if self.buffer_list is None:
69 | output = self.naive_forward(x)
70 | else:
71 |
72 | def create_padded_x():
73 | if distri_config.split_idx() == 0:
74 | concat_x = torch.cat([x, self.buffer_list[distri_config.split_idx() + 1][0]], dim=2)
75 | padded_x = F.pad(concat_x, [0, 0, boundary_size, 0], mode="constant")
76 | elif distri_config.split_idx() == distri_config.n_device_per_batch - 1:
77 | concat_x = torch.cat([self.buffer_list[distri_config.split_idx() - 1][1], x], dim=2)
78 | padded_x = F.pad(concat_x, [0, 0, 0, boundary_size], mode="constant")
79 | else:
80 | padded_x = torch.cat(
81 | [
82 | self.buffer_list[distri_config.split_idx() - 1][1],
83 | x,
84 | self.buffer_list[distri_config.split_idx() + 1][0],
85 | ],
86 | dim=2,
87 | )
88 | return padded_x
89 |
90 | boundary = torch.stack([x[:, :, :boundary_size, :], x[:, :, -boundary_size:, :]], dim=0)
91 |
92 | if distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps:
93 | dist.all_gather(self.buffer_list, boundary, group=distri_config.batch_group, async_op=False)
94 | padded_x = create_padded_x()
95 | output = F.conv2d(
96 | padded_x,
97 | self.module.weight,
98 | self.module.bias,
99 | stride=self.module.stride[0],
100 | padding=(0, self.module.padding[1]),
101 | )
102 | else:
103 | padded_x = create_padded_x()
104 | output = F.conv2d(
105 | padded_x,
106 | self.module.weight,
107 | self.module.bias,
108 | stride=self.module.stride[0],
109 | padding=(0, self.module.padding[1]),
110 | )
111 | if distri_config.mode != "no_sync":
112 | self.comm_manager.enqueue(self.idx, boundary)
113 |
114 | self.counter += 1
115 | return output
116 |
--------------------------------------------------------------------------------
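The patch-parallel convolution above never gathers full activations: each device sends only `boundary_size` rows (the convolution padding) to its neighbors, zero-pads the outer edges, and runs the convolution with vertical padding disabled. A minimal single-process sketch of that halo exchange, simulating two devices with plain slicing (no `dist` calls; all variable names and sizes are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)
full = conv(x)

halo = conv.padding[0]                         # rows exchanged with each neighbor
top, bottom = x[:, :, :8, :], x[:, :, 8:, :]   # row split across two "devices"

# Device 0: zero-pad the outer edge, borrow `halo` rows from device 1.
padded0 = F.pad(torch.cat([top, bottom[:, :, :halo, :]], dim=2), [0, 0, halo, 0])
out0 = F.conv2d(padded0, conv.weight, conv.bias, padding=(0, halo))

# Device 1: borrow `halo` rows from device 0, zero-pad the outer edge.
padded1 = F.pad(torch.cat([top[:, :, -halo:, :], bottom], dim=2), [0, 0, 0, halo])
out1 = F.conv2d(padded1, conv.weight, conv.bias, padding=(0, halo))

# Stitching the patch outputs reproduces the full convolution.
assert torch.allclose(torch.cat([out0, out1], dim=2), full, atol=1e-5)
```
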
/distrifuser/modules/pp/groupnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributed as dist
3 | from torch import nn
4 |
5 | from distrifuser.modules.base_module import BaseModule
6 | from distrifuser.utils import DistriConfig
7 |
8 |
9 | class DistriGroupNorm(BaseModule):
10 | def __init__(self, module: nn.GroupNorm, distri_config: DistriConfig):
11 | assert isinstance(module, nn.GroupNorm)
12 | super(DistriGroupNorm, self).__init__(module, distri_config)
13 |
14 | def forward(self, x: torch.Tensor) -> torch.Tensor:
15 | module = self.module
16 | assert isinstance(module, nn.GroupNorm)
17 | distri_config = self.distri_config
18 |
19 | if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
20 | if self.comm_manager.handles[self.idx] is not None:
21 | self.comm_manager.handles[self.idx].wait()
22 | self.comm_manager.handles[self.idx] = None
23 |
24 | assert x.ndim == 4
25 | n, c, h, w = x.shape
26 | num_groups = module.num_groups
27 | group_size = c // num_groups
28 |
29 | if distri_config.mode in ["stale_gn", "corrected_async_gn"]:
30 | if self.buffer_list is None:
31 | if self.comm_manager.buffer_list is None:
32 | n, c, h, w = x.shape
33 | self.idx = self.comm_manager.register_tensor(
34 | shape=[2, n, num_groups, 1, 1, 1], torch_dtype=x.dtype, layer_type="gn"
35 | )
36 | else:
37 | self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
38 | x = x.view([n, num_groups, group_size, h, w])
39 | x_mean = x.mean(dim=[2, 3, 4], keepdim=True)  # [n, num_groups, 1, 1, 1]
40 | x2_mean = (x**2).mean(dim=[2, 3, 4], keepdim=True)  # [n, num_groups, 1, 1, 1]
41 | slice_mean = torch.stack([x_mean, x2_mean], dim=0)
42 |
43 | if self.buffer_list is None:
44 | full_mean = slice_mean
45 | elif self.counter <= distri_config.warmup_steps:
46 | dist.all_gather(self.buffer_list, slice_mean, group=distri_config.batch_group, async_op=False)
47 | full_mean = sum(self.buffer_list) / distri_config.n_device_per_batch
48 | else:
49 | if distri_config.mode == "corrected_async_gn":
50 | correction = slice_mean - self.buffer_list[distri_config.split_idx()]
51 | full_mean = sum(self.buffer_list) / distri_config.n_device_per_batch + correction
52 | else:
53 | new_buffer_list = [buffer for buffer in self.buffer_list]
54 | new_buffer_list[distri_config.split_idx()] = slice_mean
55 | full_mean = sum(new_buffer_list) / distri_config.n_device_per_batch
56 | self.comm_manager.enqueue(self.idx, slice_mean)
57 |
58 | full_x_mean, full_x2_mean = full_mean[0], full_mean[1]
59 | var = full_x2_mean - full_x_mean**2
60 | if distri_config.mode == "corrected_async_gn":
61 | slice_x_mean, slice_x2_mean = slice_mean[0], slice_mean[1]
62 | slice_var = slice_x2_mean - slice_x_mean**2
63 | var = torch.where(var < 0, slice_var, var) # Correct negative variance
64 |
65 | num_elements = group_size * h * w
66 | var = var * (num_elements / (num_elements - 1))
67 | std = (var + module.eps).sqrt()
68 | output = (x - full_x_mean) / std
69 | output = output.view([n, c, h, w])
70 | if module.affine:
71 | output = output * module.weight.view([1, -1, 1, 1])
72 | output = output + module.bias.view([1, -1, 1, 1])
73 | else:
74 | if self.counter <= distri_config.warmup_steps or distri_config.mode in ["sync_gn", "full_sync"]:
75 | x = x.view([n, num_groups, group_size, h, w])
76 | x_mean = x.mean(dim=[2, 3, 4], keepdim=True)  # [n, num_groups, 1, 1, 1]
77 | x2_mean = (x**2).mean(dim=[2, 3, 4], keepdim=True)  # [n, num_groups, 1, 1, 1]
78 | mean = torch.stack([x_mean, x2_mean], dim=0)
79 | dist.all_reduce(mean, op=dist.ReduceOp.SUM, group=distri_config.batch_group)
80 | mean = mean / distri_config.n_device_per_batch
81 | x_mean = mean[0]
82 | x2_mean = mean[1]
83 | var = x2_mean - x_mean**2
84 | num_elements = group_size * h * w
85 | var = var * (num_elements / (num_elements - 1))
86 | std = (var + module.eps).sqrt()
87 | output = (x - x_mean) / std
88 | output = output.view([n, c, h, w])
89 | if module.affine:
90 | output = output * module.weight.view([1, -1, 1, 1])
91 | output = output + module.bias.view([1, -1, 1, 1])
92 | elif distri_config.mode in ["separate_gn", "no_sync"]:
93 | output = module(x)
94 | else:
95 | raise NotImplementedError
96 | self.counter += 1
97 | return output
98 |
--------------------------------------------------------------------------------
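DistriGroupNorm above shares only per-group moments rather than activations: every device computes E[x] and E[x²] over its own patch, and since all patches hold the same number of elements, averaging those per-patch moments recovers the global statistics exactly; in `corrected_async_gn` mode the stale average is additionally corrected by the fresh-minus-stale local moments, and negative variances fall back to the local variance. A minimal single-process sketch of the moment identity (one group, illustrative sizes, no `dist` calls):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 8, 16, 16)                  # pretend the 8 channels form a single group
patches = x.chunk(2, dim=2)                    # row split across two "devices"

# Each "device" contributes E[x] and E[x^2] over its own patch.
slice_moments = [torch.stack([p.mean(), (p ** 2).mean()]) for p in patches]
full_mean, full_x2 = (sum(slice_moments) / len(slice_moments)).unbind()
var = full_x2 - full_mean ** 2                 # Var[x] = E[x^2] - E[x]^2

assert torch.allclose(full_mean, x.mean(), atol=1e-5)
assert torch.allclose(var, x.var(unbiased=False), atol=1e-5)
```
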
/distrifuser/modules/tp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/distrifuser/8aebdd6d0db87884ba03dc596c98d28f3f38a1a3/distrifuser/modules/tp/__init__.py
--------------------------------------------------------------------------------
/distrifuser/modules/tp/attention.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 | from diffusers.models.attention_processor import Attention
3 | from torch import distributed as dist
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from distrifuser.modules.base_module import BaseModule
8 | from distrifuser.utils import DistriConfig
9 |
10 |
11 | class DistriAttentionTP(BaseModule):
12 | def __init__(self, module: Attention, distri_config: DistriConfig):
13 | super(DistriAttentionTP, self).__init__(module, distri_config)
14 |
15 | heads = module.heads
16 | sliced_heads = heads // distri_config.n_device_per_batch
17 | remainder_heads = heads % distri_config.n_device_per_batch
18 | if distri_config.split_idx() < remainder_heads:
19 | sliced_heads += 1
20 | self.sliced_heads = sliced_heads
21 |
22 | if sliced_heads > 0:
23 | if distri_config.split_idx() < remainder_heads:
24 | start_head = distri_config.split_idx() * sliced_heads
25 | else:
26 | start_head = (
27 | remainder_heads * (sliced_heads + 1) + (distri_config.split_idx() - remainder_heads) * sliced_heads
28 | )
29 | end_head = start_head + sliced_heads
30 |
31 | dim = module.to_q.out_features // heads
32 |
33 | sharded_to_q = nn.Linear(
34 | module.to_q.in_features,
35 | sliced_heads * dim,
36 | bias=module.to_q.bias is not None,
37 | device=module.to_q.weight.device,
38 | dtype=module.to_q.weight.dtype,
39 | )
40 | sharded_to_q.weight.data.copy_(module.to_q.weight.data[start_head * dim : end_head * dim])
41 | if module.to_q.bias is not None:
42 | sharded_to_q.bias.data.copy_(module.to_q.bias.data[start_head * dim : end_head * dim])
43 |
44 | sharded_to_k = nn.Linear(
45 | module.to_k.in_features,
46 | sliced_heads * dim,
47 | bias=module.to_k.bias is not None,
48 | device=module.to_k.weight.device,
49 | dtype=module.to_k.weight.dtype,
50 | )
51 | sharded_to_k.weight.data.copy_(module.to_k.weight.data[start_head * dim : end_head * dim])
52 | if module.to_k.bias is not None:
53 | sharded_to_k.bias.data.copy_(module.to_k.bias.data[start_head * dim : end_head * dim])
54 |
55 | sharded_to_v = nn.Linear(
56 | module.to_v.in_features,
57 | sliced_heads * dim,
58 | bias=module.to_v.bias is not None,
59 | device=module.to_v.weight.device,
60 | dtype=module.to_v.weight.dtype,
61 | )
62 | sharded_to_v.weight.data.copy_(module.to_v.weight.data[start_head * dim : end_head * dim])
63 | if module.to_v.bias is not None:
64 | sharded_to_v.bias.data.copy_(module.to_v.bias.data[start_head * dim : end_head * dim])
65 |
66 | sharded_to_out = nn.Linear(
67 | sliced_heads * dim,
68 | module.to_out[0].out_features,
69 | bias=module.to_out[0].bias is not None,
70 | device=module.to_out[0].weight.device,
71 | dtype=module.to_out[0].weight.dtype,
72 | )
73 | sharded_to_out.weight.data.copy_(module.to_out[0].weight.data[:, start_head * dim : end_head * dim])
74 | if module.to_out[0].bias is not None:
75 | sharded_to_out.bias.data.copy_(module.to_out[0].bias.data)
76 |
77 | del module.to_q
78 | del module.to_k
79 | del module.to_v
80 |
81 | old_to_out = module.to_out[0]
82 |
83 | module.to_q = sharded_to_q
84 | module.to_k = sharded_to_k
85 | module.to_v = sharded_to_v
86 | module.to_out[0] = sharded_to_out
87 | module.heads = sliced_heads
88 |
89 | del old_to_out
90 |
91 | torch.cuda.empty_cache()
92 |
93 | def forward(
94 | self,
95 | hidden_states: torch.FloatTensor,
96 | encoder_hidden_states: torch.FloatTensor or None = None,
97 | attention_mask: torch.FloatTensor or None = None,
98 | **cross_attention_kwargs,
99 | ) -> torch.Tensor:
100 | distri_config = self.distri_config
101 | module = self.module
102 | residual = hidden_states
103 |
104 | if self.sliced_heads > 0:
105 | input_ndim = hidden_states.ndim
106 |
107 | assert input_ndim == 3
108 |
109 | batch_size, sequence_length, _ = (
110 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
111 | )
112 |
113 | if attention_mask is not None:
114 | attention_mask = module.prepare_attention_mask(attention_mask, sequence_length, batch_size)
115 | # scaled_dot_product_attention expects attention_mask shape to be
116 | # (batch, heads, source_length, target_length)
117 | attention_mask = attention_mask.view(batch_size, module.heads, -1, attention_mask.shape[-1])
118 |
119 | if module.group_norm is not None:
120 | hidden_states = module.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
121 |
122 | query = module.to_q(hidden_states)
123 |
124 | if encoder_hidden_states is None:
125 | encoder_hidden_states = hidden_states
126 | elif module.norm_cross:
127 | encoder_hidden_states = module.norm_encoder_hidden_states(encoder_hidden_states)
128 |
129 | key = module.to_k(encoder_hidden_states)
130 | value = module.to_v(encoder_hidden_states)
131 |
132 | inner_dim = key.shape[-1]
133 | head_dim = inner_dim // module.heads
134 |
135 | query = query.view(batch_size, -1, module.heads, head_dim).transpose(1, 2)
136 |
137 | key = key.view(batch_size, -1, module.heads, head_dim).transpose(1, 2)
138 | value = value.view(batch_size, -1, module.heads, head_dim).transpose(1, 2)
139 |
140 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
141 | # TODO: add support for attn.scale when we move to Torch 2.1
142 | hidden_states = F.scaled_dot_product_attention(
143 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
144 | )
145 |
146 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, module.heads * head_dim)
147 | hidden_states = hidden_states.to(query.dtype)
148 |
149 | # linear proj
150 | hidden_states = F.linear(hidden_states, module.to_out[0].weight, bias=None)
151 | # dropout
152 | hidden_states = module.to_out[1](hidden_states)
153 | else:
154 | hidden_states = torch.zeros(
155 | [hidden_states.shape[0], hidden_states.shape[1], module.to_out[0].out_features],
156 | device=hidden_states.device,
157 | dtype=hidden_states.dtype,
158 | )
159 | dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
160 | if module.to_out[0].bias is not None:
161 | hidden_states = hidden_states + module.to_out[0].bias.view(1, 1, -1)
162 |
163 | if module.residual_connection:
164 | hidden_states = hidden_states + residual
165 |
166 | hidden_states = hidden_states / module.rescale_output_factor
167 |
168 | self.counter += 1
169 |
170 | return hidden_states
171 |
--------------------------------------------------------------------------------
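Attention heads are independent, so DistriAttentionTP above gives each device a contiguous slice of heads plus the matching column slice of `to_out`, computes attention locally, and relies on the `all_reduce` to sum the partial output projections before the bias is added once. A minimal single-process sketch of that equivalence (illustrative sizes, the loop stands in for the devices, no `dist` calls):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, heads, head_dim, n_dev = 2, 5, 4, 8, 2
d = heads * head_dim
to_q, to_k, to_v, to_out = (torch.nn.Linear(d, d) for _ in range(4))
x = torch.randn(batch, seq, d)

def attend(q, k, v, h):
    # [B, S, h * head_dim] -> [B, h, S, head_dim] -> attention -> [B, S, h * head_dim]
    q, k, v = (t.view(batch, seq, h, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(batch, seq, h * head_dim)

reference = to_out(attend(to_q(x), to_k(x), to_v(x), heads))

per_dev = heads // n_dev
partial = torch.zeros(batch, seq, d)
for r in range(n_dev):
    rows = slice(r * per_dev * head_dim, (r + 1) * per_dev * head_dim)
    out_r = attend(
        F.linear(x, to_q.weight[rows], to_q.bias[rows]),
        F.linear(x, to_k.weight[rows], to_k.bias[rows]),
        F.linear(x, to_v.weight[rows], to_v.bias[rows]),
        per_dev,
    )
    # Column slice of the output projection; the bias is deferred until after the "all_reduce".
    partial = partial + F.linear(out_r, to_out.weight[:, rows], None)

assert torch.allclose(partial + to_out.bias, reference, atol=1e-5)
```
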
/distrifuser/modules/tp/conv2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributed as dist
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from distrifuser.modules.base_module import BaseModule
7 | from distrifuser.utils import DistriConfig
8 |
9 |
10 | class DistriConv2dTP(BaseModule):
11 | def __init__(self, module: nn.Conv2d, distri_config: DistriConfig):
12 | super(DistriConv2dTP, self).__init__(module, distri_config)
13 | assert module.in_channels % distri_config.n_device_per_batch == 0
14 |
15 | sharded_module = nn.Conv2d(
16 | module.in_channels // distri_config.n_device_per_batch,
17 | module.out_channels,
18 | module.kernel_size,
19 | module.stride,
20 | module.padding,
21 | module.dilation,
22 | module.groups,
23 | module.bias is not None,
24 | module.padding_mode,
25 | device=module.weight.device,
26 | dtype=module.weight.dtype,
27 | )
28 | start_idx = distri_config.split_idx() * (module.in_channels // distri_config.n_device_per_batch)
29 | end_idx = (distri_config.split_idx() + 1) * (module.in_channels // distri_config.n_device_per_batch)
30 | sharded_module.weight.data.copy_(module.weight.data[:, start_idx:end_idx])
31 | if module.bias is not None:
32 | sharded_module.bias.data.copy_(module.bias.data)
33 |
34 | self.module = sharded_module
35 | del module
36 |
37 | def forward(self, x: torch.Tensor) -> torch.Tensor:
38 | distri_config = self.distri_config
39 |
40 | b, c, h, w = x.shape
41 | start_idx = distri_config.split_idx() * (c // distri_config.n_device_per_batch)
42 | end_idx = (distri_config.split_idx() + 1) * (c // distri_config.n_device_per_batch)
43 | output = F.conv2d(
44 | x[:, start_idx:end_idx],
45 | self.module.weight,
46 | bias=None,
47 | stride=self.module.stride,
48 | padding=self.module.padding,
49 | dilation=self.module.dilation,
50 | groups=self.module.groups,
51 | )
52 | dist.all_reduce(output, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
53 | if self.module.bias is not None:
54 | output = output + self.module.bias.view(1, -1, 1, 1)
55 |
56 | self.counter += 1
57 | return output
58 |
--------------------------------------------------------------------------------
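DistriConv2dTP above shards over input channels: each device convolves only its channel slice with the matching weight slice, the partial outputs are summed by the `all_reduce`, and the bias is added once at the end. A minimal single-process sketch of why that matches the unsharded convolution (illustrative sizes, the loop stands in for the devices):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_ch, out_ch, n_dev = 8, 6, 2
conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
x = torch.randn(1, in_ch, 16, 16)
reference = conv(x)

shard = in_ch // n_dev
partial = torch.zeros_like(reference)
for r in range(n_dev):
    c = slice(r * shard, (r + 1) * shard)
    # Each device owns x[:, c] and conv.weight[:, c]; the bias is added once after the "all_reduce".
    partial = partial + F.conv2d(x[:, c], conv.weight[:, c], None, padding=1)

assert torch.allclose(partial + conv.bias.view(1, -1, 1, 1), reference, atol=1e-5)
```
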
/distrifuser/modules/tp/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 | from diffusers.models.attention import FeedForward, GEGLU
3 | from torch import distributed as dist
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from ..base_module import BaseModule
8 | from ...utils import DistriConfig
9 |
10 |
11 | class DistriFeedForwardTP(BaseModule):
12 | def __init__(self, module: FeedForward, distri_config: DistriConfig):
13 | super(DistriFeedForwardTP, self).__init__(module, distri_config)
14 | assert isinstance(module.net[0], GEGLU)
15 | assert module.net[0].proj.out_features % (distri_config.n_device_per_batch * 2) == 0
16 | assert module.net[2].in_features % distri_config.n_device_per_batch == 0
17 |
18 | mid_features = module.net[2].in_features // distri_config.n_device_per_batch
19 |
20 | sharded_fc1 = nn.Linear(
21 | module.net[0].proj.in_features,
22 | mid_features * 2,
23 | bias=module.net[0].proj.bias is not None,
24 | device=module.net[0].proj.weight.device,
25 | dtype=module.net[0].proj.weight.dtype,
26 | )
27 | start_idx = distri_config.split_idx() * mid_features
28 | end_idx = (distri_config.split_idx() + 1) * mid_features
29 | sharded_fc1.weight.data[:mid_features].copy_(module.net[0].proj.weight.data[start_idx:end_idx])
30 | if module.net[0].proj.bias is not None:
31 | sharded_fc1.bias.data[:mid_features].copy_(module.net[0].proj.bias.data[start_idx:end_idx])
32 | start_idx = (distri_config.n_device_per_batch + distri_config.split_idx()) * mid_features
33 | end_idx = (distri_config.n_device_per_batch + distri_config.split_idx() + 1) * mid_features
34 | sharded_fc1.weight.data[mid_features:].copy_(module.net[0].proj.weight.data[start_idx:end_idx])
35 | if module.net[0].proj.bias is not None:
36 | sharded_fc1.bias.data[mid_features:].copy_(module.net[0].proj.bias.data[start_idx:end_idx])
37 |
38 | sharded_fc2 = nn.Linear(
39 | mid_features,
40 | module.net[2].out_features,
41 | bias=module.net[2].bias is not None,
42 | device=module.net[2].weight.device,
43 | dtype=module.net[2].weight.dtype,
44 | )
45 | sharded_fc2.weight.data.copy_(
46 | module.net[2].weight.data[
47 | :, distri_config.split_idx() * mid_features : (distri_config.split_idx() + 1) * mid_features
48 | ]
49 | )
50 | if module.net[2].bias is not None:
51 | sharded_fc2.bias.data.copy_(module.net[2].bias.data)
52 |
53 | old_fc1 = module.net[0].proj
54 | old_fc2 = module.net[2]
55 |
56 | module.net[0].proj = sharded_fc1
57 | module.net[2] = sharded_fc2
58 |
59 | del old_fc1
60 | del old_fc2
61 | torch.cuda.empty_cache()
62 |
63 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
64 | distri_config = self.distri_config
65 | module = self.module
66 |
67 | assert scale == 1.0
68 | for i, submodule in enumerate(module.net):
69 | if i == 0:
70 | hidden_states, gate = submodule.proj(hidden_states).chunk(2, dim=-1)
71 | hidden_states = hidden_states * submodule.gelu(gate)
72 | elif i == 2:
73 | hidden_states = F.linear(hidden_states, submodule.weight, None)
74 | else:
75 | hidden_states = submodule(hidden_states)
76 |
77 | dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
78 | if module.net[2].bias is not None:
79 | hidden_states = hidden_states + module.net[2].bias.view(1, 1, -1)
80 |
81 | self.counter += 1
82 |
83 | return hidden_states
84 |
--------------------------------------------------------------------------------
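The GEGLU feed-forward above is sharded the same way: each device keeps a slice of the value rows and the matching slice of the gate rows of the first projection, together with the corresponding input columns of `net[2]`, and the `all_reduce` sums the partial second-projection outputs before the bias is added once. A minimal single-process sketch of that equivalence (illustrative sizes, the loop stands in for the devices):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff, n_dev = 8, 16, 2
fc1 = torch.nn.Linear(d_model, 2 * d_ff)       # GEGLU projection: [value rows | gate rows]
fc2 = torch.nn.Linear(d_ff, d_model)
x = torch.randn(3, 5, d_model)

value, gate = fc1(x).chunk(2, dim=-1)
reference = fc2(value * F.gelu(gate))

shard = d_ff // n_dev
partial = torch.zeros_like(reference)
for r in range(n_dev):
    v = slice(r * shard, (r + 1) * shard)                # this device's value rows
    g = slice(d_ff + r * shard, d_ff + (r + 1) * shard)  # and its gate rows
    hidden = F.linear(x, fc1.weight[v], fc1.bias[v]) * F.gelu(F.linear(x, fc1.weight[g], fc1.bias[g]))
    # fc2 sees only the columns this device owns; its bias is added once after the "all_reduce".
    partial = partial + F.linear(hidden, fc2.weight[:, v], None)

assert torch.allclose(partial + fc2.bias, reference, atol=1e-5)
```
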
/distrifuser/modules/tp/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D, USE_PEFT_BACKEND
3 | from torch import distributed as dist
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from ..base_module import BaseModule
8 | from ...utils import DistriConfig
9 |
10 |
11 | class DistriResnetBlock2DTP(BaseModule):
12 | def __init__(self, module: ResnetBlock2D, distri_config: DistriConfig):
13 | super(DistriResnetBlock2DTP, self).__init__(module, distri_config)
14 | assert module.conv1.out_channels % distri_config.n_device_per_batch == 0
15 |
16 | mid_channels = module.conv1.out_channels // distri_config.n_device_per_batch
17 |
18 | sharded_conv1 = nn.Conv2d(
19 | module.conv1.in_channels,
20 | mid_channels,
21 | module.conv1.kernel_size,
22 | module.conv1.stride,
23 | module.conv1.padding,
24 | module.conv1.dilation,
25 | module.conv1.groups,
26 | module.conv1.bias is not None,
27 | module.conv1.padding_mode,
28 | device=module.conv1.weight.device,
29 | dtype=module.conv1.weight.dtype,
30 | )
31 | sharded_conv1.weight.data.copy_(
32 | module.conv1.weight.data[
33 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
34 | ]
35 | )
36 | if module.conv1.bias is not None:
37 | sharded_conv1.bias.data.copy_(
38 | module.conv1.bias.data[
39 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
40 | ]
41 | )
42 |
43 | sharded_conv2 = nn.Conv2d(
44 | mid_channels,
45 | module.conv2.out_channels,
46 | module.conv2.kernel_size,
47 | module.conv2.stride,
48 | module.conv2.padding,
49 | module.conv2.dilation,
50 | module.conv2.groups,
51 | module.conv2.bias is not None,
52 | module.conv2.padding_mode,
53 | device=module.conv2.weight.device,
54 | dtype=module.conv2.weight.dtype,
55 | )
56 | sharded_conv2.weight.data.copy_(
57 | module.conv2.weight.data[
58 | :, distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
59 | ]
60 | )
61 | if module.conv2.bias is not None:
62 | sharded_conv2.bias.data.copy_(module.conv2.bias.data)
63 |
64 | assert module.time_emb_proj is not None
65 | assert module.time_embedding_norm == "default"
66 |
67 | sharded_time_emb_proj = nn.Linear(
68 | module.time_emb_proj.in_features,
69 | mid_channels,
70 | bias=module.time_emb_proj.bias is not None,
71 | device=module.time_emb_proj.weight.device,
72 | dtype=module.time_emb_proj.weight.dtype,
73 | )
74 | sharded_time_emb_proj.weight.data.copy_(
75 | module.time_emb_proj.weight.data[
76 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
77 | ]
78 | )
79 | if module.time_emb_proj.bias is not None:
80 | sharded_time_emb_proj.bias.data.copy_(
81 | module.time_emb_proj.bias.data[
82 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
83 | ]
84 | )
85 |
86 | sharded_norm2 = nn.GroupNorm(
87 | module.norm2.num_groups // distri_config.n_device_per_batch,
88 | mid_channels,
89 | module.norm2.eps,
90 | module.norm2.affine,
91 | device=module.norm2.weight.device,
92 | dtype=module.norm2.weight.dtype,
93 | )
94 | if module.norm2.affine:
95 | sharded_norm2.weight.data.copy_(
96 | module.norm2.weight.data[
97 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
98 | ]
99 | )
100 | sharded_norm2.bias.data.copy_(
101 | module.norm2.bias.data[
102 | distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
103 | ]
104 | )
105 |
106 | del module.conv1
107 | del module.conv2
108 | del module.time_emb_proj
109 | del module.norm2
110 | module.conv1 = sharded_conv1
111 | module.conv2 = sharded_conv2
112 | module.time_emb_proj = sharded_time_emb_proj
113 | module.norm2 = sharded_norm2
114 |
115 | torch.cuda.empty_cache()
116 |
117 | def forward(
118 | self,
119 | input_tensor: torch.FloatTensor,
120 | temb: torch.FloatTensor,
121 | scale: float = 1.0,
122 | ) -> torch.FloatTensor:
123 | assert scale == 1.0
124 |
125 | distri_config = self.distri_config
126 | module = self.module
127 |
128 | hidden_states = input_tensor
129 | hidden_states = module.norm1(hidden_states)
130 |
131 | hidden_states = module.nonlinearity(hidden_states)
132 |
133 | if module.upsample is not None:
134 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
135 | if hidden_states.shape[0] >= 64:
136 | input_tensor = input_tensor.contiguous()
137 | hidden_states = hidden_states.contiguous()
138 | input_tensor = (
139 | module.upsample(input_tensor, scale=scale)
140 | if isinstance(module.upsample, Upsample2D)
141 | else module.upsample(input_tensor)
142 | )
143 | hidden_states = (
144 | module.upsample(hidden_states, scale=scale)
145 | if isinstance(module.upsample, Upsample2D)
146 | else module.upsample(hidden_states)
147 | )
148 | elif module.downsample is not None:
149 | input_tensor = (
150 | module.downsample(input_tensor, scale=scale)
151 | if isinstance(module.downsample, Downsample2D)
152 | else module.downsample(input_tensor)
153 | )
154 | hidden_states = (
155 | module.downsample(hidden_states, scale=scale)
156 | if isinstance(module.downsample, Downsample2D)
157 | else module.downsample(hidden_states)
158 | )
159 |
160 | hidden_states = module.conv1(hidden_states)
161 |
162 | if module.time_emb_proj is not None:
163 | if not module.skip_time_act:
164 | temb = module.nonlinearity(temb)
165 | temb = module.time_emb_proj(temb)[:, :, None, None]
166 |
167 | if temb is not None and module.time_embedding_norm == "default":
168 | hidden_states = hidden_states + temb
169 |
170 | hidden_states = module.norm2(hidden_states)
171 |
172 | if temb is not None and module.time_embedding_norm == "scale_shift":
173 | scale, shift = torch.chunk(temb, 2, dim=1)
174 | hidden_states = hidden_states * (1 + scale) + shift
175 |
176 | hidden_states = module.nonlinearity(hidden_states)
177 |
178 | hidden_states = module.dropout(hidden_states)
179 | hidden_states = F.conv2d(
180 | hidden_states,
181 | module.conv2.weight,
182 | bias=None,
183 | stride=module.conv2.stride,
184 | padding=module.conv2.padding,
185 | dilation=module.conv2.dilation,
186 | groups=module.conv2.groups,
187 | )
188 |
189 | dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
190 | if module.conv2.bias is not None:
191 | hidden_states = hidden_states + module.conv2.bias.view(1, -1, 1, 1)
192 |
193 | if module.conv_shortcut is not None:
194 | input_tensor = (
195 | module.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else module.conv_shortcut(input_tensor)
196 | )
197 |
198 | output_tensor = (input_tensor + hidden_states) / module.output_scale_factor
199 |
200 | self.counter += 1
201 |
202 | return output_tensor
203 |
--------------------------------------------------------------------------------
/distrifuser/pipelines.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel
3 |
4 | from .models.distri_sdxl_unet_pp import DistriUNetPP
5 | from .models.distri_sdxl_unet_tp import DistriUNetTP
6 | from .models.naive_patch_sdxl import NaivePatchUNet
7 | from .utils import DistriConfig, PatchParallelismCommManager
8 |
9 |
10 | class DistriSDXLPipeline:
11 | def __init__(self, pipeline: StableDiffusionXLPipeline, module_config: DistriConfig):
12 | self.pipeline = pipeline
13 | self.distri_config = module_config
14 |
15 | self.static_inputs = None
16 |
17 | self.prepare()
18 |
19 | @staticmethod
20 | def from_pretrained(distri_config: DistriConfig, **kwargs):
21 | device = distri_config.device
22 | pretrained_model_name_or_path = kwargs.pop(
23 | "pretrained_model_name_or_path", "stabilityai/stable-diffusion-xl-base-1.0"
24 | )
25 | torch_dtype = kwargs.pop("torch_dtype", torch.float16)
26 | unet = UNet2DConditionModel.from_pretrained(
27 | pretrained_model_name_or_path, torch_dtype=torch_dtype, subfolder="unet"
28 | ).to(device)
29 |
30 | if distri_config.parallelism == "patch":
31 | unet = DistriUNetPP(unet, distri_config)
32 | elif distri_config.parallelism == "tensor":
33 | unet = DistriUNetTP(unet, distri_config)
34 | elif distri_config.parallelism == "naive_patch":
35 | unet = NaivePatchUNet(unet, distri_config)
36 | else:
37 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}")
38 |
39 | pipeline = StableDiffusionXLPipeline.from_pretrained(
40 | pretrained_model_name_or_path, torch_dtype=torch_dtype, unet=unet, **kwargs
41 | ).to(device)
42 | return DistriSDXLPipeline(pipeline, distri_config)
43 |
44 | def set_progress_bar_config(self, **kwargs):
45 | self.pipeline.set_progress_bar_config(**kwargs)
46 |
47 | @torch.no_grad()
48 | def __call__(self, *args, **kwargs):
49 | assert "height" not in kwargs, "height should not be in kwargs"
50 | assert "width" not in kwargs, "width should not be in kwargs"
51 | config = self.distri_config
52 | if not config.do_classifier_free_guidance:
53 | if "guidance_scale" not in kwargs:
54 | kwargs["guidance_scale"] = 1
55 | else:
56 | assert kwargs["guidance_scale"] == 1
57 | self.pipeline.unet.set_counter(0)
58 | return self.pipeline(height=config.height, width=config.width, *args, **kwargs)
59 |
60 | @torch.no_grad()
61 | def prepare(self, **kwargs):
62 | distri_config = self.distri_config
63 |
64 | static_inputs = {}
65 | static_outputs = []
66 | cuda_graphs = []
67 | pipeline = self.pipeline
68 |
69 | height = distri_config.height
70 | width = distri_config.width
71 | assert height % 8 == 0 and width % 8 == 0
72 |
73 | original_size = (height, width)
74 | target_size = (height, width)
75 | crops_coords_top_left = (0, 0)
76 |
77 | device = distri_config.device
78 |
79 | prompt_embeds, _, pooled_prompt_embeds, _ = pipeline.encode_prompt(
80 | prompt="",
81 | prompt_2=None,
82 | device=device,
83 | num_images_per_prompt=1,
84 | do_classifier_free_guidance=False,
85 | negative_prompt=None,
86 | negative_prompt_2=None,
87 | prompt_embeds=None,
88 | negative_prompt_embeds=None,
89 | pooled_prompt_embeds=None,
90 | negative_pooled_prompt_embeds=None,
91 | )
92 | batch_size = 2 if distri_config.do_classifier_free_guidance else 1
93 |
94 | num_channels_latents = pipeline.unet.config.in_channels
95 | latents = pipeline.prepare_latents(
96 | batch_size, num_channels_latents, height, width, prompt_embeds.dtype, device, None
97 | )
98 |
99 | # 7. Prepare added time ids & embeddings
100 | add_text_embeds = pooled_prompt_embeds
101 | if pipeline.text_encoder_2 is None:
102 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
103 | else:
104 | text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim
105 |
106 | add_time_ids = pipeline._get_add_time_ids(
107 | original_size,
108 | crops_coords_top_left,
109 | target_size,
110 | dtype=prompt_embeds.dtype,
111 | text_encoder_projection_dim=text_encoder_projection_dim,
112 | )
113 |
114 | prompt_embeds = prompt_embeds.to(device)
115 | add_text_embeds = add_text_embeds.to(device)
116 | add_time_ids = add_time_ids.to(device).repeat(1, 1)
117 |
118 | if batch_size > 1:
119 | prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1)
120 | add_text_embeds = add_text_embeds.repeat(batch_size, 1)
121 | add_time_ids = add_time_ids.repeat(batch_size, 1)
122 |
123 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
124 | t = torch.zeros([batch_size], device=device, dtype=torch.long)
125 |
126 | static_inputs["sample"] = latents
127 | static_inputs["timestep"] = t
128 | static_inputs["encoder_hidden_states"] = prompt_embeds
129 | static_inputs["added_cond_kwargs"] = added_cond_kwargs
130 |
131 | # Used to create communication buffer
132 | comm_manager = None
133 | if distri_config.n_device_per_batch > 1:
134 | comm_manager = PatchParallelismCommManager(distri_config)
135 | pipeline.unet.set_comm_manager(comm_manager)
136 |
137 | # Only used for creating the communication buffer
138 | pipeline.unet.set_counter(0)
139 | pipeline.unet(**static_inputs, return_dict=False, record=True)
140 | if comm_manager.numel > 0:
141 | comm_manager.create_buffer()
142 |
143 | # Pre-run
144 | pipeline.unet.set_counter(0)
145 | pipeline.unet(**static_inputs, return_dict=False, record=True)
146 |
147 | if distri_config.use_cuda_graph:
148 | if comm_manager is not None:
149 | comm_manager.clear()
150 | if distri_config.parallelism == "naive_patch":
151 | counters = [0, 1]
152 | elif distri_config.parallelism == "patch":
153 | counters = [0, distri_config.warmup_steps + 1, distri_config.warmup_steps + 2]
154 | elif distri_config.parallelism == "tensor":
155 | counters = [0]
156 | else:
157 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}")
158 | for counter in counters:
159 | graph = torch.cuda.CUDAGraph()
160 | with torch.cuda.graph(graph):
161 | pipeline.unet.set_counter(counter)
162 | output = pipeline.unet(**static_inputs, return_dict=False, record=True)[0]
163 | static_outputs.append(output)
164 | cuda_graphs.append(graph)
165 | pipeline.unet.setup_cuda_graph(static_outputs, cuda_graphs)
166 |
167 | self.static_inputs = static_inputs
168 |
169 |
170 | class DistriSDPipeline:
171 | def __init__(self, pipeline: StableDiffusionPipeline, module_config: DistriConfig):
172 | self.pipeline = pipeline
173 | self.distri_config = module_config
174 |
175 | self.static_inputs = None
176 |
177 | self.prepare()
178 |
179 | @staticmethod
180 | def from_pretrained(distri_config: DistriConfig, **kwargs):
181 | device = distri_config.device
182 | pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", "CompVis/stable-diffusion-v1-4")
183 | torch_dtype = kwargs.pop("torch_dtype", torch.float16)
184 | unet = UNet2DConditionModel.from_pretrained(
185 | pretrained_model_name_or_path, torch_dtype=torch_dtype, subfolder="unet"
186 | ).to(device)
187 |
188 | if distri_config.parallelism == "patch":
189 | unet = DistriUNetPP(unet, distri_config)
190 | elif distri_config.parallelism == "tensor":
191 | unet = DistriUNetTP(unet, distri_config)
192 | elif distri_config.parallelism == "naive_patch":
193 | unet = NaivePatchUNet(unet, distri_config)
194 | else:
195 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}")
196 |
197 | pipeline = StableDiffusionPipeline.from_pretrained(
198 | pretrained_model_name_or_path, torch_dtype=torch_dtype, unet=unet, **kwargs
199 | ).to(device)
200 | return DistriSDPipeline(pipeline, distri_config)
201 |
202 | def set_progress_bar_config(self, **kwargs):
203 | self.pipeline.set_progress_bar_config(**kwargs)
204 |
205 | @torch.no_grad()
206 | def __call__(self, *args, **kwargs):
207 | assert "height" not in kwargs, "height should not be in kwargs"
208 | assert "width" not in kwargs, "width should not be in kwargs"
209 | config = self.distri_config
210 | if not config.do_classifier_free_guidance:
211 | if "guidance_scale" not in kwargs:
212 | kwargs["guidance_scale"] = 1
213 | else:
214 | assert kwargs["guidance_scale"] == 1
215 | self.pipeline.unet.set_counter(0)
216 | return self.pipeline(height=config.height, width=config.width, *args, **kwargs)
217 |
218 | @torch.no_grad()
219 | def prepare(self, **kwargs):
220 | distri_config = self.distri_config
221 |
222 | static_inputs = {}
223 | static_outputs = []
224 | cuda_graphs = []
225 | pipeline = self.pipeline
226 |
227 | height = distri_config.height
228 | width = distri_config.width
229 | assert height % 8 == 0 and width % 8 == 0
230 |
231 | device = distri_config.device
232 |
233 | prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
234 | "",
235 | device,
236 | num_images_per_prompt=1,
237 | do_classifier_free_guidance=False,
238 | negative_prompt=None,
239 | prompt_embeds=None,
240 | negative_prompt_embeds=None,
241 | lora_scale=None,
242 | clip_skip=kwargs.get("clip_skip", None),
243 | )
244 |
245 | batch_size = 2 if distri_config.do_classifier_free_guidance else 1
246 |
247 | num_channels_latents = pipeline.unet.config.in_channels
248 | latents = pipeline.prepare_latents(
249 | batch_size, num_channels_latents, height, width, prompt_embeds.dtype, device, None
250 | )
251 |
252 | prompt_embeds = prompt_embeds.to(device)
253 |
254 | if batch_size > 1:
255 | prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1)
256 |
257 | t = torch.zeros([batch_size], device=device, dtype=torch.long)
258 |
259 | static_inputs["sample"] = latents
260 | static_inputs["timestep"] = t
261 | static_inputs["encoder_hidden_states"] = prompt_embeds
262 |
263 | # Used to create communication buffer
264 | comm_manager = None
265 | if distri_config.n_device_per_batch > 1:
266 | comm_manager = PatchParallelismCommManager(distri_config)
267 | pipeline.unet.set_comm_manager(comm_manager)
268 |
269 | # Only used for creating the communication buffer
270 | pipeline.unet.set_counter(0)
271 | pipeline.unet(**static_inputs, return_dict=False, record=True)
272 | if comm_manager.numel > 0:
273 | comm_manager.create_buffer()
274 |
275 | # Pre-run
276 | pipeline.unet.set_counter(0)
277 | pipeline.unet(**static_inputs, return_dict=False, record=True)
278 |
279 | if distri_config.use_cuda_graph:
280 | if comm_manager is not None:
281 | comm_manager.clear()
282 | if distri_config.parallelism == "naive_patch":
283 | counters = [0, 1]
284 | elif distri_config.parallelism == "patch":
285 | counters = [0, distri_config.warmup_steps + 1, distri_config.warmup_steps + 2]
286 | elif distri_config.parallelism == "tensor":
287 | counters = [0]
288 | else:
289 | raise ValueError(f"Unknown parallelism: {distri_config.parallelism}")
290 | for counter in counters:
291 | graph = torch.cuda.CUDAGraph()
292 | with torch.cuda.graph(graph):
293 | pipeline.unet.set_counter(counter)
294 | output = pipeline.unet(**static_inputs, return_dict=False, record=True)[0]
295 | static_outputs.append(output)
296 | cuda_graphs.append(graph)
297 | pipeline.unet.setup_cuda_graph(static_outputs, cuda_graphs)
298 |
299 | self.static_inputs = static_inputs
300 |
--------------------------------------------------------------------------------
/distrifuser/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from packaging import version
3 | from torch import distributed as dist
4 |
5 |
6 | def check_env():
7 | if version.parse(torch.version.cuda) < version.parse("11.3"):
8 | # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
9 | raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above")
10 | if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.2.0"):
11 | # https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
12 | raise RuntimeError(
13 | "CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. "
14 | "If it is not released yet, please install nightly built PyTorch with "
15 | "`pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121`"
16 | )
17 |
18 |
19 | def is_power_of_2(n: int) -> bool:
20 | return (n & (n - 1) == 0) and n != 0
21 |
22 |
23 | class DistriConfig:
24 | def __init__(
25 | self,
26 | height: int = 1024,
27 | width: int = 1024,
28 | do_classifier_free_guidance: bool = True,
29 | split_batch: bool = True,
30 | warmup_steps: int = 4,
31 | comm_checkpoint: int = 60,
32 | mode: str = "corrected_async_gn",
33 | use_cuda_graph: bool = True,
34 | parallelism: str = "patch",
35 | split_scheme: str = "row",
36 | verbose: bool = False,
37 | ):
38 | try:
39 | # Initialize the process group
40 | dist.init_process_group("nccl")
41 | # Get the rank and world_size
42 | rank = dist.get_rank()
43 | world_size = dist.get_world_size()
44 | except Exception as e:
45 | rank = 0
46 | world_size = 1
47 | print(f"Failed to initialize process group: {e}, falling back to single GPU")
48 |
49 | assert is_power_of_2(world_size)
50 | check_env()
51 |
52 | self.world_size = world_size
53 | self.rank = rank
54 | self.height = height
55 | self.width = width
56 | self.do_classifier_free_guidance = do_classifier_free_guidance
57 | self.split_batch = split_batch
58 | self.warmup_steps = warmup_steps
59 | self.comm_checkpoint = comm_checkpoint
60 | self.mode = mode
61 | self.use_cuda_graph = use_cuda_graph
62 |
63 | self.parallelism = parallelism
64 | self.split_scheme = split_scheme
65 |
66 | self.verbose = verbose
67 |
68 | if do_classifier_free_guidance and split_batch:
69 | n_device_per_batch = world_size // 2
70 | if n_device_per_batch == 0:
71 | n_device_per_batch = 1
72 | else:
73 | n_device_per_batch = world_size
74 |
75 | self.n_device_per_batch = n_device_per_batch
76 |
77 | self.height = height
78 | self.width = width
79 |
80 | device = torch.device(f"cuda:{rank}")
81 | torch.cuda.set_device(device)
82 | self.device = device
83 |
84 | batch_group = None
85 | split_group = None
86 | if do_classifier_free_guidance and split_batch and world_size >= 2:
87 | batch_groups = []
88 | for i in range(2):
89 | batch_groups.append(dist.new_group(list(range(i * (world_size // 2), (i + 1) * (world_size // 2)))))
90 | batch_group = batch_groups[self.batch_idx()]
91 | split_groups = []
92 | for i in range(world_size // 2):
93 | split_groups.append(dist.new_group([i, i + world_size // 2]))
94 | split_group = split_groups[self.split_idx()]
95 | self.batch_group = batch_group
96 | self.split_group = split_group
97 |
98 | def batch_idx(self, rank: int or None = None) -> int:
99 | if rank is None:
100 | rank = self.rank
101 | if self.do_classifier_free_guidance and self.split_batch:
102 | return 1 - int(rank < (self.world_size // 2))
103 | else:
104 | return 0 # raise NotImplementedError
105 |
106 | def split_idx(self, rank: int or None = None) -> int:
107 | if rank is None:
108 | rank = self.rank
109 | return rank % self.n_device_per_batch
110 |
111 |
112 | class PatchParallelismCommManager:
113 | def __init__(self, distri_config: DistriConfig):
114 | self.distri_config = distri_config
115 |
116 | self.torch_dtype = None
117 | self.numel = 0
118 | self.numel_dict = {}
119 |
120 | self.buffer_list = None
121 |
122 | self.starts = []
123 | self.ends = []
124 | self.shapes = []
125 |
126 | self.idx_queue = []
127 |
128 | self.handles = None
129 |
130 | def register_tensor(
131 | self, shape: tuple[int, ...] or list[int], torch_dtype: torch.dtype, layer_type: str = None
132 | ) -> int:
133 | if self.torch_dtype is None:
134 | self.torch_dtype = torch_dtype
135 | else:
136 | assert self.torch_dtype == torch_dtype
137 | self.starts.append(self.numel)
138 | numel = 1
139 | for dim in shape:
140 | numel *= dim
141 | self.numel += numel
142 | if layer_type is not None:
143 | if layer_type not in self.numel_dict:
144 | self.numel_dict[layer_type] = 0
145 | self.numel_dict[layer_type] += numel
146 |
147 | self.ends.append(self.numel)
148 | self.shapes.append(shape)
149 | return len(self.starts) - 1
150 |
151 | def create_buffer(self):
152 | distri_config = self.distri_config
153 | if distri_config.rank == 0 and distri_config.verbose:
154 | print(
155 | f"Create buffer with {self.numel / 1e6:.3f}M parameters for {len(self.starts)} tensors on each device."
156 | )
157 | for layer_type, numel in self.numel_dict.items():
158 | print(f" {layer_type}: {numel / 1e6:.3f}M parameters")
159 |
160 | self.buffer_list = [
161 | torch.empty(self.numel, dtype=self.torch_dtype, device=self.distri_config.device)
162 | for _ in range(self.distri_config.n_device_per_batch)
163 | ]
164 | self.handles = [None for _ in range(len(self.starts))]
165 |
166 | def get_buffer_list(self, idx: int) -> list[torch.Tensor]:
167 | buffer_list = [t[self.starts[idx] : self.ends[idx]].view(self.shapes[idx]) for t in self.buffer_list]
168 | return buffer_list
169 |
170 | def communicate(self):
171 | distri_config = self.distri_config
172 | start = self.starts[self.idx_queue[0]]
173 | end = self.ends[self.idx_queue[-1]]
174 | tensor = self.buffer_list[distri_config.split_idx()][start:end]
175 | buffer_list = [t[start:end] for t in self.buffer_list]
176 | handle = dist.all_gather(buffer_list, tensor, group=self.distri_config.batch_group, async_op=True)
177 | for i in self.idx_queue:
178 | self.handles[i] = handle
179 | self.idx_queue = []
180 |
181 | def enqueue(self, idx: int, tensor: torch.Tensor):
182 | distri_config = self.distri_config
183 | if idx == 0 and len(self.idx_queue) > 0:
184 | self.communicate()
185 | assert len(self.idx_queue) == 0 or self.idx_queue[-1] == idx - 1
186 | self.idx_queue.append(idx)
187 | self.buffer_list[distri_config.split_idx()][self.starts[idx] : self.ends[idx]].copy_(tensor.flatten())
188 |
189 | if len(self.idx_queue) == distri_config.comm_checkpoint:
190 | self.communicate()
191 |
192 | def clear(self):
193 | if len(self.idx_queue) > 0:
194 | self.communicate()
195 | if self.handles is not None:
196 | for i in range(len(self.handles)):
197 | if self.handles[i] is not None:
198 | self.handles[i].wait()
199 | self.handles[i] = None
200 |
--------------------------------------------------------------------------------
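PatchParallelismCommManager above packs every registered activation into a single flat buffer per peer device, recording a [start, end) range and a shape per tensor, so one asynchronous `all_gather` over a contiguous slice can service many layers at once while each layer reads its data through a view of that buffer. A minimal sketch of the bookkeeping (illustrative shapes, no `dist` calls):

```python
import torch

shapes = [(2, 4, 1, 8), (2, 16, 32)]           # illustrative shapes for two registered layers
starts, ends, offset = [], [], 0
for shape in shapes:
    starts.append(offset)
    offset += torch.Size(shape).numel()
    ends.append(offset)

flat = torch.empty(offset)                      # one such flat buffer is kept per peer device
views = [flat[s:e].view(shape) for s, e, shape in zip(starts, ends, shapes)]

views[1].fill_(1.0)                             # writing through a view updates the flat buffer
assert flat[starts[1]:ends[1]].eq(1.0).all()
```
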
/scripts/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from cleanfid import fid
7 | from PIL import Image
8 | from torch.utils.data import DataLoader, Dataset
9 | from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
10 | from torchvision.transforms import Resize
11 | from tqdm import tqdm
12 |
13 |
14 | def read_image(path: str):
15 | """
16 | input: path
17 | output: tensor (C, H, W)
18 | """
19 | img = np.asarray(Image.open(path))
20 | if len(img.shape) == 2:
21 | img = np.repeat(img[:, :, None], 3, axis=2)
22 | img = torch.from_numpy(img).permute(2, 0, 1)
23 | return img
24 |
25 |
26 | class MultiImageDataset(Dataset):
27 | def __init__(self, root0, root1, is_gt=False):
28 | super().__init__()
29 | self.root0 = root0
30 | self.root1 = root1
31 | file_names0 = os.listdir(root0)
32 | file_names1 = os.listdir(root1)
33 |
34 | self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
35 | self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
36 | self.is_gt = is_gt
37 | assert len(self.image_names0) == len(self.image_names1)
38 |
39 | def __len__(self):
40 | return len(self.image_names0)
41 |
42 | def __getitem__(self, idx):
43 | img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
44 | if self.is_gt:
45 | # resize to 1024 x 1024
46 | img0 = Resize((1024, 1024))(img0)
47 | img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
48 |
49 | batch_list = [img0, img1]
50 | return batch_list
51 |
52 |
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("--batch_size", type=int, default=64)
56 | parser.add_argument("--num_workers", type=int, default=8)
57 | parser.add_argument("--is_gt", action="store_true")
58 | parser.add_argument("--input_root0", type=str, required=True)
59 | parser.add_argument("--input_root1", type=str, required=True)
60 | args = parser.parse_args()
61 |
62 | psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
63 | lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
64 |
65 | dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
66 | dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
67 |
68 | progress_bar = tqdm(dataloader)
69 | with torch.inference_mode():
70 | for i, batch in enumerate(progress_bar):
71 | batch = [img.to("cuda") / 255 for img in batch]
72 | batch_size = batch[0].shape[0]
73 | psnr.update(batch[0], batch[1])
74 | lpips.update(batch[0], batch[1])
75 | fid_score = fid.compute_fid(args.input_root0, args.input_root1)
76 |
77 | print("PSNR:", psnr.compute().item())
78 | print("LPIPS:", lpips.compute().item())
79 | print("FID:", fid_score)
80 |
--------------------------------------------------------------------------------
/scripts/dump_coco.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from argparse import ArgumentParser
4 |
5 | from datasets import load_dataset
6 | from tqdm import tqdm, trange
7 |
8 | if __name__ == "__main__":
9 | parser = ArgumentParser()
10 | parser.add_argument("--output_root", type=str, default="./coco")
11 | args = parser.parse_args()
12 |
13 | dataset = load_dataset("HuggingFaceM4/COCO", name="2014_captions", split="validation")
14 |
15 | prompt_list = []
16 | for i in trange(len(dataset["sentences_raw"])):
17 | prompt = dataset["sentences_raw"][i][i % len(dataset["sentences_raw"][i])]
18 | prompt_list.append(prompt)
19 |
20 | os.makedirs(args.output_root, exist_ok=True)
21 | prompt_path = os.path.join(args.output_root, "prompts.json")
22 | with open(prompt_path, "w") as f:
23 | json.dump(prompt_list, f, indent=4)
24 |
25 | os.makedirs(os.path.join(args.output_root, "images"), exist_ok=True)
26 |
27 | dataset = load_dataset("HuggingFaceM4/COCO", name="2014_captions", split="validation")
28 | for i, image in enumerate(tqdm(dataset["image"])):
29 | image.save(os.path.join(args.output_root, "images", f"{i:04}.png"))
30 |
--------------------------------------------------------------------------------
/scripts/export_html.py:
--------------------------------------------------------------------------------
1 | """
2 | An auxiliary script to generate HTML files for image visualization.
3 | """
4 |
5 | import argparse
6 | import json
7 | import os
8 | import random
9 | import shutil
10 |
11 | import dominate
12 | from dominate.tags import h3, img, table, td, tr
13 |
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--image_dirs", type=str, nargs="+", required=True)
18 | parser.add_argument("--caption_path", type=str, default=None)
19 | parser.add_argument("--output_root", type=str, required=True)
20 | parser.add_argument("--aliases", type=str, default=None, nargs="+")
21 | parser.add_argument("--title", type=str, default=None)
22 | parser.add_argument("--hard_copy", action="store_true")
23 | parser.add_argument("--max_images", type=int, default=None)
24 | parser.add_argument("--seed", type=int, default=0)
25 | args = parser.parse_args()
26 | if args.aliases is not None:
27 | assert len(args.image_dirs) == len(args.aliases)
28 | else:
29 | args.aliases = [str(i) for i in range(len(args.image_dirs))]
30 | return args
31 |
32 |
33 | def check_existence(image_dirs, filename):
34 | for image_dir in image_dirs:
35 | if not os.path.exists(os.path.join(image_dir, filename)):
36 | print(os.path.join(image_dir, filename))
37 | return False
38 | return True
39 |
40 |
41 | if __name__ == "__main__":
42 | args = get_args()
43 | filenames = sorted(os.listdir(args.image_dirs[0]))
44 | filenames = [
45 | filename
46 | for filename in filenames
47 | if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg")
48 | ]
49 | if args.max_images is not None:
50 | random.seed(args.seed)
51 | random.shuffle(filenames)
52 | filenames = filenames[: args.max_images]
53 | filenames = sorted(filenames)
54 | doc = dominate.document(title="Visualization" if args.title is None else args.title)
55 | if args.title:
56 | with doc:
57 | h3(args.title)
58 | t_main = table(border=1, style="table-layout: fixed;")
59 | prompts = json.load(open(args.caption_path, "r"))
60 | for i, filename in enumerate(filenames):
61 | bname = os.path.splitext(filename)[0]
62 | if not check_existence(args.image_dirs, filename):
63 | continue
64 | title_row = tr()
65 | _tr = tr()
66 | title_row.add(td(f"{bname}"))
67 | _tr.add(td(prompts[int(bname)]))
68 | for image_dir, alias in zip(args.image_dirs, args.aliases):
69 | title_row.add(td(f"{alias}"))
70 | _td = td(style="word-wrap: break-word;", halign="center", valign="top")
71 | source_path = os.path.abspath(os.path.join(image_dir, filename))
72 | target_path = os.path.abspath(os.path.join(args.output_root, "images", alias, filename))
73 | os.makedirs(os.path.dirname(os.path.abspath(target_path)), exist_ok=True)
74 | if args.hard_copy:
75 | shutil.copy(source_path, target_path)
76 | else:
77 | os.symlink(source_path, target_path)
78 | _td.add(img(style="width:256px", src=os.path.relpath(target_path, args.output_root)))
79 | _tr.add(_td)
80 | t_main.add(title_row)
81 | t_main.add(_tr)
82 | with open(os.path.join(args.output_root, "index.html"), "w") as f:
83 | f.write(t_main.render())
84 |
--------------------------------------------------------------------------------
/scripts/generate_coco.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | from datasets import load_dataset
6 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler
7 | from tqdm import trange
8 |
9 | from distrifuser.pipelines import DistriSDXLPipeline
10 | from distrifuser.utils import DistriConfig
11 |
12 |
13 | def get_args() -> argparse.Namespace:
14 | parser = argparse.ArgumentParser()
15 | # Diffuser specific arguments
16 | parser.add_argument("--output_root", type=str, default=None)
17 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
18 | parser.add_argument("--image_size", type=int, nargs="*", default=1024, help="Image size of generation")
19 | parser.add_argument("--guidance_scale", type=float, default=5.0)
20 | parser.add_argument("--scheduler", type=str, default="ddim", choices=["euler", "dpm-solver", "ddim"])
21 |
22 | # DistriFuser specific arguments
23 | parser.add_argument(
24 | "--no_split_batch", action="store_true", help="Disable the batch splitting for classifier-free guidance"
25 | )
26 | parser.add_argument("--warmup_steps", type=int, default=4, help="Number of warmup steps")
27 | parser.add_argument(
28 | "--sync_mode",
29 | type=str,
30 | default="corrected_async_gn",
31 | choices=["separate_gn", "stale_gn", "corrected_async_gn", "sync_gn", "full_sync", "no_sync"],
32 | help="Different GroupNorm synchronization modes",
33 | )
34 | parser.add_argument(
35 | "--parallelism",
36 | type=str,
37 | default="patch",
38 | choices=["patch", "tensor", "naive_patch"],
39 | help="patch parallelism, tensor parallelism or naive patch",
40 | )
41 | parser.add_argument(
42 | "--split_scheme",
43 | type=str,
44 | default="alternate",
45 | choices=["row", "col", "alternate"],
46 | help="Split scheme for naive patch",
47 | )
48 | parser.add_argument("--no_cuda_graph", action="store_true", help="Disable CUDA graph")
49 |
50 | parser.add_argument("--split", nargs=2, type=int, default=None, help="Split the dataset into chunks")
51 |
52 | args = parser.parse_args()
53 | return args
54 |
55 |
56 | def main():
57 | args = get_args()
58 |
59 | if isinstance(args.image_size, int):
60 | args.image_size = [args.image_size, args.image_size]
61 | else:
62 | if len(args.image_size) == 1:
63 | args.image_size = [args.image_size[0], args.image_size[0]]
64 | else:
65 | assert len(args.image_size) == 2
66 | distri_config = DistriConfig(
67 | height=args.image_size[0],
68 | width=args.image_size[1],
69 | do_classifier_free_guidance=args.guidance_scale > 1,
70 | split_batch=not args.no_split_batch,
71 | warmup_steps=args.warmup_steps,
72 | mode=args.sync_mode,
73 | use_cuda_graph=not args.no_cuda_graph,
74 | parallelism=args.parallelism,
75 | split_scheme=args.split_scheme,
76 | )
77 |
78 | pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
79 | if args.scheduler == "euler":
80 | scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
81 | elif args.scheduler == "dpm-solver":
82 | scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
83 | elif args.scheduler == "ddim":
84 | scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
85 | else:
86 | raise NotImplementedError
87 | pipeline = DistriSDXLPipeline.from_pretrained(
88 | pretrained_model_name_or_path=pretrained_model_name_or_path,
89 | distri_config=distri_config,
90 | variant="fp16",
91 | use_safetensors=True,
92 | scheduler=scheduler,
93 | )
94 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0, position=1, leave=False)
95 |
96 | if args.output_root is None:
97 | args.output_root = os.path.join(
98 | "results",
99 | "coco",
100 | f"{args.scheduler}-{args.num_inference_steps}",
101 | f"gpus{distri_config.world_size if args.no_split_batch else distri_config.world_size // 2}-"
102 | f"warmup{args.warmup_steps}-{args.sync_mode}",
103 | )
104 | if distri_config.rank == 0:
105 | os.makedirs(args.output_root, exist_ok=True)
106 |
107 | dataset = load_dataset("HuggingFaceM4/COCO", name="2014_captions", split="validation", trust_remote_code=True)
108 |
109 | if args.split is not None:
110 | assert args.split[0] < args.split[1]
111 | chunk_size = (5000 + args.split[1] - 1) // args.split[1]
112 | start_idx = args.split[0] * chunk_size
113 | end_idx = min((args.split[0] + 1) * chunk_size, 5000)
114 | else:
115 | start_idx = 0
116 | end_idx = 5000
117 |
118 | for i in trange(start_idx, end_idx, disable=distri_config.rank != 0, position=0, leave=False):
119 | prompt = dataset["sentences_raw"][i][i % len(dataset["sentences_raw"][i])]
120 | seed = i
121 |
122 | image = pipeline(
123 | prompt=prompt,
124 | generator=torch.Generator(device="cuda").manual_seed(seed),
125 | num_inference_steps=args.num_inference_steps,
126 | guidance_scale=args.guidance_scale,
127 | ).images[0]
128 | if distri_config.rank == 0:
129 | output_path = os.path.join(args.output_root, f"{i:04d}.png")
130 | image.save(output_path)
131 |
132 |
133 | if __name__ == "__main__":
134 | main()
135 |
--------------------------------------------------------------------------------
/scripts/profile_macs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from diffusers import StableDiffusionXLPipeline
5 | from torchprofile import profile_macs
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--image_size", type=int, nargs="*", default=1024, help="Image size of generation")
10 | args = parser.parse_args()
11 |
12 | if isinstance(args.image_size, int):
13 | args.image_size = [args.image_size // 8, args.image_size // 8]
14 | elif len(args.image_size) == 1:
15 | args.image_size = [args.image_size[0] // 8, args.image_size[0] // 8]
16 | else:
17 | assert len(args.image_size) == 2
18 | args.image_size = [args.image_size[0] // 8, args.image_size[1] // 8]
19 |
20 | pipeline = StableDiffusionXLPipeline.from_pretrained(
21 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
22 | ).to("cuda")
23 |
24 | unet = pipeline.unet
25 |
26 | latent_model_input = torch.randn(2, 4, *args.image_size, dtype=unet.dtype).to("cuda")
27 | t = torch.randn(1).to("cuda")
28 | prompt_embeds = torch.randn(2, 77, 2048, dtype=unet.dtype).to("cuda")
29 | add_text_embeds = torch.randn(2, 1280, dtype=unet.dtype).to("cuda")
30 | add_time_ids = torch.randint(0, 1024, (2, 6)).to("cuda")
31 |
32 | with torch.no_grad():
33 | macs = profile_macs(
34 | unet,
35 | args=(
36 | latent_model_input,
37 | t,
38 | prompt_embeds,
39 | None,
40 | None,
41 | None,
42 | None,
43 | {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
44 | ),
45 | )
46 | print(f"MACs: {macs / 1e9:.3f}G")
47 |
--------------------------------------------------------------------------------
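A note on the `// 8` above: `profile_macs.py` profiles the UNet directly in latent space, and the SDXL VAE downsamples images by a factor of 8 in each spatial dimension, so a 1024×1024 generation corresponds to the 4×128×128 latent that matches the `torch.randn(2, 4, ...)` input. A minimal sanity-check sketch of that arithmetic; the helper name and keyword defaults are ours:

```python
def latent_shape(height: int, width: int, vae_scale_factor: int = 8, latent_channels: int = 4) -> tuple:
    """Shape of the SDXL latent the UNet actually sees for a given image size."""
    return (latent_channels, height // vae_scale_factor, width // vae_scale_factor)


# The default 1024x1024 setting is profiled on a 4x128x128 latent.
assert latent_shape(1024, 1024) == (4, 128, 128)
```

Note also that the profiled batch size is 2, matching the classifier-free-guidance batch (conditional plus unconditional), so the reported MACs correspond to one denoising step with guidance enabled.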
/scripts/run_sdxl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import torch
6 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler
7 | from tqdm import trange
8 |
9 | from distrifuser.pipelines import DistriSDXLPipeline
10 | from distrifuser.utils import DistriConfig
11 |
12 |
13 | def get_args() -> argparse.Namespace:
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--mode",
17 | type=str,
18 | default="generation",
19 | choices=["generation", "benchmark"],
20 | help="Purpose of running the script",
21 | )
22 |
23 | # Diffuser specific arguments
24 | parser.add_argument(
25 | "--prompt", type=str, default="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
26 | )
27 | parser.add_argument("--output_path", type=str, default=None)
28 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
29 | parser.add_argument("--image_size", type=int, nargs="*", default=1024, help="Image size of generation")
30 | parser.add_argument("--guidance_scale", type=float, default=5.0)
31 | parser.add_argument("--scheduler", type=str, default="ddim", choices=["euler", "dpm-solver", "ddim"])
32 | parser.add_argument("--seed", type=int, default=1234, help="Random seed")
33 |
34 | # DistriFuser specific arguments
35 | parser.add_argument(
36 | "--no_split_batch", action="store_true", help="Disable the batch splitting for classifier-free guidance"
37 | )
38 | parser.add_argument("--warmup_steps", type=int, default=4, help="Number of warmup steps")
39 | parser.add_argument(
40 | "--sync_mode",
41 | type=str,
42 | default="corrected_async_gn",
43 | choices=["separate_gn", "stale_gn", "corrected_async_gn", "sync_gn", "full_sync", "no_sync"],
44 | help="Different GroupNorm synchronization modes",
45 | )
46 | parser.add_argument(
47 | "--parallelism",
48 | type=str,
49 | default="patch",
50 | choices=["patch", "tensor", "naive_patch"],
51 | help="patch parallelism, tensor parallelism or naive patch",
52 | )
53 | parser.add_argument("--no_cuda_graph", action="store_true", help="Disable CUDA graph")
54 | parser.add_argument(
55 | "--split_scheme",
56 | type=str,
57 | default="alternate",
58 | choices=["row", "col", "alternate"],
59 | help="Split scheme for naive patch",
60 | )
61 |
62 | # Benchmark specific arguments
63 | parser.add_argument("--output_type", type=str, default="pil", choices=["latent", "pil"])
64 | parser.add_argument("--warmup_times", type=int, default=5, help="Number of warmup times")
65 | parser.add_argument("--test_times", type=int, default=20, help="Number of test times")
66 | parser.add_argument(
67 | "--ignore_ratio", type=float, default=0.2, help="Ignored ratio of the slowest and fastest steps"
68 | )
69 |
70 | args = parser.parse_args()
71 | return args
72 |
73 |
74 | def main():
75 | args = get_args()
76 |
77 | if isinstance(args.image_size, int):
78 | args.image_size = [args.image_size, args.image_size]
79 | else:
80 | if len(args.image_size) == 1:
81 | args.image_size = [args.image_size[0], args.image_size[0]]
82 | else:
83 | assert len(args.image_size) == 2
84 | distri_config = DistriConfig(
85 | height=args.image_size[0],
86 | width=args.image_size[1],
87 | do_classifier_free_guidance=args.guidance_scale > 1,
88 | split_batch=not args.no_split_batch,
89 | warmup_steps=args.warmup_steps,
90 | mode=args.sync_mode,
91 | use_cuda_graph=not args.no_cuda_graph,
92 | parallelism=args.parallelism,
93 | split_scheme=args.split_scheme,
94 | )
95 |
96 | pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
97 | if args.scheduler == "euler":
98 | scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
99 | elif args.scheduler == "dpm-solver":
100 | scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
101 | elif args.scheduler == "ddim":
102 | scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
103 | else:
104 | raise NotImplementedError
105 | pipeline = DistriSDXLPipeline.from_pretrained(
106 | pretrained_model_name_or_path=pretrained_model_name_or_path,
107 | distri_config=distri_config,
108 | variant="fp16",
109 | use_safetensors=True,
110 | scheduler=scheduler,
111 | )
112 |
113 | if args.mode == "generation":
114 | assert args.output_path is not None
115 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
116 | image = pipeline(
117 | prompt=args.prompt,
118 | generator=torch.Generator(device="cuda").manual_seed(args.seed),
119 | num_inference_steps=args.num_inference_steps,
120 | guidance_scale=args.guidance_scale,
121 | ).images[0]
122 | os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
123 | image.save(args.output_path)
124 | elif args.mode == "benchmark":
125 | pipeline.set_progress_bar_config(position=1, desc="Generation", leave=False, disable=distri_config.rank != 0)
126 | for i in trange(args.warmup_times, position=0, desc="Warmup", leave=False, disable=distri_config.rank != 0):
127 | pipeline(
128 | prompt=args.prompt,
129 | generator=torch.Generator(device="cuda").manual_seed(args.seed),
130 | num_inference_steps=args.num_inference_steps,
131 | guidance_scale=args.guidance_scale,
132 | output_type=args.output_type,
133 | )
134 | torch.cuda.synchronize()
135 | latency_list = []
136 | for i in trange(args.test_times, position=0, desc="Test", leave=False, disable=distri_config.rank != 0):
137 | start_time = time.time()
138 | pipeline(
139 | prompt=args.prompt,
140 | generator=torch.Generator(device="cuda").manual_seed(args.seed),
141 | num_inference_steps=args.num_inference_steps,
142 | guidance_scale=args.guidance_scale,
143 | output_type=args.output_type,
144 | )
145 | torch.cuda.synchronize()
146 | end_time = time.time()
147 | latency_list.append(end_time - start_time)
148 | latency_list = sorted(latency_list)
149 | ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
150 | if ignored_count > 0:
151 | latency_list = latency_list[ignored_count:-ignored_count]
152 | if distri_config.rank == 0:
153 | print(f"Latency: {sum(latency_list) / len(latency_list):.5f} s")
154 | else:
155 | raise NotImplementedError
156 |
157 |
158 | if __name__ == "__main__":
159 | main()
160 |
--------------------------------------------------------------------------------
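The `benchmark` mode above reports a trimmed mean: after sorting the per-run latencies, `--ignore_ratio` controls the fraction of runs discarded, split evenly between the fastest and slowest ends, and the remainder is averaged. A minimal sketch of that arithmetic, isolated from the pipeline; the helper name `trimmed_mean` is ours, and as with the other scripts, a run is expected to launch one process per GPU (for example via `torchrun`):

```python
def trimmed_mean(latencies: list[float], ignore_ratio: float = 0.2) -> float:
    """Average latency after dropping the fastest and slowest runs.

    Mirrors the benchmark logic above: with the default 20 test runs and
    ignore_ratio=0.2, the 2 fastest and 2 slowest runs are discarded and
    the remaining 16 are averaged.
    """
    latencies = sorted(latencies)
    ignored = int(ignore_ratio * len(latencies) / 2)
    if ignored > 0:
        latencies = latencies[ignored:-ignored]
    return sum(latencies) / len(latencies)


# With only 4 samples, a larger ratio is needed before anything is dropped.
print(f"{trimmed_mean([1.0, 1.1, 1.2, 5.0], ignore_ratio=0.5):.2f}")  # 1.15
```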
/scripts/sd_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from distrifuser.pipelines import DistriSDPipeline
4 | from distrifuser.utils import DistriConfig
5 |
6 | distri_config = DistriConfig(height=512, width=512, warmup_steps=4, mode="stale_gn")
7 | pipeline = DistriSDPipeline.from_pretrained(
8 | distri_config=distri_config,
9 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
10 | )
11 |
12 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
13 | image = pipeline(
14 | prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
15 | generator=torch.Generator(device="cuda").manual_seed(233),
16 | ).images[0]
17 | if distri_config.rank == 0:
18 | image.save("astronaut.png")
19 |
--------------------------------------------------------------------------------
/scripts/sdxl_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from distrifuser.pipelines import DistriSDXLPipeline
4 | from distrifuser.utils import DistriConfig
5 |
6 | distri_config = DistriConfig(height=1024, width=1024, warmup_steps=4)
7 | pipeline = DistriSDXLPipeline.from_pretrained(
8 | distri_config=distri_config,
9 | pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
10 | variant="fp16",
11 | use_safetensors=True,
12 | )
13 |
14 | pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
15 | image = pipeline(
16 | prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
17 | generator=torch.Generator(device="cuda").manual_seed(233),
18 | ).images[0]
19 | if distri_config.rank == 0:
20 | image.save("astronaut.png")
21 |
--------------------------------------------------------------------------------
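The example above relies on the `DistriConfig` defaults. For reference, here is a sketch of the same pipeline with the main knobs spelled out, mirroring the options that `scripts/run_sdxl.py` forwards from its command-line flags; the values shown are simply that script's defaults, not tuned recommendations:

```python
import torch

from distrifuser.pipelines import DistriSDXLPipeline
from distrifuser.utils import DistriConfig

# The same options that scripts/run_sdxl.py passes through from its CLI arguments.
distri_config = DistriConfig(
    height=1024,
    width=1024,
    do_classifier_free_guidance=True,  # guidance_scale > 1
    split_batch=True,                  # split the classifier-free-guidance batch across devices
    warmup_steps=4,                    # number of warmup steps
    mode="corrected_async_gn",         # GroupNorm synchronization mode
    use_cuda_graph=True,
    parallelism="patch",               # "patch", "tensor", or "naive_patch"
    split_scheme="alternate",          # split scheme for naive patch
)
pipeline = DistriSDXLPipeline.from_pretrained(
    distri_config=distri_config,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    use_safetensors=True,
)
pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
image = pipeline(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    generator=torch.Generator(device="cuda").manual_seed(233),
).images[0]
if distri_config.rank == 0:
    image.save("astronaut.png")
```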
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | if __name__ == "__main__":
5 | with open("README.md", "r") as f:
6 | long_description = f.read()
7 |     with open("distrifuser/__version__.py", "r") as f:
8 |         version = eval(f.read().strip().split()[-1])
9 |
10 | setup(
11 | name="distrifuser",
12 | author="Muyang Li, Tianle Cai, Jiaxin Cao, Qinsheng Zhang, Han Cai, Junjie Bai, Yangqing Jia, Ming-Yu Liu, Kai Li and Song Han",
13 | author_email="muyangli@mit.edu",
14 | packages=find_packages(),
15 | install_requires=["torch>=2.2", "diffusers==0.24.0", "transformers", "tqdm"],
16 | url="https://github.com/mit-han-lab/distrifuser",
17 | description="DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models",
18 | long_description=long_description,
19 | long_description_content_type="text/markdown",
20 | version=version,
21 | classifiers=[
22 | "Programming Language :: Python :: 3",
23 | "Operating System :: OS Independent",
24 | ],
25 | include_package_data=True,
26 | python_requires=">=3.10",
27 | )
28 |
--------------------------------------------------------------------------------
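One note on the version lookup above: it evaluates the last whitespace-separated token of `distrifuser/__version__.py`, so it assumes that file consists of a single assignment. A hypothetical illustration of a file in that shape (the number shown is a placeholder, not the package's actual version):

```python
# distrifuser/__version__.py, hypothetical contents for illustration only
__version__ = "0.0.0"
```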