├── rh_config.json ├── __init__.py ├── requirements.txt ├── .gitignore ├── uso └── flux │ ├── math.py │ ├── modules │ ├── conditioner.py │ ├── autoencoder.py │ └── layers.py │ ├── sampling.py │ ├── model.py │ ├── pipeline.py │ └── util.py ├── README_CN.md ├── README.md ├── rh_uso_nodes.py ├── inference.py ├── app.py └── LICENSE /rh_config.json: -------------------------------------------------------------------------------- 1 | {"enable": true, "untracked_paths": []} -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .rh_uso_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=1.1.1 2 | deepspeed>=0.14.4 3 | einops>=0.8.0 4 | transformers>=4.43.3 5 | huggingface-hub 6 | diffusers>=0.30.1 7 | sentencepiece>=0.2.0 8 | gradio>=5.22.0 9 | opencv-python 10 | matplotlib 11 | safetensors>=0.4.5 12 | scipy>=1.10.1 13 | numpy>=1.24.4 14 | onnxruntime-gpu 15 | # httpx==0.23.3 16 | git+https://github.com/openai/CLIP.git 17 | --extra-index-url https://download.pytorch.org/whl/cu124 18 | torch>=2.4.0 19 | torchvision>=0.19.0 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # PyInstaller 24 | *.manifest 25 | *.spec 26 | 27 | # Virtual environments 28 | .env 29 | .venv 30 | env/ 31 | venv/ 32 | ENV/ 33 | env.bak/ 34 | venv.bak/ 35 | 36 | # IDE 37 | .vscode/ 38 | .idea/ 39 | *.swp 40 | *.swo 41 | *~ 42 | 43 | # OS 44 | .DS_Store 45 | Thumbs.db 46 | 47 | # Temporary files 48 | temp/ 49 | tmp/ 50 | *.tmp 51 | *.temp 52 | 53 | # Logs 54 | *.log -------------------------------------------------------------------------------- /uso/flux/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
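# This module implements the rotary position embedding (RoPE) attention used by FLUX:
#   - rope(pos, dim, theta) builds per-position 2x2 rotation matrices over dim/2
#     frequency bands (output shape [..., seq_len, dim/2, 2, 2]).
#   - apply_rope(xq, xk, freqs_cis) rotates query/key channel pairs with those matrices.
#   - attention(q, k, v, pe) applies RoPE, runs scaled dot-product attention, and
#     merges the heads back into the channel dimension ("B H L D" -> "B L (H D)").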
15 | 16 | import torch 17 | from einops import rearrange 18 | from torch import Tensor 19 | 20 | 21 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 22 | q, k = apply_rope(q, k, pe) 23 | 24 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 25 | x = rearrange(x, "B H L D -> B L (H D)") 26 | 27 | return x 28 | 29 | 30 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 31 | assert dim % 2 == 0 32 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 33 | omega = 1.0 / (theta**scale) 34 | out = torch.einsum("...n,d->...nd", pos, omega) 35 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 36 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 37 | return out.float() 38 | 39 | 40 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 41 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 42 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 43 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 44 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 45 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 46 | -------------------------------------------------------------------------------- /uso/flux/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
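# HFEmbedder wraps the two FLUX text encoders behind a single interface:
#   - CLIP (is_clip=True, popped from hf_kwargs): returns the pooled "pooler_output".
#   - T5 (default): returns the full "last_hidden_state" token sequence.
# Prompts are tokenized with truncation and max_length padding, and the encoder is
# kept in eval mode with gradients disabled.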
15 | 16 | from torch import Tensor, nn 17 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, 18 | T5Tokenizer) 19 | 20 | 21 | class HFEmbedder(nn.Module): 22 | def __init__(self, version: str, max_length: int, **hf_kwargs): 23 | super().__init__() 24 | # self.is_clip = "clip" in version.lower() 25 | #kiki 26 | self.is_clip = hf_kwargs.pop('is_clip', False) 27 | self.max_length = max_length 28 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 29 | 30 | if self.is_clip: 31 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) 32 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) 33 | else: 34 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) 35 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) 36 | 37 | self.hf_module = self.hf_module.eval().requires_grad_(False) 38 | 39 | def forward(self, text: list[str]) -> Tensor: 40 | batch_encoding = self.tokenizer( 41 | text, 42 | truncation=True, 43 | max_length=self.max_length, 44 | return_length=False, 45 | return_overflowing_tokens=False, 46 | padding="max_length", 47 | return_tensors="pt", 48 | ) 49 | 50 | outputs = self.hf_module( 51 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 52 | attention_mask=None, 53 | output_hidden_states=False, 54 | ) 55 | return outputs[self.output_key] 56 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # ComfyUI USO 节点 2 | 3 | 一个用于ComfyUI的自定义节点,集成USO(统一风格和主题驱动生成)模型,实现高质量的风格和主题控制图像生成。 4 | 5 | ## ✨ 特性 6 | 7 | - 🎨 **统一风格与主题生成**: 基于FLUX架构的USO模型 8 | - 🎯 **风格驱动生成**: 根据特定艺术风格生成图像 9 | - 👤 **主题驱动生成**: 在生成过程中保持主题一致性 10 | - 🔄 **多风格支持**: 在单次生成中结合多种风格 11 | - ⚙️ **内存优化**: 支持FP8精度,适用于消费级GPU(约16GB显存) 12 | - 🚀 **灵活控制**: 高级参数控制,精细调节生成结果 13 | 14 | ## 🔧 节点列表 15 | 16 | ### 核心节点 17 | - **RH_USO_Loader**: 加载和初始化USO模型,包含优化选项 18 | - **RH_USO_Generator**: 具有风格和主题控制的图像生成器 19 | 20 | ## 🚀 快速安装 21 | 22 | ### 步骤1: 安装节点 23 | ```bash 24 | # 进入ComfyUI自定义节点目录 25 | cd ComfyUI/custom_nodes 26 | 27 | # 克隆仓库 28 | git clone https://github.com/HM-RunningHub/ComfyUI_RH_USO 29 | 30 | # 安装依赖 31 | cd ComfyUI_RH_USO 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ### 步骤2: 下载所需模型 36 | ```bash 37 | # 下载FLUX.1-dev模型(必需的基础模型) 38 | huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusers/FLUX.1-dev 39 | huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models/diffusers/FLUX.1-dev 40 | 41 | # 下载USO模型 42 | huggingface-cli download bytedance-research/USO --local-dir models/uso 43 | 44 | # 下载SigLIP模型 45 | huggingface-cli download google/siglip-so400m-patch14-384 --local-dir models/clip/siglip-so400m-patch14-384 46 | 47 | # 最终模型结构应该如下: 48 | models/ 49 | ├── diffusers/ 50 | │ └── FLUX.1-dev/ 51 | │ ├── flux1-dev.safetensors 52 | │ └── ae.safetensors 53 | ├── uso/ 54 | │ ├── assets/ 55 | │ │ └── uso.webp 56 | │ ├── config.json 57 | │ ├── download_repo_enhanced.py 58 | │ ├── README.md 59 | │ └── uso_flux_v1.0/ 60 | │ ├── dit_lora.safetensors 61 | │ └── projector.safetensors 62 | └── clip/ 63 | └── siglip-so400m-patch14-384/ 64 | 65 | # 重启ComfyUI 66 | ``` 67 | 68 | ## 📖 使用方法 69 | 70 | ### 基础工作流 71 | ``` 72 | [RH_USO_Loader] → [RH_USO_Generator] → [Save Image] 73 | ``` 74 | 75 | ### 生成类型 76 | 77 | #### 风格驱动生成 78 | - 加载风格参考图像 79 | - 
输入描述内容的文本提示 80 | - 生成指定风格的图像 81 | 82 | #### 主题驱动生成 83 | - 加载主题参考图像 84 | - 输入包含场景描述的文本提示 85 | - 生成保持主题身份的图像 86 | 87 | #### 风格+主题生成 88 | - 同时加载风格和主题参考图像 89 | - 结合风格转换与主题一致性 90 | - 生成具有统一风格且保持主题的图像 91 | 92 | ## 🛠️ 技术要求 93 | 94 | - **GPU**: 16GB+显存(使用FP8优化) 95 | - **内存**: 推荐32GB+ 96 | - **存储**: 约35GB用于所有模型 97 | - FLUX.1-dev: ~24GB (flux1-dev.safetensors + ae.safetensors) 98 | - USO模型: ~6GB 99 | - SigLIP: ~1.5GB 100 | - **CUDA**: 优化性能需要CUDA支持 101 | 102 | ## ⚠️ 重要提示 103 | 104 | - **模型路径**: 模型必须放置在特定目录: 105 | - FLUX.1-dev → `models/diffusers/FLUX.1-dev/` 106 | - USO模型 → `models/uso/` 107 | - SigLIP → `models/clip/siglip-so400m-patch14-384/` 108 | - 推荐消费级GPU使用FP8模式(减少显存占用) 109 | - 所有模型文件必须在首次使用前下载完成 110 | 111 | ## 📄 许可证 112 | 113 | 本项目采用Apache 2.0许可证。 114 | 115 | ## 🔗 参考链接 116 | 117 | - [USO项目页面](https://bytedance.github.io/USO/) 118 | - [USO论文](https://arxiv.org/abs/2508.18966) 119 | - [USO HuggingFace](https://huggingface.co/bytedance-research/USO) 120 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI) 121 | 122 | ## 🤝 贡献 123 | 124 | 欢迎贡献!请随时提交问题和拉取请求。 125 | 126 | ## ⭐ 引用 127 | 128 | 如果您觉得这个项目有用,请考虑引用原始USO论文: 129 | 130 | ```bibtex 131 | @article{wu2025uso, 132 | title={USO: Unified Style and Subject-Driven Generation via Disentangled and Reward Learning}, 133 | author={Shaojin Wu and Mengqi Huang and Yufeng Cheng and Wenxu Wu and Jiahe Tian and Yiming Luo and Fei Ding and Qian He}, 134 | year={2025}, 135 | eprint={2508.18966}, 136 | archivePrefix={arXiv}, 137 | primaryClass={cs.CV}, 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI USO Node 2 | 3 | A custom node for ComfyUI that integrates USO (Unified Style and Subject-Driven Generation) for high-quality image generation with style and subject control. 
4 | 5 | ## ✨ Features 6 | 7 | - 🎨 **Unified Style & Subject Generation**: Powered by USO model based on FLUX architecture 8 | - 🎯 **Style-Driven Generation**: Generate images with specific artistic styles 9 | - 👤 **Subject-Driven Generation**: Maintain subject consistency across generations 10 | - 🔄 **Multi-Style Support**: Combine multiple styles in a single generation 11 | - ⚙️ **Memory Optimization**: FP8 precision support for consumer-grade GPUs (~16GB VRAM) 12 | - 🚀 **Flexible Control**: Advanced parameter control for fine-tuning results 13 | 14 | ## 🔧 Node List 15 | 16 | ### Core Nodes 17 | - **RH_USO_Loader**: Load and initialize USO models with optimization options 18 | - **RH_USO_Generator**: Generate images with style and subject control 19 | 20 | ## 🚀 Quick Installation 21 | 22 | ### Step 1: Install the Node 23 | ```bash 24 | # Navigate to ComfyUI custom_nodes directory 25 | cd ComfyUI/custom_nodes 26 | 27 | # Clone the repository 28 | git clone https://github.com/HM-RunningHub/ComfyUI_RH_USO 29 | 30 | # Install dependencies 31 | cd ComfyUI_RH_USO 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ### Step 2: Download Required Models 36 | ```bash 37 | # Download FLUX.1-dev model (Required base model) 38 | huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusers/FLUX.1-dev 39 | huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir models/diffusers/FLUX.1-dev 40 | 41 | # Download USO model 42 | huggingface-cli download bytedance-research/USO --local-dir models/uso 43 | 44 | # Download SigLIP model 45 | huggingface-cli download google/siglip-so400m-patch14-384 --local-dir models/clip/siglip-so400m-patch14-384 46 | 47 | # Final model structure should look like: 48 | models/ 49 | ├── diffusers/ 50 | │ └── FLUX.1-dev/ 51 | │ ├── flux1-dev.safetensors 52 | │ └── ae.safetensors 53 | │ └── .... 
54 | ├── uso/ 55 | │ ├── assets/ 56 | │ │ └── uso.webp 57 | │ ├── config.json 58 | │ ├── download_repo_enhanced.py 59 | │ ├── README.md 60 | │ └── uso_flux_v1.0/ 61 | │ ├── dit_lora.safetensors 62 | │ └── projector.safetensors 63 | └── clip/ 64 | └── siglip-so400m-patch14-384/ 65 | 66 | # Restart ComfyUI 67 | ``` 68 | 69 | ## 📖 Usage 70 | 71 | ### Basic Workflow 72 | ``` 73 | [RH_USO_Loader] → [RH_USO_Generator] → [Save Image] 74 | ``` 75 | 76 | ### Generation Types 77 | 78 | #### Style-Driven Generation 79 | - Load style reference images 80 | - Input text prompt describing the content 81 | - Generate images in the specified style 82 | 83 | #### Subject-Driven Generation 84 | - Load subject reference image 85 | - Input text prompt with scene description 86 | - Generate images maintaining subject identity 87 | 88 | #### Style + Subject Generation 89 | - Load both style and subject reference images 90 | - Combine style transfer with subject consistency 91 | - Generate images with unified style and preserved subjects 92 | 93 | ## 🛠️ Technical Requirements 94 | 95 | - **GPU**: 16GB+ VRAM (with FP8 optimization) 96 | - **RAM**: 32GB+ recommended 97 | - **Storage**: ~35GB for all models 98 | - FLUX.1-dev: ~24GB (flux1-dev.safetensors + ae.safetensors) 99 | - USO models: ~6GB 100 | - SigLIP: ~1.5GB 101 | - **CUDA**: Required for optimal performance 102 | 103 | ## ⚠️ Important Notes 104 | 105 | - **Model Paths**: Models must be placed in specific directories: 106 | - FLUX.1-dev → `models/diffusers/FLUX.1-dev/` 107 | - USO models → `models/uso/` 108 | - SigLIP → `models/clip/siglip-so400m-patch14-384/` 109 | - FP8 mode recommended for consumer GPUs (reduces VRAM usage) 110 | - All model files must be downloaded before first use 111 | 112 | ## 📄 License 113 | 114 | This project is licensed under Apache 2.0 License. 115 | 116 | ## 🔗 References 117 | 118 | - [USO Project Page](https://bytedance.github.io/USO/) 119 | - [USO Paper](https://arxiv.org/abs/2508.18966) 120 | - [USO HuggingFace](https://huggingface.co/bytedance-research/USO) 121 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI) 122 | 123 | ## 🔗 Example 124 | image 125 | image 126 | image 127 | 128 | 129 | ## 🤝 Contributing 130 | 131 | Contributions are welcome! Please feel free to submit issues and pull requests. 
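Before opening an issue about missing models, you can quickly verify that the files from the layout in Step 2 are in place. The snippet below is a minimal sketch rather than part of the node pack: the filename `check_uso_models.py` is hypothetical, and it assumes you run it from the ComfyUI root with the `models/` layout shown above.

```python
# check_uso_models.py -- hypothetical helper, not shipped with this repository.
# Assumes the ComfyUI root as the working directory and the layout from Step 2.
from pathlib import Path

MODELS_DIR = Path("models")

REQUIRED_FILES = [
    "diffusers/FLUX.1-dev/flux1-dev.safetensors",
    "diffusers/FLUX.1-dev/ae.safetensors",
    "uso/uso_flux_v1.0/dit_lora.safetensors",
    "uso/uso_flux_v1.0/projector.safetensors",
    "clip/siglip-so400m-patch14-384/config.json",
]

missing = [rel for rel in REQUIRED_FILES if not (MODELS_DIR / rel).is_file()]
if missing:
    print("Missing model files:")
    for rel in missing:
        print(f"  - {MODELS_DIR / rel}")
    raise SystemExit(1)
print("All required USO model files were found.")
```

Run it with `python check_uso_models.py` from the ComfyUI root before starting ComfyUI for the first time.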
132 | 133 | ## ⭐ Citation 134 | 135 | If you find this project useful, please consider citing the original USO paper: 136 | 137 | ```bibtex 138 | @article{wu2025uso, 139 | title={USO: Unified Style and Subject-Driven Generation via Disentangled and Reward Learning}, 140 | author={Shaojin Wu and Mengqi Huang and Yufeng Cheng and Wenxu Wu and Jiahe Tian and Yiming Luo and Fei Ding and Qian He}, 141 | year={2025}, 142 | eprint={2508.18966}, 143 | archivePrefix={arXiv}, 144 | primaryClass={cs.CV}, 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /rh_uso_nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dataclasses 3 | from typing import Literal 4 | 5 | from accelerate import Accelerator 6 | from transformers import HfArgumentParser 7 | from PIL import Image 8 | import json 9 | import itertools 10 | import torch 11 | 12 | from .uso.flux.pipeline import USOPipeline, preprocess_ref 13 | from transformers import SiglipVisionModel, SiglipImageProcessor 14 | from tqdm import tqdm 15 | import folder_paths 16 | import numpy as np 17 | import comfy.utils 18 | 19 | class RH_USO_Loader: 20 | @classmethod 21 | def INPUT_TYPES(s): 22 | return { 23 | "required": { 24 | 25 | }, 26 | } 27 | 28 | RETURN_TYPES = ("RHUSOMudules",) 29 | RETURN_NAMES = ("USO Modules",) 30 | FUNCTION = "load" 31 | 32 | CATEGORY = "Runninghub/USO" 33 | 34 | def load(self, **kwargs): 35 | # accelerator = Accelerator() 36 | device = 'cuda' 37 | siglip_path = os.path.join(folder_paths.models_dir, 'clip', 'siglip-so400m-patch14-384') 38 | siglip_processor = SiglipImageProcessor.from_pretrained( 39 | siglip_path 40 | ) 41 | siglip_model = SiglipVisionModel.from_pretrained( 42 | siglip_path 43 | ) 44 | siglip_model.eval() 45 | siglip_model.to(device) 46 | print("SigLIP model loaded successfully") 47 | 48 | # hardcode hyperparamters -kiki 49 | model_type = 'flux-dev-fp8' 50 | lora_rank = 128 51 | 52 | pipeline = USOPipeline( 53 | model_type, 54 | device, 55 | True, #args.offload, 56 | only_lora=True, 57 | lora_rank=lora_rank, 58 | hf_download=False, 59 | ) 60 | if siglip_model is not None: 61 | pipeline.model.vision_encoder = siglip_model 62 | print('-----> hook siglip encoder') 63 | return ({'siglip_processor':siglip_processor, 'pipeline':pipeline}, ) 64 | 65 | class RH_USO_Sampler: 66 | 67 | @classmethod 68 | def INPUT_TYPES(s): 69 | return { 70 | "required": { 71 | "uso": ("RHUSOMudules", ), 72 | "prompt": ("STRING", {"multiline": True, 73 | 'default': ''}), 74 | "width": ("INT", {"default": 1024}), 75 | "height": ("INT", {"default": 1024}), 76 | "num_inference_steps": ("INT", {"default": 25}), 77 | "guidance": ("FLOAT", {"default": 4.0}), 78 | "seed": ("INT", {"default": 20, "min": 0, "max": 0xffffffffffffffff, 79 | "tooltip": "The random seed used for creating the noise."}), 80 | }, 81 | "optional": { 82 | "content_image": ("IMAGE", ), 83 | "style_image": ("IMAGE", ), 84 | "style2_image": ("IMAGE", ), 85 | } 86 | } 87 | 88 | RETURN_TYPES = ("IMAGE",) 89 | RETURN_NAMES = ("image",) 90 | FUNCTION = "sample" 91 | 92 | CATEGORY = "Runninghub/USO" 93 | 94 | def tensor_2_pil(self, img_tensor): 95 | if img_tensor is not None: 96 | i = 255. 
* img_tensor.squeeze().cpu().numpy() 97 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) 98 | return img 99 | else: 100 | return None 101 | 102 | def preprocess_ref(self, raw_image: Image.Image, long_size: int = 512, scale_ratio: int = 1): 103 | # Get the width and height of the original image 104 | image_w, image_h = raw_image.size 105 | if image_w == image_h and image_w == 16: 106 | return raw_image 107 | 108 | # Determine the long and short sides 109 | if image_w >= image_h: 110 | new_w = long_size 111 | new_h = int((long_size / image_w) * image_h) 112 | else: 113 | new_h = long_size 114 | new_w = int((long_size / image_h) * image_w) 115 | 116 | # Resize proportionally to the new width and height 117 | raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) 118 | 119 | # Snap to multiples of 16 * scale_ratio so a canny img can be scaled 120 | scale_ratio = int(scale_ratio) 121 | target_w = new_w // (16 * scale_ratio) * (16 * scale_ratio) 122 | target_h = new_h // (16 * scale_ratio) * (16 * scale_ratio) 123 | 124 | # Compute the starting crop coordinates for a center crop 125 | left = (new_w - target_w) // 2 126 | top = (new_h - target_h) // 2 127 | right = left + target_w 128 | bottom = top + target_h 129 | 130 | # Perform the center crop 131 | raw_image = raw_image.crop((left, top, right, bottom)) 132 | 133 | # Convert to RGB mode 134 | raw_image = raw_image.convert("RGB") 135 | return raw_image 136 | 137 | def sample(self, **kwargs): 138 | ref_imgs = [] 139 | content_image = self.tensor_2_pil(kwargs.get('content_image', None)) 140 | style_image = self.tensor_2_pil(kwargs.get('style_image', None)) 141 | style2_image = self.tensor_2_pil(kwargs.get('style2_image', None)) 142 | print(f'conds-c/s1/s2:{content_image is not None} {style_image is not None} {style2_image is not None}') 143 | ref_imgs.append(content_image) 144 | if style_image is not None: 145 | ref_imgs.append(style_image) 146 | if style2_image is not None: 147 | ref_imgs.append(style2_image) 148 | siglip_inputs = None 149 | 150 | width = kwargs.get('width') 151 | height = kwargs.get('height') 152 | prompt = kwargs.get('prompt') 153 | guidance = kwargs.get('guidance') 154 | num_steps = kwargs.get('num_inference_steps') 155 | seed = kwargs.get('seed') % (2 ** 32) 156 | 157 | # hardcode hyperparameters -kiki 158 | content_ref = 512 159 | pe = 'd' 160 | 161 | uso = kwargs.get('uso') 162 | siglip_processor = uso['siglip_processor'] 163 | pipeline = uso['pipeline'] 164 | with torch.no_grad(): 165 | siglip_inputs = [ 166 | siglip_processor(img, return_tensors="pt").to(pipeline.device) 167 | for img in ref_imgs[1:] if isinstance(img, Image.Image) 168 | ] 169 | 170 | ref_imgs_pil = [ 171 | self.preprocess_ref(img, content_ref) for img in ref_imgs[:1] if isinstance(img, Image.Image) 172 | ] 173 | self.pbar = comfy.utils.ProgressBar(num_steps) 174 | 175 | image_gen = pipeline( 176 | prompt=prompt, 177 | width=width, 178 | height=height, 179 | guidance=guidance, 180 | num_steps=num_steps, 181 | seed=seed, 182 | ref_imgs=ref_imgs_pil, 183 | pe=pe, 184 | siglip_inputs=siglip_inputs, 185 | update_func=self.update, 186 | ) 187 | 188 | image = np.array(image_gen).astype(np.float32) / 255.0 189 | image = torch.from_numpy(image)[None,] 190 | 191 | return (image, ) 192 | 193 | def update(self): 194 | self.pbar.update(1) 195 | 196 | 197 | NODE_CLASS_MAPPINGS = { 198 | "RunningHub USO Loader": RH_USO_Loader, 199 | "RunningHub USO Sampler":RH_USO_Sampler, 200 | } 201 | 202 | NODE_DISPLAY_NAME_MAPPINGS = { 203 | "RunningHub USO Loader": "RunningHub USO Loader", 204 | "RunningHub USO Sampler": "RunningHub USO Sampler", 205 | } -------------------------------------------------------------------------------- /inference.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import dataclasses 17 | from typing import Literal 18 | 19 | from accelerate import Accelerator 20 | from transformers import HfArgumentParser 21 | from PIL import Image 22 | import json 23 | import itertools 24 | import torch 25 | 26 | from uso.flux.pipeline import USOPipeline, preprocess_ref 27 | from transformers import SiglipVisionModel, SiglipImageProcessor 28 | from tqdm import tqdm 29 | 30 | 31 | def horizontal_concat(images): 32 | widths, heights = zip(*(img.size for img in images)) 33 | 34 | total_width = sum(widths) 35 | max_height = max(heights) 36 | 37 | new_im = Image.new("RGB", (total_width, max_height)) 38 | 39 | x_offset = 0 40 | for img in images: 41 | new_im.paste(img, (x_offset, 0)) 42 | x_offset += img.size[0] 43 | 44 | return new_im 45 | 46 | 47 | @dataclasses.dataclass 48 | class InferenceArgs: 49 | prompt: str | None = None 50 | image_paths: list[str] | None = None 51 | eval_json_path: str | None = None 52 | # offload: bool = False 53 | offload: bool = True 54 | num_images_per_prompt: int = 1 55 | model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev-fp8" 56 | width: int = 1024 57 | height: int = 1024 58 | num_steps: int = 25 59 | guidance: float = 4 60 | seed: int = 3407 61 | save_path: str = "output/inference" 62 | only_lora: bool = True 63 | concat_refs: bool = False 64 | lora_rank: int = 128 65 | pe: Literal["d", "h", "w", "o"] = "d" 66 | content_ref: int = 512 67 | ckpt_path: str | None = None 68 | use_siglip: bool = True 69 | instruct_edit: bool = False 70 | hf_download: bool = True 71 | 72 | 73 | def main(args: InferenceArgs): 74 | accelerator = Accelerator() 75 | 76 | # init SigLIP model 77 | siglip_processor = None 78 | siglip_model = None 79 | 80 | siglip_path = '/workspace/comfyui/models/clip/siglip-so400m-patch14-384' 81 | if args.use_siglip: 82 | siglip_processor = SiglipImageProcessor.from_pretrained( 83 | # "google/siglip-so400m-patch14-384" 84 | siglip_path 85 | ) 86 | siglip_model = SiglipVisionModel.from_pretrained( 87 | # "google/siglip-so400m-patch14-384" 88 | siglip_path 89 | ) 90 | siglip_model.eval() 91 | siglip_model.to(accelerator.device) 92 | print("SigLIP model loaded successfully") 93 | 94 | pipeline = USOPipeline( 95 | args.model_type, 96 | accelerator.device, 97 | args.offload, 98 | only_lora=args.only_lora, 99 | lora_rank=args.lora_rank, 100 | hf_download=args.hf_download, 101 | ) 102 | if args.use_siglip and siglip_model is not None: 103 | pipeline.model.vision_encoder = siglip_model 104 | print('-----> hook siglip encoder') 105 | 106 | assert ( 107 | args.prompt is not None or args.eval_json_path is not None 108 | ), "Please provide either prompt or eval_json_path" 109 | 110 | if args.eval_json_path is not None: 111 | with open(args.eval_json_path, 
"rt") as f: 112 | data_dicts = json.load(f) 113 | data_root = os.path.dirname(args.eval_json_path) 114 | else: 115 | data_root = "" 116 | data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}] 117 | 118 | print( 119 | f"process: {accelerator.num_processes}/{accelerator.process_index}, \ 120 | process images: {len(data_dicts)}/{len(data_dicts[accelerator.process_index::accelerator.num_processes])}" 121 | ) 122 | 123 | data_dicts = data_dicts[accelerator.process_index :: accelerator.num_processes] 124 | 125 | accelerator.wait_for_everyone() 126 | local_task_count = len(data_dicts) * args.num_images_per_prompt 127 | if accelerator.is_main_process: 128 | progress_bar = tqdm(total=local_task_count, desc="Generating Images") 129 | 130 | for (i, data_dict), j in itertools.product( 131 | enumerate(data_dicts), range(args.num_images_per_prompt) 132 | ): 133 | ref_imgs = [] 134 | for _, img_path in enumerate(data_dict["image_paths"]): 135 | if img_path != "": 136 | img = Image.open(os.path.join(data_root, img_path)).convert("RGB") 137 | ref_imgs.append(img) 138 | else: 139 | ref_imgs.append(None) 140 | siglip_inputs = None 141 | if args.use_siglip and siglip_processor is not None: 142 | with torch.no_grad(): 143 | siglip_inputs = [ 144 | siglip_processor(img, return_tensors="pt").to(pipeline.device) 145 | for img in ref_imgs[1:] if isinstance(img, Image.Image) 146 | ] 147 | 148 | ref_imgs_pil = [ 149 | preprocess_ref(img, args.content_ref) for img in ref_imgs[:1] if isinstance(img, Image.Image) 150 | ] 151 | 152 | if args.instruct_edit: 153 | args.width, args.height = ref_imgs_pil[0].size 154 | args.width, args.height = args.width * (1024 / args.content_ref), args.height * (1024 / args.content_ref) 155 | image_gen = pipeline( 156 | prompt=data_dict["prompt"], 157 | width=args.width, 158 | height=args.height, 159 | guidance=args.guidance, 160 | num_steps=args.num_steps, 161 | seed=args.seed + j, 162 | ref_imgs=ref_imgs_pil, 163 | pe=args.pe, 164 | siglip_inputs=siglip_inputs, 165 | ) 166 | if args.concat_refs: 167 | image_gen = horizontal_concat([image_gen, *ref_imgs]) 168 | 169 | if "save_dir" in data_dict: 170 | config_save_path = os.path.join(args.save_path, data_dict["save_dir"] + f"_{j}.json") 171 | image_save_path = os.path.join(args.save_path, data_dict["save_dir"] + f"_{j}.png") 172 | else: 173 | os.makedirs(args.save_path, exist_ok=True) 174 | config_save_path = os.path.join(args.save_path, f"{i}_{j}.json") 175 | image_save_path = os.path.join(args.save_path, f"{i}_{j}.png") 176 | 177 | # save config and image 178 | os.makedirs(os.path.dirname(image_save_path), exist_ok=True) 179 | image_gen.save(image_save_path) 180 | # ensure the prompt and image_paths are saved in the config file 181 | args.prompt = data_dict["prompt"] 182 | args.image_paths = data_dict["image_paths"] 183 | args_dict = vars(args) 184 | with open(config_save_path, "w") as f: 185 | json.dump(args_dict, f, indent=4) 186 | 187 | if accelerator.is_main_process: 188 | progress_bar.update(1) 189 | if accelerator.is_main_process: 190 | progress_bar.close() 191 | 192 | 193 | if __name__ == "__main__": 194 | parser = HfArgumentParser([InferenceArgs]) 195 | args = parser.parse_args_into_dataclasses()[0] 196 | main(args) 197 | -------------------------------------------------------------------------------- /uso/flux/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | from typing import Literal 18 | 19 | import torch 20 | from einops import rearrange, repeat 21 | from torch import Tensor 22 | from tqdm import tqdm 23 | 24 | from .model import Flux 25 | from .modules.conditioner import HFEmbedder 26 | 27 | 28 | def get_noise( 29 | num_samples: int, 30 | height: int, 31 | width: int, 32 | device: torch.device, 33 | dtype: torch.dtype, 34 | seed: int, 35 | ): 36 | return torch.randn( 37 | num_samples, 38 | 16, 39 | # allow for packing 40 | 2 * math.ceil(height / 16), 41 | 2 * math.ceil(width / 16), 42 | device=device, 43 | dtype=dtype, 44 | generator=torch.Generator(device=device).manual_seed(seed), 45 | ) 46 | 47 | 48 | def prepare( 49 | t5: HFEmbedder, 50 | clip: HFEmbedder, 51 | img: Tensor, 52 | prompt: str | list[str], 53 | ref_img: None | Tensor = None, 54 | pe: Literal["d", "h", "w", "o"] = "d", 55 | ) -> dict[str, Tensor]: 56 | assert pe in ["d", "h", "w", "o"] 57 | bs, c, h, w = img.shape 58 | if bs == 1 and not isinstance(prompt, str): 59 | bs = len(prompt) 60 | 61 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 62 | if img.shape[0] == 1 and bs > 1: 63 | img = repeat(img, "1 ... -> bs ...", bs=bs) 64 | 65 | img_ids = torch.zeros(h // 2, w // 2, 3) 66 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 67 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 68 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 69 | 70 | if ref_img is not None: 71 | _, _, ref_h, ref_w = ref_img.shape 72 | ref_img = rearrange( 73 | ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2 74 | ) 75 | if ref_img.shape[0] == 1 and bs > 1: 76 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs) 77 | ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3) 78 | # img id分别在宽高偏移各自最大值 79 | h_offset = h // 2 if pe in {"d", "h"} else 0 80 | w_offset = w // 2 if pe in {"d", "w"} else 0 81 | ref_img_ids[..., 1] = ( 82 | ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset 83 | ) 84 | ref_img_ids[..., 2] = ( 85 | ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset 86 | ) 87 | ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs) 88 | 89 | if isinstance(prompt, str): 90 | prompt = [prompt] 91 | txt = t5(prompt) 92 | if txt.shape[0] == 1 and bs > 1: 93 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 94 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 95 | 96 | vec = clip(prompt) 97 | if vec.shape[0] == 1 and bs > 1: 98 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 99 | 100 | if ref_img is not None: 101 | return { 102 | "img": img, 103 | "img_ids": img_ids.to(img.device), 104 | "ref_img": ref_img, 105 | "ref_img_ids": ref_img_ids.to(img.device), 106 | "txt": txt.to(img.device), 107 | "txt_ids": txt_ids.to(img.device), 108 | "vec": vec.to(img.device), 109 | } 110 | else: 111 | return { 112 | "img": img, 113 | "img_ids": img_ids.to(img.device), 114 | "txt": txt.to(img.device), 115 | "txt_ids": txt_ids.to(img.device), 116 | "vec": vec.to(img.device), 117 | } 118 | 119 | 120 | def prepare_multi_ip( 121 | t5: HFEmbedder, 122 | clip: HFEmbedder, 123 | img: Tensor, 124 | prompt: str | list[str], 125 | ref_imgs: list[Tensor] | None = None, 126 | pe: Literal["d", "h", "w", "o"] = "d", 127 | ) -> dict[str, Tensor]: 128 | assert pe in ["d", "h", "w", "o"] 129 | bs, c, h, w = img.shape 130 | if bs == 1 and not isinstance(prompt, str): 131 | bs = len(prompt) 132 | 133 | # tgt img 134 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 135 | if img.shape[0] == 1 and bs > 1: 136 | img = repeat(img, "1 ... -> bs ...", bs=bs) 137 | 138 | img_ids = torch.zeros(h // 2, w // 2, 3) 139 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 140 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 141 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 142 | 143 | ref_img_ids = [] 144 | ref_imgs_list = [] 145 | 146 | pe_shift_w, pe_shift_h = w // 2, h // 2 147 | for ref_img in ref_imgs: 148 | _, _, ref_h1, ref_w1 = ref_img.shape 149 | ref_img = rearrange( 150 | ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2 151 | ) 152 | if ref_img.shape[0] == 1 and bs > 1: 153 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs) 154 | ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3) 155 | # img id分别在宽高偏移各自最大值 156 | h_offset = pe_shift_h if pe in {"d", "h"} else 0 157 | w_offset = pe_shift_w if pe in {"d", "w"} else 0 158 | ref_img_ids1[..., 1] = ( 159 | ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset 160 | ) 161 | ref_img_ids1[..., 2] = ( 162 | ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset 163 | ) 164 | ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs) 165 | ref_img_ids.append(ref_img_ids1) 166 | ref_imgs_list.append(ref_img) 167 | 168 | # 更新pe shift 169 | pe_shift_h += ref_h1 // 2 170 | pe_shift_w += ref_w1 // 2 171 | 172 | if isinstance(prompt, str): 173 | prompt = [prompt] 174 | txt = t5(prompt) 175 | if txt.shape[0] == 1 and bs > 1: 176 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 177 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 178 | 179 | vec = clip(prompt) 180 | if vec.shape[0] == 1 and bs > 1: 181 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 182 | 183 | return { 184 | "img": img, 185 | "img_ids": img_ids.to(img.device), 186 | "ref_img": tuple(ref_imgs_list), 187 | "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids], 188 | "txt": txt.to(img.device), 189 | "txt_ids": txt_ids.to(img.device), 190 | "vec": vec.to(img.device), 191 | } 192 | 193 | 194 | def time_shift(mu: float, sigma: float, t: Tensor): 195 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 196 | 197 | 198 | def get_lin_function( 199 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 200 | ): 201 | m = (y2 - y1) / (x2 - x1) 202 | b = y1 - m * x1 203 | return lambda x: m * x + b 204 | 205 | 206 | def get_schedule( 207 | num_steps: int, 208 | image_seq_len: int, 209 | base_shift: float = 0.5, 210 | max_shift: float = 1.15, 211 | shift: bool = True, 212 | ) -> list[float]: 213 | # extra step for zero 214 | timesteps = torch.linspace(1, 0, num_steps + 1) 215 | 216 | # shifting the schedule to favor high timesteps for higher signal images 217 | if shift: 218 | # eastimate mu based on linear estimation between two points 219 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 220 | timesteps = time_shift(mu, 1.0, timesteps) 221 | 222 | return timesteps.tolist() 223 | 224 | 225 | def denoise( 226 | model: Flux, 227 | # model input 228 | img: Tensor, 229 | img_ids: Tensor, 230 | txt: Tensor, 231 | txt_ids: Tensor, 232 | vec: Tensor, 233 | # sampling parameters 234 | timesteps: list[float], 235 | guidance: float = 4.0, 236 | ref_img: Tensor = None, 237 | ref_img_ids: Tensor = None, 238 | siglip_inputs: list[Tensor] | None = None, 239 | #kiki 240 | update_func = None, 241 | ): 242 | i = 0 243 | guidance_vec = torch.full( 244 | (img.shape[0],), guidance, device=img.device, dtype=img.dtype 245 | ) 246 | for t_curr, t_prev in tqdm( 247 | zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1 248 | ): 249 | if update_func is not None: 250 | update_func() 251 | # for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): 252 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 253 | pred = model( 254 | img=img, 255 | img_ids=img_ids, 256 | ref_img=ref_img, 257 | ref_img_ids=ref_img_ids, 258 | txt=txt, 259 | txt_ids=txt_ids, 260 | y=vec, 261 | timesteps=t_vec, 262 | guidance=guidance_vec, 263 | siglip_inputs=siglip_inputs, 264 | ) 265 | img = img + (t_prev - t_curr) * pred 266 | i += 1 267 | return img 268 | 269 | 270 | def unpack(x: Tensor, height: int, width: int) -> Tensor: 271 | return rearrange( 272 | x, 273 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 274 | h=math.ceil(height / 16), 275 | w=math.ceil(width / 16), 276 | ph=2, 277 | pw=2, 278 | ) 279 | -------------------------------------------------------------------------------- /uso/flux/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from torch import Tensor, nn 20 | 21 | from .modules.layers import ( 22 | DoubleStreamBlock, 23 | EmbedND, 24 | LastLayer, 25 | MLPEmbedder, 26 | SingleStreamBlock, 27 | timestep_embedding, 28 | SigLIPMultiFeatProjModel, 29 | ) 30 | import os 31 | 32 | 33 | @dataclass 34 | class FluxParams: 35 | in_channels: int 36 | vec_in_dim: int 37 | context_in_dim: int 38 | hidden_size: int 39 | mlp_ratio: float 40 | num_heads: int 41 | depth: int 42 | depth_single_blocks: int 43 | axes_dim: list[int] 44 | theta: int 45 | qkv_bias: bool 46 | guidance_embed: bool 47 | 48 | 49 | class Flux(nn.Module): 50 | """ 51 | Transformer model for flow matching on sequences. 52 | """ 53 | 54 | _supports_gradient_checkpointing = True 55 | 56 | def __init__(self, params: FluxParams): 57 | super().__init__() 58 | 59 | self.params = params 60 | self.in_channels = params.in_channels 61 | self.out_channels = self.in_channels 62 | if params.hidden_size % params.num_heads != 0: 63 | raise ValueError( 64 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 65 | ) 66 | pe_dim = params.hidden_size // params.num_heads 67 | if sum(params.axes_dim) != pe_dim: 68 | raise ValueError( 69 | f"Got {params.axes_dim} but expected positional dim {pe_dim}" 70 | ) 71 | self.hidden_size = params.hidden_size 72 | self.num_heads = params.num_heads 73 | self.pe_embedder = EmbedND( 74 | dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim 75 | ) 76 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 77 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 78 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 79 | self.guidance_in = ( 80 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 81 | if params.guidance_embed 82 | else nn.Identity() 83 | ) 84 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 85 | 86 | self.double_blocks = nn.ModuleList( 87 | [ 88 | DoubleStreamBlock( 89 | self.hidden_size, 90 | self.num_heads, 91 | mlp_ratio=params.mlp_ratio, 92 | qkv_bias=params.qkv_bias, 93 | ) 94 | for _ in range(params.depth) 95 | ] 96 | ) 97 | 98 | self.single_blocks = nn.ModuleList( 99 | [ 100 | SingleStreamBlock( 101 | self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio 102 | ) 103 | for _ in range(params.depth_single_blocks) 104 | ] 105 | ) 106 | 107 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 108 | self.gradient_checkpointing = False 109 | 110 | # feature embedder for siglip multi-feat inputs 111 | self.feature_embedder = SigLIPMultiFeatProjModel( 112 | siglip_token_nums=729, 113 | style_token_nums=64, 114 | siglip_token_dims=1152, 115 | hidden_size=self.hidden_size, 116 | context_layer_norm=True, 117 | ) 118 | print("use semantic encoder siglip multi-feat to encode style image") 119 | 120 | self.vision_encoder = None 121 | 122 | def _set_gradient_checkpointing(self, module, value=False): 123 | if hasattr(module, "gradient_checkpointing"): 124 | module.gradient_checkpointing = value 125 | 
126 | @property 127 | def attn_processors(self): 128 | # set recursively 129 | processors = {} # type: dict[str, nn.Module] 130 | 131 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): 132 | if hasattr(module, "set_processor"): 133 | processors[f"{name}.processor"] = module.processor 134 | 135 | for sub_name, child in module.named_children(): 136 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 137 | 138 | return processors 139 | 140 | for name, module in self.named_children(): 141 | fn_recursive_add_processors(name, module, processors) 142 | 143 | return processors 144 | 145 | def set_attn_processor(self, processor): 146 | r""" 147 | Sets the attention processor to use to compute attention. 148 | 149 | Parameters: 150 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 151 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 152 | for **all** `Attention` layers. 153 | 154 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 155 | processor. This is strongly recommended when setting trainable attention processors. 156 | 157 | """ 158 | count = len(self.attn_processors.keys()) 159 | 160 | if isinstance(processor, dict) and len(processor) != count: 161 | raise ValueError( 162 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 163 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 164 | ) 165 | 166 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 167 | if hasattr(module, "set_processor"): 168 | if not isinstance(processor, dict): 169 | module.set_processor(processor) 170 | else: 171 | module.set_processor(processor.pop(f"{name}.processor")) 172 | 173 | for sub_name, child in module.named_children(): 174 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 175 | 176 | for name, module in self.named_children(): 177 | fn_recursive_attn_processor(name, module, processor) 178 | 179 | def forward( 180 | self, 181 | img: Tensor, 182 | img_ids: Tensor, 183 | txt: Tensor, 184 | txt_ids: Tensor, 185 | timesteps: Tensor, 186 | y: Tensor, 187 | guidance: Tensor | None = None, 188 | ref_img: Tensor | None = None, 189 | ref_img_ids: Tensor | None = None, 190 | siglip_inputs: list[Tensor] | None = None, 191 | ) -> Tensor: 192 | if img.ndim != 3 or txt.ndim != 3: 193 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 194 | 195 | # running on sequences img 196 | img = self.img_in(img) 197 | vec = self.time_in(timestep_embedding(timesteps, 256)) 198 | if self.params.guidance_embed: 199 | if guidance is None: 200 | raise ValueError( 201 | "Didn't get guidance strength for guidance distilled model." 
202 | ) 203 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 204 | vec = vec + self.vector_in(y) 205 | txt = self.txt_in(txt) 206 | if self.feature_embedder is not None and siglip_inputs is not None and len(siglip_inputs) > 0 and self.vision_encoder is not None: 207 | # processing style feat into textural hidden space 208 | siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in siglip_inputs] 209 | # siglip_embedding = [self.vision_encoder(**(emb.to(torch.bfloat16)), output_hidden_states=True) for emb in siglip_inputs] 210 | siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1) 211 | txt = torch.cat((siglip_embedding, txt), dim=1) 212 | siglip_embedding_ids = torch.zeros( 213 | siglip_embedding.shape[0], siglip_embedding.shape[1], 3 214 | ).to(txt_ids.device) 215 | txt_ids = torch.cat((siglip_embedding_ids, txt_ids), dim=1) 216 | 217 | ids = torch.cat((txt_ids, img_ids), dim=1) 218 | 219 | # concat ref_img/img 220 | img_end = img.shape[1] 221 | if ref_img is not None: 222 | if isinstance(ref_img, tuple) or isinstance(ref_img, list): 223 | img_in = [img] + [self.img_in(ref) for ref in ref_img] 224 | img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids] 225 | img = torch.cat(img_in, dim=1) 226 | ids = torch.cat(img_ids, dim=1) 227 | else: 228 | img = torch.cat((img, self.img_in(ref_img)), dim=1) 229 | ids = torch.cat((ids, ref_img_ids), dim=1) 230 | pe = self.pe_embedder(ids) 231 | 232 | for index_block, block in enumerate(self.double_blocks): 233 | if self.training and self.gradient_checkpointing: 234 | img, txt = torch.utils.checkpoint.checkpoint( 235 | block, 236 | img=img, 237 | txt=txt, 238 | vec=vec, 239 | pe=pe, 240 | use_reentrant=False, 241 | ) 242 | else: 243 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 244 | 245 | img = torch.cat((txt, img), 1) 246 | for block in self.single_blocks: 247 | if self.training and self.gradient_checkpointing: 248 | img = torch.utils.checkpoint.checkpoint( 249 | block, img, vec=vec, pe=pe, use_reentrant=False 250 | ) 251 | else: 252 | img = block(img, vec=vec, pe=pe) 253 | img = img[:, txt.shape[1] :, ...] 254 | # index img 255 | img = img[:, :img_end, ...] 256 | 257 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 258 | return img 259 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
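# Gradio demo entry point: create_demo() builds a USOPipeline for the selected
# FLUX variant (optionally CPU-offloaded), attaches a SigLIP vision encoder for
# style references, and exposes a prompt plus up to three reference images
# (content, style, extra style) in a web UI. Launch with: python app.py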
14 | 15 | import dataclasses 16 | import json 17 | import os 18 | from pathlib import Path 19 | 20 | import gradio as gr 21 | import torch 22 | 23 | from uso.flux.pipeline import USOPipeline 24 | from transformers import SiglipVisionModel, SiglipImageProcessor 25 | 26 | 27 | with open("assets/uso_text.svg", "r", encoding="utf-8") as svg_file: 28 | text_content = svg_file.read() 29 | 30 | with open("assets/uso_logo.svg", "r", encoding="utf-8") as svg_file: 31 | logo_content = svg_file.read() 32 | 33 | title = f""" 34 |
35 | {text_content} 36 | by UXO Team 37 | {logo_content} 38 |
39 | """.strip() 40 | 41 | badges_text = r""" 42 |
43 | 44 | Build 45 | Build 46 | 47 |
48 | """.strip() 49 | 50 | tips = """ 51 | **What is USO?** 🎨 52 | USO is a unified style-subject optimized customization model and the latest addition to the UXO family (USO and UNO). 53 | It can freely combine any subject with any style in any scenario. 54 | 55 | **How to use?** 💡 56 | We provide step-by-step instructions in our GitHub repo. 57 | Additionally, try the examples provided below the demo to quickly get familiar with USO and spark your creativity! 58 | 59 |
60 | The model is trained at 1024x1024 resolution and supports 3 types of usage. 📌 Tips: 61 | 62 | * **Only content img**: supports the following types: 63 | * Subject/Identity-driven (supports natural prompts, e.g., *A clock on the table.* *The woman near the sea.*; excels at producing **photorealistic portraits**) 64 | * Style edit (layout-preserved): *Transform the image into Ghibli style/Pixel style/Retro comic style/Watercolor painting style...*. 65 | * Style edit (layout-shift): *Ghibli style, the man on the beach.*. 66 | * **Only style img**: references the input style and generates anything that follows the prompt. USO excels at this and further supports multiple style references (in beta). 67 | * **Content img + style img**: places the content into the desired style. 68 | * Layout-preserved: set the prompt to **empty**. 69 | * Layout-shift: use a natural prompt.
""" 70 | 71 | star = r""" 72 | If USO is helpful, please help to ⭐ our Github Repo. Thanks a lot!""" 73 | 74 | def get_examples(examples_dir: str = "assets/examples") -> list: 75 | examples = Path(examples_dir) 76 | ans = [] 77 | for example in examples.iterdir(): 78 | if not example.is_dir() or len(os.listdir(example)) == 0: 79 | continue 80 | with open(example / "config.json") as f: 81 | example_dict = json.load(f) 82 | 83 | 84 | example_list = [] 85 | example_list.append(example_dict["prompt"]) # prompt 86 | 87 | for key in ["image_ref1", "image_ref2", "image_ref3"]: 88 | if key in example_dict: 89 | example_list.append(str(example / example_dict[key])) 90 | else: 91 | example_list.append(None) 92 | 93 | example_list.append(example_dict["seed"]) 94 | ans.append(example_list) 95 | return ans 96 | 97 | 98 | def create_demo( 99 | model_type: str, 100 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 101 | offload: bool = False, 102 | ): 103 | pipeline = USOPipeline( 104 | model_type, device, offload, only_lora=True, lora_rank=128, hf_download=True 105 | ) 106 | print("USOPipeline loaded successfully") 107 | 108 | siglip_processor = SiglipImageProcessor.from_pretrained( 109 | "google/siglip-so400m-patch14-384" 110 | ) 111 | siglip_model = SiglipVisionModel.from_pretrained( 112 | "google/siglip-so400m-patch14-384" 113 | ) 114 | siglip_model.eval() 115 | siglip_model.to(device) 116 | pipeline.model.vision_encoder = siglip_model 117 | pipeline.model.vision_encoder_processor = siglip_processor 118 | print("SigLIP model loaded successfully") 119 | 120 | with gr.Blocks() as demo: 121 | gr.Markdown(title) 122 | gr.Markdown(badges_text) 123 | gr.Markdown(tips) 124 | with gr.Row(): 125 | with gr.Column(): 126 | prompt = gr.Textbox(label="Prompt", value="A beautiful woman.") 127 | with gr.Row(): 128 | image_prompt1 = gr.Image( 129 | label="Content Reference Img", visible=True, interactive=True, type="pil" 130 | ) 131 | image_prompt2 = gr.Image( 132 | label="Style Reference Img", visible=True, interactive=True, type="pil" 133 | ) 134 | image_prompt3 = gr.Image( 135 | label="Extra Style Reference Img (Beta)", visible=True, interactive=True, type="pil" 136 | ) 137 | 138 | with gr.Row(): 139 | with gr.Row(): 140 | width = gr.Slider( 141 | 512, 1536, 1024, step=16, label="Generation Width" 142 | ) 143 | height = gr.Slider( 144 | 512, 1536, 1024, step=16, label="Generation Height" 145 | ) 146 | with gr.Row(): 147 | with gr.Row(): 148 | keep_size = gr.Checkbox( 149 | label="Keep input size", 150 | value=False, 151 | interactive=True 152 | ) 153 | with gr.Column(): 154 | gr.Markdown("Set it to True if you only need style editing or want to keep the layout.") 155 | 156 | with gr.Accordion("Advanced Options", open=True): 157 | with gr.Row(): 158 | num_steps = gr.Slider( 159 | 1, 50, 25, step=1, label="Number of steps" 160 | ) 161 | guidance = gr.Slider( 162 | 1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True 163 | ) 164 | content_long_size = gr.Slider( 165 | 0, 1024, 512, step=16, label="Content reference size" 166 | ) 167 | seed = gr.Number(-1, label="Seed (-1 for random)") 168 | 169 | generate_btn = gr.Button("Generate") 170 | gr.Markdown(star) 171 | 172 | with gr.Column(): 173 | output_image = gr.Image(label="Generated Image") 174 | download_btn = gr.File( 175 | label="Download full-resolution", type="filepath", interactive=False 176 | ) 177 | 178 | inputs = [ 179 | prompt, 180 | image_prompt1, 181 | image_prompt2, 182 | image_prompt3, 183 | seed, 184 | width, 185 | height, 186 | 
guidance, 187 | num_steps, 188 | keep_size, 189 | content_long_size, 190 | ] 191 | generate_btn.click( 192 | fn=pipeline.gradio_generate, 193 | inputs=inputs, 194 | outputs=[output_image, download_btn], 195 | ) 196 | 197 | # example_text = gr.Text("", visible=False, label="Case For:") 198 | examples = get_examples("./assets/gradio_examples") 199 | 200 | gr.Examples( 201 | examples=examples, 202 | inputs=[ 203 | prompt, 204 | image_prompt1, 205 | image_prompt2, 206 | image_prompt3, 207 | seed, 208 | ], 209 | # cache_examples='lazy', 210 | outputs=[output_image, download_btn], 211 | fn=pipeline.gradio_generate, 212 | label='row 1-4: identity/subject-driven; row 5-7: style-subject-driven; row 8-9: style-driven; row 10-12: multi-style-driven task; row 13: txt2img', 213 | examples_per_page=15 214 | ) 215 | 216 | return demo 217 | 218 | 219 | if __name__ == "__main__": 220 | from typing import Literal 221 | 222 | from transformers import HfArgumentParser 223 | 224 | @dataclasses.dataclass 225 | class AppArgs: 226 | name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell", "flux-krea-dev"] = "flux-dev" 227 | device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu" 228 | offload: bool = dataclasses.field( 229 | default=False, 230 | metadata={ 231 | "help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used." 232 | }, 233 | ) 234 | port: int = 7860 235 | 236 | parser = HfArgumentParser([AppArgs]) 237 | args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs] 238 | args = args_tuple[0] 239 | 240 | demo = create_demo(args.name, args.device, args.offload) 241 | demo.launch(server_port=args.port) 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /uso/flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from einops import rearrange 20 | from torch import Tensor, nn 21 | 22 | 23 | @dataclass 24 | class AutoEncoderParams: 25 | resolution: int 26 | in_channels: int 27 | ch: int 28 | out_ch: int 29 | ch_mult: list[int] 30 | num_res_blocks: int 31 | z_channels: int 32 | scale_factor: float 33 | shift_factor: float 34 | 35 | 36 | def swish(x: Tensor) -> Tensor: 37 | return x * torch.sigmoid(x) 38 | 39 | 40 | class AttnBlock(nn.Module): 41 | def __init__(self, in_channels: int): 42 | super().__init__() 43 | self.in_channels = in_channels 44 | 45 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 46 | 47 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 48 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 49 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 50 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 51 | 52 | def attention(self, h_: Tensor) -> Tensor: 53 | h_ = self.norm(h_) 54 | q = self.q(h_) 55 | k = self.k(h_) 56 | v = self.v(h_) 57 | 58 | b, c, h, w = q.shape 59 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 60 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 61 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 62 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 63 | 64 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 65 | 66 | def forward(self, x: Tensor) -> Tensor: 67 | return x + self.proj_out(self.attention(x)) 68 | 69 | 70 | class ResnetBlock(nn.Module): 71 | def __init__(self, in_channels: int, out_channels: int): 72 | super().__init__() 73 | self.in_channels = in_channels 74 | out_channels = in_channels if out_channels is None else out_channels 75 | self.out_channels = out_channels 76 | 77 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 79 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 80 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 81 | if self.in_channels != self.out_channels: 82 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 83 | 84 | def forward(self, x): 85 | h = x 86 | h = self.norm1(h) 87 | h = swish(h) 88 | h = self.conv1(h) 89 | 90 | h = self.norm2(h) 91 | h = swish(h) 92 | h = self.conv2(h) 93 | 94 | if self.in_channels != self.out_channels: 95 | x = self.nin_shortcut(x) 96 | 97 | return x + h 98 | 99 | 100 | class Downsample(nn.Module): 101 | def __init__(self, in_channels: int): 102 | super().__init__() 103 | # no asymmetric padding in torch conv, must do it ourselves 104 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 105 | 106 | def forward(self, x: Tensor): 107 | pad = (0, 1, 0, 1) 108 | x = nn.functional.pad(x, pad, mode="constant", value=0) 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | 
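# Shape note (added comment, not in the original file): Downsample above pads only the
# right/bottom edge before a 3x3, stride-2, padding-0 conv, so spatial dims are halved
# exactly; e.g. a (1, 64, 32, 32) input is padded to 33x33 and comes out as (1, 64, 16, 16).
# Upsample below is the rough inverse: 2x nearest-neighbour interpolation followed by a
# 3x3 conv that keeps the channel count.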
class Upsample(nn.Module): 114 | def __init__(self, in_channels: int): 115 | super().__init__() 116 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 117 | 118 | def forward(self, x: Tensor): 119 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 120 | x = self.conv(x) 121 | return x 122 | 123 | 124 | class Encoder(nn.Module): 125 | def __init__( 126 | self, 127 | resolution: int, 128 | in_channels: int, 129 | ch: int, 130 | ch_mult: list[int], 131 | num_res_blocks: int, 132 | z_channels: int, 133 | ): 134 | super().__init__() 135 | self.ch = ch 136 | self.num_resolutions = len(ch_mult) 137 | self.num_res_blocks = num_res_blocks 138 | self.resolution = resolution 139 | self.in_channels = in_channels 140 | # downsampling 141 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 142 | 143 | curr_res = resolution 144 | in_ch_mult = (1,) + tuple(ch_mult) 145 | self.in_ch_mult = in_ch_mult 146 | self.down = nn.ModuleList() 147 | block_in = self.ch 148 | for i_level in range(self.num_resolutions): 149 | block = nn.ModuleList() 150 | attn = nn.ModuleList() 151 | block_in = ch * in_ch_mult[i_level] 152 | block_out = ch * ch_mult[i_level] 153 | for _ in range(self.num_res_blocks): 154 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 155 | block_in = block_out 156 | down = nn.Module() 157 | down.block = block 158 | down.attn = attn 159 | if i_level != self.num_resolutions - 1: 160 | down.downsample = Downsample(block_in) 161 | curr_res = curr_res // 2 162 | self.down.append(down) 163 | 164 | # middle 165 | self.mid = nn.Module() 166 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 167 | self.mid.attn_1 = AttnBlock(block_in) 168 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 169 | 170 | # end 171 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 172 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 173 | 174 | def forward(self, x: Tensor) -> Tensor: 175 | # downsampling 176 | hs = [self.conv_in(x)] 177 | for i_level in range(self.num_resolutions): 178 | for i_block in range(self.num_res_blocks): 179 | h = self.down[i_level].block[i_block](hs[-1]) 180 | if len(self.down[i_level].attn) > 0: 181 | h = self.down[i_level].attn[i_block](h) 182 | hs.append(h) 183 | if i_level != self.num_resolutions - 1: 184 | hs.append(self.down[i_level].downsample(hs[-1])) 185 | 186 | # middle 187 | h = hs[-1] 188 | h = self.mid.block_1(h) 189 | h = self.mid.attn_1(h) 190 | h = self.mid.block_2(h) 191 | # end 192 | h = self.norm_out(h) 193 | h = swish(h) 194 | h = self.conv_out(h) 195 | return h 196 | 197 | 198 | class Decoder(nn.Module): 199 | def __init__( 200 | self, 201 | ch: int, 202 | out_ch: int, 203 | ch_mult: list[int], 204 | num_res_blocks: int, 205 | in_channels: int, 206 | resolution: int, 207 | z_channels: int, 208 | ): 209 | super().__init__() 210 | self.ch = ch 211 | self.num_resolutions = len(ch_mult) 212 | self.num_res_blocks = num_res_blocks 213 | self.resolution = resolution 214 | self.in_channels = in_channels 215 | self.ffactor = 2 ** (self.num_resolutions - 1) 216 | 217 | # compute in_ch_mult, block_in and curr_res at lowest res 218 | block_in = ch * ch_mult[self.num_resolutions - 1] 219 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 220 | self.z_shape = (1, z_channels, curr_res, curr_res) 221 | 222 | # z to block_in 223 | self.conv_in = 
nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 224 | 225 | # middle 226 | self.mid = nn.Module() 227 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 228 | self.mid.attn_1 = AttnBlock(block_in) 229 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 230 | 231 | # upsampling 232 | self.up = nn.ModuleList() 233 | for i_level in reversed(range(self.num_resolutions)): 234 | block = nn.ModuleList() 235 | attn = nn.ModuleList() 236 | block_out = ch * ch_mult[i_level] 237 | for _ in range(self.num_res_blocks + 1): 238 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 239 | block_in = block_out 240 | up = nn.Module() 241 | up.block = block 242 | up.attn = attn 243 | if i_level != 0: 244 | up.upsample = Upsample(block_in) 245 | curr_res = curr_res * 2 246 | self.up.insert(0, up) # prepend to get consistent order 247 | 248 | # end 249 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 250 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 251 | 252 | def forward(self, z: Tensor) -> Tensor: 253 | # z to block_in 254 | h = self.conv_in(z) 255 | 256 | # middle 257 | h = self.mid.block_1(h) 258 | h = self.mid.attn_1(h) 259 | h = self.mid.block_2(h) 260 | 261 | # upsampling 262 | for i_level in reversed(range(self.num_resolutions)): 263 | for i_block in range(self.num_res_blocks + 1): 264 | h = self.up[i_level].block[i_block](h) 265 | if len(self.up[i_level].attn) > 0: 266 | h = self.up[i_level].attn[i_block](h) 267 | if i_level != 0: 268 | h = self.up[i_level].upsample(h) 269 | 270 | # end 271 | h = self.norm_out(h) 272 | h = swish(h) 273 | h = self.conv_out(h) 274 | return h 275 | 276 | 277 | class DiagonalGaussian(nn.Module): 278 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 279 | super().__init__() 280 | self.sample = sample 281 | self.chunk_dim = chunk_dim 282 | 283 | def forward(self, z: Tensor) -> Tensor: 284 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 285 | if self.sample: 286 | std = torch.exp(0.5 * logvar) 287 | return mean + std * torch.randn_like(mean) 288 | else: 289 | return mean 290 | 291 | 292 | class AutoEncoder(nn.Module): 293 | def __init__(self, params: AutoEncoderParams): 294 | super().__init__() 295 | self.encoder = Encoder( 296 | resolution=params.resolution, 297 | in_channels=params.in_channels, 298 | ch=params.ch, 299 | ch_mult=params.ch_mult, 300 | num_res_blocks=params.num_res_blocks, 301 | z_channels=params.z_channels, 302 | ) 303 | self.decoder = Decoder( 304 | resolution=params.resolution, 305 | in_channels=params.in_channels, 306 | ch=params.ch, 307 | out_ch=params.out_ch, 308 | ch_mult=params.ch_mult, 309 | num_res_blocks=params.num_res_blocks, 310 | z_channels=params.z_channels, 311 | ) 312 | self.reg = DiagonalGaussian() 313 | 314 | self.scale_factor = params.scale_factor 315 | self.shift_factor = params.shift_factor 316 | 317 | def encode(self, x: Tensor) -> Tensor: 318 | z = self.reg(self.encoder(x)) 319 | z = self.scale_factor * (z - self.shift_factor) 320 | return z 321 | 322 | def decode(self, z: Tensor) -> Tensor: 323 | z = z / self.scale_factor + self.shift_factor 324 | return self.decoder(z) 325 | 326 | def forward(self, x: Tensor) -> Tensor: 327 | return self.decode(self.encode(x)) 328 | -------------------------------------------------------------------------------- /uso/flux/pipeline.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import math 18 | from typing import Literal, Optional 19 | from torch import Tensor 20 | 21 | import torch 22 | from einops import rearrange 23 | from PIL import ExifTags, Image 24 | import torchvision.transforms.functional as TVF 25 | 26 | from .modules.layers import ( 27 | DoubleStreamBlockLoraProcessor, 28 | DoubleStreamBlockProcessor, 29 | SingleStreamBlockLoraProcessor, 30 | SingleStreamBlockProcessor, 31 | ) 32 | from .sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack 33 | from .util import ( 34 | get_lora_rank, 35 | load_ae, 36 | load_checkpoint, 37 | load_clip, 38 | load_flow_model, 39 | load_flow_model_only_lora, 40 | load_t5, 41 | ) 42 | 43 | 44 | def find_nearest_scale(image_h, image_w, predefined_scales): 45 | """ 46 | Given the height and width of an image, find the nearest predefined scale. 47 | 48 | :param image_h: height of the image 49 | :param image_w: width of the image 50 | :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...] 51 | :return: the nearest predefined scale (h, w) 52 | """ 53 | # Compute the aspect ratio of the input image 54 | image_ratio = image_h / image_w 55 | 56 | # Initialize variables to track the smallest difference and the nearest scale 57 | min_diff = float("inf") 58 | nearest_scale = None 59 | 60 | # Iterate over all predefined scales to find the one whose aspect ratio is closest to the input image 61 | for scale_h, scale_w in predefined_scales: 62 | predefined_ratio = scale_h / scale_w 63 | diff = abs(predefined_ratio - image_ratio) 64 | 65 | if diff < min_diff: 66 | min_diff = diff 67 | nearest_scale = (scale_h, scale_w) 68 | 69 | return nearest_scale 70 | 71 | 72 | def preprocess_ref(raw_image: Image.Image, long_size: int = 512, scale_ratio: int = 1): 73 | # Get the width and height of the original image 74 | image_w, image_h = raw_image.size 75 | if image_w == image_h and image_w == 16: 76 | return raw_image 77 | 78 | # Compute the long and short sides 79 | if image_w >= image_h: 80 | new_w = long_size 81 | new_h = int((long_size / image_w) * image_h) 82 | else: 83 | new_h = long_size 84 | new_w = int((long_size / image_h) * image_w) 85 | 86 | # Resize proportionally to the new width and height 87 | raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) 88 | 89 | # So that the canny image can be scaled 90 | scale_ratio = int(scale_ratio) 91 | target_w = new_w // (16 * scale_ratio) * (16 * scale_ratio) 92 | target_h = new_h // (16 * scale_ratio) * (16 * scale_ratio) 93 | 94 | # Compute the crop origin for a center crop 95 | left = (new_w - target_w) // 2 96 | top = (new_h - target_h) // 2 97 | right = left + target_w 98 | bottom = top + target_h 99 | 100 | # Perform the center crop 101 | raw_image = raw_image.crop((left, top, right, bottom)) 102 | 103 | # Convert to RGB mode 104 | raw_image = raw_image.convert("RGB") 105 | return raw_image 106 | 107 | 108 | def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1): 109 | target_height_ref1 = int(target_height_ref1 // 64 * 64) 110 | target_width_ref1 = int(target_width_ref1 // 64 * 64) 111 | h, w =
image.shape[-2:] 112 | if h < target_height_ref1 or w < target_width_ref1: 113 | # Compute the aspect ratio 114 | aspect_ratio = w / h 115 | if h < target_height_ref1: 116 | new_h = target_height_ref1 117 | new_w = new_h * aspect_ratio 118 | if new_w < target_width_ref1: 119 | new_w = target_width_ref1 120 | new_h = new_w / aspect_ratio 121 | else: 122 | new_w = target_width_ref1 123 | new_h = new_w / aspect_ratio 124 | if new_h < target_height_ref1: 125 | new_h = target_height_ref1 126 | new_w = new_h * aspect_ratio 127 | else: 128 | aspect_ratio = w / h 129 | tgt_aspect_ratio = target_width_ref1 / target_height_ref1 130 | if aspect_ratio > tgt_aspect_ratio: 131 | new_h = target_height_ref1 132 | new_w = new_h * aspect_ratio 133 | else: 134 | new_w = target_width_ref1 135 | new_h = new_w / aspect_ratio 136 | # Resize the image with TVF.resize 137 | image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w))) 138 | # Compute the center-crop offsets 139 | top = (image.shape[-2] - target_height_ref1) // 2 140 | left = (image.shape[-1] - target_width_ref1) // 2 141 | # Center-crop with TVF.crop 142 | image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1) 143 | return image 144 | 145 | 146 | class USOPipeline: 147 | def __init__( 148 | self, 149 | model_type: str, 150 | device: torch.device, 151 | offload: bool = False, 152 | only_lora: bool = False, 153 | lora_rank: int = 16, 154 | hf_download: bool = True, 155 | ): 156 | self.device = device 157 | self.offload = offload 158 | self.model_type = model_type 159 | 160 | print(f'----> model type is {model_type}({only_lora})') 161 | self.clip = load_clip(self.device) 162 | print('----> loaded clip') 163 | self.t5 = load_t5(self.device, max_length=512) 164 | print('----> loaded t5') 165 | self.ae = load_ae(model_type, device="cpu" if offload else self.device) 166 | print('----> loaded ae') 167 | self.use_fp8 = "fp8" in model_type 168 | if only_lora: 169 | self.model = load_flow_model_only_lora( 170 | model_type, 171 | device="cpu" if offload else self.device, 172 | lora_rank=lora_rank, 173 | use_fp8=self.use_fp8, 174 | hf_download=hf_download, 175 | ) 176 | else: 177 | self.model = load_flow_model( 178 | model_type, device="cpu" if offload else self.device 179 | ) 180 | 181 | def load_ckpt(self, ckpt_path): 182 | if ckpt_path is not None: 183 | from safetensors.torch import load_file as load_sft 184 | 185 | print("Loading checkpoint to replace old keys") 186 | # load_sft doesn't support torch.device 187 | if ckpt_path.endswith("safetensors"): 188 | sd = load_sft(ckpt_path, device="cpu") 189 | missing, unexpected = self.model.load_state_dict( 190 | sd, strict=False, assign=True 191 | ) 192 | else: 193 | dit_state = torch.load(ckpt_path, map_location="cpu") 194 | sd = {} 195 | for k in dit_state.keys(): 196 | sd[k.replace("module.", "")] = dit_state[k] 197 | missing, unexpected = self.model.load_state_dict( 198 | sd, strict=False, assign=True 199 | ) 200 | self.model.to(str(self.device)) 201 | print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}") 202 | 203 | def set_lora( 204 | self, 205 | local_path: str = None, 206 | repo_id: str = None, 207 | name: str = None, 208 | lora_weight: float = 0.7, 209 | ): 210 | checkpoint = load_checkpoint(local_path, repo_id, name) 211 | self.update_model_with_lora(checkpoint, lora_weight) 212 | 213 | def set_lora_from_collection( 214 | self, lora_type: str = "realism", lora_weight: float = 0.7 215 | ): 216 | checkpoint = load_checkpoint( 217 | None, self.hf_lora_collection, self.lora_types_to_names[lora_type] 218
| ) 219 | self.update_model_with_lora(checkpoint, lora_weight) 220 | 221 | def update_model_with_lora(self, checkpoint, lora_weight): 222 | rank = get_lora_rank(checkpoint) 223 | lora_attn_procs = {} 224 | 225 | for name, _ in self.model.attn_processors.items(): 226 | lora_state_dict = {} 227 | for k in checkpoint.keys(): 228 | if name in k: 229 | lora_state_dict[k[len(name) + 1 :]] = checkpoint[k] * lora_weight 230 | 231 | if len(lora_state_dict): 232 | if name.startswith("single_blocks"): 233 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor( 234 | dim=3072, rank=rank 235 | ) 236 | else: 237 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( 238 | dim=3072, rank=rank 239 | ) 240 | lora_attn_procs[name].load_state_dict(lora_state_dict) 241 | lora_attn_procs[name].to(self.device) 242 | else: 243 | if name.startswith("single_blocks"): 244 | lora_attn_procs[name] = SingleStreamBlockProcessor() 245 | else: 246 | lora_attn_procs[name] = DoubleStreamBlockProcessor() 247 | 248 | self.model.set_attn_processor(lora_attn_procs) 249 | 250 | def __call__( 251 | self, 252 | prompt: str, 253 | width: int = 512, 254 | height: int = 512, 255 | guidance: float = 4, 256 | num_steps: int = 50, 257 | seed: int = 123456789, 258 | **kwargs, 259 | ): 260 | width = 16 * (width // 16) 261 | height = 16 * (height // 16) 262 | 263 | device_type = self.device if isinstance(self.device, str) else self.device.type 264 | with torch.autocast( 265 | enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16 266 | ): 267 | return self.forward( 268 | prompt, width, height, guidance, num_steps, seed, **kwargs 269 | ) 270 | 271 | @torch.inference_mode() 272 | def gradio_generate( 273 | self, 274 | prompt: str, 275 | image_prompt1: Image.Image, 276 | image_prompt2: Image.Image, 277 | image_prompt3: Image.Image, 278 | seed: int, 279 | width: int = 1024, 280 | height: int = 1024, 281 | guidance: float = 4, 282 | num_steps: int = 25, 283 | keep_size: bool = False, 284 | content_long_size: int = 512, 285 | ): 286 | ref_content_imgs = [image_prompt1] 287 | ref_content_imgs = [img for img in ref_content_imgs if isinstance(img, Image.Image)] 288 | ref_content_imgs = [preprocess_ref(img, content_long_size) for img in ref_content_imgs] 289 | 290 | ref_style_imgs = [image_prompt2, image_prompt3] 291 | ref_style_imgs = [img for img in ref_style_imgs if isinstance(img, Image.Image)] 292 | ref_style_imgs = [self.model.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs] 293 | 294 | seed = seed if seed != -1 else torch.randint(0, 10**8, (1,)).item() 295 | 296 | # whether keep input image size 297 | if keep_size and len(ref_content_imgs)>0: 298 | width, height = ref_content_imgs[0].size 299 | width, height = int(width * (1024 / content_long_size)), int(height * (1024 / content_long_size)) 300 | img = self( 301 | prompt=prompt, 302 | width=width, 303 | height=height, 304 | guidance=guidance, 305 | num_steps=num_steps, 306 | seed=seed, 307 | ref_imgs=ref_content_imgs, 308 | siglip_inputs=ref_style_imgs, 309 | ) 310 | 311 | filename = f"output/gradio/{seed}_{prompt[:20]}.png" 312 | os.makedirs(os.path.dirname(filename), exist_ok=True) 313 | exif_data = Image.Exif() 314 | exif_data[ExifTags.Base.Make] = "USO" 315 | exif_data[ExifTags.Base.Model] = self.model_type 316 | info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}" 317 | exif_data[ExifTags.Base.ImageDescription] = info 318 | img.save(filename, format="png", exif=exif_data) 319 | return img, filename 
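    # Usage sketch (added comment, not part of the original file; file names, prompt and
    # seed are placeholders, the calls follow __init__ and gradio_generate above):
    #
    #   pipe = USOPipeline("flux-dev", device="cuda", offload=False, only_lora=True)
    #   content = preprocess_ref(Image.open("subject.png"), long_size=512)
    #   style = pipe.model.vision_encoder_processor(Image.open("style.png"), return_tensors="pt").to("cuda")
    #   out = pipe(prompt="a portrait", width=1024, height=1024, guidance=4,
    #              num_steps=25, seed=42, ref_imgs=[content], siglip_inputs=[style])
    #
    # __call__ passes ref_imgs/siglip_inputs through **kwargs to forward() below and
    # returns a PIL.Image, exactly as gradio_generate does.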
320 | 321 | @torch.inference_mode 322 | def forward( 323 | self, 324 | prompt: str, 325 | width: int, 326 | height: int, 327 | guidance: float, 328 | num_steps: int, 329 | seed: int, 330 | ref_imgs: list[Image.Image] | None = None, 331 | pe: Literal["d", "h", "w", "o"] = "d", 332 | siglip_inputs: list[Tensor] | None = None, 333 | **kwargs 334 | ): 335 | 336 | update_func = kwargs.get('update_func', lambda *args, **kwargs: None) 337 | x = get_noise( 338 | 1, height, width, device=self.device, dtype=torch.bfloat16, seed=seed 339 | ) 340 | timesteps = get_schedule( 341 | num_steps, 342 | (width // 8) * (height // 8) // (16 * 16), 343 | shift=True, 344 | ) 345 | if self.offload: 346 | self.ae.encoder = self.ae.encoder.to(self.device) 347 | x_1_refs = [ 348 | self.ae.encode( 349 | (TVF.to_tensor(ref_img) * 2.0 - 1.0) 350 | .unsqueeze(0) 351 | .to(self.device, torch.float32) 352 | ).to(torch.bfloat16) 353 | for ref_img in ref_imgs 354 | ] 355 | 356 | if self.offload: 357 | self.offload_model_to_cpu(self.ae.encoder) 358 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 359 | inp_cond = prepare_multi_ip( 360 | t5=self.t5, 361 | clip=self.clip, 362 | img=x, 363 | prompt=prompt, 364 | ref_imgs=x_1_refs, 365 | pe=pe, 366 | ) 367 | 368 | if self.offload: 369 | self.offload_model_to_cpu(self.t5, self.clip) 370 | self.model = self.model.to(self.device) 371 | 372 | x = denoise( 373 | self.model, 374 | **inp_cond, 375 | timesteps=timesteps, 376 | guidance=guidance, 377 | siglip_inputs=siglip_inputs, 378 | update_func=update_func, 379 | ) 380 | 381 | if self.offload: 382 | self.offload_model_to_cpu(self.model) 383 | self.ae.decoder.to(x.device) 384 | x = unpack(x.float(), height, width) 385 | x = self.ae.decode(x) 386 | self.offload_model_to_cpu(self.ae.decoder) 387 | 388 | x1 = x.clamp(-1, 1) 389 | x1 = rearrange(x1[-1], "c h w -> h w c") 390 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) 391 | return output_img 392 | 393 | def offload_model_to_cpu(self, *models): 394 | if not self.offload: 395 | return 396 | for model in models: 397 | model.cpu() 398 | torch.cuda.empty_cache() 399 | -------------------------------------------------------------------------------- /uso/flux/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
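# Model path note (added comment, not in the original file): in this ComfyUI port the
# weights are primarily resolved below under folder_paths.models_dir instead of being
# fetched from the Hub, using these subpaths as they appear later in this file:
#   diffusers/FLUX.1-dev/flux1-dev.safetensors and diffusers/FLUX.1-dev/ae.safetensors
#   uso/uso_flux_v1.0/dit_lora.safetensors and uso/uso_flux_v1.0/projector.safetensors
#   clip/xflux_text_encoders and clip_vision/clip-vit-large-patch14
# If ComfyUI's folder_paths module cannot be imported, models_dir falls back to
# /workspace/comfyui/models/ (see the try/except further down).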
15 | 16 | import os 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | import json 21 | import numpy as np 22 | from huggingface_hub import hf_hub_download 23 | from safetensors import safe_open 24 | from safetensors.torch import load_file as load_sft 25 | 26 | from .model import Flux, FluxParams 27 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams 28 | from .modules.conditioner import HFEmbedder 29 | 30 | import re 31 | from .modules.layers import ( 32 | DoubleStreamBlockLoraProcessor, 33 | SingleStreamBlockLoraProcessor, 34 | ) 35 | 36 | import os 37 | try: 38 | import folder_paths 39 | print('run in comfyui') 40 | except: 41 | print('not run in comfyui') 42 | from types import SimpleNamespace 43 | folder_paths = SimpleNamespace() 44 | folder_paths.models_dir = '/workspace/comfyui/models/' 45 | 46 | 47 | def load_model(ckpt, device="cpu"): 48 | if ckpt.endswith("safetensors"): 49 | from safetensors import safe_open 50 | 51 | pl_sd = {} 52 | with safe_open(ckpt, framework="pt", device=device) as f: 53 | for k in f.keys(): 54 | pl_sd[k] = f.get_tensor(k) 55 | else: 56 | pl_sd = torch.load(ckpt, map_location=device) 57 | return pl_sd 58 | 59 | 60 | def load_safetensors(path): 61 | tensors = {} 62 | with safe_open(path, framework="pt", device="cpu") as f: 63 | for key in f.keys(): 64 | tensors[key] = f.get_tensor(key) 65 | return tensors 66 | 67 | 68 | def get_lora_rank(checkpoint): 69 | for k in checkpoint.keys(): 70 | if k.endswith(".down.weight"): 71 | return checkpoint[k].shape[0] 72 | 73 | 74 | def load_checkpoint(local_path, repo_id, name): 75 | if local_path is not None: 76 | if ".safetensors" in local_path: 77 | print(f"Loading .safetensors checkpoint from {local_path}") 78 | checkpoint = load_safetensors(local_path) 79 | else: 80 | print(f"Loading checkpoint from {local_path}") 81 | checkpoint = torch.load(local_path, map_location="cpu") 82 | elif repo_id is not None and name is not None: 83 | print(f"Loading checkpoint {name} from repo id {repo_id}") 84 | checkpoint = load_from_repo_id(repo_id, name) 85 | else: 86 | raise ValueError( 87 | "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" 88 | ) 89 | return checkpoint 90 | 91 | 92 | def c_crop(image): 93 | width, height = image.size 94 | new_size = min(width, height) 95 | left = (width - new_size) / 2 96 | top = (height - new_size) / 2 97 | right = (width + new_size) / 2 98 | bottom = (height + new_size) / 2 99 | return image.crop((left, top, right, bottom)) 100 | 101 | 102 | def pad64(x): 103 | return int(np.ceil(float(x) / 64.0) * 64 - x) 104 | 105 | 106 | def HWC3(x): 107 | assert x.dtype == np.uint8 108 | if x.ndim == 2: 109 | x = x[:, :, None] 110 | assert x.ndim == 3 111 | H, W, C = x.shape 112 | assert C == 1 or C == 3 or C == 4 113 | if C == 3: 114 | return x 115 | if C == 1: 116 | return np.concatenate([x, x, x], axis=2) 117 | if C == 4: 118 | color = x[:, :, 0:3].astype(np.float32) 119 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 120 | y = color * alpha + 255.0 * (1.0 - alpha) 121 | y = y.clip(0, 255).astype(np.uint8) 122 | return y 123 | 124 | 125 | @dataclass 126 | class ModelSpec: 127 | params: FluxParams 128 | ae_params: AutoEncoderParams 129 | ckpt_path: str | None 130 | ae_path: str | None 131 | repo_id: str | None 132 | repo_flow: str | None 133 | repo_ae: str | None 134 | repo_id_ae: str | None 135 | 136 | 137 | configs = { 138 | "flux-dev": ModelSpec( 139 | repo_id="black-forest-labs/FLUX.1-dev", 140 | repo_id_ae="black-forest-labs/FLUX.1-dev", 
141 | repo_flow="flux1-dev.safetensors", 142 | repo_ae="ae.safetensors", 143 | ckpt_path=os.getenv("FLUX_DEV"), 144 | params=FluxParams( 145 | in_channels=64, 146 | vec_in_dim=768, 147 | context_in_dim=4096, 148 | hidden_size=3072, 149 | mlp_ratio=4.0, 150 | num_heads=24, 151 | depth=19, 152 | depth_single_blocks=38, 153 | axes_dim=[16, 56, 56], 154 | theta=10_000, 155 | qkv_bias=True, 156 | guidance_embed=True, 157 | ), 158 | ae_path=os.getenv("AE"), 159 | ae_params=AutoEncoderParams( 160 | resolution=256, 161 | in_channels=3, 162 | ch=128, 163 | out_ch=3, 164 | ch_mult=[1, 2, 4, 4], 165 | num_res_blocks=2, 166 | z_channels=16, 167 | scale_factor=0.3611, 168 | shift_factor=0.1159, 169 | ), 170 | ), 171 | "flux-dev-fp8": ModelSpec( 172 | repo_id="black-forest-labs/FLUX.1-dev", 173 | repo_id_ae="black-forest-labs/FLUX.1-dev", 174 | repo_flow="flux1-dev.safetensors", 175 | repo_ae="ae.safetensors", 176 | ckpt_path=os.getenv("FLUX_DEV_FP8"), 177 | params=FluxParams( 178 | in_channels=64, 179 | vec_in_dim=768, 180 | context_in_dim=4096, 181 | hidden_size=3072, 182 | mlp_ratio=4.0, 183 | num_heads=24, 184 | depth=19, 185 | depth_single_blocks=38, 186 | axes_dim=[16, 56, 56], 187 | theta=10_000, 188 | qkv_bias=True, 189 | guidance_embed=True, 190 | ), 191 | ae_path=os.getenv("AE"), 192 | ae_params=AutoEncoderParams( 193 | resolution=256, 194 | in_channels=3, 195 | ch=128, 196 | out_ch=3, 197 | ch_mult=[1, 2, 4, 4], 198 | num_res_blocks=2, 199 | z_channels=16, 200 | scale_factor=0.3611, 201 | shift_factor=0.1159, 202 | ), 203 | ), 204 | "flux-krea-dev": ModelSpec( 205 | repo_id="black-forest-labs/FLUX.1-Krea-dev", 206 | repo_id_ae="black-forest-labs/FLUX.1-Krea-dev", 207 | repo_flow="flux1-krea-dev.safetensors", 208 | repo_ae="ae.safetensors", 209 | ckpt_path=os.getenv("FLUX_KREA_DEV"), 210 | params=FluxParams( 211 | in_channels=64, 212 | vec_in_dim=768, 213 | context_in_dim=4096, 214 | hidden_size=3072, 215 | mlp_ratio=4.0, 216 | num_heads=24, 217 | depth=19, 218 | depth_single_blocks=38, 219 | axes_dim=[16, 56, 56], 220 | theta=10_000, 221 | qkv_bias=True, 222 | guidance_embed=True, 223 | ), 224 | ae_path=os.getenv("AE"), 225 | ae_params=AutoEncoderParams( 226 | resolution=256, 227 | in_channels=3, 228 | ch=128, 229 | out_ch=3, 230 | ch_mult=[1, 2, 4, 4], 231 | num_res_blocks=2, 232 | z_channels=16, 233 | scale_factor=0.3611, 234 | shift_factor=0.1159, 235 | ), 236 | ), 237 | "flux-schnell": ModelSpec( 238 | repo_id="black-forest-labs/FLUX.1-schnell", 239 | repo_id_ae="black-forest-labs/FLUX.1-dev", 240 | repo_flow="flux1-schnell.safetensors", 241 | repo_ae="ae.safetensors", 242 | ckpt_path=os.getenv("FLUX_SCHNELL"), 243 | params=FluxParams( 244 | in_channels=64, 245 | vec_in_dim=768, 246 | context_in_dim=4096, 247 | hidden_size=3072, 248 | mlp_ratio=4.0, 249 | num_heads=24, 250 | depth=19, 251 | depth_single_blocks=38, 252 | axes_dim=[16, 56, 56], 253 | theta=10_000, 254 | qkv_bias=True, 255 | guidance_embed=False, 256 | ), 257 | ae_path=os.getenv("AE"), 258 | ae_params=AutoEncoderParams( 259 | resolution=256, 260 | in_channels=3, 261 | ch=128, 262 | out_ch=3, 263 | ch_mult=[1, 2, 4, 4], 264 | num_res_blocks=2, 265 | z_channels=16, 266 | scale_factor=0.3611, 267 | shift_factor=0.1159, 268 | ), 269 | ), 270 | } 271 | 272 | 273 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 274 | if len(missing) > 0 and len(unexpected) > 0: 275 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 276 | print("\n" + "-" * 79 + "\n") 277 | print(f"Got 
{len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 278 | elif len(missing) > 0: 279 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 280 | elif len(unexpected) > 0: 281 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 282 | 283 | 284 | def load_from_repo_id(repo_id, checkpoint_name): 285 | ckpt_path = hf_hub_download(repo_id, checkpoint_name) 286 | sd = load_sft(ckpt_path, device="cpu") 287 | return sd 288 | 289 | 290 | def load_flow_model( 291 | name: str, device: str | torch.device = "cuda", hf_download: bool = True 292 | ): 293 | # Loading Flux 294 | print("Init model") 295 | ckpt_path = configs[name].ckpt_path 296 | if ( 297 | ckpt_path is None 298 | and configs[name].repo_id is not None 299 | and configs[name].repo_flow is not None 300 | ): 301 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 302 | 303 | # with torch.device("meta" if ckpt_path is not None else device): 304 | with torch.device(device): 305 | model = Flux(configs[name].params).to(torch.bfloat16) 306 | 307 | if ckpt_path is not None: 308 | print("Loading main checkpoint") 309 | # load_sft doesn't support torch.device 310 | sd = load_model(ckpt_path, device="cpu") 311 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 312 | print_load_warning(missing, unexpected) 313 | return model.to(str(device)) 314 | 315 | 316 | def load_flow_model_only_lora( 317 | name: str, 318 | device: str | torch.device = "cuda", 319 | hf_download: bool = True, 320 | lora_rank: int = 16, 321 | use_fp8: bool = False, 322 | ): 323 | # Loading Flux 324 | # ckpt_path = configs[name].ckpt_path 325 | # kiki 326 | ckpt_path = os.path.join(folder_paths.models_dir, 'diffusers', 'FLUX.1-dev', 'flux1-dev.safetensors') 327 | if ( 328 | ckpt_path is None 329 | and configs[name].repo_id is not None 330 | and configs[name].repo_flow is not None 331 | ): 332 | ckpt_path = hf_hub_download( 333 | configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors") 334 | ) 335 | 336 | # if hf_download: 337 | # try: 338 | # lora_ckpt_path = hf_hub_download( 339 | # "bytedance-research/USO", "uso_flux_v1.0/dit_lora.safetensors" 340 | # ) 341 | # except Exception as e: 342 | # print(f"Failed to download lora checkpoint: {e}") 343 | # print("Trying to load lora from local") 344 | # lora_ckpt_path = os.environ.get("LORA", None) 345 | # try: 346 | # proj_ckpt_path = hf_hub_download( 347 | # "bytedance-research/USO", "uso_flux_v1.0/projector.safetensors" 348 | # ) 349 | # except Exception as e: 350 | # print(f"Failed to download projection_model checkpoint: {e}") 351 | # print("Trying to load projection_model from local") 352 | # proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None) 353 | # else: 354 | # lora_ckpt_path = os.environ.get("LORA", None) 355 | # proj_ckpt_path = os.environ.get("PROJECTION_MODEL", None) 356 | # print(lora_ckpt_path) 357 | # print(proj_ckpt_path) 358 | base_ckpt_path = os.path.join(folder_paths.models_dir, 'uso', 'uso_flux_v1.0') 359 | lora_ckpt_path = os.path.join(base_ckpt_path, 'dit_lora.safetensors') 360 | proj_ckpt_path = os.path.join(base_ckpt_path, 'projector.safetensors') 361 | with torch.device("meta" if ckpt_path is not None else device): 362 | model = Flux(configs[name].params) 363 | 364 | model = set_lora( 365 | model, lora_rank, device="meta" if lora_ckpt_path is not None else device 366 | ) 367 | 368 | if ckpt_path is not None: 369 | print(f"Loading lora from {lora_ckpt_path}") 370 | lora_sd = ( 
371 | load_sft(lora_ckpt_path, device=str(device)) 372 | if lora_ckpt_path.endswith("safetensors") 373 | else torch.load(lora_ckpt_path, map_location="cpu") 374 | ) 375 | proj_sd = ( 376 | load_sft(proj_ckpt_path, device=str(device)) 377 | if proj_ckpt_path.endswith("safetensors") 378 | else torch.load(proj_ckpt_path, map_location="cpu") 379 | ) 380 | lora_sd.update(proj_sd) 381 | 382 | print("Loading main checkpoint") 383 | # load_sft doesn't support torch.device 384 | 385 | if ckpt_path.endswith("safetensors"): 386 | if use_fp8: 387 | print( 388 | "####\n" 389 | "We are in fp8 mode right now; since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken,\n" 390 | "we convert to fp8 on the fly from the bf16 checkpoint.\n" 391 | "If your storage is constrained, " 392 | "you can save the fp8 checkpoint yourself and replace the bf16 checkpoint with it.\n" 393 | ) 394 | sd = load_sft(ckpt_path, device="cpu") 395 | sd = { 396 | k: v.to(dtype=torch.float8_e4m3fn, device=device) 397 | for k, v in sd.items() 398 | } 399 | else: 400 | sd = load_sft(ckpt_path, device=str(device)) 401 | 402 | sd.update(lora_sd) 403 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 404 | else: 405 | dit_state = torch.load(ckpt_path, map_location="cpu") 406 | sd = {} 407 | for k in dit_state.keys(): 408 | sd[k.replace("module.", "")] = dit_state[k] 409 | sd.update(lora_sd) 410 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 411 | model.to(str(device)) 412 | print_load_warning(missing, unexpected) 413 | return model 414 | 415 | 416 | def set_lora( 417 | model: Flux, 418 | lora_rank: int, 419 | double_blocks_indices: list[int] | None = None, 420 | single_blocks_indices: list[int] | None = None, 421 | device: str | torch.device = "cpu", 422 | ) -> Flux: 423 | double_blocks_indices = ( 424 | list(range(model.params.depth)) 425 | if double_blocks_indices is None 426 | else double_blocks_indices 427 | ) 428 | single_blocks_indices = ( 429 | list(range(model.params.depth_single_blocks)) 430 | if single_blocks_indices is None 431 | else single_blocks_indices 432 | ) 433 | 434 | lora_attn_procs = {} 435 | with torch.device(device): 436 | for name, attn_processor in model.attn_processors.items(): 437 | match = re.search(r"\.(\d+)\.", name) 438 | if match: 439 | layer_index = int(match.group(1)) 440 | 441 | if ( 442 | name.startswith("double_blocks") 443 | and layer_index in double_blocks_indices 444 | ): 445 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( 446 | dim=model.params.hidden_size, rank=lora_rank 447 | ) 448 | elif ( 449 | name.startswith("single_blocks") 450 | and layer_index in single_blocks_indices 451 | ): 452 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor( 453 | dim=model.params.hidden_size, rank=lora_rank 454 | ) 455 | else: 456 | lora_attn_procs[name] = attn_processor 457 | model.set_attn_processor(lora_attn_procs) 458 | return model 459 | 460 | 461 | def load_flow_model_quintized( 462 | name: str, device: str | torch.device = "cuda", hf_download: bool = True 463 | ): 464 | # Loading Flux 465 | from optimum.quanto import requantize 466 | 467 | print("Init model") 468 | ckpt_path = configs[name].ckpt_path 469 | if ( 470 | ckpt_path is None 471 | and configs[name].repo_id is not None 472 | and configs[name].repo_flow is not None 473 | and hf_download 474 | ): 475 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 476 | json_path = hf_hub_download(configs[name].repo_id, "flux_dev_quantization_map.json") 477 | 478 |
model = Flux(configs[name].params).to(torch.bfloat16) 479 | 480 | print("Loading checkpoint") 481 | # load_sft doesn't support torch.device 482 | sd = load_sft(ckpt_path, device="cpu") 483 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()} 484 | model.load_state_dict(sd, assign=True) 485 | return model 486 | with open(json_path, "r") as f: 487 | quantization_map = json.load(f) 488 | print("Start a quantization process...") 489 | requantize(model, sd, quantization_map, device=device) 490 | print("Model is quantized!") 491 | return model 492 | 493 | 494 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 495 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 496 | #version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders") 497 | # version = '/workspace/comfyui/models/clip/xflux_text_encoders' 498 | version = os.path.join(folder_paths.models_dir, 'clip', 'xflux_text_encoders') 499 | return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to( 500 | device 501 | ) 502 | 503 | 504 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 505 | # version = os.environ.get("CLIP", "openai/clip-vit-large-patch14") 506 | #kiki 507 | # version = '/workspace/comfyui/models/clip_vision/clip-vit-large-patch14' 508 | version = os.path.join(folder_paths.models_dir, 'clip_vision', 'clip-vit-large-patch14') 509 | return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16, is_clip=True).to(device) 510 | 511 | 512 | def load_ae( 513 | name: str, device: str | torch.device = "cuda", hf_download: bool = True 514 | ) -> AutoEncoder: 515 | # ckpt_path = configs[name].ae_path 516 | # if ( 517 | # ckpt_path is None 518 | # and configs[name].repo_id is not None 519 | # and configs[name].repo_ae is not None 520 | # and hf_download 521 | # ): 522 | # ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) 523 | #kiki 524 | ckpt_path = os.path.join(folder_paths.models_dir, 'diffusers', 'FLUX.1-dev', 'ae.safetensors') 525 | 526 | # Loading the autoencoder 527 | print("Init AE") 528 | with torch.device("meta" if ckpt_path is not None else device): 529 | ae = AutoEncoder(configs[name].ae_params) 530 | 531 | if ckpt_path is not None: 532 | sd = load_sft(ckpt_path, device=str(device)) 533 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 534 | print_load_warning(missing, unexpected) 535 | return ae 536 | -------------------------------------------------------------------------------- /uso/flux/modules/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
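# Overview (added comment, not in the original file): this module collects the FLUX
# transformer building blocks defined below: EmbedND rotary position embedding, the
# timestep/MLP embedders, Modulation, DoubleStreamBlock, SingleStreamBlock and LastLayer,
# together with the LoRA-enabled attention processors (DoubleStreamBlockLoraProcessor,
# SingleStreamBlockLoraProcessor) that add low-rank adapters on the qkv and output
# projections, and the SigLIP multi-feature projector used for style features.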
15 | 16 | import math 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | from einops import rearrange, repeat 21 | from torch import Tensor, nn 22 | 23 | from ..math import attention, rope 24 | 25 | 26 | class EmbedND(nn.Module): 27 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 28 | super().__init__() 29 | self.dim = dim 30 | self.theta = theta 31 | self.axes_dim = axes_dim 32 | 33 | def forward(self, ids: Tensor) -> Tensor: 34 | n_axes = ids.shape[-1] 35 | emb = torch.cat( 36 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 37 | dim=-3, 38 | ) 39 | 40 | return emb.unsqueeze(1) 41 | 42 | 43 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 44 | """ 45 | Create sinusoidal timestep embeddings. 46 | :param t: a 1-D Tensor of N indices, one per batch element. 47 | These may be fractional. 48 | :param dim: the dimension of the output. 49 | :param max_period: controls the minimum frequency of the embeddings. 50 | :return: an (N, D) Tensor of positional embeddings. 51 | """ 52 | t = time_factor * t 53 | half = dim // 2 54 | freqs = torch.exp( 55 | -math.log(max_period) 56 | * torch.arange(start=0, end=half, dtype=torch.float32) 57 | / half 58 | ).to(t.device) 59 | 60 | args = t[:, None].float() * freqs[None] 61 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 62 | if dim % 2: 63 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 64 | if torch.is_floating_point(t): 65 | embedding = embedding.to(t) 66 | return embedding 67 | 68 | 69 | class MLPEmbedder(nn.Module): 70 | def __init__(self, in_dim: int, hidden_dim: int): 71 | super().__init__() 72 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 73 | self.silu = nn.SiLU() 74 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 75 | 76 | def forward(self, x: Tensor) -> Tensor: 77 | return self.out_layer(self.silu(self.in_layer(x))) 78 | 79 | 80 | class RMSNorm(torch.nn.Module): 81 | def __init__(self, dim: int): 82 | super().__init__() 83 | self.scale = nn.Parameter(torch.ones(dim)) 84 | 85 | def forward(self, x: Tensor): 86 | x_dtype = x.dtype 87 | x = x.float() 88 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 89 | return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) 90 | 91 | 92 | class QKNorm(torch.nn.Module): 93 | def __init__(self, dim: int): 94 | super().__init__() 95 | self.query_norm = RMSNorm(dim) 96 | self.key_norm = RMSNorm(dim) 97 | 98 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 99 | q = self.query_norm(q) 100 | k = self.key_norm(k) 101 | return q.to(v), k.to(v) 102 | 103 | 104 | class LoRALinearLayer(nn.Module): 105 | def __init__( 106 | self, 107 | in_features, 108 | out_features, 109 | rank=4, 110 | network_alpha=None, 111 | device=None, 112 | dtype=None, 113 | ): 114 | super().__init__() 115 | 116 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 117 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 118 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
119 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 120 | self.network_alpha = network_alpha 121 | self.rank = rank 122 | 123 | nn.init.normal_(self.down.weight, std=1 / rank) 124 | nn.init.zeros_(self.up.weight) 125 | 126 | def forward(self, hidden_states): 127 | orig_dtype = hidden_states.dtype 128 | dtype = self.down.weight.dtype 129 | 130 | down_hidden_states = self.down(hidden_states.to(dtype)) 131 | up_hidden_states = self.up(down_hidden_states) 132 | 133 | if self.network_alpha is not None: 134 | up_hidden_states *= self.network_alpha / self.rank 135 | 136 | return up_hidden_states.to(orig_dtype) 137 | 138 | 139 | class FLuxSelfAttnProcessor: 140 | def __call__(self, attn, x, pe, **attention_kwargs): 141 | qkv = attn.qkv(x) 142 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 143 | q, k = attn.norm(q, k, v) 144 | x = attention(q, k, v, pe=pe) 145 | x = attn.proj(x) 146 | return x 147 | 148 | 149 | class LoraFluxAttnProcessor(nn.Module): 150 | 151 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 152 | super().__init__() 153 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 154 | self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) 155 | self.lora_weight = lora_weight 156 | 157 | def __call__(self, attn, x, pe, **attention_kwargs): 158 | qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight 159 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 160 | q, k = attn.norm(q, k, v) 161 | x = attention(q, k, v, pe=pe) 162 | x = attn.proj(x) + self.proj_lora(x) * self.lora_weight 163 | return x 164 | 165 | 166 | class SelfAttention(nn.Module): 167 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 168 | super().__init__() 169 | self.num_heads = num_heads 170 | head_dim = dim // num_heads 171 | 172 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 173 | self.norm = QKNorm(head_dim) 174 | self.proj = nn.Linear(dim, dim) 175 | 176 | def forward(self): # never called; the parent block's processor uses qkv/norm/proj directly 177 | pass 178 | 179 | 180 | @dataclass 181 | class ModulationOut: 182 | shift: Tensor 183 | scale: Tensor 184 | gate: Tensor 185 | 186 | 187 | class Modulation(nn.Module): 188 | def __init__(self, dim: int, double: bool): 189 | super().__init__() 190 | self.is_double = double 191 | self.multiplier = 6 if double else 3 192 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 193 | 194 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 195 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( 196 | self.multiplier, dim=-1 197 | ) 198 | 199 | return ( 200 | ModulationOut(*out[:3]), 201 | ModulationOut(*out[3:]) if self.is_double else None, 202 | ) 203 | 204 | 205 | class DoubleStreamBlockLoraProcessor(nn.Module): 206 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 207 | super().__init__() 208 | self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 209 | self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) 210 | self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 211 | self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) 212 | self.lora_weight = lora_weight 213 | 214 | def forward(self, attn, img, txt, vec, pe, **attention_kwargs): 215 | img_mod1, img_mod2 = attn.img_mod(vec) 216 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 217 | 218 | # prepare image for attention 219 | img_modulated = attn.img_norm1(img) 220 | img_modulated = (1 +
img_mod1.scale) * img_modulated + img_mod1.shift 221 | img_qkv = ( 222 | attn.img_attn.qkv(img_modulated) 223 | + self.qkv_lora1(img_modulated) * self.lora_weight 224 | ) 225 | img_q, img_k, img_v = rearrange( 226 | img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads 227 | ) 228 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 229 | 230 | # prepare txt for attention 231 | txt_modulated = attn.txt_norm1(txt) 232 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 233 | txt_qkv = ( 234 | attn.txt_attn.qkv(txt_modulated) 235 | + self.qkv_lora2(txt_modulated) * self.lora_weight 236 | ) 237 | txt_q, txt_k, txt_v = rearrange( 238 | txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads 239 | ) 240 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 241 | 242 | # run actual attention 243 | q = torch.cat((txt_q, img_q), dim=2) 244 | k = torch.cat((txt_k, img_k), dim=2) 245 | v = torch.cat((txt_v, img_v), dim=2) 246 | 247 | attn1 = attention(q, k, v, pe=pe) 248 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 249 | 250 | # calculate the img blocks 251 | img = img + img_mod1.gate * ( 252 | attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight 253 | ) 254 | img = img + img_mod2.gate * attn.img_mlp( 255 | (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift 256 | ) 257 | 258 | # calculate the txt blocks 259 | txt = txt + txt_mod1.gate * ( 260 | attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight 261 | ) 262 | txt = txt + txt_mod2.gate * attn.txt_mlp( 263 | (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift 264 | ) 265 | return img, txt 266 | 267 | 268 | class DoubleStreamBlockProcessor: 269 | def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): 270 | img_mod1, img_mod2 = attn.img_mod(vec) 271 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 272 | 273 | # prepare image for attention 274 | img_modulated = attn.img_norm1(img) 275 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 276 | img_qkv = attn.img_attn.qkv(img_modulated) 277 | img_q, img_k, img_v = rearrange( 278 | img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim 279 | ) 280 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 281 | 282 | # prepare txt for attention 283 | txt_modulated = attn.txt_norm1(txt) 284 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 285 | txt_qkv = attn.txt_attn.qkv(txt_modulated) 286 | txt_q, txt_k, txt_v = rearrange( 287 | txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim 288 | ) 289 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 290 | 291 | # run actual attention 292 | q = torch.cat((txt_q, img_q), dim=2) 293 | k = torch.cat((txt_k, img_k), dim=2) 294 | v = torch.cat((txt_v, img_v), dim=2) 295 | 296 | attn1 = attention(q, k, v, pe=pe) 297 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 298 | 299 | # calculate the img blocks 300 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn) 301 | img = img + img_mod2.gate * attn.img_mlp( 302 | (1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift 303 | ) 304 | 305 | # calculate the txt blocks 306 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) 307 | txt = txt + txt_mod2.gate * attn.txt_mlp( 308 | (1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift 309 | ) 310 | return img, txt 311 | 312 | 313 | class DoubleStreamBlock(nn.Module): 314 | def __init__( 315 | self, hidden_size: int,
num_heads: int, mlp_ratio: float, qkv_bias: bool = False 316 | ): 317 | super().__init__() 318 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 319 | self.num_heads = num_heads 320 | self.hidden_size = hidden_size 321 | self.head_dim = hidden_size // num_heads 322 | 323 | self.img_mod = Modulation(hidden_size, double=True) 324 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 325 | self.img_attn = SelfAttention( 326 | dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias 327 | ) 328 | 329 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 330 | self.img_mlp = nn.Sequential( 331 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 332 | nn.GELU(approximate="tanh"), 333 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 334 | ) 335 | 336 | self.txt_mod = Modulation(hidden_size, double=True) 337 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 338 | self.txt_attn = SelfAttention( 339 | dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias 340 | ) 341 | 342 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 343 | self.txt_mlp = nn.Sequential( 344 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 345 | nn.GELU(approximate="tanh"), 346 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 347 | ) 348 | processor = DoubleStreamBlockProcessor() 349 | self.set_processor(processor) 350 | 351 | def set_processor(self, processor) -> None: 352 | self.processor = processor 353 | 354 | def get_processor(self): 355 | return self.processor 356 | 357 | def forward( 358 | self, 359 | img: Tensor, 360 | txt: Tensor, 361 | vec: Tensor, 362 | pe: Tensor, 363 | image_proj: Tensor = None, 364 | ip_scale: float = 1.0, 365 | ) -> tuple[Tensor, Tensor]: 366 | if image_proj is None: 367 | return self.processor(self, img, txt, vec, pe) 368 | else: 369 | return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) 370 | 371 | 372 | class SingleStreamBlockLoraProcessor(nn.Module): 373 | def __init__( 374 | self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1 375 | ): 376 | super().__init__() 377 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 378 | self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) 379 | self.lora_weight = lora_weight 380 | 381 | def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 382 | 383 | mod, _ = attn.modulation(vec) 384 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 385 | qkv, mlp = torch.split( 386 | attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1 387 | ) 388 | qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight 389 | 390 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 391 | q, k = attn.norm(q, k, v) 392 | 393 | # compute attention 394 | attn_1 = attention(q, k, v, pe=pe) 395 | 396 | # compute activation in mlp stream, cat again and run second linear layer 397 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 398 | output = ( 399 | output 400 | + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 401 | * self.lora_weight 402 | ) 403 | output = x + mod.gate * output 404 | return output 405 | 406 | 407 | class SingleStreamBlockProcessor: 408 | def __call__( 409 | self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs 410 | ) -> Tensor: 411 | 412 | mod, _ = attn.modulation(vec) 413 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 414 | qkv, mlp = torch.split( 
415 | attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1 416 | ) 417 | 418 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 419 | q, k = attn.norm(q, k, v) 420 | 421 | # compute attention 422 | attn_1 = attention(q, k, v, pe=pe) 423 | 424 | # compute activation in mlp stream, cat again and run second linear layer 425 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 426 | output = x + mod.gate * output 427 | return output 428 | 429 | 430 | class SingleStreamBlock(nn.Module): 431 | """ 432 | A DiT block with parallel linear layers as described in 433 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 434 | """ 435 | 436 | def __init__( 437 | self, 438 | hidden_size: int, 439 | num_heads: int, 440 | mlp_ratio: float = 4.0, 441 | qk_scale: float | None = None, 442 | ): 443 | super().__init__() 444 | self.hidden_dim = hidden_size 445 | self.num_heads = num_heads 446 | self.head_dim = hidden_size // num_heads 447 | self.scale = qk_scale or self.head_dim**-0.5 448 | 449 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 450 | # qkv and mlp_in 451 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 452 | # proj and mlp_out 453 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 454 | 455 | self.norm = QKNorm(self.head_dim) 456 | 457 | self.hidden_size = hidden_size 458 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 459 | 460 | self.mlp_act = nn.GELU(approximate="tanh") 461 | self.modulation = Modulation(hidden_size, double=False) 462 | 463 | processor = SingleStreamBlockProcessor() 464 | self.set_processor(processor) 465 | 466 | def set_processor(self, processor) -> None: 467 | self.processor = processor 468 | 469 | def get_processor(self): 470 | return self.processor 471 | 472 | def forward( 473 | self, 474 | x: Tensor, 475 | vec: Tensor, 476 | pe: Tensor, 477 | image_proj: Tensor | None = None, 478 | ip_scale: float = 1.0, 479 | ) -> Tensor: 480 | if image_proj is None: 481 | return self.processor(self, x, vec, pe) 482 | else: 483 | return self.processor(self, x, vec, pe, image_proj, ip_scale) 484 | 485 | 486 | class LastLayer(nn.Module): 487 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 488 | super().__init__() 489 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 490 | self.linear = nn.Linear( 491 | hidden_size, patch_size * patch_size * out_channels, bias=True 492 | ) 493 | self.adaLN_modulation = nn.Sequential( 494 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 495 | ) 496 | 497 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 498 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 499 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 500 | x = self.linear(x) 501 | return x 502 | 503 | 504 | class SigLIPMultiFeatProjModel(torch.nn.Module): 505 | """ 506 | SigLIP Multi-Feature Projection Model for processing style features from different layers 507 | and projecting them into a unified hidden space. 
508 | 509 | Args: 510 | siglip_token_nums (int): Number of SigLIP tokens, default 257 511 | style_token_nums (int): Number of style tokens, default 256 512 | siglip_token_dims (int): Dimension of SigLIP tokens, default 1536 513 | hidden_size (int): Hidden layer size, default 3072 514 | context_layer_norm (bool): Whether to use context layer normalization, default False 515 | """ 516 | 517 | def __init__( 518 | self, 519 | siglip_token_nums: int = 257, 520 | style_token_nums: int = 256, 521 | siglip_token_dims: int = 1536, 522 | hidden_size: int = 3072, 523 | context_layer_norm: bool = False, 524 | ): 525 | super().__init__() 526 | 527 | # High-level feature processing (layer -2) 528 | self.high_embedding_linear = nn.Sequential( 529 | nn.Linear(siglip_token_nums, style_token_nums), 530 | nn.SiLU() 531 | ) 532 | self.high_layer_norm = ( 533 | nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() 534 | ) 535 | self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) 536 | 537 | # Mid-level feature processing (layer -11) 538 | self.mid_embedding_linear = nn.Sequential( 539 | nn.Linear(siglip_token_nums, style_token_nums), 540 | nn.SiLU() 541 | ) 542 | self.mid_layer_norm = ( 543 | nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() 544 | ) 545 | self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) 546 | 547 | # Low-level feature processing (layer -20) 548 | self.low_embedding_linear = nn.Sequential( 549 | nn.Linear(siglip_token_nums, style_token_nums), 550 | nn.SiLU() 551 | ) 552 | self.low_layer_norm = ( 553 | nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() 554 | ) 555 | self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True) 556 | 557 | def forward(self, siglip_outputs): 558 | """ 559 | Forward pass function 560 | 561 | Args: 562 | siglip_outputs: Output from SigLIP model, containing hidden_states 563 | 564 | Returns: 565 | torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] 566 | """ 567 | dtype = next(self.high_embedding_linear.parameters()).dtype 568 | 569 | # Process high-level features (layer -2) 570 | high_embedding = self._process_layer_features( 571 | siglip_outputs.hidden_states[-2], 572 | self.high_embedding_linear, 573 | self.high_layer_norm, 574 | self.high_projection, 575 | dtype 576 | ) 577 | 578 | # Process mid-level features (layer -11) 579 | mid_embedding = self._process_layer_features( 580 | siglip_outputs.hidden_states[-11], 581 | self.mid_embedding_linear, 582 | self.mid_layer_norm, 583 | self.mid_projection, 584 | dtype 585 | ) 586 | 587 | # Process low-level features (layer -20) 588 | low_embedding = self._process_layer_features( 589 | siglip_outputs.hidden_states[-20], 590 | self.low_embedding_linear, 591 | self.low_layer_norm, 592 | self.low_projection, 593 | dtype 594 | ) 595 | 596 | # Concatenate features from all layers 597 | return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) 598 | 599 | def _process_layer_features( 600 | self, 601 | hidden_states: torch.Tensor, 602 | embedding_linear: nn.Module, 603 | layer_norm: nn.Module, 604 | projection: nn.Module, 605 | dtype: torch.dtype 606 | ) -> torch.Tensor: 607 | """ 608 | Helper function to process features from a single layer 609 | 610 | Args: 611 | hidden_states: Input hidden states [bs, seq_len, dim] 612 | embedding_linear: Embedding linear layer 613 | layer_norm: Layer normalization 614 | projection: Projection layer 615 | 
dtype: Target data type 616 | 617 | Returns: 618 | torch.Tensor: Processed features [bs, style_token_nums, hidden_size] 619 | """ 620 | # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] 621 | embedding = embedding_linear( 622 | hidden_states.to(dtype).transpose(1, 2) 623 | ).transpose(1, 2) 624 | 625 | # Apply layer normalization 626 | embedding = layer_norm(embedding) 627 | 628 | # Project to target hidden space 629 | embedding = projection(embedding) 630 | 631 | return embedding 632 | --------------------------------------------------------------------------------
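How the pieces in this file fit together: DoubleStreamBlock keeps image and text tokens in two parallel streams that share a single attention pass, SingleStreamBlock runs the concatenated sequence through a fused attention/MLP layer, and LastLayer applies adaLN modulation before projecting back to patch space. The snippet below is only a minimal shape-check sketch, not part of the repository: it assumes toy sizes (hidden_size=256, num_heads=4; the real FLUX backbone uses much larger dimensions), builds a single-axis rotary embedding by hand from uso.flux.math.rope instead of the model's multi-axis positional ids, and assumes the uso package is importable from the repository root.

import torch

from uso.flux.math import rope
from uso.flux.modules.layers import DoubleStreamBlock, LastLayer, SingleStreamBlock

bsz, txt_len, img_len = 1, 8, 32
hidden_size, num_heads = 256, 4              # assumed toy sizes for a quick CPU check
head_dim = hidden_size // num_heads

double_block = DoubleStreamBlock(hidden_size, num_heads, mlp_ratio=4.0, qkv_bias=True)
single_block = SingleStreamBlock(hidden_size, num_heads, mlp_ratio=4.0)
final_layer = LastLayer(hidden_size, patch_size=1, out_channels=16)

img = torch.randn(bsz, img_len, hidden_size)   # patchified image tokens
txt = torch.randn(bsz, txt_len, hidden_size)   # text tokens
vec = torch.randn(bsz, hidden_size)            # pooled conditioning vector

# Hand-built rotary embedding for the concatenated (txt + img) sequence,
# shape (B, 1, L, head_dim/2, 2, 2) as expected by attention()/apply_rope().
pos = torch.arange(txt_len + img_len, dtype=torch.float64)[None, :]
pe = rope(pos, head_dim, theta=10_000).unsqueeze(1)

img, txt = double_block(img, txt, vec, pe)     # dispatched to DoubleStreamBlockProcessor
x = torch.cat((txt, img), dim=1)
x = single_block(x, vec, pe)                   # dispatched to SingleStreamBlockProcessor
out = final_layer(x[:, txt_len:, :], vec)      # project only the image tokens
print(out.shape)                               # torch.Size([1, 32, 16])

SigLIPMultiFeatProjModel is independent of the attention blocks: it maps three levels of SigLIP hidden states (layers -2, -11 and -20) to 3 * style_token_nums style tokens in the transformer's hidden size. A shape check with the default sizes, using an assumed stand-in object with random hidden states in place of a real SigLIP vision-tower output:

import torch
from types import SimpleNamespace

from uso.flux.modules.layers import SigLIPMultiFeatProjModel

proj = SigLIPMultiFeatProjModel()              # defaults: 257 tokens of dim 1536 -> hidden 3072
fake_outputs = SimpleNamespace(
    hidden_states=[torch.randn(1, 257, 1536) for _ in range(27)]  # any list deep enough for index -20
)
style_tokens = proj(fake_outputs)
print(style_tokens.shape)                      # torch.Size([1, 768, 3072]) = 3 * 256 style tokens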