├── rh_config.json
├── __init__.py
├── requirements.txt
├── .gitignore
├── uso
│   └── flux
│       ├── math.py
│       ├── modules
│       │   ├── conditioner.py
│       │   ├── autoencoder.py
│       │   └── layers.py
│       ├── sampling.py
│       ├── model.py
│       ├── pipeline.py
│       └── util.py
├── README_CN.md
├── README.md
├── rh_uso_nodes.py
├── inference.py
├── app.py
└── LICENSE
/rh_config.json:
--------------------------------------------------------------------------------
1 | {"enable": true, "untracked_paths": []}
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .rh_uso_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate>=1.1.1
2 | deepspeed>=0.14.4
3 | einops>=0.8.0
4 | transformers>=4.43.3
5 | huggingface-hub
6 | diffusers>=0.30.1
7 | sentencepiece>=0.2.0
8 | gradio>=5.22.0
9 | opencv-python
10 | matplotlib
11 | safetensors>=0.4.5
12 | scipy>=1.10.1
13 | numpy>=1.24.4
14 | onnxruntime-gpu
15 | # httpx==0.23.3
16 | git+https://github.com/openai/CLIP.git
17 | --extra-index-url https://download.pytorch.org/whl/cu124
18 | torch>=2.4.0
19 | torchvision>=0.19.0
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # PyInstaller
24 | *.manifest
25 | *.spec
26 |
27 | # Virtual environments
28 | .env
29 | .venv
30 | env/
31 | venv/
32 | ENV/
33 | env.bak/
34 | venv.bak/
35 |
36 | # IDE
37 | .vscode/
38 | .idea/
39 | *.swp
40 | *.swo
41 | *~
42 |
43 | # OS
44 | .DS_Store
45 | Thumbs.db
46 |
47 | # Temporary files
48 | temp/
49 | tmp/
50 | *.tmp
51 | *.temp
52 |
53 | # Logs
54 | *.log
--------------------------------------------------------------------------------
/uso/flux/math.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | from einops import rearrange
18 | from torch import Tensor
19 |
20 |
21 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
22 | q, k = apply_rope(q, k, pe)
23 |
24 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
25 | x = rearrange(x, "B H L D -> B L (H D)")
26 |
27 | return x
28 |
29 |
30 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
31 | assert dim % 2 == 0
32 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33 | omega = 1.0 / (theta**scale)
34 | out = torch.einsum("...n,d->...nd", pos, omega)
35 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
36 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
37 | return out.float()
38 |
39 |
40 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
41 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
42 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
43 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
44 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
45 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
46 |
--------------------------------------------------------------------------------
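A quick shape check for the helpers in `uso/flux/math.py` above. This is an illustrative sketch only: it assumes the repository root is on `PYTHONPATH`, and it mimics how the pipeline's positional embedder passes rotary embeddings with a broadcast head dimension.

```python
import torch
from uso.flux.math import rope, attention  # assumes the repo root is importable

B, H, L, D = 1, 24, 64, 128                 # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

pos = torch.arange(L, dtype=torch.float32)[None, :]  # (B, L) token positions
pe = rope(pos, D, theta=10_000)                       # (B, L, D/2, 2, 2) rotation matrices
pe = pe.unsqueeze(1)                                  # broadcast over heads -> (B, 1, L, D/2, 2, 2)

out = attention(q, k, v, pe)                          # applies RoPE, then scaled dot-product attention
print(out.shape)                                      # torch.Size([1, 64, 3072]) == (B, L, H*D)
```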
/uso/flux/modules/conditioner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from torch import Tensor, nn
17 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18 | T5Tokenizer)
19 |
20 |
21 | class HFEmbedder(nn.Module):
22 | def __init__(self, version: str, max_length: int, **hf_kwargs):
23 | super().__init__()
24 | # self.is_clip = "clip" in version.lower()
25 | #kiki
26 | self.is_clip = hf_kwargs.pop('is_clip', False)
27 | self.max_length = max_length
28 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
29 |
30 | if self.is_clip:
31 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
32 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
33 | else:
34 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
35 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
36 |
37 | self.hf_module = self.hf_module.eval().requires_grad_(False)
38 |
39 | def forward(self, text: list[str]) -> Tensor:
40 | batch_encoding = self.tokenizer(
41 | text,
42 | truncation=True,
43 | max_length=self.max_length,
44 | return_length=False,
45 | return_overflowing_tokens=False,
46 | padding="max_length",
47 | return_tensors="pt",
48 | )
49 |
50 | outputs = self.hf_module(
51 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
52 | attention_mask=None,
53 | output_hidden_states=False,
54 | )
55 | return outputs[self.output_key]
56 |
--------------------------------------------------------------------------------
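For orientation, a minimal usage sketch of `HFEmbedder`. The checkpoint ids below are illustrative assumptions (any CLIP text encoder / T5 encoder of matching architecture works, and the T5-XXL download is large); the ids the pipeline actually uses are configured elsewhere in the repo.

```python
import torch
from uso.flux.modules.conditioner import HFEmbedder

# Illustrative checkpoint ids, not necessarily the ones the pipeline uses.
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77,
                  is_clip=True, torch_dtype=torch.float32)
t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)

vec = clip(["a watercolor fox in a forest"])  # pooled CLIP embedding ("pooler_output")
txt = t5(["a watercolor fox in a forest"])    # per-token T5 hidden states ("last_hidden_state")
print(vec.shape, txt.shape)                   # e.g. torch.Size([1, 768]) torch.Size([1, 512, 4096])
```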
/README_CN.md:
--------------------------------------------------------------------------------
1 | # ComfyUI USO 节点
2 |
3 | 一个用于ComfyUI的自定义节点,集成USO(统一风格和主题驱动生成)模型,实现高质量的风格和主题控制图像生成。
4 |
5 | ## ✨ 特性
6 |
7 | - 🎨 **统一风格与主题生成**: 基于FLUX架构的USO模型
8 | - 🎯 **风格驱动生成**: 根据特定艺术风格生成图像
9 | - 👤 **主题驱动生成**: 在生成过程中保持主题一致性
10 | - 🔄 **多风格支持**: 在单次生成中结合多种风格
11 | - ⚙️ **内存优化**: 支持FP8精度,适用于消费级GPU(约16GB显存)
12 | - 🚀 **灵活控制**: 高级参数控制,精细调节生成结果
13 |
14 | ## 🔧 节点列表
15 |
16 | ### 核心节点
17 | - **RH_USO_Loader**: 加载和初始化USO模型,包含优化选项
18 | - **RH_USO_Sampler**: 具有风格和主题控制的图像生成器
19 |
20 | ## 🚀 快速安装
21 |
22 | ### 步骤1: 安装节点
23 | ```bash
24 | # 进入ComfyUI自定义节点目录
25 | cd ComfyUI/custom_nodes
26 |
27 | # 克隆仓库
28 | git clone https://github.com/HM-RunningHub/ComfyUI_RH_USO
29 |
30 | # 安装依赖
31 | cd ComfyUI_RH_USO
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ### 步骤2: 下载所需模型
36 | ```bash
37 | # 下载FLUX.1-dev模型(必需的基础模型)
38 | huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusers/FLUX.1-dev
39 | huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models/diffusers/FLUX.1-dev
40 |
41 | # 下载USO模型
42 | huggingface-cli download bytedance-research/USO --local-dir models/uso
43 |
44 | # 下载SigLIP模型
45 | huggingface-cli download google/siglip-so400m-patch14-384 --local-dir models/clip/siglip-so400m-patch14-384
46 |
47 | # 最终模型结构应该如下:
48 | models/
49 | ├── diffusers/
50 | │ └── FLUX.1-dev/
51 | │ ├── flux1-dev.safetensors
52 | │ └── ae.safetensors
53 | ├── uso/
54 | │ ├── assets/
55 | │ │ └── uso.webp
56 | │ ├── config.json
57 | │ ├── download_repo_enhanced.py
58 | │ ├── README.md
59 | │ └── uso_flux_v1.0/
60 | │ ├── dit_lora.safetensors
61 | │ └── projector.safetensors
62 | └── clip/
63 | └── siglip-so400m-patch14-384/
64 |
65 | # 重启ComfyUI
66 | ```
67 |
68 | ## 📖 使用方法
69 |
70 | ### 基础工作流
71 | ```
72 | [RH_USO_Loader] → [RH_USO_Sampler] → [Save Image]
73 | ```
74 |
75 | ### 生成类型
76 |
77 | #### 风格驱动生成
78 | - 加载风格参考图像
79 | - 输入描述内容的文本提示
80 | - 生成指定风格的图像
81 |
82 | #### 主题驱动生成
83 | - 加载主题参考图像
84 | - 输入包含场景描述的文本提示
85 | - 生成保持主题身份的图像
86 |
87 | #### 风格+主题生成
88 | - 同时加载风格和主题参考图像
89 | - 结合风格转换与主题一致性
90 | - 生成具有统一风格且保持主题的图像
91 |
92 | ## 🛠️ 技术要求
93 |
94 | - **GPU**: 16GB+显存(使用FP8优化)
95 | - **内存**: 推荐32GB+
96 | - **存储**: 约35GB用于所有模型
97 | - FLUX.1-dev: ~24GB (flux1-dev.safetensors + ae.safetensors)
98 | - USO模型: ~6GB
99 | - SigLIP: ~1.5GB
100 | - **CUDA**: 优化性能需要CUDA支持
101 |
102 | ## ⚠️ 重要提示
103 |
104 | - **模型路径**: 模型必须放置在特定目录:
105 | - FLUX.1-dev → `models/diffusers/FLUX.1-dev/`
106 | - USO模型 → `models/uso/`
107 | - SigLIP → `models/clip/siglip-so400m-patch14-384/`
108 | - 推荐消费级GPU使用FP8模式(减少显存占用)
109 | - 所有模型文件必须在首次使用前下载完成
110 |
111 | ## 📄 许可证
112 |
113 | 本项目采用Apache 2.0许可证。
114 |
115 | ## 🔗 参考链接
116 |
117 | - [USO项目页面](https://bytedance.github.io/USO/)
118 | - [USO论文](https://arxiv.org/abs/2508.18966)
119 | - [USO HuggingFace](https://huggingface.co/bytedance-research/USO)
120 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
121 |
122 | ## 🤝 贡献
123 |
124 | 欢迎贡献!请随时提交问题和拉取请求。
125 |
126 | ## ⭐ 引用
127 |
128 | 如果您觉得这个项目有用,请考虑引用原始USO论文:
129 |
130 | ```bibtex
131 | @article{wu2025uso,
132 | title={USO: Unified Style and Subject-Driven Generation via Disentangled and Reward Learning},
133 | author={Shaojin Wu and Mengqi Huang and Yufeng Cheng and Wenxu Wu and Jiahe Tian and Yiming Luo and Fei Ding and Qian He},
134 | year={2025},
135 | eprint={2508.18966},
136 | archivePrefix={arXiv},
137 | primaryClass={cs.CV},
138 | }
139 | ```
140 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI USO Node
2 |
3 | A custom node for ComfyUI that integrates USO (Unified Style and Subject-Driven Generation) for high-quality image generation with style and subject control.
4 |
5 | ## ✨ Features
6 |
7 | - 🎨 **Unified Style & Subject Generation**: Powered by USO model based on FLUX architecture
8 | - 🎯 **Style-Driven Generation**: Generate images with specific artistic styles
9 | - 👤 **Subject-Driven Generation**: Maintain subject consistency across generations
10 | - 🔄 **Multi-Style Support**: Combine multiple styles in a single generation
11 | - ⚙️ **Memory Optimization**: FP8 precision support for consumer-grade GPUs (~16GB VRAM)
12 | - 🚀 **Flexible Control**: Advanced parameter control for fine-tuning results
13 |
14 | ## 🔧 Node List
15 |
16 | ### Core Nodes
17 | - **RH_USO_Loader**: Load and initialize USO models with optimization options
18 | - **RH_USO_Sampler**: Generate images with style and subject control
19 |
20 | ## 🚀 Quick Installation
21 |
22 | ### Step 1: Install the Node
23 | ```bash
24 | # Navigate to ComfyUI custom_nodes directory
25 | cd ComfyUI/custom_nodes
26 |
27 | # Clone the repository
28 | git clone https://github.com/HM-RunningHub/ComfyUI_RH_USO
29 |
30 | # Install dependencies
31 | cd ComfyUI_RH_USO
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ### Step 2: Download Required Models
36 | ```bash
37 | # Download FLUX.1-dev model (Required base model)
38 | huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusers/FLUX.1-dev
39 | huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models/diffusers/FLUX.1-dev
40 |
41 | # Download USO model
42 | huggingface-cli download bytedance-research/USO --local-dir models/uso
43 |
44 | # Download SigLIP model
45 | huggingface-cli download google/siglip-so400m-patch14-384 --local-dir models/clip/siglip-so400m-patch14-384
46 |
47 | # Final model structure should look like:
48 | models/
49 | ├── diffusers/
50 | │ └── FLUX.1-dev/
51 | │ ├── flux1-dev.safetensors
52 | │ └── ae.safetensors
53 | │ └── ....
54 | ├── uso/
55 | │ ├── assets/
56 | │ │ └── uso.webp
57 | │ ├── config.json
58 | │ ├── download_repo_enhanced.py
59 | │ ├── README.md
60 | │ └── uso_flux_v1.0/
61 | │ ├── dit_lora.safetensors
62 | │ └── projector.safetensors
63 | └── clip/
64 | └── siglip-so400m-patch14-384/
65 |
66 | # Restart ComfyUI
67 | ```
68 |
69 | ## 📖 Usage
70 |
71 | ### Basic Workflow
72 | ```
73 | [RH_USO_Loader] → [RH_USO_Sampler] → [Save Image]
74 | ```
75 |
76 | ### Generation Types
77 |
78 | #### Style-Driven Generation
79 | - Load style reference images
80 | - Input text prompt describing the content
81 | - Generate images in the specified style
82 |
83 | #### Subject-Driven Generation
84 | - Load subject reference image
85 | - Input text prompt with scene description
86 | - Generate images maintaining subject identity
87 |
88 | #### Style + Subject Generation
89 | - Load both style and subject reference images
90 | - Combine style transfer with subject consistency
91 | - Generate images with unified style and preserved subjects
92 |
93 | ## 🛠️ Technical Requirements
94 |
95 | - **GPU**: 16GB+ VRAM (with FP8 optimization)
96 | - **RAM**: 32GB+ recommended
97 | - **Storage**: ~35GB for all models
98 | - FLUX.1-dev: ~24GB (flux1-dev.safetensors + ae.safetensors)
99 | - USO models: ~6GB
100 | - SigLIP: ~1.5GB
101 | - **CUDA**: Required for optimal performance
102 |
103 | ## ⚠️ Important Notes
104 |
105 | - **Model Paths**: Models must be placed in specific directories:
106 | - FLUX.1-dev → `models/diffusers/FLUX.1-dev/`
107 | - USO models → `models/uso/`
108 | - SigLIP → `models/clip/siglip-so400m-patch14-384/`
109 | - FP8 mode recommended for consumer GPUs (reduces VRAM usage)
110 | - All model files must be downloaded before first use
111 |
112 | ## 📄 License
113 |
114 | This project is licensed under the Apache 2.0 License.
115 |
116 | ## 🔗 References
117 |
118 | - [USO Project Page](https://bytedance.github.io/USO/)
119 | - [USO Paper](https://arxiv.org/abs/2508.18966)
120 | - [USO HuggingFace](https://huggingface.co/bytedance-research/USO)
121 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
122 |
123 | ## 🔗 Example
124 |
125 |
126 |
127 |
128 |
129 | ## 🤝 Contributing
130 |
131 | Contributions are welcome! Please feel free to submit issues and pull requests.
132 |
133 | ## ⭐ Citation
134 |
135 | If you find this project useful, please consider citing the original USO paper:
136 |
137 | ```bibtex
138 | @article{wu2025uso,
139 | title={USO: Unified Style and Subject-Driven Generation via Disentangled and Reward Learning},
140 | author={Shaojin Wu and Mengqi Huang and Yufeng Cheng and Wenxu Wu and Jiahe Tian and Yiming Luo and Fei Ding and Qian He},
141 | year={2025},
142 | eprint={2508.18966},
143 | archivePrefix={arXiv},
144 | primaryClass={cs.CV},
145 | }
146 | ```
147 |
--------------------------------------------------------------------------------
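Since both READMEs pin the models to fixed locations, a small check script can confirm the Step 2 layout before launching ComfyUI. This is a hedged convenience sketch, not part of the node pack; run it from the ComfyUI root (the directory that contains `models/`).

```python
import os

# Expected layout from the README, relative to the ComfyUI root.
EXPECTED = [
    "models/diffusers/FLUX.1-dev/flux1-dev.safetensors",
    "models/diffusers/FLUX.1-dev/ae.safetensors",
    "models/uso/uso_flux_v1.0/dit_lora.safetensors",
    "models/uso/uso_flux_v1.0/projector.safetensors",
    "models/clip/siglip-so400m-patch14-384",  # directory with the SigLIP weights/config
]

missing = [p for p in EXPECTED if not os.path.exists(p)]
if missing:
    print("Missing model paths:")
    for p in missing:
        print("  -", p)
else:
    print("All expected USO model paths are present.")
```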
/rh_uso_nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import dataclasses
3 | from typing import Literal
4 |
5 | from accelerate import Accelerator
6 | from transformers import HfArgumentParser
7 | from PIL import Image
8 | import json
9 | import itertools
10 | import torch
11 |
12 | from .uso.flux.pipeline import USOPipeline, preprocess_ref
13 | from transformers import SiglipVisionModel, SiglipImageProcessor
14 | from tqdm import tqdm
15 | import folder_paths
16 | import numpy as np
17 | import comfy.utils
18 |
19 | class RH_USO_Loader:
20 | @classmethod
21 | def INPUT_TYPES(s):
22 | return {
23 | "required": {
24 |
25 | },
26 | }
27 |
28 | RETURN_TYPES = ("RHUSOModules",)
29 | RETURN_NAMES = ("USO Modules",)
30 | FUNCTION = "load"
31 |
32 | CATEGORY = "Runninghub/USO"
33 |
34 | def load(self, **kwargs):
35 | # accelerator = Accelerator()
36 | device = 'cuda'
37 | siglip_path = os.path.join(folder_paths.models_dir, 'clip', 'siglip-so400m-patch14-384')
38 | siglip_processor = SiglipImageProcessor.from_pretrained(
39 | siglip_path
40 | )
41 | siglip_model = SiglipVisionModel.from_pretrained(
42 | siglip_path
43 | )
44 | siglip_model.eval()
45 | siglip_model.to(device)
46 | print("SigLIP model loaded successfully")
47 |
48 | # hardcoded hyperparameters -kiki
49 | model_type = 'flux-dev-fp8'
50 | lora_rank = 128
51 |
52 | pipeline = USOPipeline(
53 | model_type,
54 | device,
55 | True, #args.offload,
56 | only_lora=True,
57 | lora_rank=lora_rank,
58 | hf_download=False,
59 | )
60 | if siglip_model is not None:
61 | pipeline.model.vision_encoder = siglip_model
62 | print('-----> hook siglip encoder')
63 | return ({'siglip_processor':siglip_processor, 'pipeline':pipeline}, )
64 |
65 | class RH_USO_Sampler:
66 |
67 | @classmethod
68 | def INPUT_TYPES(s):
69 | return {
70 | "required": {
71 | "uso": ("RHUSOMudules", ),
72 | "prompt": ("STRING", {"multiline": True,
73 | 'default': ''}),
74 | "width": ("INT", {"default": 1024}),
75 | "height": ("INT", {"default": 1024}),
76 | "num_inference_steps": ("INT", {"default": 25}),
77 | "guidance": ("FLOAT", {"default": 4.0}),
78 | "seed": ("INT", {"default": 20, "min": 0, "max": 0xffffffffffffffff,
79 | "tooltip": "The random seed used for creating the noise."}),
80 | },
81 | "optional": {
82 | "content_image": ("IMAGE", ),
83 | "style_image": ("IMAGE", ),
84 | "style2_image": ("IMAGE", ),
85 | }
86 | }
87 |
88 | RETURN_TYPES = ("IMAGE",)
89 | RETURN_NAMES = ("image",)
90 | FUNCTION = "sample"
91 |
92 | CATEGORY = "Runninghub/USO"
93 |
94 | def tensor_2_pil(self, img_tensor):
95 | if img_tensor is not None:
96 | i = 255. * img_tensor.squeeze().cpu().numpy()
97 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
98 | return img
99 | else:
100 | return None
101 |
102 | def preprocess_ref(self, raw_image: Image.Image, long_size: int = 512, scale_ratio: int = 1):
103 | # get the width and height of the original image
104 | image_w, image_h = raw_image.size
105 | if image_w == image_h and image_w == 16:
106 | return raw_image
107 |
108 | # compute the long and short sides
109 | if image_w >= image_h:
110 | new_w = long_size
111 | new_h = int((long_size / image_w) * image_h)
112 | else:
113 | new_h = long_size
114 | new_w = int((long_size / image_h) * image_w)
115 |
116 | # resize proportionally to the new width and height
117 | raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
118 |
119 | # so that the canny img can be scaled
120 | scale_ratio = int(scale_ratio)
121 | target_w = new_w // (16 * scale_ratio) * (16 * scale_ratio)
122 | target_h = new_h // (16 * scale_ratio) * (16 * scale_ratio)
123 |
124 | # compute the crop origin to achieve a center crop
125 | left = (new_w - target_w) // 2
126 | top = (new_h - target_h) // 2
127 | right = left + target_w
128 | bottom = top + target_h
129 |
130 | # perform the center crop
131 | raw_image = raw_image.crop((left, top, right, bottom))
132 |
133 | # convert to RGB mode
134 | raw_image = raw_image.convert("RGB")
135 | return raw_image
136 |
137 | def sample(self, **kwargs):
138 | ref_imgs = []
139 | content_image = self.tensor_2_pil(kwargs.get('content_image', None))
140 | style_image = self.tensor_2_pil(kwargs.get('style_image', None))
141 | style2_image = self.tensor_2_pil(kwargs.get('style2_image', None))
142 | print(f'conds-c/s1/s2:{content_image is not None} {style_image is not None} {style2_image is not None}')
143 | ref_imgs.append(content_image)
144 | if style_image is not None:
145 | ref_imgs.append(style_image)
146 | if style2_image is not None:
147 | ref_imgs.append(style2_image)
148 | siglip_inputs = None
149 |
150 | width = kwargs.get('width')
151 | height = kwargs.get('height')
152 | prompt = kwargs.get('prompt')
153 | guidance = kwargs.get('guidance')
154 | num_steps = kwargs.get('num_inference_steps')
155 | seed = kwargs.get('seed') % (2 ** 32)
156 |
157 | # hardcoded hyperparameters -kiki
158 | content_ref = 512
159 | pe = 'd'
160 |
161 | uso = kwargs.get('uso')
162 | siglip_processor = uso['siglip_processor']
163 | pipeline = uso['pipeline']
164 | with torch.no_grad():
165 | siglip_inputs = [
166 | siglip_processor(img, return_tensors="pt").to(pipeline.device)
167 | for img in ref_imgs[1:] if isinstance(img, Image.Image)
168 | ]
169 |
170 | ref_imgs_pil = [
171 | self.preprocess_ref(img, content_ref) for img in ref_imgs[:1] if isinstance(img, Image.Image)
172 | ]
173 | self.pbar = comfy.utils.ProgressBar(num_steps)
174 |
175 | image_gen = pipeline(
176 | prompt=prompt,
177 | width=width,
178 | height=height,
179 | guidance=guidance,
180 | num_steps=num_steps,
181 | seed=seed,
182 | ref_imgs=ref_imgs_pil,
183 | pe=pe,
184 | siglip_inputs=siglip_inputs,
185 | update_func=self.update,
186 | )
187 |
188 | image = np.array(image_gen).astype(np.float32) / 255.0
189 | image = torch.from_numpy(image)[None,]
190 |
191 | return (image, )
192 |
193 | def update(self):
194 | self.pbar.update(1)
195 |
196 |
197 | NODE_CLASS_MAPPINGS = {
198 | "RunningHub USO Loader": RH_USO_Loader,
199 | "RunningHub USO Sampler":RH_USO_Sampler,
200 | }
201 |
202 | NODE_DISPLAY_NAME_MAPPINGS = {
203 | "RunningHub USO Loader": "RunningHub USO Loader",
204 | "RunningHub USO Sampler": "RunningHub USO Sampler",
205 | }
--------------------------------------------------------------------------------
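The two classes above can also be driven headlessly for debugging. The sketch below is assumption-heavy: it only works inside a ComfyUI checkout (so `folder_paths` and `comfy.utils` resolve), with the models installed as described in the README, and every file path in it is a placeholder. It converts a PIL reference into the `[B, H, W, C]` float tensor the sampler expects.

```python
# Hypothetical headless driver; every path below is a placeholder.
import sys
sys.path.insert(0, "ComfyUI")                   # so `folder_paths` / `comfy.utils` import
sys.path.insert(0, "ComfyUI/custom_nodes")      # so this node package is importable

import numpy as np
import torch
from PIL import Image
from ComfyUI_RH_USO import NODE_CLASS_MAPPINGS  # re-exported by this repo's __init__.py

def pil_to_comfy(img: Image.Image) -> torch.Tensor:
    """PIL image -> ComfyUI IMAGE tensor of shape [1, H, W, C] in [0, 1]."""
    arr = np.array(img.convert("RGB")).astype(np.float32) / 255.0
    return torch.from_numpy(arr)[None, ...]

loader = NODE_CLASS_MAPPINGS["RunningHub USO Loader"]()
sampler = NODE_CLASS_MAPPINGS["RunningHub USO Sampler"]()

(uso_modules,) = loader.load()                  # {'siglip_processor': ..., 'pipeline': ...}
style_ref = pil_to_comfy(Image.open("style_ref.png"))

(image,) = sampler.sample(
    uso=uso_modules,
    prompt="A wooden sculpture of a fox, studio lighting.",
    width=1024, height=1024,
    num_inference_steps=25, guidance=4.0, seed=20,
    style_image=style_ref,                      # content_image / style2_image are also optional inputs
)
print(image.shape)                              # torch.Size([1, 1024, 1024, 3])
```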
/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import dataclasses
17 | from typing import Literal
18 |
19 | from accelerate import Accelerator
20 | from transformers import HfArgumentParser
21 | from PIL import Image
22 | import json
23 | import itertools
24 | import torch
25 |
26 | from uso.flux.pipeline import USOPipeline, preprocess_ref
27 | from transformers import SiglipVisionModel, SiglipImageProcessor
28 | from tqdm import tqdm
29 |
30 |
31 | def horizontal_concat(images):
32 | widths, heights = zip(*(img.size for img in images))
33 |
34 | total_width = sum(widths)
35 | max_height = max(heights)
36 |
37 | new_im = Image.new("RGB", (total_width, max_height))
38 |
39 | x_offset = 0
40 | for img in images:
41 | new_im.paste(img, (x_offset, 0))
42 | x_offset += img.size[0]
43 |
44 | return new_im
45 |
46 |
47 | @dataclasses.dataclass
48 | class InferenceArgs:
49 | prompt: str | None = None
50 | image_paths: list[str] | None = None
51 | eval_json_path: str | None = None
52 | # offload: bool = False
53 | offload: bool = True
54 | num_images_per_prompt: int = 1
55 | model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev-fp8"
56 | width: int = 1024
57 | height: int = 1024
58 | num_steps: int = 25
59 | guidance: float = 4
60 | seed: int = 3407
61 | save_path: str = "output/inference"
62 | only_lora: bool = True
63 | concat_refs: bool = False
64 | lora_rank: int = 128
65 | pe: Literal["d", "h", "w", "o"] = "d"
66 | content_ref: int = 512
67 | ckpt_path: str | None = None
68 | use_siglip: bool = True
69 | instruct_edit: bool = False
70 | hf_download: bool = True
71 |
72 |
73 | def main(args: InferenceArgs):
74 | accelerator = Accelerator()
75 |
76 | # init SigLIP model
77 | siglip_processor = None
78 | siglip_model = None
79 |
80 | siglip_path = '/workspace/comfyui/models/clip/siglip-so400m-patch14-384'
81 | if args.use_siglip:
82 | siglip_processor = SiglipImageProcessor.from_pretrained(
83 | # "google/siglip-so400m-patch14-384"
84 | siglip_path
85 | )
86 | siglip_model = SiglipVisionModel.from_pretrained(
87 | # "google/siglip-so400m-patch14-384"
88 | siglip_path
89 | )
90 | siglip_model.eval()
91 | siglip_model.to(accelerator.device)
92 | print("SigLIP model loaded successfully")
93 |
94 | pipeline = USOPipeline(
95 | args.model_type,
96 | accelerator.device,
97 | args.offload,
98 | only_lora=args.only_lora,
99 | lora_rank=args.lora_rank,
100 | hf_download=args.hf_download,
101 | )
102 | if args.use_siglip and siglip_model is not None:
103 | pipeline.model.vision_encoder = siglip_model
104 | print('-----> hook siglip encoder')
105 |
106 | assert (
107 | args.prompt is not None or args.eval_json_path is not None
108 | ), "Please provide either prompt or eval_json_path"
109 |
110 | if args.eval_json_path is not None:
111 | with open(args.eval_json_path, "rt") as f:
112 | data_dicts = json.load(f)
113 | data_root = os.path.dirname(args.eval_json_path)
114 | else:
115 | data_root = ""
116 | data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}]
117 |
118 | print(
119 | f"process: {accelerator.num_processes}/{accelerator.process_index}, \
120 | process images: {len(data_dicts)}/{len(data_dicts[accelerator.process_index::accelerator.num_processes])}"
121 | )
122 |
123 | data_dicts = data_dicts[accelerator.process_index :: accelerator.num_processes]
124 |
125 | accelerator.wait_for_everyone()
126 | local_task_count = len(data_dicts) * args.num_images_per_prompt
127 | if accelerator.is_main_process:
128 | progress_bar = tqdm(total=local_task_count, desc="Generating Images")
129 |
130 | for (i, data_dict), j in itertools.product(
131 | enumerate(data_dicts), range(args.num_images_per_prompt)
132 | ):
133 | ref_imgs = []
134 | for _, img_path in enumerate(data_dict["image_paths"]):
135 | if img_path != "":
136 | img = Image.open(os.path.join(data_root, img_path)).convert("RGB")
137 | ref_imgs.append(img)
138 | else:
139 | ref_imgs.append(None)
140 | siglip_inputs = None
141 | if args.use_siglip and siglip_processor is not None:
142 | with torch.no_grad():
143 | siglip_inputs = [
144 | siglip_processor(img, return_tensors="pt").to(pipeline.device)
145 | for img in ref_imgs[1:] if isinstance(img, Image.Image)
146 | ]
147 |
148 | ref_imgs_pil = [
149 | preprocess_ref(img, args.content_ref) for img in ref_imgs[:1] if isinstance(img, Image.Image)
150 | ]
151 |
152 | if args.instruct_edit:
153 | args.width, args.height = ref_imgs_pil[0].size
154 | args.width, args.height = args.width * (1024 / args.content_ref), args.height * (1024 / args.content_ref)
155 | image_gen = pipeline(
156 | prompt=data_dict["prompt"],
157 | width=args.width,
158 | height=args.height,
159 | guidance=args.guidance,
160 | num_steps=args.num_steps,
161 | seed=args.seed + j,
162 | ref_imgs=ref_imgs_pil,
163 | pe=args.pe,
164 | siglip_inputs=siglip_inputs,
165 | )
166 | if args.concat_refs:
167 | image_gen = horizontal_concat([image_gen, *ref_imgs])
168 |
169 | if "save_dir" in data_dict:
170 | config_save_path = os.path.join(args.save_path, data_dict["save_dir"] + f"_{j}.json")
171 | image_save_path = os.path.join(args.save_path, data_dict["save_dir"] + f"_{j}.png")
172 | else:
173 | os.makedirs(args.save_path, exist_ok=True)
174 | config_save_path = os.path.join(args.save_path, f"{i}_{j}.json")
175 | image_save_path = os.path.join(args.save_path, f"{i}_{j}.png")
176 |
177 | # save config and image
178 | os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
179 | image_gen.save(image_save_path)
180 | # ensure the prompt and image_paths are saved in the config file
181 | args.prompt = data_dict["prompt"]
182 | args.image_paths = data_dict["image_paths"]
183 | args_dict = vars(args)
184 | with open(config_save_path, "w") as f:
185 | json.dump(args_dict, f, indent=4)
186 |
187 | if accelerator.is_main_process:
188 | progress_bar.update(1)
189 | if accelerator.is_main_process:
190 | progress_bar.close()
191 |
192 |
193 | if __name__ == "__main__":
194 | parser = HfArgumentParser([InferenceArgs])
195 | args = parser.parse_args_into_dataclasses()[0]
196 | main(args)
197 |
--------------------------------------------------------------------------------
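The batch path of `main` reads `--eval_json_path`, which it expects to be a JSON list of records carrying `prompt`, `image_paths`, and optionally `save_dir` (the keys accessed in the loop above); image paths are resolved relative to the JSON file's directory, and an empty string means "no image in this slot". A hedged sketch for producing such a file, with placeholder paths:

```python
import json

records = [
    {
        "prompt": "The woman near the sea, Ghibli style.",
        "image_paths": ["refs/subject.png", "refs/style.png"],  # slot 0: content ref, later slots: style refs
        "save_dir": "woman_ghibli",                             # optional output name prefix
    },
    {
        "prompt": "A clock on the table.",
        "image_paths": ["refs/clock.png", ""],                  # no style reference
    },
]

with open("eval.json", "w") as f:
    json.dump(records, f, indent=2)

# Then, for example:
#   python inference.py --eval_json_path eval.json --num_images_per_prompt 2
```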
/uso/flux/sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import math
17 | from typing import Literal
18 |
19 | import torch
20 | from einops import rearrange, repeat
21 | from torch import Tensor
22 | from tqdm import tqdm
23 |
24 | from .model import Flux
25 | from .modules.conditioner import HFEmbedder
26 |
27 |
28 | def get_noise(
29 | num_samples: int,
30 | height: int,
31 | width: int,
32 | device: torch.device,
33 | dtype: torch.dtype,
34 | seed: int,
35 | ):
36 | return torch.randn(
37 | num_samples,
38 | 16,
39 | # allow for packing
40 | 2 * math.ceil(height / 16),
41 | 2 * math.ceil(width / 16),
42 | device=device,
43 | dtype=dtype,
44 | generator=torch.Generator(device=device).manual_seed(seed),
45 | )
46 |
47 |
48 | def prepare(
49 | t5: HFEmbedder,
50 | clip: HFEmbedder,
51 | img: Tensor,
52 | prompt: str | list[str],
53 | ref_img: None | Tensor = None,
54 | pe: Literal["d", "h", "w", "o"] = "d",
55 | ) -> dict[str, Tensor]:
56 | assert pe in ["d", "h", "w", "o"]
57 | bs, c, h, w = img.shape
58 | if bs == 1 and not isinstance(prompt, str):
59 | bs = len(prompt)
60 |
61 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
62 | if img.shape[0] == 1 and bs > 1:
63 | img = repeat(img, "1 ... -> bs ...", bs=bs)
64 |
65 | img_ids = torch.zeros(h // 2, w // 2, 3)
66 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
67 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
68 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
69 |
70 | if ref_img is not None:
71 | _, _, ref_h, ref_w = ref_img.shape
72 | ref_img = rearrange(
73 | ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
74 | )
75 | if ref_img.shape[0] == 1 and bs > 1:
76 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
77 | ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
78 | # offset the ref img ids in height/width by their respective maxima
79 | h_offset = h // 2 if pe in {"d", "h"} else 0
80 | w_offset = w // 2 if pe in {"d", "w"} else 0
81 | ref_img_ids[..., 1] = (
82 | ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
83 | )
84 | ref_img_ids[..., 2] = (
85 | ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
86 | )
87 | ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
88 |
89 | if isinstance(prompt, str):
90 | prompt = [prompt]
91 | txt = t5(prompt)
92 | if txt.shape[0] == 1 and bs > 1:
93 | txt = repeat(txt, "1 ... -> bs ...", bs=bs)
94 | txt_ids = torch.zeros(bs, txt.shape[1], 3)
95 |
96 | vec = clip(prompt)
97 | if vec.shape[0] == 1 and bs > 1:
98 | vec = repeat(vec, "1 ... -> bs ...", bs=bs)
99 |
100 | if ref_img is not None:
101 | return {
102 | "img": img,
103 | "img_ids": img_ids.to(img.device),
104 | "ref_img": ref_img,
105 | "ref_img_ids": ref_img_ids.to(img.device),
106 | "txt": txt.to(img.device),
107 | "txt_ids": txt_ids.to(img.device),
108 | "vec": vec.to(img.device),
109 | }
110 | else:
111 | return {
112 | "img": img,
113 | "img_ids": img_ids.to(img.device),
114 | "txt": txt.to(img.device),
115 | "txt_ids": txt_ids.to(img.device),
116 | "vec": vec.to(img.device),
117 | }
118 |
119 |
120 | def prepare_multi_ip(
121 | t5: HFEmbedder,
122 | clip: HFEmbedder,
123 | img: Tensor,
124 | prompt: str | list[str],
125 | ref_imgs: list[Tensor] | None = None,
126 | pe: Literal["d", "h", "w", "o"] = "d",
127 | ) -> dict[str, Tensor]:
128 | assert pe in ["d", "h", "w", "o"]
129 | bs, c, h, w = img.shape
130 | if bs == 1 and not isinstance(prompt, str):
131 | bs = len(prompt)
132 |
133 | # tgt img
134 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
135 | if img.shape[0] == 1 and bs > 1:
136 | img = repeat(img, "1 ... -> bs ...", bs=bs)
137 |
138 | img_ids = torch.zeros(h // 2, w // 2, 3)
139 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
140 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
141 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
142 |
143 | ref_img_ids = []
144 | ref_imgs_list = []
145 |
146 | pe_shift_w, pe_shift_h = w // 2, h // 2
147 | for ref_img in ref_imgs:
148 | _, _, ref_h1, ref_w1 = ref_img.shape
149 | ref_img = rearrange(
150 | ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
151 | )
152 | if ref_img.shape[0] == 1 and bs > 1:
153 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
154 | ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
155 | # offset the ref img ids in height/width by their respective maxima
156 | h_offset = pe_shift_h if pe in {"d", "h"} else 0
157 | w_offset = pe_shift_w if pe in {"d", "w"} else 0
158 | ref_img_ids1[..., 1] = (
159 | ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
160 | )
161 | ref_img_ids1[..., 2] = (
162 | ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
163 | )
164 | ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
165 | ref_img_ids.append(ref_img_ids1)
166 | ref_imgs_list.append(ref_img)
167 |
168 | # update the pe shift
169 | pe_shift_h += ref_h1 // 2
170 | pe_shift_w += ref_w1 // 2
171 |
172 | if isinstance(prompt, str):
173 | prompt = [prompt]
174 | txt = t5(prompt)
175 | if txt.shape[0] == 1 and bs > 1:
176 | txt = repeat(txt, "1 ... -> bs ...", bs=bs)
177 | txt_ids = torch.zeros(bs, txt.shape[1], 3)
178 |
179 | vec = clip(prompt)
180 | if vec.shape[0] == 1 and bs > 1:
181 | vec = repeat(vec, "1 ... -> bs ...", bs=bs)
182 |
183 | return {
184 | "img": img,
185 | "img_ids": img_ids.to(img.device),
186 | "ref_img": tuple(ref_imgs_list),
187 | "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
188 | "txt": txt.to(img.device),
189 | "txt_ids": txt_ids.to(img.device),
190 | "vec": vec.to(img.device),
191 | }
192 |
193 |
194 | def time_shift(mu: float, sigma: float, t: Tensor):
195 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
196 |
197 |
198 | def get_lin_function(
199 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
200 | ):
201 | m = (y2 - y1) / (x2 - x1)
202 | b = y1 - m * x1
203 | return lambda x: m * x + b
204 |
205 |
206 | def get_schedule(
207 | num_steps: int,
208 | image_seq_len: int,
209 | base_shift: float = 0.5,
210 | max_shift: float = 1.15,
211 | shift: bool = True,
212 | ) -> list[float]:
213 | # extra step for zero
214 | timesteps = torch.linspace(1, 0, num_steps + 1)
215 |
216 | # shifting the schedule to favor high timesteps for higher signal images
217 | if shift:
218 | # estimate mu via linear interpolation between the two anchor points
219 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
220 | timesteps = time_shift(mu, 1.0, timesteps)
221 |
222 | return timesteps.tolist()
223 |
224 |
225 | def denoise(
226 | model: Flux,
227 | # model input
228 | img: Tensor,
229 | img_ids: Tensor,
230 | txt: Tensor,
231 | txt_ids: Tensor,
232 | vec: Tensor,
233 | # sampling parameters
234 | timesteps: list[float],
235 | guidance: float = 4.0,
236 | ref_img: Tensor = None,
237 | ref_img_ids: Tensor = None,
238 | siglip_inputs: list[Tensor] | None = None,
239 | #kiki
240 | update_func = None,
241 | ):
242 | i = 0
243 | guidance_vec = torch.full(
244 | (img.shape[0],), guidance, device=img.device, dtype=img.dtype
245 | )
246 | for t_curr, t_prev in tqdm(
247 | zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
248 | ):
249 | if update_func is not None:
250 | update_func()
251 | # for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
252 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
253 | pred = model(
254 | img=img,
255 | img_ids=img_ids,
256 | ref_img=ref_img,
257 | ref_img_ids=ref_img_ids,
258 | txt=txt,
259 | txt_ids=txt_ids,
260 | y=vec,
261 | timesteps=t_vec,
262 | guidance=guidance_vec,
263 | siglip_inputs=siglip_inputs,
264 | )
265 | img = img + (t_prev - t_curr) * pred
266 | i += 1
267 | return img
268 |
269 |
270 | def unpack(x: Tensor, height: int, width: int) -> Tensor:
271 | return rearrange(
272 | x,
273 | "b (h w) (c ph pw) -> b c (h ph) (w pw)",
274 | h=math.ceil(height / 16),
275 | w=math.ceil(width / 16),
276 | ph=2,
277 | pw=2,
278 | )
279 |
--------------------------------------------------------------------------------
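The schedule helpers are easy to inspect in isolation. For a 1024x1024 generation the packed latent has (1024/16) * (1024/16) = 4096 tokens, which is exactly the x2 anchor of `get_lin_function`, so `mu` lands on `max_shift`. A small sketch:

```python
from uso.flux.sampling import get_lin_function, get_schedule

image_seq_len = (1024 // 16) * (1024 // 16)        # 4096 packed latent tokens for a 1024x1024 image
mu = get_lin_function(y1=0.5, y2=1.15)(image_seq_len)
print(mu)                                           # ~1.15 at the 4096-token anchor

timesteps = get_schedule(num_steps=25, image_seq_len=image_seq_len)
print(len(timesteps), timesteps[0], timesteps[-1])  # 26 values, from 1.0 down to 0.0
```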
/uso/flux/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass
17 |
18 | import torch
19 | from torch import Tensor, nn
20 |
21 | from .modules.layers import (
22 | DoubleStreamBlock,
23 | EmbedND,
24 | LastLayer,
25 | MLPEmbedder,
26 | SingleStreamBlock,
27 | timestep_embedding,
28 | SigLIPMultiFeatProjModel,
29 | )
30 | import os
31 |
32 |
33 | @dataclass
34 | class FluxParams:
35 | in_channels: int
36 | vec_in_dim: int
37 | context_in_dim: int
38 | hidden_size: int
39 | mlp_ratio: float
40 | num_heads: int
41 | depth: int
42 | depth_single_blocks: int
43 | axes_dim: list[int]
44 | theta: int
45 | qkv_bias: bool
46 | guidance_embed: bool
47 |
48 |
49 | class Flux(nn.Module):
50 | """
51 | Transformer model for flow matching on sequences.
52 | """
53 |
54 | _supports_gradient_checkpointing = True
55 |
56 | def __init__(self, params: FluxParams):
57 | super().__init__()
58 |
59 | self.params = params
60 | self.in_channels = params.in_channels
61 | self.out_channels = self.in_channels
62 | if params.hidden_size % params.num_heads != 0:
63 | raise ValueError(
64 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
65 | )
66 | pe_dim = params.hidden_size // params.num_heads
67 | if sum(params.axes_dim) != pe_dim:
68 | raise ValueError(
69 | f"Got {params.axes_dim} but expected positional dim {pe_dim}"
70 | )
71 | self.hidden_size = params.hidden_size
72 | self.num_heads = params.num_heads
73 | self.pe_embedder = EmbedND(
74 | dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
75 | )
76 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
77 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
78 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
79 | self.guidance_in = (
80 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
81 | if params.guidance_embed
82 | else nn.Identity()
83 | )
84 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
85 |
86 | self.double_blocks = nn.ModuleList(
87 | [
88 | DoubleStreamBlock(
89 | self.hidden_size,
90 | self.num_heads,
91 | mlp_ratio=params.mlp_ratio,
92 | qkv_bias=params.qkv_bias,
93 | )
94 | for _ in range(params.depth)
95 | ]
96 | )
97 |
98 | self.single_blocks = nn.ModuleList(
99 | [
100 | SingleStreamBlock(
101 | self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
102 | )
103 | for _ in range(params.depth_single_blocks)
104 | ]
105 | )
106 |
107 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
108 | self.gradient_checkpointing = False
109 |
110 | # feature embedder for siglip multi-feat inputs
111 | self.feature_embedder = SigLIPMultiFeatProjModel(
112 | siglip_token_nums=729,
113 | style_token_nums=64,
114 | siglip_token_dims=1152,
115 | hidden_size=self.hidden_size,
116 | context_layer_norm=True,
117 | )
118 | print("use semantic encoder siglip multi-feat to encode style image")
119 |
120 | self.vision_encoder = None
121 |
122 | def _set_gradient_checkpointing(self, module, value=False):
123 | if hasattr(module, "gradient_checkpointing"):
124 | module.gradient_checkpointing = value
125 |
126 | @property
127 | def attn_processors(self):
128 | # set recursively
129 | processors = {} # type: dict[str, nn.Module]
130 |
131 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
132 | if hasattr(module, "set_processor"):
133 | processors[f"{name}.processor"] = module.processor
134 |
135 | for sub_name, child in module.named_children():
136 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
137 |
138 | return processors
139 |
140 | for name, module in self.named_children():
141 | fn_recursive_add_processors(name, module, processors)
142 |
143 | return processors
144 |
145 | def set_attn_processor(self, processor):
146 | r"""
147 | Sets the attention processor to use to compute attention.
148 |
149 | Parameters:
150 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
151 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
152 | for **all** `Attention` layers.
153 |
154 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
155 | processor. This is strongly recommended when setting trainable attention processors.
156 |
157 | """
158 | count = len(self.attn_processors.keys())
159 |
160 | if isinstance(processor, dict) and len(processor) != count:
161 | raise ValueError(
162 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
163 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
164 | )
165 |
166 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
167 | if hasattr(module, "set_processor"):
168 | if not isinstance(processor, dict):
169 | module.set_processor(processor)
170 | else:
171 | module.set_processor(processor.pop(f"{name}.processor"))
172 |
173 | for sub_name, child in module.named_children():
174 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
175 |
176 | for name, module in self.named_children():
177 | fn_recursive_attn_processor(name, module, processor)
178 |
179 | def forward(
180 | self,
181 | img: Tensor,
182 | img_ids: Tensor,
183 | txt: Tensor,
184 | txt_ids: Tensor,
185 | timesteps: Tensor,
186 | y: Tensor,
187 | guidance: Tensor | None = None,
188 | ref_img: Tensor | None = None,
189 | ref_img_ids: Tensor | None = None,
190 | siglip_inputs: list[Tensor] | None = None,
191 | ) -> Tensor:
192 | if img.ndim != 3 or txt.ndim != 3:
193 | raise ValueError("Input img and txt tensors must have 3 dimensions.")
194 |
195 | # running on sequences img
196 | img = self.img_in(img)
197 | vec = self.time_in(timestep_embedding(timesteps, 256))
198 | if self.params.guidance_embed:
199 | if guidance is None:
200 | raise ValueError(
201 | "Didn't get guidance strength for guidance distilled model."
202 | )
203 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
204 | vec = vec + self.vector_in(y)
205 | txt = self.txt_in(txt)
206 | if self.feature_embedder is not None and siglip_inputs is not None and len(siglip_inputs) > 0 and self.vision_encoder is not None:
207 | # project style features into the textual hidden space
208 | siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in siglip_inputs]
209 | # siglip_embedding = [self.vision_encoder(**(emb.to(torch.bfloat16)), output_hidden_states=True) for emb in siglip_inputs]
210 | siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
211 | txt = torch.cat((siglip_embedding, txt), dim=1)
212 | siglip_embedding_ids = torch.zeros(
213 | siglip_embedding.shape[0], siglip_embedding.shape[1], 3
214 | ).to(txt_ids.device)
215 | txt_ids = torch.cat((siglip_embedding_ids, txt_ids), dim=1)
216 |
217 | ids = torch.cat((txt_ids, img_ids), dim=1)
218 |
219 | # concat ref_img/img
220 | img_end = img.shape[1]
221 | if ref_img is not None:
222 | if isinstance(ref_img, tuple) or isinstance(ref_img, list):
223 | img_in = [img] + [self.img_in(ref) for ref in ref_img]
224 | img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
225 | img = torch.cat(img_in, dim=1)
226 | ids = torch.cat(img_ids, dim=1)
227 | else:
228 | img = torch.cat((img, self.img_in(ref_img)), dim=1)
229 | ids = torch.cat((ids, ref_img_ids), dim=1)
230 | pe = self.pe_embedder(ids)
231 |
232 | for index_block, block in enumerate(self.double_blocks):
233 | if self.training and self.gradient_checkpointing:
234 | img, txt = torch.utils.checkpoint.checkpoint(
235 | block,
236 | img=img,
237 | txt=txt,
238 | vec=vec,
239 | pe=pe,
240 | use_reentrant=False,
241 | )
242 | else:
243 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
244 |
245 | img = torch.cat((txt, img), 1)
246 | for block in self.single_blocks:
247 | if self.training and self.gradient_checkpointing:
248 | img = torch.utils.checkpoint.checkpoint(
249 | block, img, vec=vec, pe=pe, use_reentrant=False
250 | )
251 | else:
252 | img = block(img, vec=vec, pe=pe)
253 | img = img[:, txt.shape[1] :, ...]
254 | # index img
255 | img = img[:, :img_end, ...]
256 |
257 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
258 | return img
259 |
--------------------------------------------------------------------------------
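The constructor above enforces two consistency checks: `hidden_size` must be divisible by `num_heads`, and `sum(axes_dim)` must equal the per-head positional dimension. For orientation only, here is that arithmetic with the publicly documented flux-dev configuration; treat these numbers as an assumption, since the parameters this repo actually uses are defined elsewhere in the package and are not reproduced in this dump. Only the dataclass is built, because instantiating `Flux` itself allocates the full transformer.

```python
from uso.flux.model import FluxParams

# flux-dev configuration as published by Black Forest Labs (assumed here for illustration).
params = FluxParams(
    in_channels=64, vec_in_dim=768, context_in_dim=4096,
    hidden_size=3072, mlp_ratio=4.0, num_heads=24,
    depth=19, depth_single_blocks=38,
    axes_dim=[16, 56, 56], theta=10_000,
    qkv_bias=True, guidance_embed=True,
)

pe_dim = params.hidden_size // params.num_heads    # 3072 / 24 = 128
assert params.hidden_size % params.num_heads == 0  # first check in Flux.__init__
assert sum(params.axes_dim) == pe_dim              # second check: 16 + 56 + 56 == 128
```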
/app.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 | import json
17 | import os
18 | from pathlib import Path
19 |
20 | import gradio as gr
21 | import torch
22 |
23 | from uso.flux.pipeline import USOPipeline
24 | from transformers import SiglipVisionModel, SiglipImageProcessor
25 |
26 |
27 | with open("assets/uso_text.svg", "r", encoding="utf-8") as svg_file:
28 | text_content = svg_file.read()
29 |
30 | with open("assets/uso_logo.svg", "r", encoding="utf-8") as svg_file:
31 | logo_content = svg_file.read()
32 |
33 | title = f"""
34 |
35 | {text_content}
36 | by UXO Team
37 | {logo_content}
38 |
39 | """.strip()
40 |
41 | badges_text = r"""
42 |
48 | """.strip()
49 |
50 | tips = """
51 | **What is USO?** 🎨
52 | USO is a unified style-subject optimized customization model and the latest addition to the UXO family (USO and UNO).
53 | It can freely combine any subjects with any styles in any scenarios.
54 |
55 | **How to use?** 💡
56 | We provide step-by-step instructions in our GitHub repo.
57 | Additionally, try the examples provided below the demo to quickly get familiar with USO and spark your creativity!
58 |
59 |
60 | The model is trained on 1024x1024 resolution and supports 3 types of usage. 📌 Tips:
61 |
62 | * **Only content img**: support following types:
63 | * Subject/Identity-driven (supports natural prompt, e.g., *A clock on the table.* *The woman near the sea.*, excels in producing **photorealistic portraits**)
64 | * Style edit (layout-preserved): *Transform the image into Ghibli style/Pixel style/Retro comic style/Watercolor painting style...*.
65 | * Style edit (layout-shift): *Ghibli style, the man on the beach.*.
66 | * **Only style img**: Reference the input style and generate anything following the prompt. USO excels at this and further supports multiple style references (in beta).
67 | * **Content img + style img**: Place the content into the desired style.
68 | * Layout-preserved: set prompt to **empty**.
69 | * Layout-shift: use a natural prompt. """
70 |
71 | star = r"""
72 | If USO is helpful, please help to ⭐ our GitHub repo. Thanks a lot!"""
73 |
74 | def get_examples(examples_dir: str = "assets/examples") -> list:
75 | examples = Path(examples_dir)
76 | ans = []
77 | for example in examples.iterdir():
78 | if not example.is_dir() or len(os.listdir(example)) == 0:
79 | continue
80 | with open(example / "config.json") as f:
81 | example_dict = json.load(f)
82 |
83 |
84 | example_list = []
85 | example_list.append(example_dict["prompt"]) # prompt
86 |
87 | for key in ["image_ref1", "image_ref2", "image_ref3"]:
88 | if key in example_dict:
89 | example_list.append(str(example / example_dict[key]))
90 | else:
91 | example_list.append(None)
92 |
93 | example_list.append(example_dict["seed"])
94 | ans.append(example_list)
95 | return ans
96 |
97 |
98 | def create_demo(
99 | model_type: str,
100 | device: str = "cuda" if torch.cuda.is_available() else "cpu",
101 | offload: bool = False,
102 | ):
103 | pipeline = USOPipeline(
104 | model_type, device, offload, only_lora=True, lora_rank=128, hf_download=True
105 | )
106 | print("USOPipeline loaded successfully")
107 |
108 | siglip_processor = SiglipImageProcessor.from_pretrained(
109 | "google/siglip-so400m-patch14-384"
110 | )
111 | siglip_model = SiglipVisionModel.from_pretrained(
112 | "google/siglip-so400m-patch14-384"
113 | )
114 | siglip_model.eval()
115 | siglip_model.to(device)
116 | pipeline.model.vision_encoder = siglip_model
117 | pipeline.model.vision_encoder_processor = siglip_processor
118 | print("SigLIP model loaded successfully")
119 |
120 | with gr.Blocks() as demo:
121 | gr.Markdown(title)
122 | gr.Markdown(badges_text)
123 | gr.Markdown(tips)
124 | with gr.Row():
125 | with gr.Column():
126 | prompt = gr.Textbox(label="Prompt", value="A beautiful woman.")
127 | with gr.Row():
128 | image_prompt1 = gr.Image(
129 | label="Content Reference Img", visible=True, interactive=True, type="pil"
130 | )
131 | image_prompt2 = gr.Image(
132 | label="Style Reference Img", visible=True, interactive=True, type="pil"
133 | )
134 | image_prompt3 = gr.Image(
135 | label="Extra Style Reference Img (Beta)", visible=True, interactive=True, type="pil"
136 | )
137 |
138 | with gr.Row():
139 | with gr.Row():
140 | width = gr.Slider(
141 | 512, 1536, 1024, step=16, label="Generation Width"
142 | )
143 | height = gr.Slider(
144 | 512, 1536, 1024, step=16, label="Generation Height"
145 | )
146 | with gr.Row():
147 | with gr.Row():
148 | keep_size = gr.Checkbox(
149 | label="Keep input size",
150 | value=False,
151 | interactive=True
152 | )
153 | with gr.Column():
154 | gr.Markdown("Set it to True if you only need style editing or want to keep the layout.")
155 |
156 | with gr.Accordion("Advanced Options", open=True):
157 | with gr.Row():
158 | num_steps = gr.Slider(
159 | 1, 50, 25, step=1, label="Number of steps"
160 | )
161 | guidance = gr.Slider(
162 | 1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True
163 | )
164 | content_long_size = gr.Slider(
165 | 0, 1024, 512, step=16, label="Content reference size"
166 | )
167 | seed = gr.Number(-1, label="Seed (-1 for random)")
168 |
169 | generate_btn = gr.Button("Generate")
170 | gr.Markdown(star)
171 |
172 | with gr.Column():
173 | output_image = gr.Image(label="Generated Image")
174 | download_btn = gr.File(
175 | label="Download full-resolution", type="filepath", interactive=False
176 | )
177 |
178 | inputs = [
179 | prompt,
180 | image_prompt1,
181 | image_prompt2,
182 | image_prompt3,
183 | seed,
184 | width,
185 | height,
186 | guidance,
187 | num_steps,
188 | keep_size,
189 | content_long_size,
190 | ]
191 | generate_btn.click(
192 | fn=pipeline.gradio_generate,
193 | inputs=inputs,
194 | outputs=[output_image, download_btn],
195 | )
196 |
197 | # example_text = gr.Text("", visible=False, label="Case For:")
198 | examples = get_examples("./assets/gradio_examples")
199 |
200 | gr.Examples(
201 | examples=examples,
202 | inputs=[
203 | prompt,
204 | image_prompt1,
205 | image_prompt2,
206 | image_prompt3,
207 | seed,
208 | ],
209 | # cache_examples='lazy',
210 | outputs=[output_image, download_btn],
211 | fn=pipeline.gradio_generate,
212 | label='row 1-4: identity/subject-driven; row 5-7: style-subject-driven; row 8-9: style-driven; row 10-12: multi-style-driven task; row 13: txt2img',
213 | examples_per_page=15
214 | )
215 |
216 | return demo
217 |
218 |
219 | if __name__ == "__main__":
220 | from typing import Literal
221 |
222 | from transformers import HfArgumentParser
223 |
224 | @dataclasses.dataclass
225 | class AppArgs:
226 | name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell", "flux-krea-dev"] = "flux-dev"
227 | device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
228 | offload: bool = dataclasses.field(
229 | default=False,
230 | metadata={
231 | "help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."
232 | },
233 | )
234 | port: int = 7860
235 |
236 | parser = HfArgumentParser([AppArgs])
237 | args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
238 | args = args_tuple[0]
239 |
240 | demo = create_demo(args.name, args.device, args.offload)
241 | demo.launch(server_port=args.port)
242 |
--------------------------------------------------------------------------------
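`get_examples` above expects each subdirectory of the examples folder to contain a `config.json` with a `prompt`, a `seed`, and optional `image_ref1`-`image_ref3` entries naming files in that same folder (those are the only keys it reads). A hedged sketch for registering a new demo example; the folder and file names are placeholders:

```python
import json
from pathlib import Path

example_dir = Path("assets/gradio_examples/my_style_demo")  # placeholder folder name
example_dir.mkdir(parents=True, exist_ok=True)

config = {
    "prompt": "A fox in the forest.",
    "image_ref1": "content.png",  # content reference (optional)
    "image_ref2": "style.png",    # style reference (optional)
    "seed": 3407,
}
(example_dir / "config.json").write_text(json.dumps(config, indent=2))
# Copy content.png / style.png into the same folder, then relaunch: python app.py
```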
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/uso/flux/modules/autoencoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass
17 |
18 | import torch
19 | from einops import rearrange
20 | from torch import Tensor, nn
21 |
22 |
23 | @dataclass
24 | class AutoEncoderParams:
25 | resolution: int
26 | in_channels: int
27 | ch: int
28 | out_ch: int
29 | ch_mult: list[int]
30 | num_res_blocks: int
31 | z_channels: int
32 | scale_factor: float
33 | shift_factor: float
34 |
35 |
36 | def swish(x: Tensor) -> Tensor:
37 | return x * torch.sigmoid(x)
38 |
39 |
40 | class AttnBlock(nn.Module):
41 | def __init__(self, in_channels: int):
42 | super().__init__()
43 | self.in_channels = in_channels
44 |
45 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
46 |
47 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
51 |
52 | def attention(self, h_: Tensor) -> Tensor:
53 | h_ = self.norm(h_)
54 | q = self.q(h_)
55 | k = self.k(h_)
56 | v = self.v(h_)
57 |
58 | b, c, h, w = q.shape
59 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
60 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
61 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
62 | h_ = nn.functional.scaled_dot_product_attention(q, k, v)
63 |
64 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
65 |
66 | def forward(self, x: Tensor) -> Tensor:
67 | return x + self.proj_out(self.attention(x))
68 |
69 |
70 | class ResnetBlock(nn.Module):
71 |     def __init__(self, in_channels: int, out_channels: int | None):
72 | super().__init__()
73 | self.in_channels = in_channels
74 | out_channels = in_channels if out_channels is None else out_channels
75 | self.out_channels = out_channels
76 |
77 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
80 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81 | if self.in_channels != self.out_channels:
82 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
83 |
84 | def forward(self, x):
85 | h = x
86 | h = self.norm1(h)
87 | h = swish(h)
88 | h = self.conv1(h)
89 |
90 | h = self.norm2(h)
91 | h = swish(h)
92 | h = self.conv2(h)
93 |
94 | if self.in_channels != self.out_channels:
95 | x = self.nin_shortcut(x)
96 |
97 | return x + h
98 |
99 |
100 | class Downsample(nn.Module):
101 | def __init__(self, in_channels: int):
102 | super().__init__()
103 | # no asymmetric padding in torch conv, must do it ourselves
104 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
105 |
106 | def forward(self, x: Tensor):
107 | pad = (0, 1, 0, 1)
108 | x = nn.functional.pad(x, pad, mode="constant", value=0)
109 | x = self.conv(x)
110 | return x
111 |
112 |
113 | class Upsample(nn.Module):
114 | def __init__(self, in_channels: int):
115 | super().__init__()
116 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117 |
118 | def forward(self, x: Tensor):
119 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120 | x = self.conv(x)
121 | return x
122 |
123 |
124 | class Encoder(nn.Module):
125 | def __init__(
126 | self,
127 | resolution: int,
128 | in_channels: int,
129 | ch: int,
130 | ch_mult: list[int],
131 | num_res_blocks: int,
132 | z_channels: int,
133 | ):
134 | super().__init__()
135 | self.ch = ch
136 | self.num_resolutions = len(ch_mult)
137 | self.num_res_blocks = num_res_blocks
138 | self.resolution = resolution
139 | self.in_channels = in_channels
140 | # downsampling
141 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
142 |
143 | curr_res = resolution
144 | in_ch_mult = (1,) + tuple(ch_mult)
145 | self.in_ch_mult = in_ch_mult
146 | self.down = nn.ModuleList()
147 | block_in = self.ch
148 | for i_level in range(self.num_resolutions):
149 | block = nn.ModuleList()
150 | attn = nn.ModuleList()
151 | block_in = ch * in_ch_mult[i_level]
152 | block_out = ch * ch_mult[i_level]
153 | for _ in range(self.num_res_blocks):
154 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
155 | block_in = block_out
156 | down = nn.Module()
157 | down.block = block
158 | down.attn = attn
159 | if i_level != self.num_resolutions - 1:
160 | down.downsample = Downsample(block_in)
161 | curr_res = curr_res // 2
162 | self.down.append(down)
163 |
164 | # middle
165 | self.mid = nn.Module()
166 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167 | self.mid.attn_1 = AttnBlock(block_in)
168 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
169 |
170 | # end
171 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
172 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
173 |
174 | def forward(self, x: Tensor) -> Tensor:
175 | # downsampling
176 | hs = [self.conv_in(x)]
177 | for i_level in range(self.num_resolutions):
178 | for i_block in range(self.num_res_blocks):
179 | h = self.down[i_level].block[i_block](hs[-1])
180 | if len(self.down[i_level].attn) > 0:
181 | h = self.down[i_level].attn[i_block](h)
182 | hs.append(h)
183 | if i_level != self.num_resolutions - 1:
184 | hs.append(self.down[i_level].downsample(hs[-1]))
185 |
186 | # middle
187 | h = hs[-1]
188 | h = self.mid.block_1(h)
189 | h = self.mid.attn_1(h)
190 | h = self.mid.block_2(h)
191 | # end
192 | h = self.norm_out(h)
193 | h = swish(h)
194 | h = self.conv_out(h)
195 | return h
196 |
197 |
198 | class Decoder(nn.Module):
199 | def __init__(
200 | self,
201 | ch: int,
202 | out_ch: int,
203 | ch_mult: list[int],
204 | num_res_blocks: int,
205 | in_channels: int,
206 | resolution: int,
207 | z_channels: int,
208 | ):
209 | super().__init__()
210 | self.ch = ch
211 | self.num_resolutions = len(ch_mult)
212 | self.num_res_blocks = num_res_blocks
213 | self.resolution = resolution
214 | self.in_channels = in_channels
215 | self.ffactor = 2 ** (self.num_resolutions - 1)
216 |
217 | # compute in_ch_mult, block_in and curr_res at lowest res
218 | block_in = ch * ch_mult[self.num_resolutions - 1]
219 | curr_res = resolution // 2 ** (self.num_resolutions - 1)
220 | self.z_shape = (1, z_channels, curr_res, curr_res)
221 |
222 | # z to block_in
223 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
224 |
225 | # middle
226 | self.mid = nn.Module()
227 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228 | self.mid.attn_1 = AttnBlock(block_in)
229 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230 |
231 | # upsampling
232 | self.up = nn.ModuleList()
233 | for i_level in reversed(range(self.num_resolutions)):
234 | block = nn.ModuleList()
235 | attn = nn.ModuleList()
236 | block_out = ch * ch_mult[i_level]
237 | for _ in range(self.num_res_blocks + 1):
238 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
239 | block_in = block_out
240 | up = nn.Module()
241 | up.block = block
242 | up.attn = attn
243 | if i_level != 0:
244 | up.upsample = Upsample(block_in)
245 | curr_res = curr_res * 2
246 | self.up.insert(0, up) # prepend to get consistent order
247 |
248 | # end
249 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
250 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
251 |
252 | def forward(self, z: Tensor) -> Tensor:
253 | # z to block_in
254 | h = self.conv_in(z)
255 |
256 | # middle
257 | h = self.mid.block_1(h)
258 | h = self.mid.attn_1(h)
259 | h = self.mid.block_2(h)
260 |
261 | # upsampling
262 | for i_level in reversed(range(self.num_resolutions)):
263 | for i_block in range(self.num_res_blocks + 1):
264 | h = self.up[i_level].block[i_block](h)
265 | if len(self.up[i_level].attn) > 0:
266 | h = self.up[i_level].attn[i_block](h)
267 | if i_level != 0:
268 | h = self.up[i_level].upsample(h)
269 |
270 | # end
271 | h = self.norm_out(h)
272 | h = swish(h)
273 | h = self.conv_out(h)
274 | return h
275 |
276 |
277 | class DiagonalGaussian(nn.Module):
278 | def __init__(self, sample: bool = True, chunk_dim: int = 1):
279 | super().__init__()
280 | self.sample = sample
281 | self.chunk_dim = chunk_dim
282 |
283 | def forward(self, z: Tensor) -> Tensor:
284 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
285 | if self.sample:
286 | std = torch.exp(0.5 * logvar)
287 | return mean + std * torch.randn_like(mean)
288 | else:
289 | return mean
290 |
291 |
292 | class AutoEncoder(nn.Module):
293 | def __init__(self, params: AutoEncoderParams):
294 | super().__init__()
295 | self.encoder = Encoder(
296 | resolution=params.resolution,
297 | in_channels=params.in_channels,
298 | ch=params.ch,
299 | ch_mult=params.ch_mult,
300 | num_res_blocks=params.num_res_blocks,
301 | z_channels=params.z_channels,
302 | )
303 | self.decoder = Decoder(
304 | resolution=params.resolution,
305 | in_channels=params.in_channels,
306 | ch=params.ch,
307 | out_ch=params.out_ch,
308 | ch_mult=params.ch_mult,
309 | num_res_blocks=params.num_res_blocks,
310 | z_channels=params.z_channels,
311 | )
312 | self.reg = DiagonalGaussian()
313 |
314 | self.scale_factor = params.scale_factor
315 | self.shift_factor = params.shift_factor
316 |
317 | def encode(self, x: Tensor) -> Tensor:
318 | z = self.reg(self.encoder(x))
319 | z = self.scale_factor * (z - self.shift_factor)
320 | return z
321 |
322 | def decode(self, z: Tensor) -> Tensor:
323 | z = z / self.scale_factor + self.shift_factor
324 | return self.decoder(z)
325 |
326 | def forward(self, x: Tensor) -> Tensor:
327 | return self.decode(self.encode(x))
328 |
--------------------------------------------------------------------------------
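The encode/decode pair in AutoEncoder above wraps the Encoder and Decoder with a fixed affine transform: encode() returns scale_factor * (z - shift_factor) after the DiagonalGaussian sample, and decode() inverts that transform before running the Decoder. A minimal sketch, not part of the repo, exercising the round trip; the tiny AutoEncoderParams values here are hypothetical and chosen only so the module builds quickly on CPU, while the real flux-dev constants appear in uso/flux/util.py.

import torch

from uso.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

# toy configuration: two resolution levels, so the encoder downsamples once by 2x
toy_params = AutoEncoderParams(
    resolution=64,
    in_channels=3,
    ch=32,                # keep channel counts divisible by the GroupNorm group size of 32
    out_ch=3,
    ch_mult=[1, 2],
    num_res_blocks=1,
    z_channels=4,
    scale_factor=0.3611,  # same affine constants as the flux-dev config
    shift_factor=0.1159,
)

ae = AutoEncoder(toy_params).eval()
x = torch.randn(1, 3, 64, 64)      # stand-in for an image scaled to [-1, 1]
with torch.no_grad():
    z = ae.encode(x)               # (1, 4, 32, 32), already scaled and shifted
    x_rec = ae.decode(z)           # (1, 3, 64, 64), decoded back to image space
print(z.shape, x_rec.shape)

--------------------------------------------------------------------------------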
/uso/flux/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import math
18 | from typing import Literal, Optional
19 | from torch import Tensor
20 |
21 | import torch
22 | from einops import rearrange
23 | from PIL import ExifTags, Image
24 | import torchvision.transforms.functional as TVF
25 |
26 | from .modules.layers import (
27 | DoubleStreamBlockLoraProcessor,
28 | DoubleStreamBlockProcessor,
29 | SingleStreamBlockLoraProcessor,
30 | SingleStreamBlockProcessor,
31 | )
32 | from .sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
33 | from .util import (
34 | get_lora_rank,
35 | load_ae,
36 | load_checkpoint,
37 | load_clip,
38 | load_flow_model,
39 | load_flow_model_only_lora,
40 | load_t5,
41 | )
42 |
43 |
44 | def find_nearest_scale(image_h, image_w, predefined_scales):
45 | """
46 |     Find the predefined scale closest to the given image height and width.
47 | 
48 |     :param image_h: height of the image
49 |     :param image_w: width of the image
50 |     :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
51 |     :return: the nearest predefined scale (h, w)
52 | """
53 |     # compute the aspect ratio of the input image
54 | image_ratio = image_h / image_w
55 |
56 |     # initialize variables to track the smallest difference and the nearest scale
57 | min_diff = float("inf")
58 | nearest_scale = None
59 |
60 |     # iterate over all predefined scales and pick the one whose aspect ratio is closest to the input image's
61 | for scale_h, scale_w in predefined_scales:
62 | predefined_ratio = scale_h / scale_w
63 | diff = abs(predefined_ratio - image_ratio)
64 |
65 | if diff < min_diff:
66 | min_diff = diff
67 | nearest_scale = (scale_h, scale_w)
68 |
69 | return nearest_scale
70 |
71 |
72 | def preprocess_ref(raw_image: Image.Image, long_size: int = 512, scale_ratio: int = 1):
73 |     # get the width and height of the original image
74 | image_w, image_h = raw_image.size
75 | if image_w == image_h and image_w == 16:
76 | return raw_image
77 |
78 |     # compute the long and short sides
79 | if image_w >= image_h:
80 | new_w = long_size
81 | new_h = int((long_size / image_w) * image_h)
82 | else:
83 | new_h = long_size
84 | new_w = int((long_size / image_h) * image_w)
85 |
86 |     # resize proportionally to the new width and height
87 | raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
88 |
89 |     # so that the canny img can be scaled
90 | scale_ratio = int(scale_ratio)
91 | target_w = new_w // (16 * scale_ratio) * (16 * scale_ratio)
92 | target_h = new_h // (16 * scale_ratio) * (16 * scale_ratio)
93 |
94 |     # compute the crop origin for a center crop
95 | left = (new_w - target_w) // 2
96 | top = (new_h - target_h) // 2
97 | right = left + target_w
98 | bottom = top + target_h
99 |
100 |     # perform the center crop
101 | raw_image = raw_image.crop((left, top, right, bottom))
102 |
103 |     # convert to RGB mode
104 | raw_image = raw_image.convert("RGB")
105 | return raw_image
106 |
107 |
108 | def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
109 | target_height_ref1 = int(target_height_ref1 // 64 * 64)
110 | target_width_ref1 = int(target_width_ref1 // 64 * 64)
111 | h, w = image.shape[-2:]
112 | if h < target_height_ref1 or w < target_width_ref1:
113 |         # compute the aspect ratio
114 | aspect_ratio = w / h
115 | if h < target_height_ref1:
116 | new_h = target_height_ref1
117 | new_w = new_h * aspect_ratio
118 | if new_w < target_width_ref1:
119 | new_w = target_width_ref1
120 | new_h = new_w / aspect_ratio
121 | else:
122 | new_w = target_width_ref1
123 | new_h = new_w / aspect_ratio
124 | if new_h < target_height_ref1:
125 | new_h = target_height_ref1
126 | new_w = new_h * aspect_ratio
127 | else:
128 | aspect_ratio = w / h
129 | tgt_aspect_ratio = target_width_ref1 / target_height_ref1
130 | if aspect_ratio > tgt_aspect_ratio:
131 | new_h = target_height_ref1
132 | new_w = new_h * aspect_ratio
133 | else:
134 | new_w = target_width_ref1
135 | new_h = new_w / aspect_ratio
136 |     # resize the image with TVF.resize
137 | image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
138 |     # compute the center-crop parameters
139 | top = (image.shape[-2] - target_height_ref1) // 2
140 | left = (image.shape[-1] - target_width_ref1) // 2
141 |     # center-crop with TVF.crop
142 | image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
143 | return image
144 |
145 |
146 | class USOPipeline:
147 | def __init__(
148 | self,
149 | model_type: str,
150 | device: torch.device,
151 | offload: bool = False,
152 | only_lora: bool = False,
153 | lora_rank: int = 16,
154 | hf_download: bool = True,
155 | ):
156 | self.device = device
157 | self.offload = offload
158 | self.model_type = model_type
159 |
160 |         print(f'----> model type is {model_type} (only_lora={only_lora})')
161 | self.clip = load_clip(self.device)
162 |         print('----> clip loaded')
163 | self.t5 = load_t5(self.device, max_length=512)
164 |         print('----> t5 loaded')
165 | self.ae = load_ae(model_type, device="cpu" if offload else self.device)
166 |         print('----> ae loaded')
167 | self.use_fp8 = "fp8" in model_type
168 | if only_lora:
169 | self.model = load_flow_model_only_lora(
170 | model_type,
171 | device="cpu" if offload else self.device,
172 | lora_rank=lora_rank,
173 | use_fp8=self.use_fp8,
174 | hf_download=hf_download,
175 | )
176 | else:
177 | self.model = load_flow_model(
178 | model_type, device="cpu" if offload else self.device
179 | )
180 |
181 | def load_ckpt(self, ckpt_path):
182 | if ckpt_path is not None:
183 | from safetensors.torch import load_file as load_sft
184 |
185 | print("Loading checkpoint to replace old keys")
186 | # load_sft doesn't support torch.device
187 | if ckpt_path.endswith("safetensors"):
188 | sd = load_sft(ckpt_path, device="cpu")
189 | missing, unexpected = self.model.load_state_dict(
190 | sd, strict=False, assign=True
191 | )
192 | else:
193 | dit_state = torch.load(ckpt_path, map_location="cpu")
194 | sd = {}
195 | for k in dit_state.keys():
196 | sd[k.replace("module.", "")] = dit_state[k]
197 | missing, unexpected = self.model.load_state_dict(
198 | sd, strict=False, assign=True
199 | )
200 | self.model.to(str(self.device))
201 |             print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
202 |
203 | def set_lora(
204 | self,
205 |         local_path: str | None = None,
206 |         repo_id: str | None = None,
207 |         name: str | None = None,
208 |         lora_weight: float = 0.7,
209 | ):
210 | checkpoint = load_checkpoint(local_path, repo_id, name)
211 | self.update_model_with_lora(checkpoint, lora_weight)
212 |
213 | def set_lora_from_collection(
214 |         self, lora_type: str = "realism", lora_weight: float = 0.7
215 | ):
216 | checkpoint = load_checkpoint(
217 | None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
218 | )
219 | self.update_model_with_lora(checkpoint, lora_weight)
220 |
221 | def update_model_with_lora(self, checkpoint, lora_weight):
222 | rank = get_lora_rank(checkpoint)
223 | lora_attn_procs = {}
224 |
225 | for name, _ in self.model.attn_processors.items():
226 | lora_state_dict = {}
227 | for k in checkpoint.keys():
228 | if name in k:
229 | lora_state_dict[k[len(name) + 1 :]] = checkpoint[k] * lora_weight
230 |
231 | if len(lora_state_dict):
232 | if name.startswith("single_blocks"):
233 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
234 | dim=3072, rank=rank
235 | )
236 | else:
237 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
238 | dim=3072, rank=rank
239 | )
240 | lora_attn_procs[name].load_state_dict(lora_state_dict)
241 | lora_attn_procs[name].to(self.device)
242 | else:
243 | if name.startswith("single_blocks"):
244 | lora_attn_procs[name] = SingleStreamBlockProcessor()
245 | else:
246 | lora_attn_procs[name] = DoubleStreamBlockProcessor()
247 |
248 | self.model.set_attn_processor(lora_attn_procs)
249 |
250 | def __call__(
251 | self,
252 | prompt: str,
253 | width: int = 512,
254 | height: int = 512,
255 | guidance: float = 4,
256 | num_steps: int = 50,
257 | seed: int = 123456789,
258 | **kwargs,
259 | ):
260 | width = 16 * (width // 16)
261 | height = 16 * (height // 16)
262 |
263 | device_type = self.device if isinstance(self.device, str) else self.device.type
264 | with torch.autocast(
265 | enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16
266 | ):
267 | return self.forward(
268 | prompt, width, height, guidance, num_steps, seed, **kwargs
269 | )
270 |
271 | @torch.inference_mode()
272 | def gradio_generate(
273 | self,
274 | prompt: str,
275 | image_prompt1: Image.Image,
276 | image_prompt2: Image.Image,
277 | image_prompt3: Image.Image,
278 | seed: int,
279 | width: int = 1024,
280 | height: int = 1024,
281 | guidance: float = 4,
282 | num_steps: int = 25,
283 | keep_size: bool = False,
284 | content_long_size: int = 512,
285 | ):
286 | ref_content_imgs = [image_prompt1]
287 | ref_content_imgs = [img for img in ref_content_imgs if isinstance(img, Image.Image)]
288 | ref_content_imgs = [preprocess_ref(img, content_long_size) for img in ref_content_imgs]
289 |
290 | ref_style_imgs = [image_prompt2, image_prompt3]
291 | ref_style_imgs = [img for img in ref_style_imgs if isinstance(img, Image.Image)]
292 | ref_style_imgs = [self.model.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
293 |
294 | seed = seed if seed != -1 else torch.randint(0, 10**8, (1,)).item()
295 |
296 |         # whether to keep the input image size
297 |         if keep_size and len(ref_content_imgs) > 0:
298 | width, height = ref_content_imgs[0].size
299 | width, height = int(width * (1024 / content_long_size)), int(height * (1024 / content_long_size))
300 | img = self(
301 | prompt=prompt,
302 | width=width,
303 | height=height,
304 | guidance=guidance,
305 | num_steps=num_steps,
306 | seed=seed,
307 | ref_imgs=ref_content_imgs,
308 | siglip_inputs=ref_style_imgs,
309 | )
310 |
311 | filename = f"output/gradio/{seed}_{prompt[:20]}.png"
312 | os.makedirs(os.path.dirname(filename), exist_ok=True)
313 | exif_data = Image.Exif()
314 | exif_data[ExifTags.Base.Make] = "USO"
315 | exif_data[ExifTags.Base.Model] = self.model_type
316 | info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
317 | exif_data[ExifTags.Base.ImageDescription] = info
318 | img.save(filename, format="png", exif=exif_data)
319 | return img, filename
320 |
321 |     @torch.inference_mode()
322 | def forward(
323 | self,
324 | prompt: str,
325 | width: int,
326 | height: int,
327 | guidance: float,
328 | num_steps: int,
329 | seed: int,
330 | ref_imgs: list[Image.Image] | None = None,
331 | pe: Literal["d", "h", "w", "o"] = "d",
332 | siglip_inputs: list[Tensor] | None = None,
333 | **kwargs
334 | ):
335 |
336 | update_func = kwargs.get('update_func', lambda *args, **kwargs: None)
337 | x = get_noise(
338 | 1, height, width, device=self.device, dtype=torch.bfloat16, seed=seed
339 | )
340 | timesteps = get_schedule(
341 | num_steps,
342 | (width // 8) * (height // 8) // (16 * 16),
343 | shift=True,
344 | )
345 | if self.offload:
346 | self.ae.encoder = self.ae.encoder.to(self.device)
347 | x_1_refs = [
348 | self.ae.encode(
349 | (TVF.to_tensor(ref_img) * 2.0 - 1.0)
350 | .unsqueeze(0)
351 | .to(self.device, torch.float32)
352 | ).to(torch.bfloat16)
353 | for ref_img in ref_imgs
354 | ]
355 |
356 | if self.offload:
357 | self.offload_model_to_cpu(self.ae.encoder)
358 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
359 | inp_cond = prepare_multi_ip(
360 | t5=self.t5,
361 | clip=self.clip,
362 | img=x,
363 | prompt=prompt,
364 | ref_imgs=x_1_refs,
365 | pe=pe,
366 | )
367 |
368 | if self.offload:
369 | self.offload_model_to_cpu(self.t5, self.clip)
370 | self.model = self.model.to(self.device)
371 |
372 | x = denoise(
373 | self.model,
374 | **inp_cond,
375 | timesteps=timesteps,
376 | guidance=guidance,
377 | siglip_inputs=siglip_inputs,
378 | update_func=update_func,
379 | )
380 |
381 | if self.offload:
382 | self.offload_model_to_cpu(self.model)
383 | self.ae.decoder.to(x.device)
384 | x = unpack(x.float(), height, width)
385 | x = self.ae.decode(x)
386 | self.offload_model_to_cpu(self.ae.decoder)
387 |
388 | x1 = x.clamp(-1, 1)
389 | x1 = rearrange(x1[-1], "c h w -> h w c")
390 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
391 | return output_img
392 |
393 | def offload_model_to_cpu(self, *models):
394 | if not self.offload:
395 | return
396 | for model in models:
397 | model.cpu()
398 | torch.cuda.empty_cache()
399 |
--------------------------------------------------------------------------------
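A usage sketch for USOPipeline above, not part of the repo. It assumes the checkpoints already sit under ComfyUI's models directory at the paths hard-coded in uso/flux/util.py (FLUX.1-dev weights and VAE, the USO DiT LoRA and projector, the T5 and CLIP encoders); the prompt, file names and LoRA rank below are illustrative placeholders.

import torch
from PIL import Image

from uso.flux.pipeline import USOPipeline, preprocess_ref

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = USOPipeline(
    model_type="flux-dev",
    device=device,
    offload=True,     # move ae / dit / text encoders to CPU while they are idle
    only_lora=True,   # base FLUX weights plus the USO DiT LoRA and projector
    lora_rank=128,    # hypothetical value; must match the rank of the LoRA checkpoint in use
)

# content reference: resized so the long side is 512, then center-cropped to multiples of 16
subject = preprocess_ref(Image.open("subject.png"), long_size=512)  # hypothetical input file

img = pipeline(
    prompt="a plush toy on a wooden desk, soft morning light",
    width=1024,
    height=1024,
    guidance=4.0,
    num_steps=25,
    seed=42,
    ref_imgs=[subject],    # VAE-encoded inside forward()
    siglip_inputs=None,    # style references go here as vision-encoder processor outputs
)
img.save("uso_output.png")

--------------------------------------------------------------------------------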
/uso/flux/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from dataclasses import dataclass
18 |
19 | import torch
20 | import json
21 | import numpy as np
22 | from huggingface_hub import hf_hub_download
23 | from safetensors import safe_open
24 | from safetensors.torch import load_file as load_sft
25 |
26 | from .model import Flux, FluxParams
27 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams
28 | from .modules.conditioner import HFEmbedder
29 |
30 | import re
31 | from .modules.layers import (
32 | DoubleStreamBlockLoraProcessor,
33 | SingleStreamBlockLoraProcessor,
34 | )
35 |
36 | import os
37 | try:
38 | import folder_paths
39 |     print('running inside ComfyUI')
40 | except ImportError:
41 |     print('not running inside ComfyUI; falling back to the default models dir')
42 | from types import SimpleNamespace
43 | folder_paths = SimpleNamespace()
44 | folder_paths.models_dir = '/workspace/comfyui/models/'
45 |
46 |
47 | def load_model(ckpt, device="cpu"):
48 | if ckpt.endswith("safetensors"):
49 | from safetensors import safe_open
50 |
51 | pl_sd = {}
52 | with safe_open(ckpt, framework="pt", device=device) as f:
53 | for k in f.keys():
54 | pl_sd[k] = f.get_tensor(k)
55 | else:
56 | pl_sd = torch.load(ckpt, map_location=device)
57 | return pl_sd
58 |
59 |
60 | def load_safetensors(path):
61 | tensors = {}
62 | with safe_open(path, framework="pt", device="cpu") as f:
63 | for key in f.keys():
64 | tensors[key] = f.get_tensor(key)
65 | return tensors
66 |
67 |
68 | def get_lora_rank(checkpoint):
69 | for k in checkpoint.keys():
70 | if k.endswith(".down.weight"):
71 | return checkpoint[k].shape[0]
72 |
73 |
74 | def load_checkpoint(local_path, repo_id, name):
75 | if local_path is not None:
76 | if ".safetensors" in local_path:
77 | print(f"Loading .safetensors checkpoint from {local_path}")
78 | checkpoint = load_safetensors(local_path)
79 | else:
80 | print(f"Loading checkpoint from {local_path}")
81 | checkpoint = torch.load(local_path, map_location="cpu")
82 | elif repo_id is not None and name is not None:
83 | print(f"Loading checkpoint {name} from repo id {repo_id}")
84 | checkpoint = load_from_repo_id(repo_id, name)
85 | else:
86 | raise ValueError(
87 | "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
88 | )
89 | return checkpoint
90 |
91 |
92 | def c_crop(image):
93 | width, height = image.size
94 | new_size = min(width, height)
95 | left = (width - new_size) / 2
96 | top = (height - new_size) / 2
97 | right = (width + new_size) / 2
98 | bottom = (height + new_size) / 2
99 | return image.crop((left, top, right, bottom))
100 |
101 |
102 | def pad64(x):
103 | return int(np.ceil(float(x) / 64.0) * 64 - x)
104 |
105 |
106 | def HWC3(x):
107 | assert x.dtype == np.uint8
108 | if x.ndim == 2:
109 | x = x[:, :, None]
110 | assert x.ndim == 3
111 | H, W, C = x.shape
112 | assert C == 1 or C == 3 or C == 4
113 | if C == 3:
114 | return x
115 | if C == 1:
116 | return np.concatenate([x, x, x], axis=2)
117 | if C == 4:
118 | color = x[:, :, 0:3].astype(np.float32)
119 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
120 | y = color * alpha + 255.0 * (1.0 - alpha)
121 | y = y.clip(0, 255).astype(np.uint8)
122 | return y
123 |
124 |
125 | @dataclass
126 | class ModelSpec:
127 | params: FluxParams
128 | ae_params: AutoEncoderParams
129 | ckpt_path: str | None
130 | ae_path: str | None
131 | repo_id: str | None
132 | repo_flow: str | None
133 | repo_ae: str | None
134 | repo_id_ae: str | None
135 |
136 |
137 | configs = {
138 | "flux-dev": ModelSpec(
139 | repo_id="black-forest-labs/FLUX.1-dev",
140 | repo_id_ae="black-forest-labs/FLUX.1-dev",
141 | repo_flow="flux1-dev.safetensors",
142 | repo_ae="ae.safetensors",
143 | ckpt_path=os.getenv("FLUX_DEV"),
144 | params=FluxParams(
145 | in_channels=64,
146 | vec_in_dim=768,
147 | context_in_dim=4096,
148 | hidden_size=3072,
149 | mlp_ratio=4.0,
150 | num_heads=24,
151 | depth=19,
152 | depth_single_blocks=38,
153 | axes_dim=[16, 56, 56],
154 | theta=10_000,
155 | qkv_bias=True,
156 | guidance_embed=True,
157 | ),
158 | ae_path=os.getenv("AE"),
159 | ae_params=AutoEncoderParams(
160 | resolution=256,
161 | in_channels=3,
162 | ch=128,
163 | out_ch=3,
164 | ch_mult=[1, 2, 4, 4],
165 | num_res_blocks=2,
166 | z_channels=16,
167 | scale_factor=0.3611,
168 | shift_factor=0.1159,
169 | ),
170 | ),
171 | "flux-dev-fp8": ModelSpec(
172 | repo_id="black-forest-labs/FLUX.1-dev",
173 | repo_id_ae="black-forest-labs/FLUX.1-dev",
174 | repo_flow="flux1-dev.safetensors",
175 | repo_ae="ae.safetensors",
176 | ckpt_path=os.getenv("FLUX_DEV_FP8"),
177 | params=FluxParams(
178 | in_channels=64,
179 | vec_in_dim=768,
180 | context_in_dim=4096,
181 | hidden_size=3072,
182 | mlp_ratio=4.0,
183 | num_heads=24,
184 | depth=19,
185 | depth_single_blocks=38,
186 | axes_dim=[16, 56, 56],
187 | theta=10_000,
188 | qkv_bias=True,
189 | guidance_embed=True,
190 | ),
191 | ae_path=os.getenv("AE"),
192 | ae_params=AutoEncoderParams(
193 | resolution=256,
194 | in_channels=3,
195 | ch=128,
196 | out_ch=3,
197 | ch_mult=[1, 2, 4, 4],
198 | num_res_blocks=2,
199 | z_channels=16,
200 | scale_factor=0.3611,
201 | shift_factor=0.1159,
202 | ),
203 | ),
204 | "flux-krea-dev": ModelSpec(
205 | repo_id="black-forest-labs/FLUX.1-Krea-dev",
206 | repo_id_ae="black-forest-labs/FLUX.1-Krea-dev",
207 | repo_flow="flux1-krea-dev.safetensors",
208 | repo_ae="ae.safetensors",
209 | ckpt_path=os.getenv("FLUX_KREA_DEV"),
210 | params=FluxParams(
211 | in_channels=64,
212 | vec_in_dim=768,
213 | context_in_dim=4096,
214 | hidden_size=3072,
215 | mlp_ratio=4.0,
216 | num_heads=24,
217 | depth=19,
218 | depth_single_blocks=38,
219 | axes_dim=[16, 56, 56],
220 | theta=10_000,
221 | qkv_bias=True,
222 | guidance_embed=True,
223 | ),
224 | ae_path=os.getenv("AE"),
225 | ae_params=AutoEncoderParams(
226 | resolution=256,
227 | in_channels=3,
228 | ch=128,
229 | out_ch=3,
230 | ch_mult=[1, 2, 4, 4],
231 | num_res_blocks=2,
232 | z_channels=16,
233 | scale_factor=0.3611,
234 | shift_factor=0.1159,
235 | ),
236 | ),
237 | "flux-schnell": ModelSpec(
238 | repo_id="black-forest-labs/FLUX.1-schnell",
239 | repo_id_ae="black-forest-labs/FLUX.1-dev",
240 | repo_flow="flux1-schnell.safetensors",
241 | repo_ae="ae.safetensors",
242 | ckpt_path=os.getenv("FLUX_SCHNELL"),
243 | params=FluxParams(
244 | in_channels=64,
245 | vec_in_dim=768,
246 | context_in_dim=4096,
247 | hidden_size=3072,
248 | mlp_ratio=4.0,
249 | num_heads=24,
250 | depth=19,
251 | depth_single_blocks=38,
252 | axes_dim=[16, 56, 56],
253 | theta=10_000,
254 | qkv_bias=True,
255 | guidance_embed=False,
256 | ),
257 | ae_path=os.getenv("AE"),
258 | ae_params=AutoEncoderParams(
259 | resolution=256,
260 | in_channels=3,
261 | ch=128,
262 | out_ch=3,
263 | ch_mult=[1, 2, 4, 4],
264 | num_res_blocks=2,
265 | z_channels=16,
266 | scale_factor=0.3611,
267 | shift_factor=0.1159,
268 | ),
269 | ),
270 | }
271 |
272 |
273 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
274 | if len(missing) > 0 and len(unexpected) > 0:
275 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
276 | print("\n" + "-" * 79 + "\n")
277 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
278 | elif len(missing) > 0:
279 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
280 | elif len(unexpected) > 0:
281 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
282 |
283 |
284 | def load_from_repo_id(repo_id, checkpoint_name):
285 | ckpt_path = hf_hub_download(repo_id, checkpoint_name)
286 | sd = load_sft(ckpt_path, device="cpu")
287 | return sd
288 |
289 |
290 | def load_flow_model(
291 | name: str, device: str | torch.device = "cuda", hf_download: bool = True
292 | ):
293 | # Loading Flux
294 | print("Init model")
295 | ckpt_path = configs[name].ckpt_path
296 | if (
297 | ckpt_path is None
298 | and configs[name].repo_id is not None
299 | and configs[name].repo_flow is not None
300 | ):
301 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
302 |
303 | # with torch.device("meta" if ckpt_path is not None else device):
304 | with torch.device(device):
305 | model = Flux(configs[name].params).to(torch.bfloat16)
306 |
307 | if ckpt_path is not None:
308 | print("Loading main checkpoint")
309 | # load_sft doesn't support torch.device
310 | sd = load_model(ckpt_path, device="cpu")
311 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
312 | print_load_warning(missing, unexpected)
313 | return model.to(str(device))
314 |
315 |
316 | def load_flow_model_only_lora(
317 | name: str,
318 | device: str | torch.device = "cuda",
319 | hf_download: bool = True,
320 | lora_rank: int = 16,
321 | use_fp8: bool = False,
322 | ):
323 | # Loading Flux
324 | # ckpt_path = configs[name].ckpt_path
325 | # kiki
326 |     ckpt_path = os.path.join(folder_paths.models_dir, 'diffusers', 'FLUX.1-dev', 'flux1-dev.safetensors')  # hard-coded local path, so the hf_hub_download fallback below never runs
327 | if (
328 | ckpt_path is None
329 | and configs[name].repo_id is not None
330 | and configs[name].repo_flow is not None
331 | ):
332 | ckpt_path = hf_hub_download(
333 | configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")
334 | )
335 |
336 | # if hf_download:
337 | # try:
338 | # lora_ckpt_path = hf_hub_download(
339 | # "bytedance-research/USO", "uso_flux_v1.0/dit_lora.safetensors"
340 | # )
341 | # except Exception as e:
342 | # print(f"Failed to download lora checkpoint: {e}")
343 | # print("Trying to load lora from local")
344 | # lora_ckpt_path = os.environ.get("LORA", None)
345 | # try:
346 | # proj_ckpt_path = hf_hub_download(
347 | # "bytedance-research/USO", "uso_flux_v1.0/projector.safetensors"
348 | # )
349 | # except Exception as e:
350 | # print(f"Failed to download projection_model checkpoint: {e}")
351 | # print("Trying to load projection_model from local")
352 | # proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None)
353 | # else:
354 | # lora_ckpt_path = os.environ.get("LORA", None)
355 | # proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None)
356 | # print(lora_ckpt_path)
357 | # print(proj_ckpt_path)
358 | base_ckpt_path = os.path.join(folder_paths.models_dir, 'uso', 'uso_flux_v1.0')
359 | lora_ckpt_path = os.path.join(base_ckpt_path, 'dit_lora.safetensors')
360 | proj_ckpt_path = os.path.join(base_ckpt_path, 'projector.safetensors')
361 | with torch.device("meta" if ckpt_path is not None else device):
362 | model = Flux(configs[name].params)
363 |
364 | model = set_lora(
365 | model, lora_rank, device="meta" if lora_ckpt_path is not None else device
366 | )
367 |
368 | if ckpt_path is not None:
369 | print(f"Loading lora from {lora_ckpt_path}")
370 | lora_sd = (
371 | load_sft(lora_ckpt_path, device=str(device))
372 | if lora_ckpt_path.endswith("safetensors")
373 | else torch.load(lora_ckpt_path, map_location="cpu")
374 | )
375 | proj_sd = (
376 | load_sft(proj_ckpt_path, device=str(device))
377 | if proj_ckpt_path.endswith("safetensors")
378 | else torch.load(proj_ckpt_path, map_location="cpu")
379 | )
380 | lora_sd.update(proj_sd)
381 |
382 | print("Loading main checkpoint")
383 | # load_sft doesn't support torch.device
384 |
385 | if ckpt_path.endswith("safetensors"):
386 | if use_fp8:
387 | print(
388 | "####\n"
389 |                 "Running in fp8 mode: since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken,\n"
390 |                 "we convert to fp8 on the fly from the bf16 checkpoint.\n"
391 |                 "If your storage is constrained, "
392 |                 "you can save the fp8 checkpoint yourself and replace the bf16 checkpoint with it.\n"
393 | )
394 | sd = load_sft(ckpt_path, device="cpu")
395 | sd = {
396 | k: v.to(dtype=torch.float8_e4m3fn, device=device)
397 | for k, v in sd.items()
398 | }
399 | else:
400 | sd = load_sft(ckpt_path, device=str(device))
401 |
402 | sd.update(lora_sd)
403 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
404 | else:
405 | dit_state = torch.load(ckpt_path, map_location="cpu")
406 | sd = {}
407 | for k in dit_state.keys():
408 | sd[k.replace("module.", "")] = dit_state[k]
409 | sd.update(lora_sd)
410 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
411 | model.to(str(device))
412 | print_load_warning(missing, unexpected)
413 | return model
414 |
415 |
416 | def set_lora(
417 | model: Flux,
418 | lora_rank: int,
419 | double_blocks_indices: list[int] | None = None,
420 | single_blocks_indices: list[int] | None = None,
421 | device: str | torch.device = "cpu",
422 | ) -> Flux:
423 | double_blocks_indices = (
424 | list(range(model.params.depth))
425 | if double_blocks_indices is None
426 | else double_blocks_indices
427 | )
428 | single_blocks_indices = (
429 | list(range(model.params.depth_single_blocks))
430 | if single_blocks_indices is None
431 | else single_blocks_indices
432 | )
433 |
434 | lora_attn_procs = {}
435 | with torch.device(device):
436 | for name, attn_processor in model.attn_processors.items():
437 | match = re.search(r"\.(\d+)\.", name)
438 | if match:
439 | layer_index = int(match.group(1))
440 |
441 | if (
442 | name.startswith("double_blocks")
443 | and layer_index in double_blocks_indices
444 | ):
445 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
446 | dim=model.params.hidden_size, rank=lora_rank
447 | )
448 | elif (
449 | name.startswith("single_blocks")
450 | and layer_index in single_blocks_indices
451 | ):
452 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
453 | dim=model.params.hidden_size, rank=lora_rank
454 | )
455 | else:
456 | lora_attn_procs[name] = attn_processor
457 | model.set_attn_processor(lora_attn_procs)
458 | return model
459 |
460 |
461 | def load_flow_model_quintized(
462 | name: str, device: str | torch.device = "cuda", hf_download: bool = True
463 | ):
464 | # Loading Flux
465 | from optimum.quanto import requantize
466 |
467 | print("Init model")
468 | ckpt_path = configs[name].ckpt_path
469 | if (
470 | ckpt_path is None
471 | and configs[name].repo_id is not None
472 | and configs[name].repo_flow is not None
473 | and hf_download
474 | ):
475 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
476 | json_path = hf_hub_download(configs[name].repo_id, "flux_dev_quantization_map.json")
477 |
478 | model = Flux(configs[name].params).to(torch.bfloat16)
479 |
480 | print("Loading checkpoint")
481 | # load_sft doesn't support torch.device
482 | sd = load_sft(ckpt_path, device="cpu")
483 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
484 | model.load_state_dict(sd, assign=True)
485 |     return model  # returns here, so the requantize path below is never reached
486 | with open(json_path, "r") as f:
487 | quantization_map = json.load(f)
488 | print("Start a quantization process...")
489 | requantize(model, sd, quantization_map, device=device)
490 | print("Model is quantized!")
491 | return model
492 |
493 |
494 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
495 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
496 | #version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
497 | # version = '/workspace/comfyui/models/clip/xflux_text_encoders'
498 | version = os.path.join(folder_paths.models_dir, 'clip', 'xflux_text_encoders')
499 | return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(
500 | device
501 | )
502 |
503 |
504 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
505 | # version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
506 | #kiki
507 | # version = '/workspace/comfyui/models/clip_vision/clip-vit-large-patch14'
508 | version = os.path.join(folder_paths.models_dir, 'clip_vision', 'clip-vit-large-patch14')
509 | return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16, is_clip=True).to(device)
510 |
511 |
512 | def load_ae(
513 | name: str, device: str | torch.device = "cuda", hf_download: bool = True
514 | ) -> AutoEncoder:
515 | # ckpt_path = configs[name].ae_path
516 | # if (
517 | # ckpt_path is None
518 | # and configs[name].repo_id is not None
519 | # and configs[name].repo_ae is not None
520 | # and hf_download
521 | # ):
522 | # ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
523 | #kiki
524 | ckpt_path = os.path.join(folder_paths.models_dir, 'diffusers', 'FLUX.1-dev', 'ae.safetensors')
525 |
526 | # Loading the autoencoder
527 | print("Init AE")
528 | with torch.device("meta" if ckpt_path is not None else device):
529 | ae = AutoEncoder(configs[name].ae_params)
530 |
531 | if ckpt_path is not None:
532 | sd = load_sft(ckpt_path, device=str(device))
533 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
534 | print_load_warning(missing, unexpected)
535 | return ae
536 |
--------------------------------------------------------------------------------
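The loaders above skip the hf_hub_download fallbacks and read fixed locations under folder_paths.models_dir. A small sketch, not part of the repo, that lists those expected locations and reports which ones are present; the relative paths are copied from load_flow_model_only_lora, load_ae, load_t5 and load_clip, and the fallback directory mirrors the SimpleNamespace default above.

import os

try:
    import folder_paths                          # available when running inside ComfyUI
    models_dir = folder_paths.models_dir
except ImportError:
    models_dir = "/workspace/comfyui/models/"    # same fallback as uso/flux/util.py

expected = [
    os.path.join("diffusers", "FLUX.1-dev", "flux1-dev.safetensors"),  # base DiT weights
    os.path.join("diffusers", "FLUX.1-dev", "ae.safetensors"),         # autoencoder
    os.path.join("uso", "uso_flux_v1.0", "dit_lora.safetensors"),      # USO DiT LoRA
    os.path.join("uso", "uso_flux_v1.0", "projector.safetensors"),     # USO projector
    os.path.join("clip", "xflux_text_encoders"),                       # T5 text encoder directory
    os.path.join("clip_vision", "clip-vit-large-patch14"),             # CLIP encoder directory
]

for rel in expected:
    path = os.path.join(models_dir, rel)
    print("ok     " if os.path.exists(path) else "missing", path)

--------------------------------------------------------------------------------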
/uso/flux/modules/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import math
17 | from dataclasses import dataclass
18 |
19 | import torch
20 | from einops import rearrange, repeat
21 | from torch import Tensor, nn
22 |
23 | from ..math import attention, rope
24 |
25 |
26 | class EmbedND(nn.Module):
27 | def __init__(self, dim: int, theta: int, axes_dim: list[int]):
28 | super().__init__()
29 | self.dim = dim
30 | self.theta = theta
31 | self.axes_dim = axes_dim
32 |
33 | def forward(self, ids: Tensor) -> Tensor:
34 | n_axes = ids.shape[-1]
35 | emb = torch.cat(
36 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
37 | dim=-3,
38 | )
39 |
40 | return emb.unsqueeze(1)
41 |
42 |
43 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
44 | """
45 | Create sinusoidal timestep embeddings.
46 | :param t: a 1-D Tensor of N indices, one per batch element.
47 | These may be fractional.
48 | :param dim: the dimension of the output.
49 | :param max_period: controls the minimum frequency of the embeddings.
50 | :return: an (N, D) Tensor of positional embeddings.
51 | """
52 | t = time_factor * t
53 | half = dim // 2
54 | freqs = torch.exp(
55 | -math.log(max_period)
56 | * torch.arange(start=0, end=half, dtype=torch.float32)
57 | / half
58 | ).to(t.device)
59 |
60 | args = t[:, None].float() * freqs[None]
61 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
62 | if dim % 2:
63 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
64 | if torch.is_floating_point(t):
65 | embedding = embedding.to(t)
66 | return embedding
67 |
68 |
69 | class MLPEmbedder(nn.Module):
70 | def __init__(self, in_dim: int, hidden_dim: int):
71 | super().__init__()
72 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
73 | self.silu = nn.SiLU()
74 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
75 |
76 | def forward(self, x: Tensor) -> Tensor:
77 | return self.out_layer(self.silu(self.in_layer(x)))
78 |
79 |
80 | class RMSNorm(torch.nn.Module):
81 | def __init__(self, dim: int):
82 | super().__init__()
83 | self.scale = nn.Parameter(torch.ones(dim))
84 |
85 | def forward(self, x: Tensor):
86 | x_dtype = x.dtype
87 | x = x.float()
88 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
89 | return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
90 |
91 |
92 | class QKNorm(torch.nn.Module):
93 | def __init__(self, dim: int):
94 | super().__init__()
95 | self.query_norm = RMSNorm(dim)
96 | self.key_norm = RMSNorm(dim)
97 |
98 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
99 | q = self.query_norm(q)
100 | k = self.key_norm(k)
101 | return q.to(v), k.to(v)
102 |
103 |
104 | class LoRALinearLayer(nn.Module):
105 | def __init__(
106 | self,
107 | in_features,
108 | out_features,
109 | rank=4,
110 | network_alpha=None,
111 | device=None,
112 | dtype=None,
113 | ):
114 | super().__init__()
115 |
116 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
117 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
118 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
119 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
120 | self.network_alpha = network_alpha
121 | self.rank = rank
122 |
123 | nn.init.normal_(self.down.weight, std=1 / rank)
124 | nn.init.zeros_(self.up.weight)
125 |
126 | def forward(self, hidden_states):
127 | orig_dtype = hidden_states.dtype
128 | dtype = self.down.weight.dtype
129 |
130 | down_hidden_states = self.down(hidden_states.to(dtype))
131 | up_hidden_states = self.up(down_hidden_states)
132 |
133 | if self.network_alpha is not None:
134 | up_hidden_states *= self.network_alpha / self.rank
135 |
136 | return up_hidden_states.to(orig_dtype)
137 |
138 |
139 | class FLuxSelfAttnProcessor:
140 | def __call__(self, attn, x, pe, **attention_kwargs):
141 | qkv = attn.qkv(x)
142 |         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
143 | q, k = attn.norm(q, k, v)
144 | x = attention(q, k, v, pe=pe)
145 | x = attn.proj(x)
146 | return x
147 |
148 |
149 | class LoraFluxAttnProcessor(nn.Module):
150 |
151 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
152 | super().__init__()
153 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
154 | self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
155 | self.lora_weight = lora_weight
156 |
157 | def __call__(self, attn, x, pe, **attention_kwargs):
158 | qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
159 |         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
160 | q, k = attn.norm(q, k, v)
161 | x = attention(q, k, v, pe=pe)
162 | x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
163 | return x
164 |
165 |
166 | class SelfAttention(nn.Module):
167 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
168 | super().__init__()
169 | self.num_heads = num_heads
170 | head_dim = dim // num_heads
171 |
172 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173 | self.norm = QKNorm(head_dim)
174 | self.proj = nn.Linear(dim, dim)
175 |
176 |     def forward(self, *args, **kwargs):
177 |         pass  # attention is computed through the owning block's processor, not here
178 |
179 |
180 | @dataclass
181 | class ModulationOut:
182 | shift: Tensor
183 | scale: Tensor
184 | gate: Tensor
185 |
186 |
187 | class Modulation(nn.Module):
188 | def __init__(self, dim: int, double: bool):
189 | super().__init__()
190 | self.is_double = double
191 | self.multiplier = 6 if double else 3
192 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
193 |
194 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
195 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
196 | self.multiplier, dim=-1
197 | )
198 |
199 | return (
200 | ModulationOut(*out[:3]),
201 | ModulationOut(*out[3:]) if self.is_double else None,
202 | )
203 |
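# Hedged usage sketch (illustrative only): the adaLN-style pattern the blocks below build from
# a Modulation output -- shift/scale applied to normalised activations, gate scaling the
# residual update. The 3072 width, 16-token sequence and the plain LayerNorm are assumptions
# for the sketch rather than values taken from this file.
def _example_modulation_usage() -> Tensor:
    dim = 3072
    mod_layer = Modulation(dim, double=False)
    norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
    x = torch.randn(1, 16, dim)    # token sequence
    vec = torch.randn(1, dim)      # pooled conditioning vector
    mod, _ = mod_layer(vec)        # double=False, so the second output is None
    x_mod = (1 + mod.scale) * norm(x) + mod.shift
    return x + mod.gate * x_mod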
204 |
205 | class DoubleStreamBlockLoraProcessor(nn.Module):
206 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
207 | super().__init__()
208 | self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
209 | self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
210 | self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
211 | self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
212 | self.lora_weight = lora_weight
213 |
214 | def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
215 | img_mod1, img_mod2 = attn.img_mod(vec)
216 | txt_mod1, txt_mod2 = attn.txt_mod(vec)
217 |
218 | # prepare image for attention
219 | img_modulated = attn.img_norm1(img)
220 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
221 | img_qkv = (
222 | attn.img_attn.qkv(img_modulated)
223 | + self.qkv_lora1(img_modulated) * self.lora_weight
224 | )
225 | img_q, img_k, img_v = rearrange(
226 | img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads
227 | )
228 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
229 |
230 | # prepare txt for attention
231 | txt_modulated = attn.txt_norm1(txt)
232 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
233 | txt_qkv = (
234 | attn.txt_attn.qkv(txt_modulated)
235 | + self.qkv_lora2(txt_modulated) * self.lora_weight
236 | )
237 | txt_q, txt_k, txt_v = rearrange(
238 | txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads
239 | )
240 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
241 |
242 | # run actual attention
243 | q = torch.cat((txt_q, img_q), dim=2)
244 | k = torch.cat((txt_k, img_k), dim=2)
245 | v = torch.cat((txt_v, img_v), dim=2)
246 |
247 | attn1 = attention(q, k, v, pe=pe)
248 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
249 |
250 |         # calculate the img blocks
251 | img = img + img_mod1.gate * (
252 | attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight
253 | )
254 | img = img + img_mod2.gate * attn.img_mlp(
255 | (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift
256 | )
257 |
258 |         # calculate the txt blocks
259 | txt = txt + txt_mod1.gate * (
260 | attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight
261 | )
262 | txt = txt + txt_mod2.gate * attn.txt_mlp(
263 | (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift
264 | )
265 | return img, txt
266 |
267 |
268 | class DoubleStreamBlockProcessor:
269 | def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
270 | img_mod1, img_mod2 = attn.img_mod(vec)
271 | txt_mod1, txt_mod2 = attn.txt_mod(vec)
272 |
273 | # prepare image for attention
274 | img_modulated = attn.img_norm1(img)
275 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
276 | img_qkv = attn.img_attn.qkv(img_modulated)
277 | img_q, img_k, img_v = rearrange(
278 | img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
279 | )
280 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
281 |
282 | # prepare txt for attention
283 | txt_modulated = attn.txt_norm1(txt)
284 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
285 | txt_qkv = attn.txt_attn.qkv(txt_modulated)
286 | txt_q, txt_k, txt_v = rearrange(
287 | txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim
288 | )
289 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
290 |
291 | # run actual attention
292 | q = torch.cat((txt_q, img_q), dim=2)
293 | k = torch.cat((txt_k, img_k), dim=2)
294 | v = torch.cat((txt_v, img_v), dim=2)
295 |
296 | attn1 = attention(q, k, v, pe=pe)
297 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
298 |
299 |         # calculate the img blocks
300 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
301 | img = img + img_mod2.gate * attn.img_mlp(
302 | (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift
303 | )
304 |
305 |         # calculate the txt blocks
306 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
307 | txt = txt + txt_mod2.gate * attn.txt_mlp(
308 | (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift
309 | )
310 | return img, txt
311 |
312 |
313 | class DoubleStreamBlock(nn.Module):
314 | def __init__(
315 | self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
316 | ):
317 | super().__init__()
318 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
319 | self.num_heads = num_heads
320 | self.hidden_size = hidden_size
321 | self.head_dim = hidden_size // num_heads
322 |
323 | self.img_mod = Modulation(hidden_size, double=True)
324 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
325 | self.img_attn = SelfAttention(
326 | dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
327 | )
328 |
329 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
330 | self.img_mlp = nn.Sequential(
331 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
332 | nn.GELU(approximate="tanh"),
333 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
334 | )
335 |
336 | self.txt_mod = Modulation(hidden_size, double=True)
337 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
338 | self.txt_attn = SelfAttention(
339 | dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
340 | )
341 |
342 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
343 | self.txt_mlp = nn.Sequential(
344 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
345 | nn.GELU(approximate="tanh"),
346 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
347 | )
348 | processor = DoubleStreamBlockProcessor()
349 | self.set_processor(processor)
350 |
351 | def set_processor(self, processor) -> None:
352 | self.processor = processor
353 |
354 | def get_processor(self):
355 | return self.processor
356 |
357 | def forward(
358 | self,
359 | img: Tensor,
360 | txt: Tensor,
361 | vec: Tensor,
362 | pe: Tensor,
363 |         image_proj: Tensor | None = None,
364 | ip_scale: float = 1.0,
365 | ) -> tuple[Tensor, Tensor]:
366 | if image_proj is None:
367 | return self.processor(self, img, txt, vec, pe)
368 | else:
369 | return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
370 |
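# Hedged usage sketch (illustrative only): constructing a DoubleStreamBlock and swapping its
# default processor for the LoRA variant via set_processor. The FLUX-like sizes (3072 hidden,
# 24 heads) and the random `pe` tensor are assumptions; real rotary embeddings come from the
# model's positional-embedding module, not from torch.randn.
def _example_double_stream_lora() -> tuple[Tensor, Tensor]:
    block = DoubleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0, qkv_bias=True)
    block.set_processor(DoubleStreamBlockLoraProcessor(dim=3072, rank=16))
    img = torch.randn(1, 64, 3072)           # image tokens
    txt = torch.randn(1, 16, 3072)           # text tokens
    vec = torch.randn(1, 3072)               # pooled conditioning vector
    pe = torch.randn(1, 1, 80, 64, 2, 2)     # placeholder positional embedding for 16 + 64 tokens
    return block(img, txt, vec, pe)          # returns the updated (img, txt) pair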
371 |
372 | class SingleStreamBlockLoraProcessor(nn.Module):
373 | def __init__(
374 | self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1
375 | ):
376 | super().__init__()
377 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
378 |         self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)  # 15360 = hidden_size + mlp_hidden_dim (3072 + 4 * 3072) in FLUX
379 | self.lora_weight = lora_weight
380 |
381 | def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
382 |
383 | mod, _ = attn.modulation(vec)
384 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
385 | qkv, mlp = torch.split(
386 | attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1
387 | )
388 | qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
389 |
390 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
391 | q, k = attn.norm(q, k, v)
392 |
393 | # compute attention
394 | attn_1 = attention(q, k, v, pe=pe)
395 |
396 | # compute activation in mlp stream, cat again and run second linear layer
397 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
398 | output = (
399 | output
400 | + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
401 | * self.lora_weight
402 | )
403 | output = x + mod.gate * output
404 | return output
405 |
406 |
407 | class SingleStreamBlockProcessor:
408 | def __call__(
409 | self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs
410 | ) -> Tensor:
411 |
412 | mod, _ = attn.modulation(vec)
413 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
414 | qkv, mlp = torch.split(
415 | attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1
416 | )
417 |
418 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
419 | q, k = attn.norm(q, k, v)
420 |
421 | # compute attention
422 | attn_1 = attention(q, k, v, pe=pe)
423 |
424 | # compute activation in mlp stream, cat again and run second linear layer
425 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
426 | output = x + mod.gate * output
427 | return output
428 |
429 |
430 | class SingleStreamBlock(nn.Module):
431 | """
432 | A DiT block with parallel linear layers as described in
433 | https://arxiv.org/abs/2302.05442 and adapted modulation interface.
434 | """
435 |
436 | def __init__(
437 | self,
438 | hidden_size: int,
439 | num_heads: int,
440 | mlp_ratio: float = 4.0,
441 | qk_scale: float | None = None,
442 | ):
443 | super().__init__()
444 | self.hidden_dim = hidden_size
445 | self.num_heads = num_heads
446 | self.head_dim = hidden_size // num_heads
447 | self.scale = qk_scale or self.head_dim**-0.5
448 |
449 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
450 | # qkv and mlp_in
451 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
452 | # proj and mlp_out
453 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
454 |
455 | self.norm = QKNorm(self.head_dim)
456 |
457 | self.hidden_size = hidden_size
458 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
459 |
460 | self.mlp_act = nn.GELU(approximate="tanh")
461 | self.modulation = Modulation(hidden_size, double=False)
462 |
463 | processor = SingleStreamBlockProcessor()
464 | self.set_processor(processor)
465 |
466 | def set_processor(self, processor) -> None:
467 | self.processor = processor
468 |
469 | def get_processor(self):
470 | return self.processor
471 |
472 | def forward(
473 | self,
474 | x: Tensor,
475 | vec: Tensor,
476 | pe: Tensor,
477 | image_proj: Tensor | None = None,
478 | ip_scale: float = 1.0,
479 | ) -> Tensor:
480 | if image_proj is None:
481 | return self.processor(self, x, vec, pe)
482 | else:
483 | return self.processor(self, x, vec, pe, image_proj, ip_scale)
484 |
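# Hedged usage sketch (illustrative only): a SingleStreamBlock runs on the already concatenated
# txt+img sequence produced after the double-stream stage. Sizes and the random `pe` placeholder
# mirror the DoubleStreamBlock sketch above and are assumptions, not values from this repo.
def _example_single_stream_block() -> Tensor:
    block = SingleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0)
    x = torch.randn(1, 80, 3072)             # txt tokens followed by img tokens
    vec = torch.randn(1, 3072)               # pooled conditioning vector
    pe = torch.randn(1, 1, 80, 64, 2, 2)     # placeholder positional embedding
    return block(x, vec, pe)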
485 |
486 | class LastLayer(nn.Module):
487 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
488 | super().__init__()
489 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
490 | self.linear = nn.Linear(
491 | hidden_size, patch_size * patch_size * out_channels, bias=True
492 | )
493 | self.adaLN_modulation = nn.Sequential(
494 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
495 | )
496 |
497 | def forward(self, x: Tensor, vec: Tensor) -> Tensor:
498 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
499 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
500 | x = self.linear(x)
501 | return x
502 |
503 |
504 | class SigLIPMultiFeatProjModel(torch.nn.Module):
505 | """
506 | SigLIP Multi-Feature Projection Model for processing style features from different layers
507 | and projecting them into a unified hidden space.
508 |
509 | Args:
510 | siglip_token_nums (int): Number of SigLIP tokens, default 257
511 | style_token_nums (int): Number of style tokens, default 256
512 | siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
513 | hidden_size (int): Hidden layer size, default 3072
514 | context_layer_norm (bool): Whether to use context layer normalization, default False
515 | """
516 |
517 | def __init__(
518 | self,
519 | siglip_token_nums: int = 257,
520 | style_token_nums: int = 256,
521 | siglip_token_dims: int = 1536,
522 | hidden_size: int = 3072,
523 | context_layer_norm: bool = False,
524 | ):
525 | super().__init__()
526 |
527 | # High-level feature processing (layer -2)
528 | self.high_embedding_linear = nn.Sequential(
529 | nn.Linear(siglip_token_nums, style_token_nums),
530 | nn.SiLU()
531 | )
532 | self.high_layer_norm = (
533 | nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
534 | )
535 | self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
536 |
537 | # Mid-level feature processing (layer -11)
538 | self.mid_embedding_linear = nn.Sequential(
539 | nn.Linear(siglip_token_nums, style_token_nums),
540 | nn.SiLU()
541 | )
542 | self.mid_layer_norm = (
543 | nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
544 | )
545 | self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
546 |
547 | # Low-level feature processing (layer -20)
548 | self.low_embedding_linear = nn.Sequential(
549 | nn.Linear(siglip_token_nums, style_token_nums),
550 | nn.SiLU()
551 | )
552 | self.low_layer_norm = (
553 | nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
554 | )
555 | self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
556 |
557 | def forward(self, siglip_outputs):
558 | """
559 | Forward pass function
560 |
561 | Args:
562 | siglip_outputs: Output from SigLIP model, containing hidden_states
563 |
564 | Returns:
565 | torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
566 | """
567 | dtype = next(self.high_embedding_linear.parameters()).dtype
568 |
569 | # Process high-level features (layer -2)
570 | high_embedding = self._process_layer_features(
571 | siglip_outputs.hidden_states[-2],
572 | self.high_embedding_linear,
573 | self.high_layer_norm,
574 | self.high_projection,
575 | dtype
576 | )
577 |
578 | # Process mid-level features (layer -11)
579 | mid_embedding = self._process_layer_features(
580 | siglip_outputs.hidden_states[-11],
581 | self.mid_embedding_linear,
582 | self.mid_layer_norm,
583 | self.mid_projection,
584 | dtype
585 | )
586 |
587 | # Process low-level features (layer -20)
588 | low_embedding = self._process_layer_features(
589 | siglip_outputs.hidden_states[-20],
590 | self.low_embedding_linear,
591 | self.low_layer_norm,
592 | self.low_projection,
593 | dtype
594 | )
595 |
596 | # Concatenate features from all layers
597 | return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
598 |
599 | def _process_layer_features(
600 | self,
601 | hidden_states: torch.Tensor,
602 | embedding_linear: nn.Module,
603 | layer_norm: nn.Module,
604 | projection: nn.Module,
605 | dtype: torch.dtype
606 | ) -> torch.Tensor:
607 | """
608 | Helper function to process features from a single layer
609 |
610 | Args:
611 | hidden_states: Input hidden states [bs, seq_len, dim]
612 | embedding_linear: Embedding linear layer
613 | layer_norm: Layer normalization
614 | projection: Projection layer
615 | dtype: Target data type
616 |
617 | Returns:
618 | torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
619 | """
620 | # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
621 | embedding = embedding_linear(
622 | hidden_states.to(dtype).transpose(1, 2)
623 | ).transpose(1, 2)
624 |
625 | # Apply layer normalization
626 | embedding = layer_norm(embedding)
627 |
628 | # Project to target hidden space
629 | embedding = projection(embedding)
630 |
631 | return embedding
632 |
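# Hedged usage sketch (illustrative only): SigLIPMultiFeatProjModel only needs an object
# exposing `.hidden_states`, so a SimpleNamespace with random tensors stands in for a real
# SigLIP vision tower here. The 27 hidden states, batch size 1 and the wrapper itself are
# assumptions; 257 tokens / 1536 dims match the defaults documented above.
def _example_siglip_projection() -> Tensor:
    from types import SimpleNamespace

    proj = SigLIPMultiFeatProjModel()        # defaults: 257 -> 256 tokens, 1536 -> 3072 dims
    fake_outputs = SimpleNamespace(
        hidden_states=[torch.randn(1, 257, 1536) for _ in range(27)]
    )
    # Layers -2, -11 and -20 are each projected to 256 style tokens and concatenated,
    # giving a [1, 3 * 256, 3072] tensor.
    return proj(fake_outputs)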
--------------------------------------------------------------------------------