├── requirements.txt ├── __init__.py ├── uno ├── utils │ └── convert_yaml_to_args_file.py └── flux │ ├── math.py │ ├── modules │ ├── conditioner.py │ ├── autoencoder.py │ └── layers.py │ ├── sampling.py │ ├── model.py │ ├── pipeline.py │ └── util.py ├── example_workflows └── flux-red-uno.json ├── README.md └── uno_nodes └── comfy_nodes.py /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | transformers 3 | huggingface-hub 4 | diffusers 5 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .uno_nodes.comfy_nodes import REDUNOModelLoader, REDUNOGenerate 2 | 3 | 4 | # 注册节点 5 | NODE_CLASS_MAPPINGS = { 6 | "REDUNOModelLoader": REDUNOModelLoader, 7 | "REDUNOGenerate": REDUNOGenerate, 8 | } 9 | 10 | NODE_DISPLAY_NAME_MAPPINGS = { 11 | "REDUNOModelLoader": "UNO Model Loader @REDAIGC", 12 | "REDUNOGenerate": "UNO Generate @REDAIGC", 13 | } 14 | 15 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /uno/utils/convert_yaml_to_args_file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import yaml 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--yaml", type=str, required=True) 20 | parser.add_argument("--arg", type=str, required=True) 21 | args = parser.parse_args() 22 | 23 | 24 | with open(args.yaml, "r") as f: 25 | data = yaml.safe_load(f) 26 | 27 | with open(args.arg, "w") as f: 28 | for k, v in data.items(): 29 | if isinstance(v, list): 30 | v = list(map(str, v)) 31 | v = " ".join(v) 32 | if v is None: 33 | continue 34 | print(f"--{k} {v}", end=" ", file=f) 35 | -------------------------------------------------------------------------------- /uno/flux/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
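# Rotary position embedding (RoPE) helpers and the fused scaled-dot-product attention used by
# the Flux transformer blocks: rope() builds the 2x2 rotation matrices, apply_rope() rotates the
# query/key projections, and attention() combines both with torch SDPA.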
15 | 16 | import torch 17 | from einops import rearrange 18 | from torch import Tensor 19 | 20 | 21 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 22 | q, k = apply_rope(q, k, pe) 23 | 24 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 25 | x = rearrange(x, "B H L D -> B L (H D)") 26 | 27 | return x 28 | 29 | 30 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 31 | assert dim % 2 == 0 32 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 33 | omega = 1.0 / (theta**scale) 34 | out = torch.einsum("...n,d->...nd", pos, omega) 35 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 36 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 37 | return out.float() 38 | 39 | 40 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 41 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 42 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 43 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 44 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 45 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 46 | -------------------------------------------------------------------------------- /uno/flux/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
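# HFEmbedder wraps a Hugging Face text encoder used purely for conditioning: CLIP (pooled output)
# when is_clip=True, otherwise a T5 encoder (last hidden state). The wrapped module is frozen with
# requires_grad_(False) and kept in eval mode.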
15 | 16 | from torch import Tensor, nn 17 | 18 | 19 | class HFEmbedder(nn.Module): 20 | def __init__(self, version: str, max_length: int, is_clip=True, **hf_kwargs): 21 | super().__init__() 22 | self.is_clip = is_clip 23 | self.max_length = max_length 24 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 25 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, 26 | T5Tokenizer) 27 | 28 | if self.is_clip: 29 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, **hf_kwargs) 30 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) 31 | else: 32 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, **hf_kwargs) 33 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) 34 | 35 | self.hf_module = self.hf_module.eval().requires_grad_(False) 36 | 37 | def forward(self, text: list[str]) -> Tensor: 38 | batch_encoding = self.tokenizer( 39 | text, 40 | truncation=True, 41 | max_length=self.max_length, 42 | return_length=False, 43 | return_overflowing_tokens=False, 44 | padding="max_length", 45 | return_tensors="pt", 46 | ) 47 | 48 | outputs = self.hf_module( 49 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 50 | attention_mask=None, 51 | output_hidden_states=False, 52 | ) 53 | return outputs[self.output_key] 54 | -------------------------------------------------------------------------------- /example_workflows/flux-red-uno.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "a2846121-208b-41df-a41d-099d6388aba8", 3 | "revision": 0, 4 | "last_node_id": 10, 5 | "last_link_id": 14, 6 | "nodes": [ 7 | { 8 | "id": 1, 9 | "type": "LoadImage", 10 | "pos": [ 11 | -66.44861602783203, 12 | -457.27520751953125 13 | ], 14 | "size": [ 15 | 474.272216796875, 16 | 398.81951904296875 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "label": "图像", 25 | "name": "IMAGE", 26 | "type": "IMAGE", 27 | "links": [ 28 | 11 29 | ] 30 | }, 31 | { 32 | "label": "遮罩", 33 | "name": "MASK", 34 | "type": "MASK", 35 | "links": null 36 | } 37 | ], 38 | "properties": { 39 | "Node name for S&R": "LoadImage", 40 | "cnr_id": "comfy-core", 41 | "ver": "0.3.27" 42 | }, 43 | "widgets_values": [ 44 | "image.webp", 45 | "image", 46 | "" 47 | ] 48 | }, 49 | { 50 | "id": 4, 51 | "type": "LoadImage", 52 | "pos": [ 53 | 423.3421325683594, 54 | -454.0016784667969 55 | ], 56 | "size": [ 57 | 459.19317626953125, 58 | 388.45269775390625 59 | ], 60 | "flags": {}, 61 | "order": 1, 62 | "mode": 0, 63 | "inputs": [], 64 | "outputs": [ 65 | { 66 | "label": "图像", 67 | "name": "IMAGE", 68 | "type": "IMAGE", 69 | "links": [ 70 | 10 71 | ] 72 | }, 73 | { 74 | "label": "遮罩", 75 | "name": "MASK", 76 | "type": "MASK", 77 | "links": null 78 | } 79 | ], 80 | "properties": { 81 | "Node name for S&R": "LoadImage", 82 | "cnr_id": "comfy-core", 83 | "ver": "0.3.27" 84 | }, 85 | "widgets_values": [ 86 | "image (2).png", 87 | "image", 88 | "" 89 | ] 90 | }, 91 | { 92 | "id": 5, 93 | "type": "SaveImage", 94 | "pos": [ 95 | 1339.749267578125, 96 | -451.31793212890625 97 | ], 98 | "size": [ 99 | 428.37542724609375, 100 | 536.8045654296875 101 | ], 102 | "flags": {}, 103 | "order": 4, 104 | "mode": 0, 105 | "inputs": [ 106 | { 107 | "label": "图像", 108 | "name": "images", 109 | "type": "IMAGE", 110 | "link": 14 111 | } 112 | ], 113 | "outputs": [], 114 | 
"properties": { 115 | "Node name for S&R": "SaveImage", 116 | "cnr_id": "comfy-core", 117 | "ver": "0.3.27" 118 | }, 119 | "widgets_values": [ 120 | "REDAIGC" 121 | ] 122 | }, 123 | { 124 | "id": 9, 125 | "type": "REDUNOModelLoader", 126 | "pos": [ 127 | 902.960693359375, 128 | -446.33685302734375 129 | ], 130 | "size": [ 131 | 416.5949401855469, 132 | 154 133 | ], 134 | "flags": {}, 135 | "order": 2, 136 | "mode": 0, 137 | "inputs": [], 138 | "outputs": [ 139 | { 140 | "label": "uno_model", 141 | "name": "uno_model", 142 | "type": "UNO_MODEL", 143 | "links": [ 144 | 12 145 | ] 146 | } 147 | ], 148 | "properties": { 149 | "Node name for S&R": "REDUNOModelLoader" 150 | }, 151 | "widgets_values": [ 152 | "RED-UNO-Diffusers-FP8.safetensors", 153 | "ae.safetensors", 154 | true, 155 | true, 156 | "UNO_dit_lora.safetensors" 157 | ] 158 | }, 159 | { 160 | "id": 10, 161 | "type": "REDUNOGenerate", 162 | "pos": [ 163 | 907.1075439453125, 164 | -247.29359436035156 165 | ], 166 | "size": [ 167 | 409.33013916015625, 168 | 336 169 | ], 170 | "flags": {}, 171 | "order": 3, 172 | "mode": 0, 173 | "inputs": [ 174 | { 175 | "label": "uno_model", 176 | "name": "uno_model", 177 | "type": "UNO_MODEL", 178 | "link": 12 179 | }, 180 | { 181 | "label": "reference_image_1", 182 | "name": "reference_image_1", 183 | "shape": 7, 184 | "type": "IMAGE", 185 | "link": 11 186 | }, 187 | { 188 | "label": "reference_image_2", 189 | "name": "reference_image_2", 190 | "shape": 7, 191 | "type": "IMAGE", 192 | "link": 10 193 | }, 194 | { 195 | "label": "reference_image_3", 196 | "name": "reference_image_3", 197 | "shape": 7, 198 | "type": "IMAGE", 199 | "link": null 200 | }, 201 | { 202 | "label": "reference_image_4", 203 | "name": "reference_image_4", 204 | "shape": 7, 205 | "type": "IMAGE", 206 | "link": null 207 | } 208 | ], 209 | "outputs": [ 210 | { 211 | "label": "IMAGE", 212 | "name": "IMAGE", 213 | "type": "IMAGE", 214 | "links": [ 215 | 14 216 | ] 217 | } 218 | ], 219 | "properties": { 220 | "Node name for S&R": "REDUNOGenerate" 221 | }, 222 | "widgets_values": [ 223 | "Two cartoon men embracing each other", 224 | 512, 225 | 512, 226 | 4, 227 | 25, 228 | 1178, 229 | "randomize", 230 | "d" 231 | ] 232 | } 233 | ], 234 | "links": [ 235 | [ 236 | 10, 237 | 4, 238 | 0, 239 | 10, 240 | 2, 241 | "IMAGE" 242 | ], 243 | [ 244 | 11, 245 | 1, 246 | 0, 247 | 10, 248 | 1, 249 | "IMAGE" 250 | ], 251 | [ 252 | 12, 253 | 9, 254 | 0, 255 | 10, 256 | 0, 257 | "UNO_MODEL" 258 | ], 259 | [ 260 | 14, 261 | 10, 262 | 0, 263 | 5, 264 | 0, 265 | "IMAGE" 266 | ] 267 | ], 268 | "groups": [], 269 | "config": {}, 270 | "extra": { 271 | "ds": { 272 | "scale": 0.9646149645000006, 273 | "offset": [ 274 | 239.6949526197524, 275 | 702.1051557136478 276 | ] 277 | }, 278 | "node_versions": { 279 | "comfy-core": "0.3.27", 280 | "ComfyUI-RED-UNO": "a57080bee0b21b41d36531eac6d6273f4c4cdc4a" 281 | }, 282 | "ue_links": [], 283 | "VHS_latentpreview": false, 284 | "VHS_latentpreviewrate": 0 285 | }, 286 | "version": 0.4 287 | } -------------------------------------------------------------------------------- /uno/flux/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | from typing import Literal 18 | 19 | import torch 20 | from einops import rearrange, repeat 21 | from torch import Tensor 22 | from tqdm import tqdm 23 | 24 | from .model import Flux 25 | from .modules.conditioner import HFEmbedder 26 | 27 | 28 | def get_noise( 29 | num_samples: int, 30 | height: int, 31 | width: int, 32 | device: torch.device, 33 | dtype: torch.dtype, 34 | seed: int, 35 | ): 36 | return torch.randn( 37 | num_samples, 38 | 16, 39 | # allow for packing 40 | 2 * math.ceil(height / 16), 41 | 2 * math.ceil(width / 16), 42 | device=device, 43 | dtype=dtype, 44 | generator=torch.Generator(device=device).manual_seed(seed), 45 | ) 46 | 47 | 48 | def prepare( 49 | t5: HFEmbedder, 50 | clip: HFEmbedder, 51 | img: Tensor, 52 | prompt: str | list[str], 53 | ref_img: None | Tensor=None, 54 | pe: Literal['d', 'h', 'w', 'o'] ='d' 55 | ) -> dict[str, Tensor]: 56 | assert pe in ['d', 'h', 'w', 'o'] 57 | bs, c, h, w = img.shape 58 | if bs == 1 and not isinstance(prompt, str): 59 | bs = len(prompt) 60 | 61 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 62 | if img.shape[0] == 1 and bs > 1: 63 | img = repeat(img, "1 ... -> bs ...", bs=bs) 64 | 65 | img_ids = torch.zeros(h // 2, w // 2, 3) 66 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 67 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 68 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 69 | 70 | if ref_img is not None: 71 | _, _, ref_h, ref_w = ref_img.shape 72 | ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 73 | if ref_img.shape[0] == 1 and bs > 1: 74 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs) 75 | ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3) 76 | # img id分别在宽高偏移各自最大值 77 | h_offset = h // 2 if pe in {'d', 'h'} else 0 78 | w_offset = w // 2 if pe in {'d', 'w'} else 0 79 | ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset 80 | ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset 81 | ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs) 82 | 83 | if isinstance(prompt, str): 84 | prompt = [prompt] 85 | txt = t5(prompt) 86 | if txt.shape[0] == 1 and bs > 1: 87 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 88 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 89 | 90 | vec = clip(prompt) 91 | if vec.shape[0] == 1 and bs > 1: 92 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 93 | 94 | if ref_img is not None: 95 | return { 96 | "img": img, 97 | "img_ids": img_ids.to(img.device), 98 | "ref_img": ref_img, 99 | "ref_img_ids": ref_img_ids.to(img.device), 100 | "txt": txt.to(img.device), 101 | "txt_ids": txt_ids.to(img.device), 102 | "vec": vec.to(img.device), 103 | } 104 | else: 105 | return { 106 | "img": img, 107 | "img_ids": img_ids.to(img.device), 108 | "txt": txt.to(img.device), 109 | "txt_ids": txt_ids.to(img.device), 110 | "vec": vec.to(img.device), 111 | } 112 | 113 | def prepare_multi_ip( 114 | t5: HFEmbedder, 115 | clip: HFEmbedder, 116 | img: Tensor, 117 | prompt: str | list[str], 118 | ref_imgs: list[Tensor] | None = None, 119 | pe: Literal['d', 'h', 'w', 'o'] = 'd' 120 | ) -> dict[str, Tensor]: 121 | assert pe in ['d', 'h', 'w', 'o'] 122 | bs, c, h, w = img.shape 123 | if bs == 1 and not isinstance(prompt, str): 124 | bs = len(prompt) 125 | 126 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 127 | if img.shape[0] == 1 and bs > 1: 128 | img = repeat(img, "1 ... -> bs ...", bs=bs) 129 | 130 | img_ids = torch.zeros(h // 2, w // 2, 3) 131 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 132 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 133 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 134 | 135 | ref_img_ids = [] 136 | ref_imgs_list = [] 137 | pe_shift_w, pe_shift_h = w // 2, h // 2 138 | for ref_img in ref_imgs: 139 | _, _, ref_h1, ref_w1 = ref_img.shape 140 | ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 141 | if ref_img.shape[0] == 1 and bs > 1: 142 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs) 143 | ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3) 144 | # img id分别在宽高偏移各自最大值 145 | h_offset = pe_shift_h if pe in {'d', 'h'} else 0 146 | w_offset = pe_shift_w if pe in {'d', 'w'} else 0 147 | ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset 148 | ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset 149 | ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs) 150 | ref_img_ids.append(ref_img_ids1) 151 | ref_imgs_list.append(ref_img) 152 | 153 | # 更新pe shift 154 | pe_shift_h += ref_h1 // 2 155 | pe_shift_w += ref_w1 // 2 156 | 157 | if isinstance(prompt, str): 158 | prompt = [prompt] 159 | txt = t5(prompt) 160 | if txt.shape[0] == 1 and bs > 1: 161 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 162 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 163 | 164 | vec = clip(prompt) 165 | if vec.shape[0] == 1 and bs > 1: 166 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 167 | 168 | return { 169 | "img": img, 170 | "img_ids": img_ids.to(img.device), 171 | "ref_img": tuple(ref_imgs_list), 172 | "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids], 173 | "txt": txt.to(img.device), 174 | "txt_ids": txt_ids.to(img.device), 175 | "vec": vec.to(img.device), 176 | } 177 | 178 | 179 | def time_shift(mu: float, sigma: float, t: Tensor): 180 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 181 | 182 | 183 | def get_lin_function( 184 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 185 | ): 186 | m = (y2 - y1) / (x2 - x1) 187 | b = y1 - m * x1 188 | return lambda x: m * x + b 189 | 190 | 191 | def get_schedule( 192 | num_steps: int, 193 | image_seq_len: int, 194 | base_shift: float = 0.5, 195 | max_shift: float = 1.15, 196 | shift: bool = True, 197 | ) -> list[float]: 198 | # extra step for zero 199 | timesteps = torch.linspace(1, 0, num_steps + 1) 200 | 201 | # shifting the schedule to favor high timesteps for higher signal images 202 | if shift: 203 | # eastimate mu based on linear estimation between two points 204 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 205 | timesteps = time_shift(mu, 1.0, timesteps) 206 | 207 | return timesteps.tolist() 208 | 209 | 210 | def denoise( 211 | model: Flux, 212 | # model input 213 | img: Tensor, 214 | img_ids: Tensor, 215 | txt: Tensor, 216 | txt_ids: Tensor, 217 | vec: Tensor, 218 | # sampling parameters 219 | timesteps: list[float], 220 | guidance: float = 4.0, 221 | ref_img: Tensor=None, 222 | ref_img_ids: Tensor=None, 223 | ): 224 | i = 0 225 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) 226 | for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1): 227 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 228 | pred = model( 229 | img=img, 230 | img_ids=img_ids, 231 | ref_img=ref_img, 232 | ref_img_ids=ref_img_ids, 233 | txt=txt, 234 | txt_ids=txt_ids, 235 | y=vec, 236 | timesteps=t_vec, 237 | guidance=guidance_vec 238 | ) 239 | img = img + (t_prev - t_curr) * pred 240 | i += 1 241 | return img 242 | 243 | 244 | def unpack(x: Tensor, height: int, width: int) -> Tensor: 245 | return rearrange( 246 | x, 247 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 248 | h=math.ceil(height / 16), 249 | w=math.ceil(width / 16), 250 | ph=2, 251 | pw=2, 252 | ) 253 | -------------------------------------------------------------------------------- /uno/flux/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
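# Flux: the flow-matching DiT backbone. Image and text tokens pass through DoubleStreamBlocks
# (separate streams) and then SingleStreamBlocks (joint stream); rotary positions come from
# EmbedND, and UNO reference-image tokens are concatenated onto the image sequence in forward().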
15 | 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from torch import Tensor, nn 20 | 21 | from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding 22 | 23 | 24 | @dataclass 25 | class FluxParams: 26 | in_channels: int 27 | vec_in_dim: int 28 | context_in_dim: int 29 | hidden_size: int 30 | mlp_ratio: float 31 | num_heads: int 32 | depth: int 33 | depth_single_blocks: int 34 | axes_dim: list[int] 35 | theta: int 36 | qkv_bias: bool 37 | guidance_embed: bool 38 | 39 | 40 | class Flux(nn.Module): 41 | """ 42 | Transformer model for flow matching on sequences. 43 | """ 44 | _supports_gradient_checkpointing = True 45 | 46 | def __init__(self, params: FluxParams): 47 | super().__init__() 48 | 49 | self.params = params 50 | self.in_channels = params.in_channels 51 | self.out_channels = self.in_channels 52 | if params.hidden_size % params.num_heads != 0: 53 | raise ValueError( 54 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 55 | ) 56 | pe_dim = params.hidden_size // params.num_heads 57 | if sum(params.axes_dim) != pe_dim: 58 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") 59 | self.hidden_size = params.hidden_size 60 | self.num_heads = params.num_heads 61 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) 62 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 63 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 64 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 65 | self.guidance_in = ( 66 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 67 | ) 68 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 69 | 70 | self.double_blocks = nn.ModuleList( 71 | [ 72 | DoubleStreamBlock( 73 | self.hidden_size, 74 | self.num_heads, 75 | mlp_ratio=params.mlp_ratio, 76 | qkv_bias=params.qkv_bias, 77 | ) 78 | for _ in range(params.depth) 79 | ] 80 | ) 81 | 82 | self.single_blocks = nn.ModuleList( 83 | [ 84 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 85 | for _ in range(params.depth_single_blocks) 86 | ] 87 | ) 88 | 89 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 90 | self.gradient_checkpointing = False 91 | 92 | def _set_gradient_checkpointing(self, module, value=False): 93 | if hasattr(module, "gradient_checkpointing"): 94 | module.gradient_checkpointing = value 95 | 96 | @property 97 | def attn_processors(self): 98 | # set recursively 99 | processors = {} # type: dict[str, nn.Module] 100 | 101 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): 102 | if hasattr(module, "set_processor"): 103 | processors[f"{name}.processor"] = module.processor 104 | 105 | for sub_name, child in module.named_children(): 106 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 107 | 108 | return processors 109 | 110 | for name, module in self.named_children(): 111 | fn_recursive_add_processors(name, module, processors) 112 | 113 | return processors 114 | 115 | def set_attn_processor(self, processor): 116 | r""" 117 | Sets the attention processor to use to compute attention. 
118 | 119 | Parameters: 120 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 121 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 122 | for **all** `Attention` layers. 123 | 124 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 125 | processor. This is strongly recommended when setting trainable attention processors. 126 | 127 | """ 128 | count = len(self.attn_processors.keys()) 129 | 130 | if isinstance(processor, dict) and len(processor) != count: 131 | raise ValueError( 132 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 133 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 134 | ) 135 | 136 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 137 | if hasattr(module, "set_processor"): 138 | if not isinstance(processor, dict): 139 | module.set_processor(processor) 140 | else: 141 | module.set_processor(processor.pop(f"{name}.processor")) 142 | 143 | for sub_name, child in module.named_children(): 144 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 145 | 146 | for name, module in self.named_children(): 147 | fn_recursive_attn_processor(name, module, processor) 148 | 149 | def forward( 150 | self, 151 | img: Tensor, 152 | img_ids: Tensor, 153 | txt: Tensor, 154 | txt_ids: Tensor, 155 | timesteps: Tensor, 156 | y: Tensor, 157 | guidance: Tensor | None = None, 158 | ref_img: Tensor | None = None, 159 | ref_img_ids: Tensor | None = None, 160 | ) -> Tensor: 161 | if img.ndim != 3 or txt.ndim != 3: 162 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 163 | 164 | # running on sequences img 165 | img = self.img_in(img) 166 | vec = self.time_in(timestep_embedding(timesteps, 256)) 167 | if self.params.guidance_embed: 168 | if guidance is None: 169 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 170 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 171 | vec = vec + self.vector_in(y) 172 | txt = self.txt_in(txt) 173 | 174 | ids = torch.cat((txt_ids, img_ids), dim=1) 175 | 176 | # concat ref_img/img 177 | img_end = img.shape[1] 178 | if ref_img is not None: 179 | if isinstance(ref_img, tuple) or isinstance(ref_img, list): 180 | img_in = [img] + [self.img_in(ref) for ref in ref_img] 181 | img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids] 182 | img = torch.cat(img_in, dim=1) 183 | ids = torch.cat(img_ids, dim=1) 184 | else: 185 | img = torch.cat((img, self.img_in(ref_img)), dim=1) 186 | ids = torch.cat((ids, ref_img_ids), dim=1) 187 | pe = self.pe_embedder(ids) 188 | 189 | for index_block, block in enumerate(self.double_blocks): 190 | if self.training and self.gradient_checkpointing: 191 | img, txt = torch.utils.checkpoint.checkpoint( 192 | block, 193 | img=img, 194 | txt=txt, 195 | vec=vec, 196 | pe=pe, 197 | use_reentrant=False, 198 | ) 199 | else: 200 | img, txt = block( 201 | img=img, 202 | txt=txt, 203 | vec=vec, 204 | pe=pe 205 | ) 206 | 207 | img = torch.cat((txt, img), 1) 208 | for block in self.single_blocks: 209 | if self.training and self.gradient_checkpointing: 210 | img = torch.utils.checkpoint.checkpoint( 211 | block, 212 | img, vec=vec, pe=pe, 213 | use_reentrant=False 214 | ) 215 | else: 216 | img = block(img, vec=vec, pe=pe) 217 | img = img[:, txt.shape[1] :, ...] 218 | # index img 219 | img = img[:, :img_end, ...] 
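        # the text tokens were stripped above; keeping only the first img_end tokens discards
        # the reference-image tokens appended during the concatenation step, so only the noisy
        # target-image tokens reach the final projection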
220 | 221 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 222 | return img 223 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-RED-UNO 2 | Default 16GB VRAM UNO in context generation ComfyUI-node, using RED-UNO FT model 3 | 4 | 默认16G显存的UNO ComfyUI in-context生成组件,使用RED-UNO FT模型 5 | 6 | ## RED-UNO FLUX-dev-FT FP8: 7 | 8 | https://civitai.com/models/958009/redcraft-or-cads-or-updated-apr14-or-commercial-and-advertising-design-system 9 | 10 | ![image](https://github.com/user-attachments/assets/d354655e-a446-48f3-893b-edfb5e61d6f7) 11 | 12 | 13 | ## RED. UNO In-Context (FP8) 4/14/2025 14 | 15 | REDAIGC FT Model used to match UNO In-Context Generation 16 | 17 | (with improved quality compared to F.1 dev) 18 | 19 | 解决了FLUX FT底模无法适配UNO组件的问题,FP8权重(显存占用16GB),同时支持Diffusers以及ComfyUI 20 | 21 | 实测对比DEV生成质量明显提升,对比BF16版本没有明显质量损失,显存占用只有16G,生成时间大概30秒 22 | 23 | 24 | 25 | # 安装说明 26 | 27 | 由于ComfyUI-RED-UNO使用Diffusers库进行推理,所以首次使用,会自动下载Diffusers格式的T5和Clip(10G左右) 28 | 29 | 存放路径是ComfyUI虚拟环境的.cache缓存文件夹,如果是系统级安装的ComfyUI,则会存放至C盘User路径下的.cache(非常占空间) 30 | 31 | 建议是给ComfyUI的虚拟环境指定一个独立的缓存目录。 32 | 33 | FLUX FT底模正常存放在ComfyUI\models\diffusion_models目录下,VAE在ComfyUI\models\vae,UNO-Lora在ComfyUI\models\loras 34 | 35 | ![image](https://github.com/user-attachments/assets/572a5206-ab23-417f-b83e-61b2f06d0b8b) 36 | 37 | use_fp8和offload 默认勾选,否则24G显存就不够用了,VAE也要使用Diffusers格式的版本否则会报错(ComfyUI目前没有直接支持UNO所以才这么折腾) 38 | 39 | VAE版本: 40 | https://huggingface.co/GuangyuanSD/16C_vae_Diffusers 41 | 42 | 43 | --- 44 | 45 | Diffusers 脚本: 46 | 47 | https://github.com/bytedance/UNO 48 | 49 | Dit-LoRA 权重: 50 | 51 | https://huggingface.co/bytedance-research/UNO 52 | 53 | VAE版本: 54 | 55 | https://huggingface.co/GuangyuanSD/16C_vae_Diffusers 56 | 57 | 58 | # ComfyUI UNO Nodes 59 | 60 | ComfyUI UNO Nodes is a collection of nodes for ComfyUI that allows you to load and use UNO models. 61 | 62 | [https://github.com/bytedance/UNO comfyui](https://github.com/QijiTec/ComfyUI-RED-UNO.git) 63 | 64 | # 使用RED-UNO,因为UNO使用的模型格式和ComfyUI不一样,否则就会报错 65 | 66 | Missing keys: 236 67 | Unexpected keys: 236 68 | 69 | 这个错误是因为底模不匹配UNO算法,下载使用这个FLUX底模 https://civitai.com/models/958009 70 | 不可以把别的模型改名字替代,因为模型结构不一样!要用Diffusers格式的模型,才可以正确加载。 71 | 72 | 73 | FP8 support 74 | open offload and fp8 support 16GB VRAM 75 | 76 | flux model in unet directory and lora in lora directory 77 | 78 | clip and t5 will autodownload 79 | 80 | 81 | ![UNO0](https://github.com/user-attachments/assets/2d5e287e-73fc-4e95-ba3d-42f81ef920fc) 82 | 83 | 84 | 85 | ## Paper Analysis 86 | 87 | **UNO - Less-to-More Generalization Unlocking More Controllability by In-Context Generation** 88 | 89 | **Introduction** 90 | 91 | Subject-driven image generation, aiming to create new images based on textual descriptions and user-provided reference images, is a central challenge in the field of Artificial Intelligence Generated Content (AIGC). However, existing methods still face significant limitations, particularly regarding **data scalability** (especially acquiring high-quality multi-subject paired data) and **subject expansibility** (stably and controllably handling multiple subjects). 
Researchers from the Intelligent Creation Team at ByteDance address these issues in their paper, "Less-to-More Generalization: Unlocking More Controllability by In-Context Generation," by proposing a novel framework named **UNO (Universal aNd cOntrollable)**, based on an innovative "Less-to-More" generalization paradigm. 92 | 93 | **Core Challenges and Motivation** 94 | 95 | Traditional approaches grapple with data scarcity (especially for multi-subject pairs), difficulties in multi-subject control (identity confusion, layout issues), and the inherent trade-off between efficiency and fidelity. UNO's motivation stems from the need to systematically tackle these bottlenecks. 96 | 97 | **UNO's Core Contributions and Technical Breakdown** 98 | 99 | UNO presents a comprehensive solution encompassing data, model, and training strategies: 100 | 101 | 1. **Innovative "Model-Data Co-evolution" Paradigm:** This is UNO's most groundbreaking conceptual contribution. Instead of passively relying on existing data, it proposes a virtuous cycle where the model itself (even less capable versions) is leveraged to systematically synthesize higher-quality, more complex customized data (evolving from single-subject to multi-subject). This superior data, in turn, trains more powerful ("more controllable") model variants. 102 | 2. **High-Quality Synthetic Data Automation Pipeline:** Utilizes the in-context generation capability of Diffusion Transformers (DiTs), high-resolution output, and a sophisticated multi-stage filtering process involving DINOv2 + VLM (with Chain-of-Thought prompting) to automatically generate large-scale, high-fidelity, high-consistency single- and multi-subject paired data. 103 | 3. **UNO Model Architecture and Training Strategy:** Built upon a DiT architecture, it employs **Progressive Cross-Modal Alignment** (training first on single-subject, then multi-subject data) and introduces the crucial component: **UnoPE**. 104 | 105 | **In-Depth Comparison: UNO vs. OminiControl vs. In-Context LoRA** 106 | 107 | All three approaches recognize and leverage the powerful in-context learning capabilities of Diffusion Transformers (DiTs), but their methodologies and focuses differ significantly. Understanding these distinctions highlights UNO's unique value: 108 | 109 | * **Core Mechanism Utilization:** 110 | * **OminiControl:** Was among the earlier works demonstrating that DiTs can understand and replicate reference subjects through specific input formats (like side-by-side image templates) without explicit fine-tuning. It primarily relies on this "emergent" capability of the DiT itself. 111 | * **In-Context LoRA:** Focuses on using the DiT's context window for "few-shot learning" or "rapid adaptation." It posits that by providing a few (reference, target) example pairs within the context, the model can quickly grasp a new concept. This rapid adaptation is then potentially enhanced or guided by training specific LoRA modules designed to process this contextual information. 112 | * **UNO:** Does *not* rely on immediate contextual examples for learning. Instead, it utilizes large-scale **pre-synthesized**, high-quality data for **pre-training/fine-tuning**. This process internalizes the ability to understand and handle reference images (both single and multiple) directly into the model's weights (or attached modules like LoRA). It's more akin to "compiling" the context understanding capability into the model beforehand. 
113 | 114 | * **Data Strategy Differences:** 115 | * **OminiControl:** Data generation appears relatively basic, potentially at lower resolutions, and with less sophisticated filtering. 116 | * **In-Context LoRA:** Likely relies on training with collections of high-quality (reference, target) example pairs. 117 | * **UNO:** Employs the most complex and systematic data strategy. It emphasizes high resolution, multiple aspect ratios, progressive generation from single to multi-subject, and rigorous filtering using VLM CoT. **Data quality is explicitly treated as crucial for pushing the model's performance ceiling**, embodying the "model-data co-evolution" philosophy. 118 | 119 | * **Multi-Subject Handling Capability:** 120 | * **OminiControl:** Lacks mechanisms explicitly designed for multi-subject scenarios. Its multi-subject performance likely depends heavily on the DiT's raw generalization ability, which might struggle with complex interactions or attribute distinction. 121 | * **In-Context LoRA:** The original paper doesn't heavily focus on specific solutions for multi-subject generation. Its effectiveness might be limited by the complexity of multi-subject examples provided in the context and the base DiT's generalization limits. 122 | * **UNO:** This is a standout advantage for UNO. It not only explicitly includes multi-subject scenarios in its data generation and training phases but, critically, introduces **UnoPE (Universal Rotary Position Embedding)**. UnoPE modifies positional encodings by assigning unique offsets to tokens from different reference images. This fundamentally helps the Transformer's attention mechanism **distinguish between different visual sources**, significantly **mitigating attribute confusion** between multiple subjects. It also encourages the model to prioritize the text prompt for layout instructions rather than simply replicating the spatial arrangement of reference images, leading to more robust and controllable multi-subject generation. 123 | 124 | * **Training and Inference:** 125 | * **OminiControl:** Training is likely simpler; inference requires providing input matching the expected template. 126 | * **In-Context LoRA:** Requires training specific LoRA modules. Inference might necessitate providing contextual example pairs along with loading the appropriate LoRA weights. 127 | * **UNO:** The training pipeline (including data generation and progressive stages) is more complex. However, the **inference-time user experience is simpler and aligns with standard tuning-free methods** like IP-Adapter. Users only need to provide the reference image(s) and the text prompt. 128 | 129 | **Comparison Summary (UNO vs. OminiControl vs. In-Context LoRA):** 130 | OminiControl represents an early exploration of DiT's in-context capabilities. In-Context LoRA investigates rapid concept adaptation using context, often coupled with LoRA. **UNO, however, offers a more comprehensive and systematic solution.** It leverages DiT's context abilities but significantly enhances the baseline capability through a superior data strategy (co-evolution, synthesis, filtering). Crucially, it **addresses the multi-subject challenge head-on with the dedicated UnoPE mechanism**, leading to robust and controllable generation in complex scenarios, all while maintaining a user-friendly, tuning-free inference process. 
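
To make the UnoPE mechanism concrete, here is a minimal sketch of how the offset position indices can be built, mirroring `prepare_multi_ip` in `uno/flux/sampling.py` of this repository; the helper names are illustrative rather than part of any published API, and the inputs are latent-space heights and widths.

```python
import torch

def build_ids(n_h: int, n_w: int, h_off: int, w_off: int) -> torch.Tensor:
    """(axis, row, col) position ids for an n_h x n_w grid, shifted by the given offsets."""
    ids = torch.zeros(n_h, n_w, 3)
    ids[..., 1] = torch.arange(n_h)[:, None] + h_off  # row index
    ids[..., 2] = torch.arange(n_w)[None, :] + w_off  # column index
    return ids.reshape(-1, 3)

def unope_position_ids(target_hw, ref_hws):
    """The target latent keeps positions starting at (0, 0); every reference image is shifted
    past the running maximum so tokens from different images never share a rotary position."""
    h, w = target_hw
    ids = [build_ids(h // 2, w // 2, 0, 0)]      # noisy target image (2x2-packed latent)
    shift_h, shift_w = h // 2, w // 2
    for rh, rw in ref_hws:                       # reference images, in order
        ids.append(build_ids(rh // 2, rw // 2, shift_h, shift_w))
        shift_h += rh // 2
        shift_w += rw // 2
    return ids
```

Because each reference occupies a position range the target never uses, attention can keep the subjects apart, while the text tokens (whose ids stay at zero) remain the main signal for layout.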
131 | 132 | **Experimental Results and Performance** 133 | 134 | UNO's effectiveness is strongly supported by its experimental results: 135 | 136 | * **Quantitative SOTA:** Achieves state-of-the-art scores on DreamBench (single-subject) and multi-subject benchmarks for subject similarity metrics (DINO, CLIP-I), while maintaining highly competitive text fidelity (CLIP-T). 137 | * **Qualitative Excellence:** Produces high-quality, natural-looking images. Demonstrates strong subject identity preservation, detail fidelity, stable multi-subject control, and effective attribute editing capabilities. 138 | * **Strong Generalization:** Showcases impressive performance across diverse applications like virtual try-on, product design, identity preservation, and stylized generation. 139 | * **Ablation Studies:** Thoroughly validate the necessity and effectiveness of each component: high-quality synthetic data, progressive training, and especially the UnoPE mechanism. 140 | 141 | **Conclusion** 142 | 143 | The UNO framework marks a significant advancement in controllable image generation. Through its unique "model-data co-evolution" concept, a superior synthetic data pipeline, and a well-designed model architecture featuring UnoPE and progressive training, UNO successfully overcomes critical bottlenecks in data scalability and multi-subject control. It not only achieves state-of-the-art performance but, more importantly, demonstrates a viable path towards unlocking greater controllability and generalization via systematic model self-improvement. As a powerful, unified, and tuning-free solution, UNO offers substantial technological support for personalized content creation, virtual reality, design, and beyond. 144 | 145 | Future work could further enhance UNO's versatility by expanding the types of synthetic data generated (e.g., including more editing or stylization pairs) to cover an even broader range of applications. 146 | 147 | --- 148 | -------------------------------------------------------------------------------- /uno/flux/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | from typing import Literal 18 | 19 | import torch 20 | from einops import rearrange 21 | from PIL import ExifTags, Image 22 | import torchvision.transforms.functional as TVF 23 | 24 | from uno.flux.modules.layers import ( 25 | DoubleStreamBlockLoraProcessor, 26 | DoubleStreamBlockProcessor, 27 | SingleStreamBlockLoraProcessor, 28 | SingleStreamBlockProcessor, 29 | ) 30 | from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack 31 | from uno.flux.util import ( 32 | get_lora_rank, 33 | load_ae, 34 | load_checkpoint, 35 | load_clip, 36 | load_flow_model, 37 | load_flow_model_only_lora, 38 | load_flow_model_quintized, 39 | load_t5, 40 | ) 41 | 42 | 43 | def find_nearest_scale(image_h, image_w, predefined_scales): 44 | """ 45 | 根据图片的高度和宽度,找到最近的预定义尺度。 46 | 47 | :param image_h: 图片的高度 48 | :param image_w: 图片的宽度 49 | :param predefined_scales: 预定义尺度列表 [(h1, w1), (h2, w2), ...] 50 | :return: 最近的预定义尺度 (h, w) 51 | """ 52 | # 计算输入图片的长宽比 53 | image_ratio = image_h / image_w 54 | 55 | # 初始化变量以存储最小差异和最近的尺度 56 | min_diff = float('inf') 57 | nearest_scale = None 58 | 59 | # 遍历所有预定义尺度,找到与输入图片长宽比最接近的尺度 60 | for scale_h, scale_w in predefined_scales: 61 | predefined_ratio = scale_h / scale_w 62 | diff = abs(predefined_ratio - image_ratio) 63 | 64 | if diff < min_diff: 65 | min_diff = diff 66 | nearest_scale = (scale_h, scale_w) 67 | 68 | return nearest_scale 69 | 70 | def preprocess_ref(raw_image: Image.Image, long_size: int = 512): 71 | # 获取原始图像的宽度和高度 72 | image_w, image_h = raw_image.size 73 | 74 | # 计算长边和短边 75 | if image_w >= image_h: 76 | new_w = long_size 77 | new_h = int((long_size / image_w) * image_h) 78 | else: 79 | new_h = long_size 80 | new_w = int((long_size / image_h) * image_w) 81 | 82 | # 按新的宽高进行等比例缩放 83 | raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) 84 | target_w = new_w // 16 * 16 85 | target_h = new_h // 16 * 16 86 | 87 | # 计算裁剪的起始坐标以实现中心裁剪 88 | left = (new_w - target_w) // 2 89 | top = (new_h - target_h) // 2 90 | right = left + target_w 91 | bottom = top + target_h 92 | 93 | # 进行中心裁剪 94 | raw_image = raw_image.crop((left, top, right, bottom)) 95 | 96 | # 转换为 RGB 模式 97 | raw_image = raw_image.convert("RGB") 98 | return raw_image 99 | 100 | class UNOPipeline: 101 | def __init__( 102 | self, 103 | model_type: str, 104 | device: torch.device, 105 | offload: bool = False, 106 | only_lora: bool = False, 107 | lora_rank: int = 16 108 | ): 109 | self.device = device 110 | self.offload = offload 111 | self.model_type = model_type 112 | 113 | self.clip = load_clip(self.device) 114 | self.t5 = load_t5(self.device, max_length=512) 115 | self.ae = load_ae(model_type, device="cpu" if offload else self.device) 116 | self.use_fp8 = "fp8" in model_type 117 | if only_lora: 118 | self.model = load_flow_model_only_lora( 119 | model_type, 120 | device="cpu" if offload else self.device, 121 | lora_rank=lora_rank, 122 | use_fp8=self.use_fp8 123 | ) 124 | else: 125 | self.model = load_flow_model(model_type, device="cpu" if offload else self.device) 126 | 127 | 128 | def load_ckpt(self, ckpt_path): 129 | if ckpt_path is not None: 130 | from safetensors.torch import load_file as load_sft 131 | print("Loading checkpoint to replace old keys") 132 | # load_sft doesn't support torch.device 133 | if ckpt_path.endswith('safetensors'): 134 | sd = load_sft(ckpt_path, device='cpu') 135 | missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True) 136 | else: 137 | dit_state = torch.load(ckpt_path, map_location='cpu') 138 | sd 
= {} 139 | for k in dit_state.keys(): 140 | sd[k.replace('module.','')] = dit_state[k] 141 | missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True) 142 | self.model.to(str(self.device)) 143 | print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}") 144 | 145 | def set_lora(self, local_path: str = None, repo_id: str = None, 146 | name: str = None, lora_weight: int = 0.7): 147 | checkpoint = load_checkpoint(local_path, repo_id, name) 148 | self.update_model_with_lora(checkpoint, lora_weight) 149 | 150 | def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): 151 | checkpoint = load_checkpoint( 152 | None, self.hf_lora_collection, self.lora_types_to_names[lora_type] 153 | ) 154 | self.update_model_with_lora(checkpoint, lora_weight) 155 | 156 | def update_model_with_lora(self, checkpoint, lora_weight): 157 | rank = get_lora_rank(checkpoint) 158 | lora_attn_procs = {} 159 | 160 | for name, _ in self.model.attn_processors.items(): 161 | lora_state_dict = {} 162 | for k in checkpoint.keys(): 163 | if name in k: 164 | lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight 165 | 166 | if len(lora_state_dict): 167 | if name.startswith("single_blocks"): 168 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) 169 | else: 170 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) 171 | lora_attn_procs[name].load_state_dict(lora_state_dict) 172 | lora_attn_procs[name].to(self.device) 173 | else: 174 | if name.startswith("single_blocks"): 175 | lora_attn_procs[name] = SingleStreamBlockProcessor() 176 | else: 177 | lora_attn_procs[name] = DoubleStreamBlockProcessor() 178 | 179 | self.model.set_attn_processor(lora_attn_procs) 180 | 181 | 182 | def __call__( 183 | self, 184 | prompt: str, 185 | width: int = 512, 186 | height: int = 512, 187 | guidance: float = 4, 188 | num_steps: int = 50, 189 | seed: int = 123456789, 190 | **kwargs 191 | ): 192 | width = 16 * (width // 16) 193 | height = 16 * (height // 16) 194 | 195 | device_type = self.device if isinstance(self.device, str) else self.device.type 196 | with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16): 197 | return self.forward( 198 | prompt, 199 | width, 200 | height, 201 | guidance, 202 | num_steps, 203 | seed, 204 | **kwargs 205 | ) 206 | 207 | @torch.inference_mode() 208 | def gradio_generate( 209 | self, 210 | prompt: str, 211 | width: int, 212 | height: int, 213 | guidance: float, 214 | num_steps: int, 215 | seed: int, 216 | image_prompt1: Image.Image, 217 | image_prompt2: Image.Image, 218 | image_prompt3: Image.Image, 219 | image_prompt4: Image.Image, 220 | ): 221 | ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4] 222 | ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)] 223 | ref_long_side = 512 if len(ref_imgs) <= 1 else 320 224 | ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs] 225 | 226 | seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item() 227 | 228 | img = self(prompt=prompt, width=width, height=height, guidance=guidance, 229 | num_steps=num_steps, seed=seed, ref_imgs=ref_imgs) 230 | 231 | filename = f"output/gradio/{seed}_{prompt[:20]}.png" 232 | os.makedirs(os.path.dirname(filename), exist_ok=True) 233 | exif_data = Image.Exif() 234 | exif_data[ExifTags.Base.Make] = "UNO" 235 | exif_data[ExifTags.Base.Model] = self.model_type 236 | info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, 
{num_steps=}" 237 | exif_data[ExifTags.Base.ImageDescription] = info 238 | img.save(filename, format="png", exif=exif_data) 239 | return img, filename 240 | 241 | @torch.inference_mode 242 | def forward( 243 | self, 244 | prompt: str, 245 | width: int, 246 | height: int, 247 | guidance: float, 248 | num_steps: int, 249 | seed: int, 250 | ref_imgs: list[Image.Image] | None = None, 251 | pe: Literal['d', 'h', 'w', 'o'] = 'd', 252 | ): 253 | x = get_noise( 254 | 1, height, width, device=self.device, 255 | dtype=torch.bfloat16, seed=seed 256 | ) 257 | timesteps = get_schedule( 258 | num_steps, 259 | (width // 8) * (height // 8) // (16 * 16), 260 | shift=True, 261 | ) 262 | if self.offload: 263 | self.ae.encoder = self.ae.encoder.to(self.device) 264 | x_1_refs = [ 265 | self.ae.encode( 266 | (TVF.to_tensor(ref_img) * 2.0 - 1.0) 267 | .unsqueeze(0).to(self.device, torch.float32) 268 | ).to(torch.bfloat16) 269 | for ref_img in ref_imgs 270 | ] 271 | 272 | if self.offload: 273 | self.offload_model_to_cpu(self.ae.encoder) 274 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 275 | inp_cond = prepare_multi_ip( 276 | t5=self.t5, clip=self.clip, 277 | img=x, 278 | prompt=prompt, ref_imgs=x_1_refs, pe=pe 279 | ) 280 | 281 | if self.offload: 282 | self.offload_model_to_cpu(self.t5, self.clip) 283 | self.model = self.model.to(self.device) 284 | 285 | x = denoise( 286 | self.model, 287 | **inp_cond, 288 | timesteps=timesteps, 289 | guidance=guidance, 290 | ) 291 | 292 | if self.offload: 293 | self.offload_model_to_cpu(self.model) 294 | self.ae.decoder.to(x.device) 295 | x = unpack(x.float(), height, width) 296 | x = self.ae.decode(x) 297 | self.offload_model_to_cpu(self.ae.decoder) 298 | 299 | x1 = x.clamp(-1, 1) 300 | x1 = rearrange(x1[-1], "c h w -> h w c") 301 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) 302 | return output_img 303 | 304 | def offload_model_to_cpu(self, *models): 305 | if not self.offload: return 306 | for model in models: 307 | model.cpu() 308 | torch.cuda.empty_cache() 309 | -------------------------------------------------------------------------------- /uno/flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from einops import rearrange 20 | from torch import Tensor, nn 21 | 22 | 23 | @dataclass 24 | class AutoEncoderParams: 25 | resolution: int 26 | in_channels: int 27 | ch: int 28 | out_ch: int 29 | ch_mult: list[int] 30 | num_res_blocks: int 31 | z_channels: int 32 | scale_factor: float 33 | shift_factor: float 34 | 35 | 36 | def swish(x: Tensor) -> Tensor: 37 | return x * torch.sigmoid(x) 38 | 39 | 40 | class AttnBlock(nn.Module): 41 | def __init__(self, in_channels: int): 42 | super().__init__() 43 | self.in_channels = in_channels 44 | 45 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 46 | 47 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 48 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 49 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 50 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 51 | 52 | def attention(self, h_: Tensor) -> Tensor: 53 | h_ = self.norm(h_) 54 | q = self.q(h_) 55 | k = self.k(h_) 56 | v = self.v(h_) 57 | 58 | b, c, h, w = q.shape 59 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 60 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 61 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 62 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 63 | 64 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 65 | 66 | def forward(self, x: Tensor) -> Tensor: 67 | return x + self.proj_out(self.attention(x)) 68 | 69 | 70 | class ResnetBlock(nn.Module): 71 | def __init__(self, in_channels: int, out_channels: int): 72 | super().__init__() 73 | self.in_channels = in_channels 74 | out_channels = in_channels if out_channels is None else out_channels 75 | self.out_channels = out_channels 76 | 77 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 79 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 80 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 81 | if self.in_channels != self.out_channels: 82 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 83 | 84 | def forward(self, x): 85 | h = x 86 | h = self.norm1(h) 87 | h = swish(h) 88 | h = self.conv1(h) 89 | 90 | h = self.norm2(h) 91 | h = swish(h) 92 | h = self.conv2(h) 93 | 94 | if self.in_channels != self.out_channels: 95 | x = self.nin_shortcut(x) 96 | 97 | return x + h 98 | 99 | 100 | class Downsample(nn.Module): 101 | def __init__(self, in_channels: int): 102 | super().__init__() 103 | # no asymmetric padding in torch conv, must do it ourselves 104 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 105 | 106 | def forward(self, x: Tensor): 107 | pad = (0, 1, 0, 1) 108 | x = nn.functional.pad(x, pad, mode="constant", value=0) 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class Upsample(nn.Module): 114 | def __init__(self, in_channels: int): 115 | super().__init__() 116 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 117 | 118 | def forward(self, x: Tensor): 119 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 120 | x = self.conv(x) 121 | return x 122 | 123 | 124 | class Encoder(nn.Module): 125 | def __init__( 126 | self, 127 | resolution: int, 128 | 
in_channels: int, 129 | ch: int, 130 | ch_mult: list[int], 131 | num_res_blocks: int, 132 | z_channels: int, 133 | ): 134 | super().__init__() 135 | self.ch = ch 136 | self.num_resolutions = len(ch_mult) 137 | self.num_res_blocks = num_res_blocks 138 | self.resolution = resolution 139 | self.in_channels = in_channels 140 | # downsampling 141 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 142 | 143 | curr_res = resolution 144 | in_ch_mult = (1,) + tuple(ch_mult) 145 | self.in_ch_mult = in_ch_mult 146 | self.down = nn.ModuleList() 147 | block_in = self.ch 148 | for i_level in range(self.num_resolutions): 149 | block = nn.ModuleList() 150 | attn = nn.ModuleList() 151 | block_in = ch * in_ch_mult[i_level] 152 | block_out = ch * ch_mult[i_level] 153 | for _ in range(self.num_res_blocks): 154 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 155 | block_in = block_out 156 | down = nn.Module() 157 | down.block = block 158 | down.attn = attn 159 | if i_level != self.num_resolutions - 1: 160 | down.downsample = Downsample(block_in) 161 | curr_res = curr_res // 2 162 | self.down.append(down) 163 | 164 | # middle 165 | self.mid = nn.Module() 166 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 167 | self.mid.attn_1 = AttnBlock(block_in) 168 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 169 | 170 | # end 171 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 172 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 173 | 174 | def forward(self, x: Tensor) -> Tensor: 175 | # downsampling 176 | hs = [self.conv_in(x)] 177 | for i_level in range(self.num_resolutions): 178 | for i_block in range(self.num_res_blocks): 179 | h = self.down[i_level].block[i_block](hs[-1]) 180 | if len(self.down[i_level].attn) > 0: 181 | h = self.down[i_level].attn[i_block](h) 182 | hs.append(h) 183 | if i_level != self.num_resolutions - 1: 184 | hs.append(self.down[i_level].downsample(hs[-1])) 185 | 186 | # middle 187 | h = hs[-1] 188 | h = self.mid.block_1(h) 189 | h = self.mid.attn_1(h) 190 | h = self.mid.block_2(h) 191 | # end 192 | h = self.norm_out(h) 193 | h = swish(h) 194 | h = self.conv_out(h) 195 | return h 196 | 197 | 198 | class Decoder(nn.Module): 199 | def __init__( 200 | self, 201 | ch: int, 202 | out_ch: int, 203 | ch_mult: list[int], 204 | num_res_blocks: int, 205 | in_channels: int, 206 | resolution: int, 207 | z_channels: int, 208 | ): 209 | super().__init__() 210 | self.ch = ch 211 | self.num_resolutions = len(ch_mult) 212 | self.num_res_blocks = num_res_blocks 213 | self.resolution = resolution 214 | self.in_channels = in_channels 215 | self.ffactor = 2 ** (self.num_resolutions - 1) 216 | 217 | # compute in_ch_mult, block_in and curr_res at lowest res 218 | block_in = ch * ch_mult[self.num_resolutions - 1] 219 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 220 | self.z_shape = (1, z_channels, curr_res, curr_res) 221 | 222 | # z to block_in 223 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 224 | 225 | # middle 226 | self.mid = nn.Module() 227 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 228 | self.mid.attn_1 = AttnBlock(block_in) 229 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 230 | 231 | # upsampling 232 | self.up = nn.ModuleList() 233 | for i_level in reversed(range(self.num_resolutions)): 234 
| block = nn.ModuleList() 235 | attn = nn.ModuleList() 236 | block_out = ch * ch_mult[i_level] 237 | for _ in range(self.num_res_blocks + 1): 238 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 239 | block_in = block_out 240 | up = nn.Module() 241 | up.block = block 242 | up.attn = attn 243 | if i_level != 0: 244 | up.upsample = Upsample(block_in) 245 | curr_res = curr_res * 2 246 | self.up.insert(0, up) # prepend to get consistent order 247 | 248 | # end 249 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 250 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 251 | 252 | def forward(self, z: Tensor) -> Tensor: 253 | # z to block_in 254 | h = self.conv_in(z) 255 | 256 | # middle 257 | h = self.mid.block_1(h) 258 | h = self.mid.attn_1(h) 259 | h = self.mid.block_2(h) 260 | 261 | # upsampling 262 | for i_level in reversed(range(self.num_resolutions)): 263 | for i_block in range(self.num_res_blocks + 1): 264 | h = self.up[i_level].block[i_block](h) 265 | if len(self.up[i_level].attn) > 0: 266 | h = self.up[i_level].attn[i_block](h) 267 | if i_level != 0: 268 | h = self.up[i_level].upsample(h) 269 | 270 | # end 271 | h = self.norm_out(h) 272 | h = swish(h) 273 | h = self.conv_out(h) 274 | return h 275 | 276 | 277 | class DiagonalGaussian(nn.Module): 278 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 279 | super().__init__() 280 | self.sample = sample 281 | self.chunk_dim = chunk_dim 282 | 283 | def forward(self, z: Tensor) -> Tensor: 284 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 285 | if self.sample: 286 | std = torch.exp(0.5 * logvar) 287 | return mean + std * torch.randn_like(mean) 288 | else: 289 | return mean 290 | 291 | 292 | class AutoEncoder(nn.Module): 293 | def __init__(self, params: AutoEncoderParams): 294 | super().__init__() 295 | self.encoder = Encoder( 296 | resolution=params.resolution, 297 | in_channels=params.in_channels, 298 | ch=params.ch, 299 | ch_mult=params.ch_mult, 300 | num_res_blocks=params.num_res_blocks, 301 | z_channels=params.z_channels, 302 | ) 303 | self.decoder = Decoder( 304 | resolution=params.resolution, 305 | in_channels=params.in_channels, 306 | ch=params.ch, 307 | out_ch=params.out_ch, 308 | ch_mult=params.ch_mult, 309 | num_res_blocks=params.num_res_blocks, 310 | z_channels=params.z_channels, 311 | ) 312 | self.reg = DiagonalGaussian() 313 | 314 | self.scale_factor = params.scale_factor 315 | self.shift_factor = params.shift_factor 316 | 317 | def encode(self, x: Tensor) -> Tensor: 318 | z = self.reg(self.encoder(x)) 319 | z = self.scale_factor * (z - self.shift_factor) 320 | return z 321 | 322 | def decode(self, z: Tensor) -> Tensor: 323 | z = z / self.scale_factor + self.shift_factor 324 | return self.decoder(z) 325 | 326 | def forward(self, x: Tensor) -> Tensor: 327 | return self.decode(self.encode(x)) 328 | -------------------------------------------------------------------------------- /uno_nodes/comfy_nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import re 5 | from PIL import Image 6 | from typing import Literal 7 | import sys 8 | 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | 11 | from comfy.model_management import get_torch_device 12 | import folder_paths 13 | 14 | from uno.flux.modules.conditioner import HFEmbedder 15 | 16 | 17 | # 添加自定义加载模型的函数 18 | def 
custom_load_flux_model(model_path, device, use_fp8, lora_rank=512, lora_path=None):
19 |     """
20 |     Load the Flux model from the given path.
21 |     """
22 |     from uno.flux.model import Flux
23 |     from uno.flux.util import load_model
24 |     from uno.flux.util import configs, print_load_warning, set_lora
25 |     from safetensors.torch import load_file as load_sft
26 | 
27 |     if use_fp8:
28 |         params = configs["flux-dev-fp8"].params
29 |     else:
30 |         params = configs["flux-dev"].params
31 | 
32 |     # Initialize the model
33 |     with torch.device("meta" if model_path is not None else device):
34 |         model = Flux(params)
35 | 
36 |     # If a LoRA checkpoint is provided, install the LoRA layers
37 |     if lora_path is not None and os.path.exists(lora_path):
38 |         print(f"Using only_lora mode with rank: {lora_rank}")
39 |         model = set_lora(model, lora_rank, device="meta" if model_path is not None else device)
40 | 
41 |     # Load the model weights
42 |     if model_path is not None:
43 |         print(f"Loading Flux model from {model_path}")
44 |         print("Loading lora")
45 |         lora_sd = (load_sft(lora_path, device=str(device)) if lora_path.endswith("safetensors")
46 |                    else torch.load(lora_path, map_location='cpu')) if lora_path is not None else {}
47 |         print("Loading main checkpoint")
48 |         if model_path.endswith('safetensors'):
49 |             if use_fp8:
50 |                 print(
51 |                     "####\n"
52 |                     "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n"
53 |                     "we convert the fp8 checkpoint on the fly from the bf16 checkpoint\n"
54 |                     "If your storage is constrained, "
55 |                     "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n"
56 |                 )
57 |                 sd = load_sft(model_path, device="cpu")
58 |                 sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
59 |             else:
60 |                 sd = load_sft(model_path, device=str(device))
61 | 
62 |             sd.update(lora_sd)
63 |             missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
64 |         else:
65 |             dit_state = torch.load(model_path, map_location='cpu')
66 |             sd = {}
67 |             for k in dit_state.keys():
68 |                 sd[k.replace('module.','')] = dit_state[k]
69 |             sd.update(lora_sd)
70 |             missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
71 |         model.to(str(device))
72 |         print_load_warning(missing, unexpected)
73 | 
74 |     return model
75 | 
76 | def custom_load_ae(ae_path, device):
77 |     """
78 |     Load the autoencoder from the given path.
79 |     """
80 |     from uno.flux.modules.autoencoder import AutoEncoder
81 |     from uno.flux.util import load_model
82 |     from uno.flux.util import configs
83 |     from safetensors.torch import load_file as load_sft
84 | 
85 |     # Get the autoencoder parameters for this model type
86 |     ae_params = configs["flux-dev"].ae_params
87 | 
88 |     # Initialize the autoencoder
89 |     with torch.device("meta" if ae_path is not None else device):
90 |         ae = AutoEncoder(ae_params)
91 | 
92 |     # Load the autoencoder weights
93 |     if ae_path is not None:
94 |         print(f"Loading AutoEncoder from {ae_path}")
95 |         if ae_path.endswith('safetensors'):
96 |             sd = load_sft(ae_path, device=str(device))
97 |         else:
98 |             sd = torch.load(ae_path, map_location=str(device))
99 |         missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
100 |         if len(missing) > 0:
101 |             print(f"Missing keys: {len(missing)}")
102 |         if len(unexpected) > 0:
103 |             print(f"Unexpected keys: {len(unexpected)}")
104 | 
105 |     # Move to the target device
106 |     ae = ae.to(str(device))
107 |     return ae
108 | 
109 | def custom_load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
110 |     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
111 |     version = "xlabs-ai/xflux_text_encoders"
112 |     if os.path.exists("/models/clip/xflux_text_encoders"):
113 |         version = "/models/clip/xflux_text_encoders"
114 |         return HFEmbedder(version, max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
115 |     cache_dir = folder_paths.get_folder_paths("clip")[0]
116 |     return HFEmbedder(version, max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16, cache_dir=cache_dir).to(device)
117 | 
118 | def custom_load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
119 |     version = "openai/clip-vit-large-patch14"
120 |     if os.path.exists("/models/clip/clip-vit-large-patch14"):
121 |         version = "/models/clip/clip-vit-large-patch14"
122 |         return HFEmbedder(version, max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
123 |     cache_dir = folder_paths.get_folder_paths("clip")[0]
124 |     return HFEmbedder(version, max_length=77, is_clip=True, torch_dtype=torch.bfloat16, cache_dir=cache_dir).to(device)
125 | 
126 | 
127 | 
128 | class REDUNOModelLoader:
129 |     def __init__(self):
130 |         self.output_dir = folder_paths.get_output_directory()
131 |         self.type = "UNO_MODEL"
132 |         self.loaded_model = None
133 | 
134 |     @classmethod
135 |     def INPUT_TYPES(cls):
136 |         # Collect the available unet and vae model files
137 |         model_paths = folder_paths.get_filename_list("unet")
138 |         vae_paths = folder_paths.get_filename_list("vae")
139 | 
140 |         # Also offer the available LoRA checkpoints
141 |         lora_paths = folder_paths.get_filename_list("loras")
142 | 
143 |         return {
144 |             "required": {
145 |                 "flux_model": (model_paths, ),
146 |                 "ae_model": (vae_paths, ),
147 |                 "use_fp8": ("BOOLEAN", {"default": False}),
148 |                 "offload": ("BOOLEAN", {"default": False}),
149 |                 "lora_model": (["None"] + lora_paths, ),
150 |             }
151 |         }
152 | 
153 |     RETURN_TYPES = ("UNO_MODEL",)
154 |     RETURN_NAMES = ("uno_model",)
155 |     FUNCTION = "load_model"
156 |     CATEGORY = "UNO"
157 | 
158 |     def load_model(self, flux_model, ae_model, use_fp8, offload, lora_model=None):
159 |         device = get_torch_device()
160 |         from uno.flux.pipeline import UNOPipeline
161 | 
162 |         try:
163 |             # Resolve the full paths of the selected model files
164 |             flux_model_path = folder_paths.get_full_path("unet", flux_model)
165 |             ae_model_path = folder_paths.get_full_path("vae", ae_model)
166 | 
167 |             # Resolve the LoRA model path, if one was selected
168 |             lora_model_path = None
169 |             if lora_model is not None and lora_model != "None":
170 |                 lora_model_path = folder_paths.get_full_path("loras", lora_model)
171 | 
172 |             print(f"Loading Flux model from: {flux_model_path}")
173 |             print(f"Loading AE model from: {ae_model_path}")
174 |             lora_rank = 512
175 |             if lora_model_path:
176 |                 print(f"Loading LoRA model from: {lora_model_path}")
177 | 
178 |             # Build a custom UNO pipeline
179 |             class CustomUNOPipeline(UNOPipeline):
180 |                 def __init__(self, use_fp8, device, flux_path, ae_path, offload=False,
181 |                              lora_rank=512, lora_path=None):
182 |                     self.device = device
183 |                     self.offload = offload
184 |                     self.model_type = "flux-dev-fp8" if use_fp8 else "flux-dev"
185 |                     self.use_fp8 = use_fp8
186 |                     # Load the CLIP and T5 text encoders
187 |                     self.clip = custom_load_clip(device="cpu" if offload else self.device)
188 |                     self.t5 = custom_load_t5(device="cpu" if offload else self.device, max_length=512)
189 | 
190 |                     # Load the custom checkpoints
191 |                     self.ae = custom_load_ae(ae_path, device="cpu" if offload else self.device)
192 |                     self.model = custom_load_flux_model(
193 |                         flux_path,
194 |                         device="cpu" if offload else self.device,
195 |                         use_fp8=use_fp8,
196 |                         lora_rank=lora_rank,
197 |                         lora_path=lora_path
198 |                     )
199 | 
200 |             # Instantiate the custom pipeline
201 |             model = CustomUNOPipeline(
202 |                 use_fp8=use_fp8,
203 |                 device=device,
204 |                 flux_path=flux_model_path,
205 |                 ae_path=ae_model_path,
206 |                 offload=offload,
207 |                 lora_rank=lora_rank,
208 |                 lora_path=lora_model_path,
209 |             )
210 | 
211 |             self.loaded_model = model
212 |             print("UNO model loaded successfully with custom models.")
213 |             return (model,)
214 |         except Exception as e:
215 |             print(f"Error loading UNO model: {e}")
216 |             import traceback
217 |             traceback.print_exc()
218 |             raise e
219 | 
220 | 
221 | class REDUNOGenerate:
222 |     def __init__(self):
223 |         self.output_dir = folder_paths.get_output_directory()
224 |         os.makedirs(self.output_dir, exist_ok=True)
225 | 
226 |     @classmethod
227 |     def INPUT_TYPES(cls):
228 |         return {
229 |             "required": {
230 |                 "uno_model": ("UNO_MODEL",),
231 |                 "prompt": ("STRING", {"multiline": True}),
232 |                 "width": ("INT", {"default": 512, "min": 256, "max": 2048, "step": 16}),
233 |                 "height": ("INT", {"default": 512, "min": 256, "max": 2048, "step": 16}),
234 |                 "guidance": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.1}),
235 |                 "num_steps": ("INT", {"default": 25, "min": 1, "max": 100}),
236 |                 "seed": ("INT", {"default": 3407}),
237 |                 "pe": (["d", "h", "w", "o"], {"default": "d"}),
238 |             },
239 |             "optional": {
240 |                 "reference_image_1": ("IMAGE",),
241 |                 "reference_image_2": ("IMAGE",),
242 |                 "reference_image_3": ("IMAGE",),
243 |                 "reference_image_4": ("IMAGE",),
244 |             }
245 |         }
246 | 
247 |     RETURN_TYPES = ("IMAGE",)
248 |     FUNCTION = "generate"
249 |     CATEGORY = "UNO"
250 | 
251 |     def generate(self, uno_model, prompt, width, height, guidance, num_steps, seed, pe,
252 |                  reference_image_1=None, reference_image_2=None, reference_image_3=None, reference_image_4=None):
253 |         # Make sure width and height are multiples of 16
254 |         width = (width // 16) * 16
255 |         height = (height // 16) * 16
256 |         from uno.flux.pipeline import preprocess_ref
257 | 
258 |         # Process reference images if provided
259 |         ref_imgs = []
260 |         ref_tensors = [reference_image_1, reference_image_2, reference_image_3, reference_image_4]
261 |         for ref_tensor in ref_tensors:
262 |             if ref_tensor is not None:
263 |                 # Convert from tensor to PIL
264 |                 if isinstance(ref_tensor, torch.Tensor):
265 |                     # Handle batch of images
266 |                     if ref_tensor.dim() == 4:  # [batch, height, width, channels]
267 |                         for i in range(ref_tensor.shape[0]):
268 |                             img = ref_tensor[i].cpu().numpy()
269 |                             ref_image_pil = Image.fromarray((img * 255).astype(np.uint8))
270 |                             # Determine reference size based on number of reference images
271 |                             ref_size = 512 if len([t for t in ref_tensors if t is not None]) <= 1 else 320
272 |                             ref_image_pil = preprocess_ref(ref_image_pil, ref_size)
273 |                             ref_imgs.append(ref_image_pil)
274 |                     else:  # [height, width, channels]
275 |                         img = ref_tensor.cpu().numpy()
276 |                         ref_image_pil = Image.fromarray((img * 255).astype(np.uint8))
277 |                         # Determine reference size based on number of reference images
278 |                         ref_size = 512 if len([t for t in ref_tensors if t is not None]) <= 1 else 320
279 |                         ref_image_pil = preprocess_ref(ref_image_pil, ref_size)
280 |                         ref_imgs.append(ref_image_pil)
281 |                 elif isinstance(ref_tensor, np.ndarray):
282 |                     # ComfyUI images are float arrays in [0, 1]; convert to 8-bit
283 |                     ref_image_pil = Image.fromarray((ref_tensor * 255).astype(np.uint8))
284 |                     # Determine reference size based on number of reference images
285 |                     ref_size = 512 if len([t for t in ref_tensors if t is not None]) <= 1 else 320
286 |                     ref_image_pil = preprocess_ref(ref_image_pil, ref_size)
287 |                     ref_imgs.append(ref_image_pil)
288 | 
289 |         try:
290 |             # Generate image
291 |             output_img = uno_model(
292 |                 prompt=prompt,
293 |                 width=width,
294 |                 height=height,
295 |                 guidance=guidance,
296 |                 num_steps=num_steps,
297 |                 seed=seed,
298 | 
ref_imgs=ref_imgs, 299 | pe=pe 300 | ) 301 | 302 | # Save the generated image 303 | output_filename = f"uno_{seed}_{prompt[:20].replace(' ', '_')}.png" 304 | output_path = os.path.join(self.output_dir, output_filename) 305 | 306 | # Convert to ComfyUI-compatible tensor 307 | if hasattr(output_img, 'images') and len(output_img.images) > 0: 308 | # Handle FluxPipelineOutput 309 | output_img.images[0].save(output_path) 310 | print(f"Saved UNO generated image to {output_path}") 311 | image = np.array(output_img.images[0]) / 255.0 # Convert to [0, 1] 312 | else: 313 | # Handle PIL Image 314 | output_img.save(output_path) 315 | print(f"Saved UNO generated image to {output_path}") 316 | image = np.array(output_img) / 255.0 # Convert to [0, 1] 317 | 318 | # Convert numpy array to torch.Tensor 319 | image = torch.from_numpy(image).float() 320 | 321 | # Make sure it's in ComfyUI format [batch, height, width, channels] 322 | if image.dim() == 3: # [height, width, channels] 323 | image = image.unsqueeze(0) # Add batch dimension to make it [1, height, width, channels] 324 | 325 | 326 | return (image,) 327 | except Exception as e: 328 | print(f"Error generating image with UNO: {e}") 329 | raise e 330 | 331 | 332 | # Register our nodes to be used in ComfyUI 333 | NODE_CLASS_MAPPINGS = { 334 | "REDUNOModelLoader": REDUNOModelLoader, 335 | "REDUNOGenerate": REDUNOGenerate, 336 | } 337 | 338 | NODE_DISPLAY_NAME_MAPPINGS = { 339 | "REDUNOModelLoader": "UNO Model Loader @REDAIGC", 340 | "REDUNOGenerate": "UNO Generate @REDAIGC", 341 | } 342 | -------------------------------------------------------------------------------- /uno/flux/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | import json 21 | import numpy as np 22 | from huggingface_hub import hf_hub_download 23 | from safetensors import safe_open 24 | from safetensors.torch import load_file as load_sft 25 | 26 | from .model import Flux, FluxParams 27 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams 28 | from .modules.conditioner import HFEmbedder 29 | 30 | import re 31 | from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor 32 | def load_model(ckpt, device='cpu'): 33 | if ckpt.endswith('safetensors'): 34 | from safetensors import safe_open 35 | pl_sd = {} 36 | with safe_open(ckpt, framework="pt", device=device) as f: 37 | for k in f.keys(): 38 | pl_sd[k] = f.get_tensor(k) 39 | else: 40 | pl_sd = torch.load(ckpt, map_location=device) 41 | return pl_sd 42 | 43 | def load_safetensors(path): 44 | tensors = {} 45 | with safe_open(path, framework="pt", device="cpu") as f: 46 | for key in f.keys(): 47 | tensors[key] = f.get_tensor(key) 48 | return tensors 49 | 50 | def get_lora_rank(checkpoint): 51 | for k in checkpoint.keys(): 52 | if k.endswith(".down.weight"): 53 | return checkpoint[k].shape[0] 54 | 55 | def load_checkpoint(local_path, repo_id, name): 56 | if local_path is not None: 57 | if '.safetensors' in local_path: 58 | print(f"Loading .safetensors checkpoint from {local_path}") 59 | checkpoint = load_safetensors(local_path) 60 | else: 61 | print(f"Loading checkpoint from {local_path}") 62 | checkpoint = torch.load(local_path, map_location='cpu') 63 | elif repo_id is not None and name is not None: 64 | print(f"Loading checkpoint {name} from repo id {repo_id}") 65 | checkpoint = load_from_repo_id(repo_id, name) 66 | else: 67 | raise ValueError( 68 | "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" 69 | ) 70 | return checkpoint 71 | 72 | 73 | def c_crop(image): 74 | width, height = image.size 75 | new_size = min(width, height) 76 | left = (width - new_size) / 2 77 | top = (height - new_size) / 2 78 | right = (width + new_size) / 2 79 | bottom = (height + new_size) / 2 80 | return image.crop((left, top, right, bottom)) 81 | 82 | def pad64(x): 83 | return int(np.ceil(float(x) / 64.0) * 64 - x) 84 | 85 | def HWC3(x): 86 | assert x.dtype == np.uint8 87 | if x.ndim == 2: 88 | x = x[:, :, None] 89 | assert x.ndim == 3 90 | H, W, C = x.shape 91 | assert C == 1 or C == 3 or C == 4 92 | if C == 3: 93 | return x 94 | if C == 1: 95 | return np.concatenate([x, x, x], axis=2) 96 | if C == 4: 97 | color = x[:, :, 0:3].astype(np.float32) 98 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 99 | y = color * alpha + 255.0 * (1.0 - alpha) 100 | y = y.clip(0, 255).astype(np.uint8) 101 | return y 102 | 103 | @dataclass 104 | class ModelSpec: 105 | params: FluxParams 106 | ae_params: AutoEncoderParams 107 | ckpt_path: str | None 108 | ae_path: str | None 109 | repo_id: str | None 110 | repo_flow: str | None 111 | repo_ae: str | None 112 | repo_id_ae: str | None 113 | 114 | 115 | configs = { 116 | "flux-dev": ModelSpec( 117 | repo_id="black-forest-labs/FLUX.1-dev", 118 | repo_id_ae="black-forest-labs/FLUX.1-dev", 119 | repo_flow="flux1-dev.safetensors", 120 | repo_ae="ae.safetensors", 121 | ckpt_path=os.getenv("FLUX_DEV"), 122 | params=FluxParams( 123 | in_channels=64, 124 | vec_in_dim=768, 125 | context_in_dim=4096, 126 | hidden_size=3072, 127 | mlp_ratio=4.0, 128 | num_heads=24, 129 | depth=19, 130 | depth_single_blocks=38, 131 | axes_dim=[16, 56, 
56], 132 | theta=10_000, 133 | qkv_bias=True, 134 | guidance_embed=True, 135 | ), 136 | ae_path=os.getenv("AE"), 137 | ae_params=AutoEncoderParams( 138 | resolution=256, 139 | in_channels=3, 140 | ch=128, 141 | out_ch=3, 142 | ch_mult=[1, 2, 4, 4], 143 | num_res_blocks=2, 144 | z_channels=16, 145 | scale_factor=0.3611, 146 | shift_factor=0.1159, 147 | ), 148 | ), 149 | "flux-dev-fp8": ModelSpec( 150 | repo_id="black-forest-labs/FLUX.1-dev", 151 | repo_id_ae="black-forest-labs/FLUX.1-dev", 152 | repo_flow="flux1-dev.safetensors", 153 | repo_ae="ae.safetensors", 154 | ckpt_path=os.getenv("FLUX_DEV_FP8"), 155 | params=FluxParams( 156 | in_channels=64, 157 | vec_in_dim=768, 158 | context_in_dim=4096, 159 | hidden_size=3072, 160 | mlp_ratio=4.0, 161 | num_heads=24, 162 | depth=19, 163 | depth_single_blocks=38, 164 | axes_dim=[16, 56, 56], 165 | theta=10_000, 166 | qkv_bias=True, 167 | guidance_embed=True, 168 | ), 169 | ae_path=os.getenv("AE"), 170 | ae_params=AutoEncoderParams( 171 | resolution=256, 172 | in_channels=3, 173 | ch=128, 174 | out_ch=3, 175 | ch_mult=[1, 2, 4, 4], 176 | num_res_blocks=2, 177 | z_channels=16, 178 | scale_factor=0.3611, 179 | shift_factor=0.1159, 180 | ), 181 | ), 182 | "flux-schnell": ModelSpec( 183 | repo_id="black-forest-labs/FLUX.1-schnell", 184 | repo_id_ae="black-forest-labs/FLUX.1-dev", 185 | repo_flow="flux1-schnell.safetensors", 186 | repo_ae="ae.safetensors", 187 | ckpt_path=os.getenv("FLUX_SCHNELL"), 188 | params=FluxParams( 189 | in_channels=64, 190 | vec_in_dim=768, 191 | context_in_dim=4096, 192 | hidden_size=3072, 193 | mlp_ratio=4.0, 194 | num_heads=24, 195 | depth=19, 196 | depth_single_blocks=38, 197 | axes_dim=[16, 56, 56], 198 | theta=10_000, 199 | qkv_bias=True, 200 | guidance_embed=False, 201 | ), 202 | ae_path=os.getenv("AE"), 203 | ae_params=AutoEncoderParams( 204 | resolution=256, 205 | in_channels=3, 206 | ch=128, 207 | out_ch=3, 208 | ch_mult=[1, 2, 4, 4], 209 | num_res_blocks=2, 210 | z_channels=16, 211 | scale_factor=0.3611, 212 | shift_factor=0.1159, 213 | ), 214 | ), 215 | } 216 | 217 | 218 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 219 | if len(missing) > 0 and len(unexpected) > 0: 220 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 221 | print("\n" + "-" * 79 + "\n") 222 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 223 | elif len(missing) > 0: 224 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 225 | elif len(unexpected) > 0: 226 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 227 | 228 | def load_from_repo_id(repo_id, checkpoint_name): 229 | ckpt_path = hf_hub_download(repo_id, checkpoint_name) 230 | sd = load_sft(ckpt_path, device='cpu') 231 | return sd 232 | 233 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 234 | # Loading Flux 235 | print("Init model") 236 | ckpt_path = configs[name].ckpt_path 237 | if ( 238 | ckpt_path is None 239 | and configs[name].repo_id is not None 240 | and configs[name].repo_flow is not None 241 | and hf_download 242 | ): 243 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 244 | 245 | with torch.device("meta" if ckpt_path is not None else device): 246 | model = Flux(configs[name].params).to(torch.bfloat16) 247 | 248 | if ckpt_path is not None: 249 | print("Loading checkpoint") 250 | # load_sft doesn't support torch.device 251 | sd = load_model(ckpt_path, device=str(device)) 252 
| missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 253 | print_load_warning(missing, unexpected) 254 | return model 255 | 256 | def load_flow_model_only_lora( 257 | name: str, 258 | device: str | torch.device = "cuda", 259 | hf_download: bool = True, 260 | lora_rank: int = 16, 261 | use_fp8: bool = False 262 | ): 263 | # Loading Flux 264 | print("Init model") 265 | ckpt_path = configs[name].ckpt_path 266 | if ( 267 | ckpt_path is None 268 | and configs[name].repo_id is not None 269 | and configs[name].repo_flow is not None 270 | and hf_download 271 | ): 272 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) 273 | 274 | if hf_download: 275 | try: 276 | lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors") 277 | except: 278 | lora_ckpt_path = os.environ.get("LORA", None) 279 | else: 280 | lora_ckpt_path = os.environ.get("LORA", None) 281 | 282 | with torch.device("meta" if ckpt_path is not None else device): 283 | model = Flux(configs[name].params) 284 | 285 | 286 | model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device) 287 | 288 | if ckpt_path is not None: 289 | print("Loading lora") 290 | lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\ 291 | else torch.load(lora_ckpt_path, map_location='cpu') 292 | 293 | print("Loading main checkpoint") 294 | # load_sft doesn't support torch.device 295 | 296 | if ckpt_path.endswith('safetensors'): 297 | if use_fp8: 298 | print( 299 | "####\n" 300 | "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n" 301 | "we convert the fp8 checkpoint on flight from bf16 checkpoint\n" 302 | "If your storage is constrained" 303 | "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n" 304 | ) 305 | sd = load_sft(ckpt_path, device="cpu") 306 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()} 307 | else: 308 | sd = load_sft(ckpt_path, device=str(device)) 309 | 310 | sd.update(lora_sd) 311 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 312 | else: 313 | dit_state = torch.load(ckpt_path, map_location='cpu') 314 | sd = {} 315 | for k in dit_state.keys(): 316 | sd[k.replace('module.','')] = dit_state[k] 317 | sd.update(lora_sd) 318 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 319 | model.to(str(device)) 320 | print_load_warning(missing, unexpected) 321 | return model 322 | 323 | 324 | def set_lora( 325 | model: Flux, 326 | lora_rank: int, 327 | double_blocks_indices: list[int] | None = None, 328 | single_blocks_indices: list[int] | None = None, 329 | device: str | torch.device = "cpu", 330 | ) -> Flux: 331 | double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices 332 | single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \ 333 | else single_blocks_indices 334 | 335 | lora_attn_procs = {} 336 | with torch.device(device): 337 | for name, attn_processor in model.attn_processors.items(): 338 | match = re.search(r'\.(\d+)\.', name) 339 | if match: 340 | layer_index = int(match.group(1)) 341 | 342 | if name.startswith("double_blocks") and layer_index in double_blocks_indices: 343 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank) 344 | elif 
name.startswith("single_blocks") and layer_index in single_blocks_indices: 345 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank) 346 | else: 347 | lora_attn_procs[name] = attn_processor 348 | model.set_attn_processor(lora_attn_procs) 349 | return model 350 | 351 | 352 | def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 353 | # Loading Flux 354 | from optimum.quanto import requantize 355 | print("Init model") 356 | ckpt_path = configs[name].ckpt_path 357 | if ( 358 | ckpt_path is None 359 | and configs[name].repo_id is not None 360 | and configs[name].repo_flow is not None 361 | and hf_download 362 | ): 363 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 364 | # json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') 365 | 366 | 367 | model = Flux(configs[name].params).to(torch.bfloat16) 368 | 369 | print("Loading checkpoint") 370 | # load_sft doesn't support torch.device 371 | sd = load_sft(ckpt_path, device='cpu') 372 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()} 373 | model.load_state_dict(sd, assign=True) 374 | return model 375 | with open(json_path, "r") as f: 376 | quantization_map = json.load(f) 377 | print("Start a quantization process...") 378 | requantize(model, sd, quantization_map, device=device) 379 | print("Model is quantized!") 380 | return model 381 | 382 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 383 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 384 | version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders") 385 | return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device) 386 | 387 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 388 | version = os.environ.get("CLIP", "openai/clip-vit-large-patch14") 389 | return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device) 390 | 391 | 392 | def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: 393 | ckpt_path = configs[name].ae_path 394 | if ( 395 | ckpt_path is None 396 | and configs[name].repo_id is not None 397 | and configs[name].repo_ae is not None 398 | and hf_download 399 | ): 400 | ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) 401 | 402 | # Loading the autoencoder 403 | print("Init AE") 404 | with torch.device("meta" if ckpt_path is not None else device): 405 | ae = AutoEncoder(configs[name].ae_params) 406 | 407 | if ckpt_path is not None: 408 | sd = load_sft(ckpt_path, device=str(device)) 409 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 410 | print_load_warning(missing, unexpected) 411 | return ae -------------------------------------------------------------------------------- /uno/flux/modules/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | from einops import rearrange 21 | from torch import Tensor, nn 22 | 23 | from ..math import attention, rope 24 | import torch.nn.functional as F 25 | 26 | class EmbedND(nn.Module): 27 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 28 | super().__init__() 29 | self.dim = dim 30 | self.theta = theta 31 | self.axes_dim = axes_dim 32 | 33 | def forward(self, ids: Tensor) -> Tensor: 34 | n_axes = ids.shape[-1] 35 | emb = torch.cat( 36 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 37 | dim=-3, 38 | ) 39 | 40 | return emb.unsqueeze(1) 41 | 42 | 43 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 44 | """ 45 | Create sinusoidal timestep embeddings. 46 | :param t: a 1-D Tensor of N indices, one per batch element. 47 | These may be fractional. 48 | :param dim: the dimension of the output. 49 | :param max_period: controls the minimum frequency of the embeddings. 50 | :return: an (N, D) Tensor of positional embeddings. 51 | """ 52 | t = time_factor * t 53 | half = dim // 2 54 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 55 | t.device 56 | ) 57 | 58 | args = t[:, None].float() * freqs[None] 59 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 60 | if dim % 2: 61 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 62 | if torch.is_floating_point(t): 63 | embedding = embedding.to(t) 64 | return embedding 65 | 66 | 67 | class MLPEmbedder(nn.Module): 68 | def __init__(self, in_dim: int, hidden_dim: int): 69 | super().__init__() 70 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 71 | self.silu = nn.SiLU() 72 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 73 | 74 | def forward(self, x: Tensor) -> Tensor: 75 | return self.out_layer(self.silu(self.in_layer(x))) 76 | 77 | 78 | class RMSNorm(torch.nn.Module): 79 | def __init__(self, dim: int): 80 | super().__init__() 81 | self.scale = nn.Parameter(torch.ones(dim)) 82 | 83 | def forward(self, x: Tensor): 84 | x_dtype = x.dtype 85 | x = x.float() 86 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 87 | return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) 88 | 89 | 90 | class QKNorm(torch.nn.Module): 91 | def __init__(self, dim: int): 92 | super().__init__() 93 | self.query_norm = RMSNorm(dim) 94 | self.key_norm = RMSNorm(dim) 95 | 96 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 97 | q = self.query_norm(q) 98 | k = self.key_norm(k) 99 | return q.to(v), k.to(v) 100 | 101 | class LoRALinearLayer(nn.Module): 102 | def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): 103 | super().__init__() 104 | 105 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 106 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 107 | # This value has the same meaning as the 
`--network_alpha` option in the kohya-ss trainer script. 108 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 109 | self.network_alpha = network_alpha 110 | self.rank = rank 111 | 112 | nn.init.normal_(self.down.weight, std=1 / rank) 113 | nn.init.zeros_(self.up.weight) 114 | 115 | def forward(self, hidden_states): 116 | orig_dtype = hidden_states.dtype 117 | dtype = self.down.weight.dtype 118 | 119 | down_hidden_states = self.down(hidden_states.to(dtype)) 120 | up_hidden_states = self.up(down_hidden_states) 121 | 122 | if self.network_alpha is not None: 123 | up_hidden_states *= self.network_alpha / self.rank 124 | 125 | return up_hidden_states.to(orig_dtype) 126 | 127 | class FLuxSelfAttnProcessor: 128 | def __call__(self, attn, x, pe, **attention_kwargs): 129 | qkv = attn.qkv(x) 130 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 131 | q, k = attn.norm(q, k, v) 132 | x = attention(q, k, v, pe=pe) 133 | x = attn.proj(x) 134 | return x 135 | 136 | class LoraFluxAttnProcessor(nn.Module): 137 | 138 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 139 | super().__init__() 140 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 141 | self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) 142 | self.lora_weight = lora_weight 143 | 144 | 145 | def __call__(self, attn, x, pe, **attention_kwargs): 146 | qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight 147 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 148 | q, k = attn.norm(q, k, v) 149 | x = attention(q, k, v, pe=pe) 150 | x = attn.proj(x) + self.proj_lora(x) * self.lora_weight 151 | return x 152 | 153 | class SelfAttention(nn.Module): 154 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 155 | super().__init__() 156 | self.num_heads = num_heads 157 | head_dim = dim // num_heads 158 | 159 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 160 | self.norm = QKNorm(head_dim) 161 | self.proj = nn.Linear(dim, dim) 162 | def forward(): 163 | pass 164 | 165 | 166 | @dataclass 167 | class ModulationOut: 168 | shift: Tensor 169 | scale: Tensor 170 | gate: Tensor 171 | 172 | 173 | class Modulation(nn.Module): 174 | def __init__(self, dim: int, double: bool): 175 | super().__init__() 176 | self.is_double = double 177 | self.multiplier = 6 if double else 3 178 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 179 | 180 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 181 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 182 | 183 | return ( 184 | ModulationOut(*out[:3]), 185 | ModulationOut(*out[3:]) if self.is_double else None, 186 | ) 187 | 188 | class DoubleStreamBlockLoraProcessor(nn.Module): 189 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 190 | super().__init__() 191 | self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 192 | self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) 193 | self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 194 | self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) 195 | self.lora_weight = lora_weight 196 | 197 | def forward(self, attn, img, txt, vec, pe, **attention_kwargs): 198 | img_mod1, img_mod2 = attn.img_mod(vec) 199 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 200 | 201 | # prepare image for attention 202 | img_modulated = attn.img_norm1(img) 203 | 
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 204 | img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight 205 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 206 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 207 | 208 | # prepare txt for attention 209 | txt_modulated = attn.txt_norm1(txt) 210 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 211 | txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight 212 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 213 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 214 | 215 | # run actual attention 216 | q = torch.cat((txt_q, img_q), dim=2) 217 | k = torch.cat((txt_k, img_k), dim=2) 218 | v = torch.cat((txt_v, img_v), dim=2) 219 | 220 | attn1 = attention(q, k, v, pe=pe) 221 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 222 | 223 | # calculate the img bloks 224 | img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight) 225 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 226 | 227 | # calculate the txt bloks 228 | txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight) 229 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 230 | return img, txt 231 | 232 | class DoubleStreamBlockProcessor: 233 | def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): 234 | img_mod1, img_mod2 = attn.img_mod(vec) 235 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 236 | 237 | # prepare image for attention 238 | img_modulated = attn.img_norm1(img) 239 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 240 | img_qkv = attn.img_attn.qkv(img_modulated) 241 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 242 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 243 | 244 | # prepare txt for attention 245 | txt_modulated = attn.txt_norm1(txt) 246 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 247 | txt_qkv = attn.txt_attn.qkv(txt_modulated) 248 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 249 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 250 | 251 | # run actual attention 252 | q = torch.cat((txt_q, img_q), dim=2) 253 | k = torch.cat((txt_k, img_k), dim=2) 254 | v = torch.cat((txt_v, img_v), dim=2) 255 | 256 | attn1 = attention(q, k, v, pe=pe) 257 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 258 | 259 | # calculate the img bloks 260 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn) 261 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 262 | 263 | # calculate the txt bloks 264 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) 265 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 266 | return img, txt 267 | 268 | class DoubleStreamBlock(nn.Module): 269 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 270 | super().__init__() 271 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 272 | self.num_heads = num_heads 273 | 
self.hidden_size = hidden_size 274 | self.head_dim = hidden_size // num_heads 275 | 276 | self.img_mod = Modulation(hidden_size, double=True) 277 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 278 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 279 | 280 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 281 | self.img_mlp = nn.Sequential( 282 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 283 | nn.GELU(approximate="tanh"), 284 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 285 | ) 286 | 287 | self.txt_mod = Modulation(hidden_size, double=True) 288 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 289 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 290 | 291 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 292 | self.txt_mlp = nn.Sequential( 293 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 294 | nn.GELU(approximate="tanh"), 295 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 296 | ) 297 | processor = DoubleStreamBlockProcessor() 298 | self.set_processor(processor) 299 | 300 | def set_processor(self, processor) -> None: 301 | self.processor = processor 302 | 303 | def get_processor(self): 304 | return self.processor 305 | 306 | def forward( 307 | self, 308 | img: Tensor, 309 | txt: Tensor, 310 | vec: Tensor, 311 | pe: Tensor, 312 | image_proj: Tensor = None, 313 | ip_scale: float =1.0, 314 | ) -> tuple[Tensor, Tensor]: 315 | if image_proj is None: 316 | return self.processor(self, img, txt, vec, pe) 317 | else: 318 | return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) 319 | 320 | 321 | class SingleStreamBlockLoraProcessor(nn.Module): 322 | def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): 323 | super().__init__() 324 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 325 | self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) 326 | self.lora_weight = lora_weight 327 | 328 | def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 329 | 330 | mod, _ = attn.modulation(vec) 331 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 332 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 333 | qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight 334 | 335 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 336 | q, k = attn.norm(q, k, v) 337 | 338 | # compute attention 339 | attn_1 = attention(q, k, v, pe=pe) 340 | 341 | # compute activation in mlp stream, cat again and run second linear layer 342 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 343 | output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight 344 | output = x + mod.gate * output 345 | return output 346 | 347 | 348 | class SingleStreamBlockProcessor: 349 | def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor: 350 | 351 | mod, _ = attn.modulation(vec) 352 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 353 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 354 | 355 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 356 | q, k = attn.norm(q, k, v) 357 | 358 | # compute attention 359 | attn_1 = attention(q, k, v, pe=pe) 360 | 361 | # 
compute activation in mlp stream, cat again and run second linear layer 362 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 363 | output = x + mod.gate * output 364 | return output 365 | 366 | class SingleStreamBlock(nn.Module): 367 | """ 368 | A DiT block with parallel linear layers as described in 369 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 370 | """ 371 | 372 | def __init__( 373 | self, 374 | hidden_size: int, 375 | num_heads: int, 376 | mlp_ratio: float = 4.0, 377 | qk_scale: float | None = None, 378 | ): 379 | super().__init__() 380 | self.hidden_dim = hidden_size 381 | self.num_heads = num_heads 382 | self.head_dim = hidden_size // num_heads 383 | self.scale = qk_scale or self.head_dim**-0.5 384 | 385 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 386 | # qkv and mlp_in 387 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 388 | # proj and mlp_out 389 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 390 | 391 | self.norm = QKNorm(self.head_dim) 392 | 393 | self.hidden_size = hidden_size 394 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 395 | 396 | self.mlp_act = nn.GELU(approximate="tanh") 397 | self.modulation = Modulation(hidden_size, double=False) 398 | 399 | processor = SingleStreamBlockProcessor() 400 | self.set_processor(processor) 401 | 402 | 403 | def set_processor(self, processor) -> None: 404 | self.processor = processor 405 | 406 | def get_processor(self): 407 | return self.processor 408 | 409 | def forward( 410 | self, 411 | x: Tensor, 412 | vec: Tensor, 413 | pe: Tensor, 414 | image_proj: Tensor | None = None, 415 | ip_scale: float = 1.0, 416 | ) -> Tensor: 417 | if image_proj is None: 418 | return self.processor(self, x, vec, pe) 419 | else: 420 | return self.processor(self, x, vec, pe, image_proj, ip_scale) 421 | 422 | 423 | 424 | class LastLayer(nn.Module): 425 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 426 | super().__init__() 427 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 428 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 429 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) 430 | 431 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 432 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 433 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 434 | x = self.linear(x) 435 | return x 436 | --------------------------------------------------------------------------------
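A minimal usage sketch of the LoRA adapter defined in uno/flux/modules/layers.py above, illustrative only: the `base` and `lora` names are just for this example, and it assumes the repository root is on sys.path (as uno_nodes/comfy_nodes.py arranges). It mirrors the `base(x) + lora(x) * lora_weight` pattern that LoraFluxAttnProcessor and the stream-block LoRA processors apply to the fused qkv and proj projections.

    import torch
    from uno.flux.modules.layers import LoRALinearLayer

    # A base projection plus a low-rank LoRA update on top of it.
    base = torch.nn.Linear(64, 64)
    lora = LoRALinearLayer(in_features=64, out_features=64, rank=4, network_alpha=4)

    x = torch.randn(2, 64)
    y = base(x) + lora(x) * 1.0  # same shape as base(x): [2, 64]
    print(y.shape)

Because `up.weight` is zero-initialized, the adapter contributes nothing until its weights are loaded, e.g. from the dit_lora.safetensors checkpoint that load_flow_model_only_lora in uno/flux/util.py downloads.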