├── src ├── flux │ ├── __main__.py │ ├── __init__.py │ ├── annotator │ │ ├── util.py │ │ └── dwpose │ │ │ ├── wholebody.py │ │ │ ├── __init__.py │ │ │ ├── onnxdet.py │ │ │ ├── util.py │ │ │ └── onnxpose.py │ ├── math.py │ ├── modules │ │ ├── conditioner.py │ │ ├── autoencoder.py │ │ └── layers.py │ ├── api.py │ ├── sampling.py │ ├── model.py │ ├── cli.py │ ├── train_pipeline.py │ └── util.py └── train.py ├── train_configs └── test_lora.yaml ├── README.md ├── inference.py └── pipeline.py /src/flux/__main__.py: -------------------------------------------------------------------------------- 1 | from .cli import app 2 | 3 | if __name__ == "__main__": 4 | app() 5 | -------------------------------------------------------------------------------- /src/flux/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ # type: ignore 3 | from ._version import version_tuple 4 | except ImportError: 5 | __version__ = "unknown (no version information available)" 6 | version_tuple = (0, 0, "unknown", "noinfo") 7 | 8 | from pathlib import Path 9 | 10 | PACKAGE = __package__.replace("_", "-") 11 | PACKAGE_ROOT = Path(__file__).parent 12 | -------------------------------------------------------------------------------- /train_configs/test_lora.yaml: -------------------------------------------------------------------------------- 1 | model_name: "flux-dev" 2 | data_config: 3 | train_batch_size: 1 4 | num_workers: 4 5 | img_size: 1024 6 | img_dir: images/ 7 | random_ratio: true # support multi crop preprocessing 8 | report_to: wandb 9 | train_batch_size: 1 10 | output_dir: lora/ 11 | max_train_steps: 100000 12 | learning_rate: 1e-5 13 | lr_scheduler: constant 14 | lr_warmup_steps: 10 15 | adam_beta1: 0.9 16 | adam_beta2: 0.999 17 | adam_weight_decay: 0.01 18 | adam_epsilon: 1e-8 19 | max_grad_norm: 1.0 20 | logging_dir: logs 21 | mixed_precision: "bf16" 22 | checkpointing_steps: 2500 23 | checkpoints_total_limit: 10 24 | tracker_project_name: lora_test 25 | resume_from_checkpoint: latest 26 | gradient_accumulation_steps: 2 27 | rank: 16 28 | single_blocks: "1,2,3,4" 29 | double_blocks: null 30 | disable_sampling: false 31 | sample_every: 250 # sample every this many steps 32 | sample_width: 1024 33 | sample_height: 1024 34 | sample_steps: 20 35 | -------------------------------------------------------------------------------- /src/flux/annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 
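    # e.g. a 768x512 uint8 input with resolution=512 keeps k=1.0 and stays 768x512,
    # while resolution=1024 gives k=2.0 and upsamples with Lanczos to 1536x1024;
    # both sides are always rounded to the nearest multiple of 64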
38 | return img 39 | -------------------------------------------------------------------------------- /src/flux/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor 4 | 5 | 6 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 7 | q, k = apply_rope(q, k, pe) 8 | 9 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 10 | x = rearrange(x, "B H L D -> B L (H D)") 11 | 12 | return x 13 | 14 | 15 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 16 | assert dim % 2 == 0 17 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 18 | omega = 1.0 / (theta**scale) 19 | out = torch.einsum("...n,d->...nd", pos, omega) 20 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 21 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 22 | return out.float() 23 | 24 | 25 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 26 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 27 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 28 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 29 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 30 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HOComp: Interaction-Aware Human-Object Composition 2 | ### *NeurIPS 2025* 3 | 4 | This is the official repository for our paper: 5 | 6 | 📄 **HOComp: Interaction-Aware Human-Object Composition** 7 | 📚 *Preprint available on arXiv* 8 | 9 | --- 10 | 11 | ## 📝 Abstract 12 | **HOComp** is a novel framework for harmonizing foreground objects into human-centric backgrounds. 13 | By leveraging a **Flux.1 Kontext** base model and a novel **Sequence Concatenation** strategy, the method achieves precise control over **human–object interactions** with high fidelity. 14 | 15 | --- 16 | 17 | ## 🛠️ Custom Inference 18 | 19 | To generate a specific interaction, provide background / foreground images, the interaction prompt, and the foreground bounding box: 20 | 21 | ```bash 22 | python run_inference.py \ 23 | --prompt "A young man holding a vintage camera" \ 24 | --bg_path "examples/background.jpg" \ 25 | --fg_path "examples/camera.png" \ 26 | --box "[300 300 700 700]" 27 | ``` 28 | 29 | 30 | ## 📌 Citation 31 | 32 | If you find our work helpful, please consider citing: 33 | 34 | ```bibtex 35 | @article{liang2025hocomp, 36 | title={HOComp: Interaction-Aware Human-Object Composition}, 37 | author={Dong Liang and Jinyuan Jia and Yuhao Liu and Rynson W. H. 
Lau}, 38 | journal={arXiv preprint arXiv:2507.16813}, 39 | year={2025} 40 | } -------------------------------------------------------------------------------- /src/flux/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, 3 | T5Tokenizer) 4 | 5 | 6 | class HFEmbedder(nn.Module): 7 | def __init__(self, version: str, max_length: int, **hf_kwargs): 8 | super().__init__() 9 | self.is_clip = version.startswith("openai") 10 | self.max_length = max_length 11 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 12 | 13 | if self.is_clip: 14 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) 15 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) 16 | else: 17 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) 18 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) 19 | 20 | self.hf_module = self.hf_module.eval().requires_grad_(False) 21 | 22 | def forward(self, text: list[str]) -> Tensor: 23 | batch_encoding = self.tokenizer( 24 | text, 25 | truncation=True, 26 | max_length=self.max_length, 27 | return_length=False, 28 | return_overflowing_tokens=False, 29 | padding="max_length", 30 | return_tensors="pt", 31 | ) 32 | 33 | outputs = self.hf_module( 34 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 35 | attention_mask=None, 36 | output_hidden_states=False, 37 | ) 38 | return outputs[self.output_key] 39 | -------------------------------------------------------------------------------- /src/flux/annotator/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import onnxruntime as ort 5 | from huggingface_hub import hf_hub_download 6 | from .onnxdet import inference_detector 7 | from .onnxpose import inference_pose 8 | 9 | 10 | class Wholebody: 11 | def __init__(self, device="cuda:0"): 12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] 13 | onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx") 14 | onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx") 15 | 16 | self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) 17 | self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) 18 | 19 | def __call__(self, oriImg): 20 | det_result = inference_detector(self.session_det, oriImg) 21 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 22 | 23 | keypoints_info = np.concatenate( 24 | (keypoints, scores[..., None]), axis=-1) 25 | # compute neck joint 26 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 27 | # neck score when visualizing pred 28 | neck[:, 2:4] = np.logical_and( 29 | keypoints_info[:, 5, 2:4] > 0.3, 30 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 31 | new_keypoints_info = np.insert( 32 | keypoints_info, 17, neck, axis=1) 33 | mmpose_idx = [ 34 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 35 | ] 36 | openpose_idx = [ 37 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 38 | ] 39 | new_keypoints_info[:, openpose_idx] = \ 40 | new_keypoints_info[:, mmpose_idx] 41 | keypoints_info = new_keypoints_info 42 | 43 | keypoints, scores = keypoints_info[ 44 | ..., :2], keypoints_info[..., 2] 45 | 46 | return 
keypoints, scores 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/flux/annotator/dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Openpose 2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 4 | # 3rd Edited by ControlNet 5 | # 4th Edited by ControlNet (added face and correct hands) 6 | 7 | import os 8 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 9 | 10 | import torch 11 | import numpy as np 12 | from . import util 13 | from .wholebody import Wholebody 14 | 15 | def draw_pose(pose, H, W): 16 | bodies = pose['bodies'] 17 | faces = pose['faces'] 18 | hands = pose['hands'] 19 | candidate = bodies['candidate'] 20 | subset = bodies['subset'] 21 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 22 | 23 | canvas = util.draw_bodypose(canvas, candidate, subset) 24 | 25 | canvas = util.draw_handpose(canvas, hands) 26 | 27 | canvas = util.draw_facepose(canvas, faces) 28 | 29 | return canvas 30 | 31 | 32 | class DWposeDetector: 33 | def __init__(self, device): 34 | 35 | self.pose_estimation = Wholebody(device) 36 | 37 | def __call__(self, oriImg): 38 | oriImg = oriImg.copy() 39 | H, W, C = oriImg.shape 40 | with torch.no_grad(): 41 | candidate, subset = self.pose_estimation(oriImg) 42 | nums, keys, locs = candidate.shape 43 | candidate[..., 0] /= float(W) 44 | candidate[..., 1] /= float(H) 45 | body = candidate[:,:18].copy() 46 | body = body.reshape(nums*18, locs) 47 | score = subset[:,:18] 48 | for i in range(len(score)): 49 | for j in range(len(score[i])): 50 | if score[i][j] > 0.3: 51 | score[i][j] = int(18*i+j) 52 | else: 53 | score[i][j] = -1 54 | 55 | un_visible = subset<0.3 56 | candidate[un_visible] = -1 57 | 58 | foot = candidate[:,18:24] 59 | 60 | faces = candidate[:,24:92] 61 | 62 | hands = candidate[:,92:113] 63 | hands = np.vstack([hands, candidate[:,113:]]) 64 | 65 | bodies = dict(candidate=body, subset=score) 66 | pose = dict(bodies=bodies, hands=hands, faces=faces) 67 | 68 | return draw_pose(pose, H, W) 69 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from PIL import Image 5 | from hocomp_pipeline import HOCompPipeline 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description="HOComp Inference Script") 9 | 10 | # --- Model Paths --- 11 | parser.add_argument("--base_model", type=str, default="black-forest-labs/FLUX.1-Kontext-dev") 12 | parser.add_argument("--lora_path", type=str, default="./checkpoints/hocomp_v1.safetensors") 13 | 14 | # 1. Background Image Path 15 | parser.add_argument("--bg_path", type=str, default="path/to/your/background.jpg", 16 | help="Path to the background image containing a human subject") 17 | 18 | # 2. 
Foreground Image Path 19 | parser.add_argument("--fg_path", type=str, default="path/to/your/object.png", 20 | help="Path to the foreground object image") 21 | 22 | parser.add_argument("--prompt", type=str, 23 | default="A young man holding a vintage camera", 24 | ) 25 | parser.add_argument("--box", type=int, nargs=4, default=[300, 300, 700, 700]) 26 | 27 | return parser.parse_args() 28 | 29 | def load_image_data(path, label, size=(1024, 1024)): 30 | 31 | 32 | if os.path.exists(path): 33 | return Image.open(path).convert("RGB") 34 | else: 35 | # Fallback for code demonstration so it doesn't crash 36 | return Image.new("RGB", size, (200, 200, 200)) 37 | 38 | def main(): 39 | args = parse_args() 40 | 41 | # 1. Initialize Pipeline 42 | pipeline = HOCompPipeline( 43 | base_model_id=args.base_model, 44 | local_lora_path=args.lora_path 45 | ) 46 | 47 | # 2. Load Visual Inputs 48 | bg_img = load_image_data(args.bg_path, "Background") 49 | fg_img = load_image_data(args.fg_path, "Foreground", size=(512, 512)) 50 | 51 | # 3. Load Semantic & Spatial Inputs (Prompt & Box) 52 | 53 | # 4. Run Sequence Concatenation Pipeline 54 | result = pipeline( 55 | prompt=args.prompt, 56 | bg_img=bg_img, 57 | fg_img=fg_img, 58 | box=args.box 59 | ) 60 | 61 | # 5. Save Output 62 | output_path = "output_hocomp.png" 63 | result.save(output_path) 64 | 65 | if __name__ == "__main__": 66 | main() -------------------------------------------------------------------------------- /src/flux/annotator/dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import onnxruntime 5 | 6 | def nms(boxes, scores, nms_thr): 7 | """Single class NMS implemented in Numpy.""" 8 | x1 = boxes[:, 0] 9 | y1 = boxes[:, 1] 10 | x2 = boxes[:, 2] 11 | y2 = boxes[:, 3] 12 | 13 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 14 | order = scores.argsort()[::-1] 15 | 16 | keep = [] 17 | while order.size > 0: 18 | i = order[0] 19 | keep.append(i) 20 | xx1 = np.maximum(x1[i], x1[order[1:]]) 21 | yy1 = np.maximum(y1[i], y1[order[1:]]) 22 | xx2 = np.minimum(x2[i], x2[order[1:]]) 23 | yy2 = np.minimum(y2[i], y2[order[1:]]) 24 | 25 | w = np.maximum(0.0, xx2 - xx1 + 1) 26 | h = np.maximum(0.0, yy2 - yy1 + 1) 27 | inter = w * h 28 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 29 | 30 | inds = np.where(ovr <= nms_thr)[0] 31 | order = order[inds + 1] 32 | 33 | return keep 34 | 35 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 36 | """Multiclass NMS implemented in Numpy. 
Class-aware version.""" 37 | final_dets = [] 38 | num_classes = scores.shape[1] 39 | for cls_ind in range(num_classes): 40 | cls_scores = scores[:, cls_ind] 41 | valid_score_mask = cls_scores > score_thr 42 | if valid_score_mask.sum() == 0: 43 | continue 44 | else: 45 | valid_scores = cls_scores[valid_score_mask] 46 | valid_boxes = boxes[valid_score_mask] 47 | keep = nms(valid_boxes, valid_scores, nms_thr) 48 | if len(keep) > 0: 49 | cls_inds = np.ones((len(keep), 1)) * cls_ind 50 | dets = np.concatenate( 51 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 52 | ) 53 | final_dets.append(dets) 54 | if len(final_dets) == 0: 55 | return None 56 | return np.concatenate(final_dets, 0) 57 | 58 | def demo_postprocess(outputs, img_size, p6=False): 59 | grids = [] 60 | expanded_strides = [] 61 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 62 | 63 | hsizes = [img_size[0] // stride for stride in strides] 64 | wsizes = [img_size[1] // stride for stride in strides] 65 | 66 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 67 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 68 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 69 | grids.append(grid) 70 | shape = grid.shape[:2] 71 | expanded_strides.append(np.full((*shape, 1), stride)) 72 | 73 | grids = np.concatenate(grids, 1) 74 | expanded_strides = np.concatenate(expanded_strides, 1) 75 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 76 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 77 | 78 | return outputs 79 | 80 | def preprocess(img, input_size, swap=(2, 0, 1)): 81 | if len(img.shape) == 3: 82 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 83 | else: 84 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 85 | 86 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 87 | resized_img = cv2.resize( 88 | img, 89 | (int(img.shape[1] * r), int(img.shape[0] * r)), 90 | interpolation=cv2.INTER_LINEAR, 91 | ).astype(np.uint8) 92 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 93 | 94 | padded_img = padded_img.transpose(swap) 95 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 96 | return padded_img, r 97 | 98 | def inference_detector(session, oriImg): 99 | input_shape = (640,640) 100 | img, ratio = preprocess(oriImg, input_shape) 101 | 102 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 103 | output = session.run(None, ort_inputs) 104 | predictions = demo_postprocess(output[0], input_shape)[0] 105 | 106 | boxes = predictions[:, :4] 107 | scores = predictions[:, 4:5] * predictions[:, 5:] 108 | 109 | boxes_xyxy = np.ones_like(boxes) 110 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 111 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 112 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 113 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 
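    # the raw predictions are (cx, cy, w, h) on the 640x640 padded input; the four
    # assignments above convert them to (x1, y1, x2, y2) corners, and dividing by the
    # preprocess ratio below maps the boxes back to original-image coordinates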
114 | boxes_xyxy /= ratio 115 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 116 | if dets is not None: 117 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 118 | isscore = final_scores>0.3 119 | iscat = final_cls_inds == 0 120 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 121 | final_boxes = final_boxes[isbbox] 122 | else: 123 | final_boxes = np.array([]) 124 | 125 | return final_boxes 126 | -------------------------------------------------------------------------------- /pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from diffusers import FluxTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler 5 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 6 | from diffusers.image_processor import VaeImageProcessor 7 | import os 8 | 9 | 10 | class LocalEncoderWrapper(nn.Module): 11 | def __init__(self, name, device, dtype): 12 | super().__init__() 13 | self.name = name 14 | self.projector = nn.Linear(768, 4).to(device, dtype=dtype) 15 | 16 | def forward(self, x): 17 | 18 | return torch.randn(x.shape[0], 256, 4, device=x.device, dtype=x.dtype) 19 | 20 | 21 | class HOCompPipeline: 22 | def __init__(self, base_model_id, local_lora_path, device="cuda", dtype=torch.bfloat16): 23 | self.device = device 24 | self.dtype = dtype 25 | 26 | 27 | 28 | self.transformer = FluxTransformer2DModel.from_pretrained( 29 | base_model_id, subfolder="transformer", torch_dtype=dtype 30 | ).to(device) 31 | self.transformer.requires_grad_(False) 32 | 33 | if os.path.exists(local_lora_path): 34 | self.transformer.load_lora_weights(local_lora_path, adapter_name="hocomp") 35 | self.transformer.fuse_lora(lora_scale=1.0) 36 | 37 | self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae", torch_dtype=dtype).to(device) 38 | self.text_encoder_2 = T5EncoderModel.from_pretrained(base_model_id, subfolder="text_encoder_2", torch_dtype=dtype).to(device) 39 | self.tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_id, subfolder="tokenizer_2") 40 | self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_id, subfolder="scheduler") 41 | self.img_proc = VaeImageProcessor(vae_scale_factor=8) 42 | 43 | self.id_encoder = LocalEncoderWrapper("ID_Net", device, dtype) 44 | self.detail_encoder = LocalEncoderWrapper("Detail_Net", device, dtype) 45 | 46 | self.mask_proj = nn.Linear(1, 4).to(device, dtype=dtype) 47 | 48 | def encode_text(self, prompt): 49 | """Standard T5 Encoding.""" 50 | txt_input = self.tokenizer_2(prompt, padding="max_length", max_length=512, truncation=True, return_tensors="pt").to(self.device) 51 | return self.text_encoder_2(txt_input.input_ids)[0].to(self.dtype) 52 | 53 | def image_to_tokens(self, image): 54 | """Encodes Image to 4-channel Tokens via VAE.""" 55 | img_t = self.img_proc.preprocess(image, height=1024, width=1024).to(self.device, dtype=self.dtype) 56 | latents = self.vae.encode(img_t).latent_dist.sample() * self.vae.config.scaling_factor 57 | # [B, 4, H, W] -> [B, H*W, 4] 58 | b, c, h, w = latents.shape 59 | return latents.permute(0, 2, 3, 1).reshape(b, -1, c) 60 | 61 | @torch.no_grad() 62 | def __call__(self, prompt, bg_img, fg_img, box, num_inference_steps=25): 63 | 64 | 65 | 66 | 67 | text_emb = self.encode_text(prompt) # [1, 512, 4096] 68 | 69 | 70 | bg_tokens = self.image_to_tokens(bg_img) 71 | fg_id_tokens = 
self.id_encoder(fg_img) 72 | fg_det_tokens = self.detail_encoder(fg_img) 73 | 74 | H, W = 1024, 1024 75 | mask = torch.zeros((1, 1024, 1024, 1), device=self.device, dtype=self.dtype) 76 | mask[:, box[1]:box[3], box[0]:box[2], :] = 1.0 77 | 78 | mask_tokens = F.interpolate(mask.permute(0,3,1,2), size=(128,128), mode="nearest").permute(0,2,3,1).reshape(1, -1, 1) 79 | mask_tokens = self.mask_proj(mask_tokens) # [1, 16384, 4] 80 | 81 | 82 | noise = torch.randn(1, 4, 128, 128, device=self.device, dtype=self.dtype) 83 | noise_tokens = noise.permute(0, 2, 3, 1).reshape(1, -1, 4) 84 | 85 | combined_tokens = torch.cat([ 86 | noise_tokens, 87 | bg_tokens, 88 | fg_id_tokens, 89 | fg_det_tokens, 90 | mask_tokens 91 | ], dim=1) 92 | 93 | self.scheduler.set_timesteps(num_inference_steps) 94 | 95 | latents_seq = combined_tokens 96 | 97 | for t in self.scheduler.timesteps: 98 | 99 | 100 | output = self.transformer( 101 | hidden_states=latents_seq, # <--- 4 Channel Sequence 102 | encoder_hidden_states=text_emb, # <--- Text 103 | timestep=t, 104 | return_dict=False 105 | )[0] 106 | 107 | noise_pred = output[:, :noise_tokens.shape[1], :] 108 | 109 | noise_pred_map = noise_pred.reshape(1, 128, 128, 4).permute(0, 3, 1, 2) 110 | noise_map = latents_seq[:, :noise_tokens.shape[1], :].reshape(1, 128, 128, 4).permute(0, 3, 1, 2) 111 | 112 | updated_noise = self.scheduler.step(noise_pred_map, t, noise_map).prev_sample 113 | 114 | updated_noise_tokens = updated_noise.permute(0, 2, 3, 1).reshape(1, -1, 4) 115 | latents_seq = torch.cat([updated_noise_tokens, bg_tokens, fg_id_tokens, fg_det_tokens, mask_tokens], dim=1) 116 | 117 | image = self.vae.decode(updated_noise / self.vae.config.scaling_factor).sample 118 | return image[0] -------------------------------------------------------------------------------- /src/flux/api.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import time 4 | from pathlib import Path 5 | 6 | import requests 7 | from PIL import Image 8 | 9 | API_ENDPOINT = "https://api.bfl.ml" 10 | 11 | 12 | class ApiException(Exception): 13 | def __init__(self, status_code: int, detail: str | list[dict] | None = None): 14 | super().__init__() 15 | self.detail = detail 16 | self.status_code = status_code 17 | 18 | def __str__(self) -> str: 19 | return self.__repr__() 20 | 21 | def __repr__(self) -> str: 22 | if self.detail is None: 23 | message = None 24 | elif isinstance(self.detail, str): 25 | message = self.detail 26 | else: 27 | message = "[" + ",".join(d["msg"] for d in self.detail) + "]" 28 | return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" 29 | 30 | 31 | class ImageRequest: 32 | def __init__( 33 | self, 34 | prompt: str, 35 | width: int = 1024, 36 | height: int = 1024, 37 | name: str = "flux.1-pro", 38 | num_steps: int = 50, 39 | prompt_upsampling: bool = False, 40 | seed: int | None = None, 41 | validate: bool = True, 42 | launch: bool = True, 43 | api_key: str | None = None, 44 | ): 45 | """ 46 | Manages an image generation request to the API. 
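        Inputs are validated on construction (unless `validate=False`) and, when
        `launch=True`, the request is posted to the API immediately.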
47 | 48 | Args: 49 | prompt: Prompt to sample 50 | width: Width of the image in pixel 51 | height: Height of the image in pixel 52 | name: Name of the model 53 | num_steps: Number of network evaluations 54 | prompt_upsampling: Use prompt upsampling 55 | seed: Fix the generation seed 56 | validate: Run input validation 57 | launch: Directly launches request 58 | api_key: Your API key if not provided by the environment 59 | 60 | Raises: 61 | ValueError: For invalid input 62 | ApiException: For errors raised from the API 63 | """ 64 | if validate: 65 | if name not in ["flux.1-pro"]: 66 | raise ValueError(f"Invalid model {name}") 67 | elif width % 32 != 0: 68 | raise ValueError(f"width must be divisible by 32, got {width}") 69 | elif not (256 <= width <= 1440): 70 | raise ValueError(f"width must be between 256 and 1440, got {width}") 71 | elif height % 32 != 0: 72 | raise ValueError(f"height must be divisible by 32, got {height}") 73 | elif not (256 <= height <= 1440): 74 | raise ValueError(f"height must be between 256 and 1440, got {height}") 75 | elif not (1 <= num_steps <= 50): 76 | raise ValueError(f"steps must be between 1 and 50, got {num_steps}") 77 | 78 | self.request_json = { 79 | "prompt": prompt, 80 | "width": width, 81 | "height": height, 82 | "variant": name, 83 | "steps": num_steps, 84 | "prompt_upsampling": prompt_upsampling, 85 | } 86 | if seed is not None: 87 | self.request_json["seed"] = seed 88 | 89 | self.request_id: str | None = None 90 | self.result: dict | None = None 91 | self._image_bytes: bytes | None = None 92 | self._url: str | None = None 93 | if api_key is None: 94 | self.api_key = os.environ.get("BFL_API_KEY") 95 | else: 96 | self.api_key = api_key 97 | 98 | if launch: 99 | self.request() 100 | 101 | def request(self): 102 | """ 103 | Request to generate the image. 104 | """ 105 | if self.request_id is not None: 106 | return 107 | response = requests.post( 108 | f"{API_ENDPOINT}/v1/image", 109 | headers={ 110 | "accept": "application/json", 111 | "x-key": self.api_key, 112 | "Content-Type": "application/json", 113 | }, 114 | json=self.request_json, 115 | ) 116 | result = response.json() 117 | if response.status_code != 200: 118 | raise ApiException(status_code=response.status_code, detail=result.get("detail")) 119 | self.request_id = response.json()["id"] 120 | 121 | def retrieve(self) -> dict: 122 | """ 123 | Wait for the generation to finish and retrieve response. 124 | """ 125 | if self.request_id is None: 126 | self.request() 127 | while self.result is None: 128 | response = requests.get( 129 | f"{API_ENDPOINT}/v1/get_result", 130 | headers={ 131 | "accept": "application/json", 132 | "x-key": self.api_key, 133 | }, 134 | params={ 135 | "id": self.request_id, 136 | }, 137 | ) 138 | result = response.json() 139 | if "status" not in result: 140 | raise ApiException(status_code=response.status_code, detail=result.get("detail")) 141 | elif result["status"] == "Ready": 142 | self.result = result["result"] 143 | elif result["status"] == "Pending": 144 | time.sleep(0.5) 145 | else: 146 | raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") 147 | return self.result 148 | 149 | @property 150 | def bytes(self) -> bytes: 151 | """ 152 | Generated image as bytes. 
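        Downloaded lazily from `self.url` on first access and cached for later calls.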
153 | """ 154 | if self._image_bytes is None: 155 | response = requests.get(self.url) 156 | if response.status_code == 200: 157 | self._image_bytes = response.content 158 | else: 159 | raise ApiException(status_code=response.status_code) 160 | return self._image_bytes 161 | 162 | @property 163 | def url(self) -> str: 164 | """ 165 | Public url to retrieve the image from 166 | """ 167 | if self._url is None: 168 | result = self.retrieve() 169 | self._url = result["sample"] 170 | return self._url 171 | 172 | @property 173 | def image(self) -> Image.Image: 174 | """ 175 | Load the image as a PIL Image 176 | """ 177 | return Image.open(io.BytesIO(self.bytes)) 178 | 179 | def save(self, path: str): 180 | """ 181 | Save the generated image to a local path 182 | """ 183 | suffix = Path(self.url).suffix 184 | if not path.endswith(suffix): 185 | path = path + suffix 186 | Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) 187 | with open(path, "wb") as file: 188 | file.write(self.bytes) 189 | 190 | 191 | if __name__ == "__main__": 192 | from fire import Fire 193 | 194 | Fire(ImageRequest) 195 | -------------------------------------------------------------------------------- /src/flux/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from torch import Tensor 7 | 8 | from .model import Flux 9 | from .modules.conditioner import HFEmbedder 10 | 11 | 12 | def get_noise( 13 | num_samples: int, 14 | height: int, 15 | width: int, 16 | device: torch.device, 17 | dtype: torch.dtype, 18 | seed: int, 19 | ): 20 | return torch.randn( 21 | num_samples, 22 | 16, 23 | # allow for packing 24 | 2 * math.ceil(height / 16), 25 | 2 * math.ceil(width / 16), 26 | device=device, 27 | dtype=dtype, 28 | generator=torch.Generator(device=device).manual_seed(seed), 29 | ) 30 | 31 | 32 | def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: 33 | bs, c, h, w = img.shape 34 | if bs == 1 and not isinstance(prompt, str): 35 | bs = len(prompt) 36 | 37 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 38 | if img.shape[0] == 1 and bs > 1: 39 | img = repeat(img, "1 ... -> bs ...", bs=bs) 40 | 41 | img_ids = torch.zeros(h // 2, w // 2, 3) 42 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 43 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 44 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 45 | 46 | if isinstance(prompt, str): 47 | prompt = [prompt] 48 | txt = t5(prompt) 49 | if txt.shape[0] == 1 and bs > 1: 50 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 51 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 52 | 53 | vec = clip(prompt) 54 | if vec.shape[0] == 1 and bs > 1: 55 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 56 | 57 | return { 58 | "img": img, 59 | "img_ids": img_ids.to(img.device), 60 | "txt": txt.to(img.device), 61 | "txt_ids": txt_ids.to(img.device), 62 | "vec": vec.to(img.device), 63 | } 64 | 65 | 66 | def time_shift(mu: float, sigma: float, t: Tensor): 67 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 68 | 69 | 70 | def get_lin_function( 71 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 72 | ) -> Callable[[float], float]: 73 | m = (y2 - y1) / (x2 - x1) 74 | b = y1 - m * x1 75 | return lambda x: m * x + b 76 | 77 | 78 | def get_schedule( 79 | num_steps: int, 80 | image_seq_len: int, 81 | base_shift: float = 0.5, 82 | max_shift: float = 1.15, 83 | shift: bool = True, 84 | ) -> list[float]: 85 | # extra step for zero 86 | timesteps = torch.linspace(1, 0, num_steps + 1) 87 | 88 | # shifting the schedule to favor high timesteps for higher signal images 89 | if shift: 90 | # eastimate mu based on linear estimation between two points 91 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 92 | timesteps = time_shift(mu, 1.0, timesteps) 93 | 94 | return timesteps.tolist() 95 | 96 | 97 | def denoise( 98 | model: Flux, 99 | # model input 100 | img: Tensor, 101 | img_ids: Tensor, 102 | txt: Tensor, 103 | txt_ids: Tensor, 104 | vec: Tensor, 105 | neg_txt: Tensor, 106 | neg_txt_ids: Tensor, 107 | neg_vec: Tensor, 108 | # sampling parameters 109 | timesteps: list[float], 110 | guidance: float = 4.0, 111 | true_gs = 1, 112 | timestep_to_start_cfg=0, 113 | # ip-adapter parameters 114 | image_proj: Tensor=None, 115 | neg_image_proj: Tensor=None, 116 | ip_scale: Tensor | float = 1.0, 117 | neg_ip_scale: Tensor | float = 1.0 118 | ): 119 | i = 0 120 | # this is ignored for schnell 121 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) 122 | for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): 123 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) 124 | pred = model( 125 | img=img, 126 | img_ids=img_ids, 127 | txt=txt, 128 | txt_ids=txt_ids, 129 | y=vec, 130 | timesteps=t_vec, 131 | guidance=guidance_vec, 132 | image_proj=image_proj, 133 | ip_scale=ip_scale, 134 | ) 135 | if i >= timestep_to_start_cfg: 136 | neg_pred = model( 137 | img=img, 138 | img_ids=img_ids, 139 | txt=neg_txt, 140 | txt_ids=neg_txt_ids, 141 | y=neg_vec, 142 | timesteps=t_vec, 143 | guidance=guidance_vec, 144 | image_proj=neg_image_proj, 145 | ip_scale=neg_ip_scale, 146 | ) 147 | pred = neg_pred + true_gs * (pred - neg_pred) 148 | img = img + (t_prev - t_curr) * pred 149 | i += 1 150 | return img 151 | 152 | def denoise_controlnet( 153 | model: Flux, 154 | controlnet:None, 155 | # model input 156 | img: Tensor, 157 | img_ids: Tensor, 158 | txt: Tensor, 159 | txt_ids: Tensor, 160 | vec: Tensor, 161 | neg_txt: Tensor, 162 | neg_txt_ids: Tensor, 163 | neg_vec: Tensor, 164 | controlnet_cond, 165 | # sampling parameters 166 | timesteps: list[float], 167 | guidance: float = 4.0, 168 | true_gs = 1, 169 | controlnet_gs=0.7, 170 | timestep_to_start_cfg=0, 171 | # ip-adapter parameters 172 | image_proj: Tensor=None, 173 | neg_image_proj: Tensor=None, 174 | ip_scale: Tensor | float = 1, 175 | neg_ip_scale: Tensor | float = 1, 176 | ): 177 | # this is ignored for schnell 178 | i = 0 179 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) 180 | for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): 181 | t_vec = torch.full((img.shape[0],), t_curr, 
dtype=img.dtype, device=img.device) 182 | block_res_samples = controlnet( 183 | img=img, 184 | img_ids=img_ids, 185 | controlnet_cond=controlnet_cond, 186 | txt=txt, 187 | txt_ids=txt_ids, 188 | y=vec, 189 | timesteps=t_vec, 190 | guidance=guidance_vec, 191 | ) 192 | pred = model( 193 | img=img, 194 | img_ids=img_ids, 195 | txt=txt, 196 | txt_ids=txt_ids, 197 | y=vec, 198 | timesteps=t_vec, 199 | guidance=guidance_vec, 200 | block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples], 201 | image_proj=image_proj, 202 | ip_scale=ip_scale, 203 | ) 204 | if i >= timestep_to_start_cfg: 205 | neg_block_res_samples = controlnet( 206 | img=img, 207 | img_ids=img_ids, 208 | controlnet_cond=controlnet_cond, 209 | txt=neg_txt, 210 | txt_ids=neg_txt_ids, 211 | y=neg_vec, 212 | timesteps=t_vec, 213 | guidance=guidance_vec, 214 | ) 215 | neg_pred = model( 216 | img=img, 217 | img_ids=img_ids, 218 | txt=neg_txt, 219 | txt_ids=neg_txt_ids, 220 | y=neg_vec, 221 | timesteps=t_vec, 222 | guidance=guidance_vec, 223 | block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples], 224 | image_proj=neg_image_proj, 225 | ip_scale=neg_ip_scale, 226 | ) 227 | pred = neg_pred + true_gs * (pred - neg_pred) 228 | 229 | img = img + (t_prev - t_curr) * pred 230 | 231 | i += 1 232 | return img 233 | 234 | def unpack(x: Tensor, height: int, width: int) -> Tensor: 235 | return rearrange( 236 | x, 237 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 238 | h=math.ceil(height / 16), 239 | w=math.ceil(width / 16), 240 | ph=2, 241 | pw=2, 242 | ) 243 | -------------------------------------------------------------------------------- /src/flux/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from einops import rearrange 6 | 7 | from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, 8 | MLPEmbedder, SingleStreamBlock, 9 | timestep_embedding) 10 | 11 | 12 | @dataclass 13 | class FluxParams: 14 | in_channels: int 15 | vec_in_dim: int 16 | context_in_dim: int 17 | hidden_size: int 18 | mlp_ratio: float 19 | num_heads: int 20 | depth: int 21 | depth_single_blocks: int 22 | axes_dim: list[int] 23 | theta: int 24 | qkv_bias: bool 25 | guidance_embed: bool 26 | 27 | 28 | class Flux(nn.Module): 29 | """ 30 | Transformer model for flow matching on sequences. 
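    Image and text tokens first pass through `depth` DoubleStreamBlocks that keep
    separate image/text streams, are then concatenated and refined by
    `depth_single_blocks` SingleStreamBlocks, and finally projected back to
    `in_channels` values per image token.

    Example (a hedged sketch; the parameter values below are illustrative, not a
    shipped config):

        params = FluxParams(
            in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072,
            mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38,
            axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True,
        )
        model = Flux(params)  # hidden_size must be divisible by num_heads, and
                              # sum(axes_dim) must equal hidden_size // num_heads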
31 | """ 32 | _supports_gradient_checkpointing = True 33 | 34 | def __init__(self, params: FluxParams): 35 | super().__init__() 36 | 37 | self.params = params 38 | self.in_channels = params.in_channels 39 | self.out_channels = self.in_channels 40 | if params.hidden_size % params.num_heads != 0: 41 | raise ValueError( 42 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 43 | ) 44 | pe_dim = params.hidden_size // params.num_heads 45 | if sum(params.axes_dim) != pe_dim: 46 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") 47 | self.hidden_size = params.hidden_size 48 | self.num_heads = params.num_heads 49 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) 50 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 51 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 52 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 53 | self.guidance_in = ( 54 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 55 | ) 56 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 57 | 58 | self.double_blocks = nn.ModuleList( 59 | [ 60 | DoubleStreamBlock( 61 | self.hidden_size, 62 | self.num_heads, 63 | mlp_ratio=params.mlp_ratio, 64 | qkv_bias=params.qkv_bias, 65 | ) 66 | for _ in range(params.depth) 67 | ] 68 | ) 69 | 70 | self.single_blocks = nn.ModuleList( 71 | [ 72 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 73 | for _ in range(params.depth_single_blocks) 74 | ] 75 | ) 76 | 77 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 78 | self.gradient_checkpointing = False 79 | 80 | def _set_gradient_checkpointing(self, module, value=False): 81 | if hasattr(module, "gradient_checkpointing"): 82 | module.gradient_checkpointing = value 83 | 84 | @property 85 | def attn_processors(self): 86 | # set recursively 87 | processors = {} 88 | 89 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): 90 | if hasattr(module, "set_processor"): 91 | processors[f"{name}.processor"] = module.processor 92 | 93 | for sub_name, child in module.named_children(): 94 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 95 | 96 | return processors 97 | 98 | for name, module in self.named_children(): 99 | fn_recursive_add_processors(name, module, processors) 100 | 101 | return processors 102 | 103 | def set_attn_processor(self, processor): 104 | r""" 105 | Sets the attention processor to use to compute attention. 106 | 107 | Parameters: 108 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 109 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 110 | for **all** `Attention` layers. 111 | 112 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 113 | processor. This is strongly recommended when setting trainable attention processors. 114 | 115 | """ 116 | count = len(self.attn_processors.keys()) 117 | 118 | if isinstance(processor, dict) and len(processor) != count: 119 | raise ValueError( 120 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 121 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
122 | ) 123 | 124 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 125 | if hasattr(module, "set_processor"): 126 | if not isinstance(processor, dict): 127 | module.set_processor(processor) 128 | else: 129 | module.set_processor(processor.pop(f"{name}.processor")) 130 | 131 | for sub_name, child in module.named_children(): 132 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 133 | 134 | for name, module in self.named_children(): 135 | fn_recursive_attn_processor(name, module, processor) 136 | 137 | def forward( 138 | self, 139 | img: Tensor, 140 | img_ids: Tensor, 141 | txt: Tensor, 142 | txt_ids: Tensor, 143 | timesteps: Tensor, 144 | y: Tensor, 145 | block_controlnet_hidden_states=None, 146 | guidance: Tensor | None = None, 147 | image_proj: Tensor | None = None, 148 | ip_scale: Tensor | float = 1.0, 149 | ) -> Tensor: 150 | if img.ndim != 3 or txt.ndim != 3: 151 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 152 | 153 | # running on sequences img 154 | img = self.img_in(img) 155 | vec = self.time_in(timestep_embedding(timesteps, 256)) 156 | if self.params.guidance_embed: 157 | if guidance is None: 158 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 159 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 160 | vec = vec + self.vector_in(y) 161 | txt = self.txt_in(txt) 162 | 163 | ids = torch.cat((txt_ids, img_ids), dim=1) 164 | pe = self.pe_embedder(ids) 165 | if block_controlnet_hidden_states is not None: 166 | controlnet_depth = len(block_controlnet_hidden_states) 167 | for index_block, block in enumerate(self.double_blocks): 168 | if self.training and self.gradient_checkpointing: 169 | 170 | def create_custom_forward(module, return_dict=None): 171 | def custom_forward(*inputs): 172 | if return_dict is not None: 173 | return module(*inputs, return_dict=return_dict) 174 | else: 175 | return module(*inputs) 176 | 177 | return custom_forward 178 | 179 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 180 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 181 | create_custom_forward(block), 182 | img, 183 | txt, 184 | vec, 185 | pe, 186 | image_proj, 187 | ip_scale, 188 | ) 189 | else: 190 | img, txt = block( 191 | img=img, 192 | txt=txt, 193 | vec=vec, 194 | pe=pe, 195 | image_proj=image_proj, 196 | ip_scale=ip_scale, 197 | ) 198 | # controlnet residual 199 | if block_controlnet_hidden_states is not None: 200 | img = img + block_controlnet_hidden_states[index_block % 2] 201 | 202 | 203 | img = torch.cat((txt, img), 1) 204 | for block in self.single_blocks: 205 | if self.training and self.gradient_checkpointing: 206 | 207 | def create_custom_forward(module, return_dict=None): 208 | def custom_forward(*inputs): 209 | if return_dict is not None: 210 | return module(*inputs, return_dict=return_dict) 211 | else: 212 | return module(*inputs) 213 | 214 | return custom_forward 215 | 216 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 217 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 218 | create_custom_forward(block), 219 | img, 220 | vec, 221 | pe, 222 | ) 223 | else: 224 | img = block(img, vec=vec, pe=pe) 225 | img = img[:, txt.shape[1] :, ...] 
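        # the slice above removes the text tokens that were prepended before the
        # single-stream blocks, so only image tokens reach the final projection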
226 | 227 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 228 | return img 229 | -------------------------------------------------------------------------------- /src/flux/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from dataclasses import dataclass 5 | from glob import iglob 6 | 7 | import torch 8 | from einops import rearrange 9 | from fire import Fire 10 | from PIL import ExifTags, Image 11 | 12 | from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack 13 | from flux.util import (configs, embed_watermark, load_ae, load_clip, 14 | load_flow_model, load_t5) 15 | from transformers import pipeline 16 | 17 | NSFW_THRESHOLD = 0.85 18 | 19 | @dataclass 20 | class SamplingOptions: 21 | prompt: str 22 | width: int 23 | height: int 24 | num_steps: int 25 | guidance: float 26 | seed: int | None 27 | 28 | 29 | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: 30 | user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" 31 | usage = ( 32 | "Usage: Either write your prompt directly, leave this field empty " 33 | "to repeat the prompt or write a command starting with a slash:\n" 34 | "- '/w ' will set the width of the generated image\n" 35 | "- '/h ' will set the height of the generated image\n" 36 | "- '/s ' sets the next seed\n" 37 | "- '/g ' sets the guidance (flux-dev only)\n" 38 | "- '/n ' sets the number of steps\n" 39 | "- '/q' to quit" 40 | ) 41 | 42 | while (prompt := input(user_question)).startswith("/"): 43 | if prompt.startswith("/w"): 44 | if prompt.count(" ") != 1: 45 | print(f"Got invalid command '{prompt}'\n{usage}") 46 | continue 47 | _, width = prompt.split() 48 | options.width = 16 * (int(width) // 16) 49 | print( 50 | f"Setting resolution to {options.width} x {options.height} " 51 | f"({options.height *options.width/1e6:.2f}MP)" 52 | ) 53 | elif prompt.startswith("/h"): 54 | if prompt.count(" ") != 1: 55 | print(f"Got invalid command '{prompt}'\n{usage}") 56 | continue 57 | _, height = prompt.split() 58 | options.height = 16 * (int(height) // 16) 59 | print( 60 | f"Setting resolution to {options.width} x {options.height} " 61 | f"({options.height *options.width/1e6:.2f}MP)" 62 | ) 63 | elif prompt.startswith("/g"): 64 | if prompt.count(" ") != 1: 65 | print(f"Got invalid command '{prompt}'\n{usage}") 66 | continue 67 | _, guidance = prompt.split() 68 | options.guidance = float(guidance) 69 | print(f"Setting guidance to {options.guidance}") 70 | elif prompt.startswith("/s"): 71 | if prompt.count(" ") != 1: 72 | print(f"Got invalid command '{prompt}'\n{usage}") 73 | continue 74 | _, seed = prompt.split() 75 | options.seed = int(seed) 76 | print(f"Setting seed to {options.seed}") 77 | elif prompt.startswith("/n"): 78 | if prompt.count(" ") != 1: 79 | print(f"Got invalid command '{prompt}'\n{usage}") 80 | continue 81 | _, steps = prompt.split() 82 | options.num_steps = int(steps) 83 | print(f"Setting seed to {options.num_steps}") 84 | elif prompt.startswith("/q"): 85 | print("Quitting") 86 | return None 87 | else: 88 | if not prompt.startswith("/h"): 89 | print(f"Got invalid command '{prompt}'\n{usage}") 90 | print(usage) 91 | if prompt != "": 92 | options.prompt = prompt 93 | return options 94 | 95 | 96 | @torch.inference_mode() 97 | def main( 98 | name: str = "flux-schnell", 99 | width: int = 1360, 100 | height: int = 768, 101 | seed: int | None = None, 102 | prompt: str = ( 103 | "a photo 
of a forest with mist swirling around the tree trunks. The word " 104 | '"FLUX" is painted over it in big, red brush strokes with visible texture' 105 | ), 106 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 107 | num_steps: int | None = None, 108 | loop: bool = False, 109 | guidance: float = 3.5, 110 | offload: bool = False, 111 | output_dir: str = "output", 112 | add_sampling_metadata: bool = True, 113 | ): 114 | """ 115 | Sample the flux model. Either interactively (set `--loop`) or run for a 116 | single image. 117 | 118 | Args: 119 | name: Name of the model to load 120 | height: height of the sample in pixels (should be a multiple of 16) 121 | width: width of the sample in pixels (should be a multiple of 16) 122 | seed: Set a seed for sampling 123 | output_name: where to save the output image, `{idx}` will be replaced 124 | by the index of the sample 125 | prompt: Prompt used for sampling 126 | device: Pytorch device 127 | num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) 128 | loop: start an interactive session and sample multiple times 129 | guidance: guidance value used for guidance distillation 130 | add_sampling_metadata: Add the prompt to the image Exif metadata 131 | """ 132 | nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection") 133 | 134 | if name not in configs: 135 | available = ", ".join(configs.keys()) 136 | raise ValueError(f"Got unknown model name: {name}, chose from {available}") 137 | 138 | torch_device = torch.device(device) 139 | if num_steps is None: 140 | num_steps = 4 if name == "flux-schnell" else 50 141 | 142 | # allow for packing and conversion to latent space 143 | height = 16 * (height // 16) 144 | width = 16 * (width // 16) 145 | 146 | output_name = os.path.join(output_dir, "img_{idx}.jpg") 147 | if not os.path.exists(output_dir): 148 | os.makedirs(output_dir) 149 | idx = 0 150 | else: 151 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)] 152 | if len(fns) > 0: 153 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 154 | else: 155 | idx = 0 156 | 157 | # init all components 158 | t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) 159 | clip = load_clip(torch_device) 160 | model = load_flow_model(name, device="cpu" if offload else torch_device) 161 | ae = load_ae(name, device="cpu" if offload else torch_device) 162 | 163 | rng = torch.Generator(device="cpu") 164 | opts = SamplingOptions( 165 | prompt=prompt, 166 | width=width, 167 | height=height, 168 | num_steps=num_steps, 169 | guidance=guidance, 170 | seed=seed, 171 | ) 172 | 173 | if loop: 174 | opts = parse_prompt(opts) 175 | 176 | while opts is not None: 177 | if opts.seed is None: 178 | opts.seed = rng.seed() 179 | print(f"Generating with seed {opts.seed}:\n{opts.prompt}") 180 | t0 = time.perf_counter() 181 | 182 | # prepare input 183 | x = get_noise( 184 | 1, 185 | opts.height, 186 | opts.width, 187 | device=torch_device, 188 | dtype=torch.bfloat16, 189 | seed=opts.seed, 190 | ) 191 | opts.seed = None 192 | if offload: 193 | ae = ae.cpu() 194 | torch.cuda.empty_cache() 195 | t5, clip = t5.to(torch_device), clip.to(torch_device) 196 | inp = prepare(t5, clip, x, prompt=opts.prompt) 197 | timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) 198 | 199 | # offload TEs to CPU, load model to gpu 200 | if offload: 201 | t5, clip = t5.cpu(), clip.cpu() 202 | torch.cuda.empty_cache() 203 | 
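            # with the text encoders back on the CPU, only the flow model (and later
            # the decoder) is resident on the GPU at any one time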
model = model.to(torch_device) 204 | 205 | # denoise initial noise 206 | x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) 207 | 208 | # offload model, load autoencoder to gpu 209 | if offload: 210 | model.cpu() 211 | torch.cuda.empty_cache() 212 | ae.decoder.to(x.device) 213 | 214 | # decode latents to pixel space 215 | x = unpack(x.float(), opts.height, opts.width) 216 | with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): 217 | x = ae.decode(x) 218 | t1 = time.perf_counter() 219 | 220 | fn = output_name.format(idx=idx) 221 | print(f"Done in {t1 - t0:.1f}s. Saving {fn}") 222 | # bring into PIL format and save 223 | x = x.clamp(-1, 1) 224 | x = embed_watermark(x.float()) 225 | x = rearrange(x[0], "c h w -> h w c") 226 | 227 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) 228 | nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] 229 | 230 | if nsfw_score < NSFW_THRESHOLD: 231 | exif_data = Image.Exif() 232 | exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" 233 | exif_data[ExifTags.Base.Make] = "Black Forest Labs" 234 | exif_data[ExifTags.Base.Model] = name 235 | if add_sampling_metadata: 236 | exif_data[ExifTags.Base.ImageDescription] = prompt 237 | img.save(fn, exif=exif_data, quality=95, subsampling=0) 238 | idx += 1 239 | else: 240 | print("Your generated image may contain NSFW content.") 241 | 242 | if loop: 243 | print("-" * 80) 244 | opts = parse_prompt(opts) 245 | else: 246 | opts = None 247 | 248 | 249 | def app(): 250 | Fire(main) 251 | 252 | 253 | if __name__ == "__main__": 254 | app() 255 | -------------------------------------------------------------------------------- /src/flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import Tensor, nn 6 | 7 | 8 | @dataclass 9 | class AutoEncoderParams: 10 | resolution: int 11 | in_channels: int 12 | ch: int 13 | out_ch: int 14 | ch_mult: list[int] 15 | num_res_blocks: int 16 | z_channels: int 17 | scale_factor: float 18 | shift_factor: float 19 | 20 | 21 | def swish(x: Tensor) -> Tensor: 22 | return x * torch.sigmoid(x) 23 | 24 | 25 | class AttnBlock(nn.Module): 26 | def __init__(self, in_channels: int): 27 | super().__init__() 28 | self.in_channels = in_channels 29 | 30 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 31 | 32 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 33 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 34 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 35 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | 37 | def attention(self, h_: Tensor) -> Tensor: 38 | h_ = self.norm(h_) 39 | q = self.q(h_) 40 | k = self.k(h_) 41 | v = self.v(h_) 42 | 43 | b, c, h, w = q.shape 44 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 45 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 46 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 47 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 48 | 49 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | return x + self.proj_out(self.attention(x)) 53 | 54 | 55 | class ResnetBlock(nn.Module): 56 | def __init__(self, in_channels: int, out_channels: int): 57 | super().__init__() 58 | self.in_channels = in_channels 59 
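        # out_channels may be passed as None, in which case the block keeps the
        # input channel count (see the fallback on the next line)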
| out_channels = in_channels if out_channels is None else out_channels 60 | self.out_channels = out_channels 61 | 62 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 63 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 64 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 65 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 66 | if self.in_channels != self.out_channels: 67 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 68 | 69 | def forward(self, x): 70 | h = x 71 | h = self.norm1(h) 72 | h = swish(h) 73 | h = self.conv1(h) 74 | 75 | h = self.norm2(h) 76 | h = swish(h) 77 | h = self.conv2(h) 78 | 79 | if self.in_channels != self.out_channels: 80 | x = self.nin_shortcut(x) 81 | 82 | return x + h 83 | 84 | 85 | class Downsample(nn.Module): 86 | def __init__(self, in_channels: int): 87 | super().__init__() 88 | # no asymmetric padding in torch conv, must do it ourselves 89 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 90 | 91 | def forward(self, x: Tensor): 92 | pad = (0, 1, 0, 1) 93 | x = nn.functional.pad(x, pad, mode="constant", value=0) 94 | x = self.conv(x) 95 | return x 96 | 97 | 98 | class Upsample(nn.Module): 99 | def __init__(self, in_channels: int): 100 | super().__init__() 101 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 102 | 103 | def forward(self, x: Tensor): 104 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 105 | x = self.conv(x) 106 | return x 107 | 108 | 109 | class Encoder(nn.Module): 110 | def __init__( 111 | self, 112 | resolution: int, 113 | in_channels: int, 114 | ch: int, 115 | ch_mult: list[int], 116 | num_res_blocks: int, 117 | z_channels: int, 118 | ): 119 | super().__init__() 120 | self.ch = ch 121 | self.num_resolutions = len(ch_mult) 122 | self.num_res_blocks = num_res_blocks 123 | self.resolution = resolution 124 | self.in_channels = in_channels 125 | # downsampling 126 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 127 | 128 | curr_res = resolution 129 | in_ch_mult = (1,) + tuple(ch_mult) 130 | self.in_ch_mult = in_ch_mult 131 | self.down = nn.ModuleList() 132 | block_in = self.ch 133 | for i_level in range(self.num_resolutions): 134 | block = nn.ModuleList() 135 | attn = nn.ModuleList() 136 | block_in = ch * in_ch_mult[i_level] 137 | block_out = ch * ch_mult[i_level] 138 | for _ in range(self.num_res_blocks): 139 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 140 | block_in = block_out 141 | down = nn.Module() 142 | down.block = block 143 | down.attn = attn 144 | if i_level != self.num_resolutions - 1: 145 | down.downsample = Downsample(block_in) 146 | curr_res = curr_res // 2 147 | self.down.append(down) 148 | 149 | # middle 150 | self.mid = nn.Module() 151 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 152 | self.mid.attn_1 = AttnBlock(block_in) 153 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 154 | 155 | # end 156 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 157 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 158 | 159 | def forward(self, x: Tensor) -> Tensor: 160 | # downsampling 161 | hs = [self.conv_in(x)] 162 | for i_level in 
range(self.num_resolutions): 163 | for i_block in range(self.num_res_blocks): 164 | h = self.down[i_level].block[i_block](hs[-1]) 165 | if len(self.down[i_level].attn) > 0: 166 | h = self.down[i_level].attn[i_block](h) 167 | hs.append(h) 168 | if i_level != self.num_resolutions - 1: 169 | hs.append(self.down[i_level].downsample(hs[-1])) 170 | 171 | # middle 172 | h = hs[-1] 173 | h = self.mid.block_1(h) 174 | h = self.mid.attn_1(h) 175 | h = self.mid.block_2(h) 176 | # end 177 | h = self.norm_out(h) 178 | h = swish(h) 179 | h = self.conv_out(h) 180 | return h 181 | 182 | 183 | class Decoder(nn.Module): 184 | def __init__( 185 | self, 186 | ch: int, 187 | out_ch: int, 188 | ch_mult: list[int], 189 | num_res_blocks: int, 190 | in_channels: int, 191 | resolution: int, 192 | z_channels: int, 193 | ): 194 | super().__init__() 195 | self.ch = ch 196 | self.num_resolutions = len(ch_mult) 197 | self.num_res_blocks = num_res_blocks 198 | self.resolution = resolution 199 | self.in_channels = in_channels 200 | self.ffactor = 2 ** (self.num_resolutions - 1) 201 | 202 | # compute in_ch_mult, block_in and curr_res at lowest res 203 | block_in = ch * ch_mult[self.num_resolutions - 1] 204 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 205 | self.z_shape = (1, z_channels, curr_res, curr_res) 206 | 207 | # z to block_in 208 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 209 | 210 | # middle 211 | self.mid = nn.Module() 212 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 213 | self.mid.attn_1 = AttnBlock(block_in) 214 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 215 | 216 | # upsampling 217 | self.up = nn.ModuleList() 218 | for i_level in reversed(range(self.num_resolutions)): 219 | block = nn.ModuleList() 220 | attn = nn.ModuleList() 221 | block_out = ch * ch_mult[i_level] 222 | for _ in range(self.num_res_blocks + 1): 223 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 224 | block_in = block_out 225 | up = nn.Module() 226 | up.block = block 227 | up.attn = attn 228 | if i_level != 0: 229 | up.upsample = Upsample(block_in) 230 | curr_res = curr_res * 2 231 | self.up.insert(0, up) # prepend to get consistent order 232 | 233 | # end 234 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 235 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 236 | 237 | def forward(self, z: Tensor) -> Tensor: 238 | # z to block_in 239 | h = self.conv_in(z) 240 | 241 | # middle 242 | h = self.mid.block_1(h) 243 | h = self.mid.attn_1(h) 244 | h = self.mid.block_2(h) 245 | 246 | # upsampling 247 | for i_level in reversed(range(self.num_resolutions)): 248 | for i_block in range(self.num_res_blocks + 1): 249 | h = self.up[i_level].block[i_block](h) 250 | if len(self.up[i_level].attn) > 0: 251 | h = self.up[i_level].attn[i_block](h) 252 | if i_level != 0: 253 | h = self.up[i_level].upsample(h) 254 | 255 | # end 256 | h = self.norm_out(h) 257 | h = swish(h) 258 | h = self.conv_out(h) 259 | return h 260 | 261 | 262 | class DiagonalGaussian(nn.Module): 263 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 264 | super().__init__() 265 | self.sample = sample 266 | self.chunk_dim = chunk_dim 267 | 268 | def forward(self, z: Tensor) -> Tensor: 269 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 270 | if self.sample: 271 | std = torch.exp(0.5 * logvar) 272 | return mean + std * torch.randn_like(mean) 273 | else: 
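# sample=False: deterministic mode, return the posterior mean instead of mean + std * eps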
274 | return mean 275 | 276 | 277 | class AutoEncoder(nn.Module): 278 | def __init__(self, params: AutoEncoderParams): 279 | super().__init__() 280 | self.encoder = Encoder( 281 | resolution=params.resolution, 282 | in_channels=params.in_channels, 283 | ch=params.ch, 284 | ch_mult=params.ch_mult, 285 | num_res_blocks=params.num_res_blocks, 286 | z_channels=params.z_channels, 287 | ) 288 | self.decoder = Decoder( 289 | resolution=params.resolution, 290 | in_channels=params.in_channels, 291 | ch=params.ch, 292 | out_ch=params.out_ch, 293 | ch_mult=params.ch_mult, 294 | num_res_blocks=params.num_res_blocks, 295 | z_channels=params.z_channels, 296 | ) 297 | self.reg = DiagonalGaussian() 298 | 299 | self.scale_factor = params.scale_factor 300 | self.shift_factor = params.shift_factor 301 | 302 | def encode(self, x: Tensor) -> Tensor: 303 | z = self.reg(self.encoder(x)) 304 | z = self.scale_factor * (z - self.shift_factor) 305 | return z 306 | 307 | def decode(self, z: Tensor) -> Tensor: 308 | z = z / self.scale_factor + self.shift_factor 309 | return self.decoder(z) 310 | 311 | def forward(self, x: Tensor) -> Tensor: 312 | return self.decode(self.encode(x)) 313 | -------------------------------------------------------------------------------- /src/flux/annotator/dwpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | eps = 0.01 8 | 9 | 10 | def smart_resize(x, s): 11 | Ht, Wt = s 12 | if x.ndim == 2: 13 | Ho, Wo = x.shape 14 | Co = 1 15 | else: 16 | Ho, Wo, Co = x.shape 17 | if Co == 3 or Co == 1: 18 | k = float(Ht + Wt) / float(Ho + Wo) 19 | return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) 20 | else: 21 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) 22 | 23 | 24 | def smart_resize_k(x, fx, fy): 25 | if x.ndim == 2: 26 | Ho, Wo = x.shape 27 | Co = 1 28 | else: 29 | Ho, Wo, Co = x.shape 30 | Ht, Wt = Ho * fy, Wo * fx 31 | if Co == 3 or Co == 1: 32 | k = float(Ht + Wt) / float(Ho + Wo) 33 | return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) 34 | else: 35 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) 36 | 37 | 38 | def padRightDownCorner(img, stride, padValue): 39 | h = img.shape[0] 40 | w = img.shape[1] 41 | 42 | pad = 4 * [None] 43 | pad[0] = 0 # up 44 | pad[1] = 0 # left 45 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down 46 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right 47 | 48 | img_padded = img 49 | pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) 50 | img_padded = np.concatenate((pad_up, img_padded), axis=0) 51 | pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) 52 | img_padded = np.concatenate((pad_left, img_padded), axis=1) 53 | pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) 54 | img_padded = np.concatenate((img_padded, pad_down), axis=0) 55 | pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) 56 | img_padded = np.concatenate((img_padded, pad_right), axis=1) 57 | 58 | return img_padded, pad 59 | 60 | 61 | def transfer(model, model_weights): 62 | transfered_model_weights = {} 63 | for weights_name in model.state_dict().keys(): 64 | transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] 65 | return transfered_model_weights 66 
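# The drawing helpers below expect keypoints in normalized [0, 1] image coordinates
# and scale them by the canvas width/height before rasterizing with OpenCV.
# `candidate` holds the keypoint coordinates; each row of `subset` indexes one
# person's joints in `candidate`, with -1 marking a joint that was not detected.
# Hypothetical usage sketch: canvas = np.zeros((H, W, 3), np.uint8), then
# canvas = draw_bodypose(canvas, candidate, subset) overlays the pose skeleton.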
| 67 | 68 | def draw_bodypose(canvas, candidate, subset): 69 | H, W, C = canvas.shape 70 | candidate = np.array(candidate) 71 | subset = np.array(subset) 72 | 73 | stickwidth = 4 74 | 75 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 76 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 77 | [1, 16], [16, 18], [3, 17], [6, 18]] 78 | 79 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 80 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 81 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 82 | 83 | for i in range(17): 84 | for n in range(len(subset)): 85 | index = subset[n][np.array(limbSeq[i]) - 1] 86 | if -1 in index: 87 | continue 88 | Y = candidate[index.astype(int), 0] * float(W) 89 | X = candidate[index.astype(int), 1] * float(H) 90 | mX = np.mean(X) 91 | mY = np.mean(Y) 92 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 93 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 94 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 95 | cv2.fillConvexPoly(canvas, polygon, colors[i]) 96 | 97 | canvas = (canvas * 0.6).astype(np.uint8) 98 | 99 | for i in range(18): 100 | for n in range(len(subset)): 101 | index = int(subset[n][i]) 102 | if index == -1: 103 | continue 104 | x, y = candidate[index][0:2] 105 | x = int(x * W) 106 | y = int(y * H) 107 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) 108 | 109 | return canvas 110 | 111 | 112 | def draw_handpose(canvas, all_hand_peaks): 113 | H, W, C = canvas.shape 114 | 115 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 116 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 117 | 118 | for peaks in all_hand_peaks: 119 | peaks = np.array(peaks) 120 | 121 | for ie, e in enumerate(edges): 122 | x1, y1 = peaks[e[0]] 123 | x2, y2 = peaks[e[1]] 124 | x1 = int(x1 * W) 125 | y1 = int(y1 * H) 126 | x2 = int(x2 * W) 127 | y2 = int(y2 * H) 128 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 129 | cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) 130 | 131 | for i, keyponit in enumerate(peaks): 132 | x, y = keyponit 133 | x = int(x * W) 134 | y = int(y * H) 135 | if x > eps and y > eps: 136 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) 137 | return canvas 138 | 139 | 140 | def draw_facepose(canvas, all_lmks): 141 | H, W, C = canvas.shape 142 | for lmks in all_lmks: 143 | lmks = np.array(lmks) 144 | for lmk in lmks: 145 | x, y = lmk 146 | x = int(x * W) 147 | y = int(y * H) 148 | if x > eps and y > eps: 149 | cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) 150 | return canvas 151 | 152 | 153 | # detect hand according to body pose keypoints 154 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp 155 | def handDetect(candidate, subset, oriImg): 156 | # right hand: wrist 4, elbow 3, shoulder 2 157 | # left hand: wrist 7, elbow 6, shoulder 5 158 | ratioWristElbow = 0.33 159 | detect_result = [] 160 | image_height, image_width = oriImg.shape[0:2] 161 | for person in subset.astype(int): 162 | # if any of three not detected 163 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0 164 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0 
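# A person contributes hand proposals only if at least one full arm chain
# (shoulder, elbow, wrist) was detected; each detected chain below yields a
# candidate square hand box extrapolated past the wrist along the forearm direction.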
165 | if not (has_left or has_right): 166 | continue 167 | hands = [] 168 | #left hand 169 | if has_left: 170 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] 171 | x1, y1 = candidate[left_shoulder_index][:2] 172 | x2, y2 = candidate[left_elbow_index][:2] 173 | x3, y3 = candidate[left_wrist_index][:2] 174 | hands.append([x1, y1, x2, y2, x3, y3, True]) 175 | # right hand 176 | if has_right: 177 | right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] 178 | x1, y1 = candidate[right_shoulder_index][:2] 179 | x2, y2 = candidate[right_elbow_index][:2] 180 | x3, y3 = candidate[right_wrist_index][:2] 181 | hands.append([x1, y1, x2, y2, x3, y3, False]) 182 | 183 | for x1, y1, x2, y2, x3, y3, is_left in hands: 184 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox 185 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); 186 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); 187 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); 188 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); 189 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); 190 | x = x3 + ratioWristElbow * (x3 - x2) 191 | y = y3 + ratioWristElbow * (y3 - y2) 192 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) 193 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) 194 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 195 | # x-y refers to the center --> offset to topLeft point 196 | # handRectangle.x -= handRectangle.width / 2.f; 197 | # handRectangle.y -= handRectangle.height / 2.f; 198 | x -= width / 2 199 | y -= width / 2 # width = height 200 | # overflow the image 201 | if x < 0: x = 0 202 | if y < 0: y = 0 203 | width1 = width 204 | width2 = width 205 | if x + width > image_width: width1 = image_width - x 206 | if y + width > image_height: width2 = image_height - y 207 | width = min(width1, width2) 208 | # the max hand box value is 20 pixels 209 | if width >= 20: 210 | detect_result.append([int(x), int(y), int(width), is_left]) 211 | 212 | ''' 213 | return value: [[x, y, w, True if left hand else False]]. 214 | width=height since the network require squared input. 
215 | x, y is the coordinate of top left 216 | ''' 217 | return detect_result 218 | 219 | 220 | # Written by Lvmin 221 | def faceDetect(candidate, subset, oriImg): 222 | # left right eye ear 14 15 16 17 223 | detect_result = [] 224 | image_height, image_width = oriImg.shape[0:2] 225 | for person in subset.astype(int): 226 | has_head = person[0] > -1 227 | if not has_head: 228 | continue 229 | 230 | has_left_eye = person[14] > -1 231 | has_right_eye = person[15] > -1 232 | has_left_ear = person[16] > -1 233 | has_right_ear = person[17] > -1 234 | 235 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): 236 | continue 237 | 238 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] 239 | 240 | width = 0.0 241 | x0, y0 = candidate[head][:2] 242 | 243 | if has_left_eye: 244 | x1, y1 = candidate[left_eye][:2] 245 | d = max(abs(x0 - x1), abs(y0 - y1)) 246 | width = max(width, d * 3.0) 247 | 248 | if has_right_eye: 249 | x1, y1 = candidate[right_eye][:2] 250 | d = max(abs(x0 - x1), abs(y0 - y1)) 251 | width = max(width, d * 3.0) 252 | 253 | if has_left_ear: 254 | x1, y1 = candidate[left_ear][:2] 255 | d = max(abs(x0 - x1), abs(y0 - y1)) 256 | width = max(width, d * 1.5) 257 | 258 | if has_right_ear: 259 | x1, y1 = candidate[right_ear][:2] 260 | d = max(abs(x0 - x1), abs(y0 - y1)) 261 | width = max(width, d * 1.5) 262 | 263 | x, y = x0, y0 264 | 265 | x -= width 266 | y -= width 267 | 268 | if x < 0: 269 | x = 0 270 | 271 | if y < 0: 272 | y = 0 273 | 274 | width1 = width * 2 275 | width2 = width * 2 276 | 277 | if x + width > image_width: 278 | width1 = image_width - x 279 | 280 | if y + width > image_height: 281 | width2 = image_height - y 282 | 283 | width = min(width1, width2) 284 | 285 | if width >= 20: 286 | detect_result.append([int(x), int(y), int(width)]) 287 | 288 | return detect_result 289 | 290 | 291 | # get max index of 2d array 292 | def npmax(array): 293 | arrayindex = array.argmax(1) 294 | arrayvalue = array.max(1) 295 | i = arrayvalue.argmax() 296 | j = arrayindex[i] 297 | return i, j 298 | -------------------------------------------------------------------------------- /src/flux/annotator/dwpose/onnxpose.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | 7 | def preprocess( 8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) 9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 | """Do preprocessing for RTMPose model inference. 11 | 12 | Args: 13 | img (np.ndarray): Input image in shape. 14 | input_size (tuple): Input image size in shape (w, h). 15 | 16 | Returns: 17 | tuple: 18 | - resized_img (np.ndarray): Preprocessed image. 19 | - center (np.ndarray): Center of image. 20 | - scale (np.ndarray): Scale of image. 
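Note: one entry per bounding box in `out_bbox` is returned (lists of resized
images, centers and scales); an empty `out_bbox` falls back to a single
full-image box.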
21 | """ 22 | # get shape of image 23 | img_shape = img.shape[:2] 24 | out_img, out_center, out_scale = [], [], [] 25 | if len(out_bbox) == 0: 26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]] 27 | for i in range(len(out_bbox)): 28 | x0 = out_bbox[i][0] 29 | y0 = out_bbox[i][1] 30 | x1 = out_bbox[i][2] 31 | y1 = out_bbox[i][3] 32 | bbox = np.array([x0, y0, x1, y1]) 33 | 34 | # get center and scale 35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25) 36 | 37 | # do affine transformation 38 | resized_img, scale = top_down_affine(input_size, scale, center, img) 39 | 40 | # normalize image 41 | mean = np.array([123.675, 116.28, 103.53]) 42 | std = np.array([58.395, 57.12, 57.375]) 43 | resized_img = (resized_img - mean) / std 44 | 45 | out_img.append(resized_img) 46 | out_center.append(center) 47 | out_scale.append(scale) 48 | 49 | return out_img, out_center, out_scale 50 | 51 | 52 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: 53 | """Inference RTMPose model. 54 | 55 | Args: 56 | sess (ort.InferenceSession): ONNXRuntime session. 57 | img (np.ndarray): Input image in shape. 58 | 59 | Returns: 60 | outputs (np.ndarray): Output of RTMPose model. 61 | """ 62 | all_out = [] 63 | # build input 64 | for i in range(len(img)): 65 | input = [img[i].transpose(2, 0, 1)] 66 | 67 | # build output 68 | sess_input = {sess.get_inputs()[0].name: input} 69 | sess_output = [] 70 | for out in sess.get_outputs(): 71 | sess_output.append(out.name) 72 | 73 | # run model 74 | outputs = sess.run(sess_output, sess_input) 75 | all_out.append(outputs) 76 | 77 | return all_out 78 | 79 | 80 | def postprocess(outputs: List[np.ndarray], 81 | model_input_size: Tuple[int, int], 82 | center: Tuple[int, int], 83 | scale: Tuple[int, int], 84 | simcc_split_ratio: float = 2.0 85 | ) -> Tuple[np.ndarray, np.ndarray]: 86 | """Postprocess for RTMPose model output. 87 | 88 | Args: 89 | outputs (np.ndarray): Output of RTMPose model. 90 | model_input_size (tuple): RTMPose model Input image size. 91 | center (tuple): Center of bbox in shape (x, y). 92 | scale (tuple): Scale of bbox in shape (w, h). 93 | simcc_split_ratio (float): Split ratio of simcc. 94 | 95 | Returns: 96 | tuple: 97 | - keypoints (np.ndarray): Rescaled keypoints. 98 | - scores (np.ndarray): Model predict scores. 99 | """ 100 | all_key = [] 101 | all_score = [] 102 | for i in range(len(outputs)): 103 | # use simcc to decode 104 | simcc_x, simcc_y = outputs[i] 105 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) 106 | 107 | # rescale keypoints 108 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 109 | all_key.append(keypoints[0]) 110 | all_score.append(scores[0]) 111 | 112 | return np.array(all_key), np.array(all_score) 113 | 114 | 115 | def bbox_xyxy2cs(bbox: np.ndarray, 116 | padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: 117 | """Transform the bbox format from (x,y,w,h) into (center, scale) 118 | 119 | Args: 120 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted 121 | as (left, top, right, bottom) 122 | padding (float): BBox padding factor that will be multilied to scale. 123 | Default: 1.0 124 | 125 | Returns: 126 | tuple: A tuple containing center and scale. 
127 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or 128 | (n, 2) 129 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or 130 | (n, 2) 131 | """ 132 | # convert single bbox from (4, ) to (1, 4) 133 | dim = bbox.ndim 134 | if dim == 1: 135 | bbox = bbox[None, :] 136 | 137 | # get bbox center and scale 138 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) 139 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5 140 | scale = np.hstack([x2 - x1, y2 - y1]) * padding 141 | 142 | if dim == 1: 143 | center = center[0] 144 | scale = scale[0] 145 | 146 | return center, scale 147 | 148 | 149 | def _fix_aspect_ratio(bbox_scale: np.ndarray, 150 | aspect_ratio: float) -> np.ndarray: 151 | """Extend the scale to match the given aspect ratio. 152 | 153 | Args: 154 | scale (np.ndarray): The image scale (w, h) in shape (2, ) 155 | aspect_ratio (float): The ratio of ``w/h`` 156 | 157 | Returns: 158 | np.ndarray: The reshaped image scale in (2, ) 159 | """ 160 | w, h = np.hsplit(bbox_scale, [1]) 161 | bbox_scale = np.where(w > h * aspect_ratio, 162 | np.hstack([w, w / aspect_ratio]), 163 | np.hstack([h * aspect_ratio, h])) 164 | return bbox_scale 165 | 166 | 167 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: 168 | """Rotate a point by an angle. 169 | 170 | Args: 171 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) 172 | angle_rad (float): rotation angle in radian 173 | 174 | Returns: 175 | np.ndarray: Rotated point in shape (2, ) 176 | """ 177 | sn, cs = np.sin(angle_rad), np.cos(angle_rad) 178 | rot_mat = np.array([[cs, -sn], [sn, cs]]) 179 | return rot_mat @ pt 180 | 181 | 182 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: 183 | """To calculate the affine matrix, three pairs of points are required. This 184 | function is used to get the 3rd point, given 2D points a & b. 185 | 186 | The 3rd point is defined by rotating vector `a - b` by 90 degrees 187 | anticlockwise, using b as the rotation center. 188 | 189 | Args: 190 | a (np.ndarray): The 1st point (x,y) in shape (2, ) 191 | b (np.ndarray): The 2nd point (x,y) in shape (2, ) 192 | 193 | Returns: 194 | np.ndarray: The 3rd point. 195 | """ 196 | direction = a - b 197 | c = b + np.r_[-direction[1], direction[0]] 198 | return c 199 | 200 | 201 | def get_warp_matrix(center: np.ndarray, 202 | scale: np.ndarray, 203 | rot: float, 204 | output_size: Tuple[int, int], 205 | shift: Tuple[float, float] = (0., 0.), 206 | inv: bool = False) -> np.ndarray: 207 | """Calculate the affine transformation matrix that can warp the bbox area 208 | in the input image to the output size. 209 | 210 | Args: 211 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 212 | scale (np.ndarray[2, ]): Scale of the bounding box 213 | wrt [width, height]. 214 | rot (float): Rotation angle (degree). 215 | output_size (np.ndarray[2, ] | list(2,)): Size of the 216 | destination heatmaps. 217 | shift (0-100%): Shift translation ratio wrt the width/height. 218 | Default (0., 0.). 219 | inv (bool): Option to inverse the affine transform direction. 
220 | (inv=False: src->dst or inv=True: dst->src) 221 | 222 | Returns: 223 | np.ndarray: A 2x3 transformation matrix 224 | """ 225 | shift = np.array(shift) 226 | src_w = scale[0] 227 | dst_w = output_size[0] 228 | dst_h = output_size[1] 229 | 230 | # compute transformation matrix 231 | rot_rad = np.deg2rad(rot) 232 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) 233 | dst_dir = np.array([0., dst_w * -0.5]) 234 | 235 | # get four corners of the src rectangle in the original image 236 | src = np.zeros((3, 2), dtype=np.float32) 237 | src[0, :] = center + scale * shift 238 | src[1, :] = center + src_dir + scale * shift 239 | src[2, :] = _get_3rd_point(src[0, :], src[1, :]) 240 | 241 | # get four corners of the dst rectangle in the input image 242 | dst = np.zeros((3, 2), dtype=np.float32) 243 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 244 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 245 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) 246 | 247 | if inv: 248 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 249 | else: 250 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 251 | 252 | return warp_mat 253 | 254 | 255 | def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, 256 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 257 | """Get the bbox image as the model input by affine transform. 258 | 259 | Args: 260 | input_size (dict): The input size of the model. 261 | bbox_scale (dict): The bbox scale of the img. 262 | bbox_center (dict): The bbox center of the img. 263 | img (np.ndarray): The original image. 264 | 265 | Returns: 266 | tuple: A tuple containing center and scale. 267 | - np.ndarray[float32]: img after affine transform. 268 | - np.ndarray[float32]: bbox scale after affine transform. 269 | """ 270 | w, h = input_size 271 | warp_size = (int(w), int(h)) 272 | 273 | # reshape bbox to fixed aspect ratio 274 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) 275 | 276 | # get the affine matrix 277 | center = bbox_center 278 | scale = bbox_scale 279 | rot = 0 280 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) 281 | 282 | # do affine transform 283 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) 284 | 285 | return img, bbox_scale 286 | 287 | 288 | def get_simcc_maximum(simcc_x: np.ndarray, 289 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 290 | """Get maximum response location and value from simcc representations. 
291 | 292 | Note: 293 | instance number: N 294 | num_keypoints: K 295 | heatmap height: H 296 | heatmap width: W 297 | 298 | Args: 299 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) 300 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) 301 | 302 | Returns: 303 | tuple: 304 | - locs (np.ndarray): locations of maximum heatmap responses in shape 305 | (K, 2) or (N, K, 2) 306 | - vals (np.ndarray): values of maximum heatmap responses in shape 307 | (K,) or (N, K) 308 | """ 309 | N, K, Wx = simcc_x.shape 310 | simcc_x = simcc_x.reshape(N * K, -1) 311 | simcc_y = simcc_y.reshape(N * K, -1) 312 | 313 | # get maximum value locations 314 | x_locs = np.argmax(simcc_x, axis=1) 315 | y_locs = np.argmax(simcc_y, axis=1) 316 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) 317 | max_val_x = np.amax(simcc_x, axis=1) 318 | max_val_y = np.amax(simcc_y, axis=1) 319 | 320 | # get maximum value across x and y axis 321 | mask = max_val_x > max_val_y 322 | max_val_x[mask] = max_val_y[mask] 323 | vals = max_val_x 324 | locs[vals <= 0.] = -1 325 | 326 | # reshape 327 | locs = locs.reshape(N, K, 2) 328 | vals = vals.reshape(N, K) 329 | 330 | return locs, vals 331 | 332 | 333 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, 334 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: 335 | """Modulate simcc distribution with Gaussian. 336 | 337 | Args: 338 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 339 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. 340 | simcc_split_ratio (int): The split ratio of simcc. 341 | 342 | Returns: 343 | tuple: A tuple containing center and scale. 344 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) 345 | - np.ndarray[float32]: scores in shape (K,) or (n, K) 346 | """ 347 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) 348 | keypoints /= simcc_split_ratio 349 | 350 | return keypoints, scores 351 | 352 | 353 | def inference_pose(session, out_bbox, oriImg): 354 | h, w = session.get_inputs()[0].shape[2:] 355 | model_input_size = (w, h) 356 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) 357 | outputs = inference(session, resized_img) 358 | keypoints, scores = postprocess(outputs, model_input_size, center, scale) 359 | 360 | return keypoints, scores -------------------------------------------------------------------------------- /src/flux/train_pipeline.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ExifTags 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | 6 | from einops import rearrange 7 | import uuid 8 | import os 9 | 10 | from src.flux.modules.layers import ( 11 | SingleStreamBlockProcessor, 12 | DoubleStreamBlockProcessor, 13 | SingleStreamBlockLoraProcessor, 14 | DoubleStreamBlockLoraProcessor, 15 | IPDoubleStreamBlockProcessor, 16 | ImageProjModel, 17 | ) 18 | from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack 19 | from src.flux.util import ( 20 | load_ae, 21 | load_clip, 22 | load_flow_model, 23 | load_t5, 24 | load_controlnet, 25 | load_flow_model_quintized, 26 | Annotator, 27 | get_lora_rank, 28 | load_checkpoint 29 | ) 30 | 31 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 32 | 33 | class XFluxPipeline: 34 | def __init__(self, model_type, device, offload: bool = False): 35 | self.device = torch.device(device) 36 | self.offload = offload 37 | self.model_type = model_type 38 | 39 
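# Load the conditioning and generative components: the CLIP and T5 text encoders,
# the latent autoencoder, and the Flux flow transformer. With offload=True the
# autoencoder and transformer are created on CPU and moved to the GPU only while
# they are needed; model types containing "fp8" use the quantized loader.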
| self.clip = load_clip(self.device) 40 | self.t5 = load_t5(self.device, max_length=512) 41 | self.ae = load_ae(model_type, device="cpu" if offload else self.device) 42 | if "fp8" in model_type: 43 | self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) 44 | else: 45 | self.model = load_flow_model(model_type, device="cpu" if offload else self.device) 46 | 47 | self.image_encoder_path = "openai/clip-vit-large-patch14" 48 | self.hf_lora_collection = "XLabs-AI/flux-lora-collection" 49 | self.lora_types_to_names = { 50 | "realism": "lora.safetensors", 51 | } 52 | self.controlnet_loaded = False 53 | self.ip_loaded = False 54 | 55 | def set_ip(self, local_path: str = None, repo_id = None, name: str = None): 56 | self.model.to(self.device) 57 | 58 | # unpack checkpoint 59 | checkpoint = load_checkpoint(local_path, repo_id, name) 60 | prefix = "double_blocks." 61 | blocks = {} 62 | proj = {} 63 | 64 | for key, value in checkpoint.items(): 65 | if key.startswith(prefix): 66 | blocks[key[len(prefix):].replace('.processor.', '.')] = value 67 | if key.startswith("ip_adapter_proj_model"): 68 | proj[key[len("ip_adapter_proj_model."):]] = value 69 | 70 | # load image encoder 71 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 72 | self.device, dtype=torch.float16 73 | ) 74 | self.clip_image_processor = CLIPImageProcessor() 75 | 76 | # setup image embedding projection model 77 | self.improj = ImageProjModel(4096, 768, 4) 78 | self.improj.load_state_dict(proj) 79 | self.improj = self.improj.to(self.device, dtype=torch.bfloat16) 80 | 81 | ip_attn_procs = {} 82 | 83 | for name, _ in self.model.attn_processors.items(): 84 | ip_state_dict = {} 85 | for k in checkpoint.keys(): 86 | if name in k: 87 | ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] 88 | if ip_state_dict: 89 | ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) 90 | ip_attn_procs[name].load_state_dict(ip_state_dict) 91 | ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) 92 | else: 93 | ip_attn_procs[name] = self.model.attn_processors[name] 94 | 95 | self.model.set_attn_processor(ip_attn_procs) 96 | self.ip_loaded = True 97 | 98 | def set_lora(self, local_path: str = None, repo_id: str = None, 99 | name: str = None, lora_weight: int = 0.7): 100 | checkpoint = load_checkpoint(local_path, repo_id, name) 101 | self.update_model_with_lora(checkpoint, lora_weight) 102 | 103 | def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): 104 | checkpoint = load_checkpoint( 105 | None, self.hf_lora_collection, self.lora_types_to_names[lora_type] 106 | ) 107 | self.update_model_with_lora(checkpoint, lora_weight) 108 | 109 | def update_model_with_lora(self, checkpoint, lora_weight): 110 | rank = get_lora_rank(checkpoint) 111 | lora_attn_procs = {} 112 | 113 | for name, _ in self.model.attn_processors.items(): 114 | lora_state_dict = {} 115 | for k in checkpoint.keys(): 116 | if name in k: 117 | lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight 118 | 119 | if len(lora_state_dict): 120 | if name.startswith("single_blocks"): 121 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) 122 | else: 123 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) 124 | lora_attn_procs[name].load_state_dict(lora_state_dict) 125 | lora_attn_procs[name].to(self.device) 126 | else: 127 | if name.startswith("single_blocks"): 128 | lora_attn_procs[name] = 
SingleStreamBlockProcessor() 129 | else: 130 | lora_attn_procs[name] = DoubleStreamBlockProcessor() 131 | 132 | self.model.set_attn_processor(lora_attn_procs) 133 | 134 | def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): 135 | self.model.to(self.device) 136 | self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) 137 | 138 | checkpoint = load_checkpoint(local_path, repo_id, name) 139 | self.controlnet.load_state_dict(checkpoint, strict=False) 140 | self.annotator = Annotator(control_type, self.device) 141 | self.controlnet_loaded = True 142 | self.control_type = control_type 143 | 144 | def get_image_proj( 145 | self, 146 | image_prompt: Tensor, 147 | ): 148 | # encode image-prompt embeds 149 | image_prompt = self.clip_image_processor( 150 | images=image_prompt, 151 | return_tensors="pt" 152 | ).pixel_values 153 | image_prompt = image_prompt.to(self.image_encoder.device) 154 | image_prompt_embeds = self.image_encoder( 155 | image_prompt 156 | ).image_embeds.to( 157 | device=self.device, dtype=torch.bfloat16, 158 | ) 159 | # encode image 160 | image_proj = self.improj(image_prompt_embeds) 161 | return image_proj 162 | 163 | def __call__(self, 164 | prompt: str, 165 | image_prompt: Image = None, 166 | controlnet_image: Image = None, 167 | width: int = 512, 168 | height: int = 512, 169 | guidance: float = 4, 170 | num_steps: int = 50, 171 | seed: int = 123456789, 172 | true_gs: float = 3, 173 | control_weight: float = 0.9, 174 | ip_scale: float = 1.0, 175 | neg_ip_scale: float = 1.0, 176 | neg_prompt: str = '', 177 | neg_image_prompt: Image = None, 178 | timestep_to_start_cfg: int = 0, 179 | ): 180 | width = 16 * (width // 16) 181 | height = 16 * (height // 16) 182 | image_proj = None 183 | neg_image_proj = None 184 | if not (image_prompt is None and neg_image_prompt is None) : 185 | assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' 186 | 187 | if image_prompt is None: 188 | image_prompt = np.zeros((width, height, 3), dtype=np.uint8) 189 | if neg_image_prompt is None: 190 | neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) 191 | 192 | image_proj = self.get_image_proj(image_prompt) 193 | neg_image_proj = self.get_image_proj(neg_image_prompt) 194 | 195 | if self.controlnet_loaded: 196 | controlnet_image = self.annotator(controlnet_image, width, height) 197 | controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) 198 | controlnet_image = controlnet_image.permute( 199 | 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) 200 | 201 | return self.forward( 202 | prompt, 203 | width, 204 | height, 205 | guidance, 206 | num_steps, 207 | seed, 208 | controlnet_image, 209 | timestep_to_start_cfg=timestep_to_start_cfg, 210 | true_gs=true_gs, 211 | control_weight=control_weight, 212 | neg_prompt=neg_prompt, 213 | image_proj=image_proj, 214 | neg_image_proj=neg_image_proj, 215 | ip_scale=ip_scale, 216 | neg_ip_scale=neg_ip_scale, 217 | ) 218 | 219 | @torch.inference_mode() 220 | def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance, 221 | num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, 222 | neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, 223 | lora_weight, local_path, lora_local_path, ip_local_path): 224 | if controlnet_image is not None: 225 | controlnet_image = Image.fromarray(controlnet_image) 226 | if ((self.controlnet_loaded and control_type != self.control_type) 227 | or 
not self.controlnet_loaded): 228 | if local_path is not None: 229 | self.set_controlnet(control_type, local_path=local_path) 230 | else: 231 | self.set_controlnet(control_type, local_path=None, 232 | repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3", 233 | name=f"flux-{control_type}-controlnet-v3.safetensors") 234 | if lora_local_path is not None: 235 | self.set_lora(local_path=lora_local_path, lora_weight=lora_weight) 236 | if image_prompt is not None: 237 | image_prompt = Image.fromarray(image_prompt) 238 | if neg_image_prompt is not None: 239 | neg_image_prompt = Image.fromarray(neg_image_prompt) 240 | if not self.ip_loaded: 241 | if ip_local_path is not None: 242 | self.set_ip(local_path=ip_local_path) 243 | else: 244 | self.set_ip(repo_id="xlabs-ai/flux-ip-adapter", 245 | name="flux-ip-adapter.safetensors") 246 | seed = int(seed) 247 | if seed == -1: 248 | seed = torch.Generator(device="cpu").seed() 249 | 250 | img = self(prompt, image_prompt, controlnet_image, width, height, guidance, 251 | num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt, 252 | neg_image_prompt, timestep_to_start_cfg) 253 | 254 | filename = f"output/gradio/{uuid.uuid4()}.jpg" 255 | os.makedirs(os.path.dirname(filename), exist_ok=True) 256 | exif_data = Image.Exif() 257 | exif_data[ExifTags.Base.Make] = "XLabs AI" 258 | exif_data[ExifTags.Base.Model] = self.model_type 259 | img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) 260 | return img, filename 261 | 262 | def forward( 263 | self, 264 | prompt, 265 | width, 266 | height, 267 | guidance, 268 | num_steps, 269 | seed, 270 | controlnet_image = None, 271 | timestep_to_start_cfg = 0, 272 | true_gs = 3.5, 273 | control_weight = 0.9, 274 | neg_prompt="", 275 | image_proj=None, 276 | neg_image_proj=None, 277 | ip_scale=1.0, 278 | neg_ip_scale=1.0, 279 | ): 280 | x = get_noise( 281 | 1, height, width, device=self.device, 282 | dtype=torch.bfloat16, seed=seed 283 | ) 284 | timesteps = get_schedule( 285 | num_steps, 286 | (width // 8) * (height // 8) // (16 * 16), 287 | shift=True, 288 | ) 289 | torch.manual_seed(seed) 290 | with torch.no_grad(): 291 | if self.offload: 292 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 293 | inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt) 294 | neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) 295 | 296 | if self.offload: 297 | self.offload_model_to_cpu(self.t5, self.clip) 298 | self.model = self.model.to(self.device) 299 | if self.controlnet_loaded: 300 | x = denoise_controlnet( 301 | self.model, 302 | **inp_cond, 303 | controlnet=self.controlnet, 304 | timesteps=timesteps, 305 | guidance=guidance, 306 | controlnet_cond=controlnet_image, 307 | timestep_to_start_cfg=timestep_to_start_cfg, 308 | neg_txt=neg_inp_cond['txt'], 309 | neg_txt_ids=neg_inp_cond['txt_ids'], 310 | neg_vec=neg_inp_cond['vec'], 311 | true_gs=true_gs, 312 | controlnet_gs=control_weight, 313 | image_proj=image_proj, 314 | neg_image_proj=neg_image_proj, 315 | ip_scale=ip_scale, 316 | neg_ip_scale=neg_ip_scale, 317 | ) 318 | else: 319 | x = denoise( 320 | self.model, 321 | **inp_cond, 322 | timesteps=timesteps, 323 | guidance=guidance, 324 | timestep_to_start_cfg=timestep_to_start_cfg, 325 | neg_txt=neg_inp_cond['txt'], 326 | neg_txt_ids=neg_inp_cond['txt_ids'], 327 | neg_vec=neg_inp_cond['vec'], 328 | true_gs=true_gs, 329 | image_proj=image_proj, 330 | neg_image_proj=neg_image_proj, 331 | ip_scale=ip_scale, 332 | neg_ip_scale=neg_ip_scale, 
333 | ) 334 | 335 | if self.offload: 336 | self.offload_model_to_cpu(self.model) 337 | self.ae.decoder.to(x.device) 338 | x = unpack(x.float(), height, width) 339 | x = self.ae.decode(x) 340 | self.offload_model_to_cpu(self.ae.decoder) 341 | 342 | x1 = x.clamp(-1, 1) 343 | x1 = rearrange(x1[-1], "c h w -> h w c") 344 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) 345 | return output_img 346 | 347 | def offload_model_to_cpu(self, *models): 348 | if not self.offload: return 349 | for model in models: 350 | model.cpu() 351 | torch.cuda.empty_cache() 352 | 353 | 354 | class XFluxSampler(XFluxPipeline): 355 | def __init__(self, clip, t5, ae, model, device): 356 | self.clip = clip 357 | self.t5 = t5 358 | self.ae = ae 359 | self.model = model 360 | self.model.eval() 361 | self.device = device 362 | self.controlnet_loaded = False 363 | self.ip_loaded = False 364 | self.offload = False 365 | -------------------------------------------------------------------------------- /src/flux/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import json 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | from huggingface_hub import hf_hub_download 10 | from safetensors import safe_open 11 | from safetensors.torch import load_file as load_sft 12 | 13 | from optimum.quanto import requantize 14 | 15 | from .model import Flux, FluxParams 16 | from .controlnet import ControlNetFlux 17 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams 18 | from .modules.conditioner import HFEmbedder 19 | from .annotator.dwpose import DWposeDetector 20 | from .annotator.mlsd import MLSDdetector 21 | from .annotator.canny import CannyDetector 22 | from .annotator.midas import MidasDetector 23 | from .annotator.hed import HEDdetector 24 | from .annotator.tile import TileDetector 25 | from .annotator.zoe import ZoeDetector 26 | 27 | 28 | def load_safetensors(path): 29 | tensors = {} 30 | with safe_open(path, framework="pt", device="cpu") as f: 31 | for key in f.keys(): 32 | tensors[key] = f.get_tensor(key) 33 | return tensors 34 | 35 | def get_lora_rank(checkpoint): 36 | for k in checkpoint.keys(): 37 | if k.endswith(".down.weight"): 38 | return checkpoint[k].shape[0] 39 | 40 | def load_checkpoint(local_path, repo_id, name): 41 | if local_path is not None: 42 | if '.safetensors' in local_path: 43 | print(f"Loading .safetensors checkpoint from {local_path}") 44 | checkpoint = load_safetensors(local_path) 45 | else: 46 | print(f"Loading checkpoint from {local_path}") 47 | checkpoint = torch.load(local_path, map_location='cpu') 48 | elif repo_id is not None and name is not None: 49 | print(f"Loading checkpoint {name} from repo id {repo_id}") 50 | checkpoint = load_from_repo_id(repo_id, name) 51 | else: 52 | raise ValueError( 53 | "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" 54 | ) 55 | return checkpoint 56 | 57 | 58 | def c_crop(image): 59 | width, height = image.size 60 | new_size = min(width, height) 61 | left = (width - new_size) / 2 62 | top = (height - new_size) / 2 63 | right = (width + new_size) / 2 64 | bottom = (height + new_size) / 2 65 | return image.crop((left, top, right, bottom)) 66 | 67 | def pad64(x): 68 | return int(np.ceil(float(x) / 64.0) * 64 - x) 69 | 70 | def HWC3(x): 71 | assert x.dtype == np.uint8 72 | if x.ndim == 2: 73 | x = x[:, :, None] 74 | assert x.ndim == 3 75 | H, W, C = x.shape 76 | assert C == 
1 or C == 3 or C == 4 77 | if C == 3: 78 | return x 79 | if C == 1: 80 | return np.concatenate([x, x, x], axis=2) 81 | if C == 4: 82 | color = x[:, :, 0:3].astype(np.float32) 83 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 84 | y = color * alpha + 255.0 * (1.0 - alpha) 85 | y = y.clip(0, 255).astype(np.uint8) 86 | return y 87 | 88 | def safer_memory(x): 89 | # Fix many MAC/AMD problems 90 | return np.ascontiguousarray(x.copy()).copy() 91 | 92 | #https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 93 | #Added upscale_method, mode params 94 | def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'): 95 | if skip_hwc3: 96 | img = input_image 97 | else: 98 | img = HWC3(input_image) 99 | H_raw, W_raw, _ = img.shape 100 | if resolution == 0: 101 | return img, lambda x: x 102 | k = float(resolution) / float(min(H_raw, W_raw)) 103 | H_target = int(np.round(float(H_raw) * k)) 104 | W_target = int(np.round(float(W_raw) * k)) 105 | img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA) 106 | H_pad, W_pad = pad64(H_target), pad64(W_target) 107 | img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) 108 | 109 | def remove_pad(x): 110 | return safer_memory(x[:H_target, :W_target, ...]) 111 | 112 | return safer_memory(img_padded), remove_pad 113 | 114 | class Annotator: 115 | def __init__(self, name: str, device: str): 116 | if name == "canny": 117 | processor = CannyDetector() 118 | elif name == "openpose": 119 | processor = DWposeDetector(device) 120 | elif name == "depth": 121 | processor = MidasDetector() 122 | elif name == "hed": 123 | processor = HEDdetector() 124 | elif name == "hough": 125 | processor = MLSDdetector() 126 | elif name == "tile": 127 | processor = TileDetector() 128 | elif name == "zoe": 129 | processor = ZoeDetector() 130 | self.name = name 131 | self.processor = processor 132 | 133 | def __call__(self, image: Image, width: int, height: int): 134 | image = np.array(image) 135 | detect_resolution = max(width, height) 136 | image, remove_pad = resize_image_with_pad(image, detect_resolution) 137 | 138 | image = np.array(image) 139 | if self.name == "canny": 140 | result = self.processor(image, low_threshold=100, high_threshold=200) 141 | elif self.name == "hough": 142 | result = self.processor(image, thr_v=0.05, thr_d=5) 143 | elif self.name == "depth": 144 | result = self.processor(image) 145 | result, _ = result 146 | else: 147 | result = self.processor(image) 148 | 149 | result = HWC3(remove_pad(result)) 150 | result = cv2.resize(result, (width, height)) 151 | return result 152 | 153 | 154 | @dataclass 155 | class ModelSpec: 156 | params: FluxParams 157 | ae_params: AutoEncoderParams 158 | ckpt_path: str | None 159 | ae_path: str | None 160 | repo_id: str | None 161 | repo_flow: str | None 162 | repo_ae: str | None 163 | repo_id_ae: str | None 164 | 165 | 166 | configs = { 167 | "flux-dev": ModelSpec( 168 | repo_id="black-forest-labs/FLUX.1-dev", 169 | repo_id_ae="black-forest-labs/FLUX.1-dev", 170 | repo_flow="flux1-dev.safetensors", 171 | repo_ae="ae.safetensors", 172 | ckpt_path=os.getenv("FLUX_DEV"), 173 | params=FluxParams( 174 | in_channels=64, 175 | vec_in_dim=768, 176 | context_in_dim=4096, 177 | hidden_size=3072, 178 | mlp_ratio=4.0, 179 | num_heads=24, 180 | depth=19, 181 | depth_single_blocks=38, 182 | axes_dim=[16, 56, 56], 183 | theta=10_000, 184 | qkv_bias=True, 185 | guidance_embed=True, 186 | ), 187 | ae_path=os.getenv("AE"), 188 | ae_params=AutoEncoderParams( 189 | 
resolution=256, 190 | in_channels=3, 191 | ch=128, 192 | out_ch=3, 193 | ch_mult=[1, 2, 4, 4], 194 | num_res_blocks=2, 195 | z_channels=16, 196 | scale_factor=0.3611, 197 | shift_factor=0.1159, 198 | ), 199 | ), 200 | "flux-dev-fp8": ModelSpec( 201 | repo_id="XLabs-AI/flux-dev-fp8", 202 | repo_id_ae="black-forest-labs/FLUX.1-dev", 203 | repo_flow="flux-dev-fp8.safetensors", 204 | repo_ae="ae.safetensors", 205 | ckpt_path=os.getenv("FLUX_DEV_FP8"), 206 | params=FluxParams( 207 | in_channels=64, 208 | vec_in_dim=768, 209 | context_in_dim=4096, 210 | hidden_size=3072, 211 | mlp_ratio=4.0, 212 | num_heads=24, 213 | depth=19, 214 | depth_single_blocks=38, 215 | axes_dim=[16, 56, 56], 216 | theta=10_000, 217 | qkv_bias=True, 218 | guidance_embed=True, 219 | ), 220 | ae_path=os.getenv("AE"), 221 | ae_params=AutoEncoderParams( 222 | resolution=256, 223 | in_channels=3, 224 | ch=128, 225 | out_ch=3, 226 | ch_mult=[1, 2, 4, 4], 227 | num_res_blocks=2, 228 | z_channels=16, 229 | scale_factor=0.3611, 230 | shift_factor=0.1159, 231 | ), 232 | ), 233 | "flux-schnell": ModelSpec( 234 | repo_id="black-forest-labs/FLUX.1-schnell", 235 | repo_id_ae="black-forest-labs/FLUX.1-dev", 236 | repo_flow="flux1-schnell.safetensors", 237 | repo_ae="ae.safetensors", 238 | ckpt_path=os.getenv("FLUX_SCHNELL"), 239 | params=FluxParams( 240 | in_channels=64, 241 | vec_in_dim=768, 242 | context_in_dim=4096, 243 | hidden_size=3072, 244 | mlp_ratio=4.0, 245 | num_heads=24, 246 | depth=19, 247 | depth_single_blocks=38, 248 | axes_dim=[16, 56, 56], 249 | theta=10_000, 250 | qkv_bias=True, 251 | guidance_embed=False, 252 | ), 253 | ae_path=os.getenv("AE"), 254 | ae_params=AutoEncoderParams( 255 | resolution=256, 256 | in_channels=3, 257 | ch=128, 258 | out_ch=3, 259 | ch_mult=[1, 2, 4, 4], 260 | num_res_blocks=2, 261 | z_channels=16, 262 | scale_factor=0.3611, 263 | shift_factor=0.1159, 264 | ), 265 | ), 266 | } 267 | 268 | 269 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 270 | if len(missing) > 0 and len(unexpected) > 0: 271 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 272 | print("\n" + "-" * 79 + "\n") 273 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 274 | elif len(missing) > 0: 275 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 276 | elif len(unexpected) > 0: 277 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 278 | 279 | def load_from_repo_id(repo_id, checkpoint_name): 280 | ckpt_path = hf_hub_download(repo_id, checkpoint_name) 281 | sd = load_sft(ckpt_path, device='cpu') 282 | return sd 283 | 284 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 285 | # Loading Flux 286 | print("Init model") 287 | ckpt_path = configs[name].ckpt_path 288 | if ( 289 | ckpt_path is None 290 | and configs[name].repo_id is not None 291 | and configs[name].repo_flow is not None 292 | and hf_download 293 | ): 294 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 295 | 296 | with torch.device("meta" if ckpt_path is not None else device): 297 | model = Flux(configs[name].params).to(torch.bfloat16) 298 | 299 | if ckpt_path is not None: 300 | print("Loading checkpoint") 301 | # load_sft doesn't support torch.device 302 | sd = load_sft(ckpt_path, device=str(device)) 303 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 304 | print_load_warning(missing, unexpected) 305 | return model 306 | 307 | def 
load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 308 | # Loading Flux 309 | print("Init model") 310 | ckpt_path = configs[name].ckpt_path 311 | if ( 312 | ckpt_path is None 313 | and configs[name].repo_id is not None 314 | and configs[name].repo_flow is not None 315 | and hf_download 316 | ): 317 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) 318 | 319 | with torch.device("meta" if ckpt_path is not None else device): 320 | model = Flux(configs[name].params) 321 | 322 | if ckpt_path is not None: 323 | print("Loading checkpoint") 324 | # load_sft doesn't support torch.device 325 | sd = load_sft(ckpt_path, device=str(device)) 326 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 327 | print_load_warning(missing, unexpected) 328 | return model 329 | 330 | def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 331 | # Loading Flux 332 | print("Init model") 333 | ckpt_path = configs[name].ckpt_path 334 | if ( 335 | ckpt_path is None 336 | and configs[name].repo_id is not None 337 | and configs[name].repo_flow is not None 338 | and hf_download 339 | ): 340 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 341 | json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') 342 | 343 | 344 | model = Flux(configs[name].params).to(torch.bfloat16) 345 | 346 | print("Loading checkpoint") 347 | # load_sft doesn't support torch.device 348 | sd = load_sft(ckpt_path, device='cpu') 349 | with open(json_path, "r") as f: 350 | quantization_map = json.load(f) 351 | print("Start a quantization process...") 352 | requantize(model, sd, quantization_map, device=device) 353 | print("Model is quantized!") 354 | return model 355 | 356 | def load_controlnet(name, device, transformer=None): 357 | with torch.device(device): 358 | controlnet = ControlNetFlux(configs[name].params) 359 | if transformer is not None: 360 | controlnet.load_state_dict(transformer.state_dict(), strict=False) 361 | return controlnet 362 | 363 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 364 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 365 | return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) 366 | 367 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 368 | return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) 369 | 370 | 371 | def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: 372 | ckpt_path = configs[name].ae_path 373 | if ( 374 | ckpt_path is None 375 | and configs[name].repo_id is not None 376 | and configs[name].repo_ae is not None 377 | and hf_download 378 | ): 379 | ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) 380 | 381 | # Loading the autoencoder 382 | print("Init AE") 383 | with torch.device("meta" if ckpt_path is not None else device): 384 | ae = AutoEncoder(configs[name].ae_params) 385 | 386 | if ckpt_path is not None: 387 | sd = load_sft(ckpt_path, device=str(device)) 388 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 389 | print_load_warning(missing, unexpected) 390 | return ae 391 | 392 | 393 | class WatermarkEmbedder: 394 | def __init__(self, watermark): 395 | self.watermark = watermark 
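# WatermarkEncoder is assumed to come from the invisible-watermark package
# (`from imwatermark import WatermarkEncoder`); it embeds the fixed 48-bit
# message defined at the bottom of this file into each frame using the
# DWT-DCT method.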
396 | self.num_bits = len(WATERMARK_BITS) 397 | self.encoder = WatermarkEncoder() 398 | self.encoder.set_watermark("bits", self.watermark) 399 | 400 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 401 | """ 402 | Adds a predefined watermark to the input image 403 | 404 | Args: 405 | image: ([N,] B, RGB, H, W) in range [-1, 1] 406 | 407 | Returns: 408 | same as input but watermarked 409 | """ 410 | image = 0.5 * image + 0.5 411 | squeeze = len(image.shape) == 4 412 | if squeeze: 413 | image = image[None, ...] 414 | n = image.shape[0] 415 | image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] 416 | # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] 417 | # watermarking libary expects input as cv2 BGR format 418 | for k in range(image_np.shape[0]): 419 | image_np[k] = self.encoder.encode(image_np[k], "dwtDct") 420 | image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( 421 | image.device 422 | ) 423 | image = torch.clamp(image / 255, min=0.0, max=1.0) 424 | if squeeze: 425 | image = image[0] 426 | image = 2 * image - 1 427 | return image 428 | 429 | 430 | # A fixed 48-bit message that was choosen at random 431 | WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 432 | # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 433 | WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] 434 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import re 6 | import random 7 | import shutil 8 | from contextlib import nullcontext 9 | from pathlib import Path 10 | from safetensors.torch import save_file 11 | 12 | import accelerate 13 | import datasets 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | import torch.utils.checkpoint 18 | import transformers 19 | from accelerate import Accelerator 20 | from accelerate.logging import get_logger 21 | from accelerate.state import AcceleratorState 22 | from accelerate.utils import ProjectConfiguration, set_seed 23 | from huggingface_hub import create_repo, upload_folder 24 | from packaging import version 25 | from tqdm.auto import tqdm 26 | from transformers import CLIPTextModel, CLIPTokenizer 27 | from transformers.utils import ContextManagers 28 | from omegaconf import OmegaConf 29 | from copy import deepcopy 30 | import diffusers 31 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline 32 | from diffusers.optimization import get_scheduler 33 | from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr 34 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid 35 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 36 | from diffusers.utils.import_utils import is_xformers_available 37 | from diffusers.utils.torch_utils import is_compiled_module 38 | from einops import rearrange 39 | from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack 40 | from src.flux.util import (configs, load_ae, load_clip, 41 | load_flow_model2, load_t5) 42 | from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor 43 | from src.flux.xflux_pipeline import XFluxSampler 44 | 45 | from image_datasets.dataset 
import loader 46 | if is_wandb_available(): 47 | import wandb 48 | logger = get_logger(__name__, log_level="INFO") 49 | 50 | def get_models(name: str, device, offload: bool, is_schnell: bool): 51 | t5 = load_t5(device, max_length=256 if is_schnell else 512) 52 | clip = load_clip(device) 53 | clip.requires_grad_(False) 54 | model = load_flow_model2(name, device="cpu") 55 | vae = load_ae(name, device="cpu" if offload else device) 56 | return model, vae, t5, clip 57 | 58 | def parse_args(): 59 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 60 | parser.add_argument( 61 | "--config", 62 | type=str, 63 | default=None, 64 | required=True, 65 | help="path to config", 66 | ) 67 | args = parser.parse_args() 68 | 69 | 70 | return args.config 71 | 72 | 73 | def main(): 74 | args = OmegaConf.load(parse_args()) 75 | is_schnell = args.model_name == "flux-schnell" 76 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 77 | 78 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 79 | 80 | accelerator = Accelerator( 81 | gradient_accumulation_steps=args.gradient_accumulation_steps, 82 | mixed_precision=args.mixed_precision, 83 | log_with=args.report_to, 84 | project_config=accelerator_project_config, 85 | ) 86 | 87 | # Make one log on every process with the configuration for debugging. 88 | logging.basicConfig( 89 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 90 | datefmt="%m/%d/%Y %H:%M:%S", 91 | level=logging.INFO, 92 | ) 93 | logger.info(accelerator.state, main_process_only=False) 94 | if accelerator.is_local_main_process: 95 | datasets.utils.logging.set_verbosity_warning() 96 | transformers.utils.logging.set_verbosity_warning() 97 | diffusers.utils.logging.set_verbosity_info() 98 | else: 99 | datasets.utils.logging.set_verbosity_error() 100 | transformers.utils.logging.set_verbosity_error() 101 | diffusers.utils.logging.set_verbosity_error() 102 | 103 | 104 | if accelerator.is_main_process: 105 | if args.output_dir is not None: 106 | os.makedirs(args.output_dir, exist_ok=True) 107 | 108 | dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell) 109 | lora_attn_procs = {} 110 | 111 | if args.double_blocks is None: 112 | double_blocks_idx = list(range(19)) 113 | else: 114 | double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")] 115 | 116 | if args.single_blocks is None: 117 | single_blocks_idx = list(range(38)) 118 | elif args.single_blocks is not None: 119 | single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")] 120 | 121 | for name, attn_processor in dit.attn_processors.items(): 122 | match = re.search(r'\.(\d+)\.', name) 123 | if match: 124 | layer_index = int(match.group(1)) 125 | 126 | if name.startswith("double_blocks") and layer_index in double_blocks_idx: 127 | print("setting LoRA Processor for", name) 128 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( 129 | dim=3072, rank=args.rank 130 | ) 131 | elif name.startswith("single_blocks") and layer_index in single_blocks_idx: 132 | print("setting LoRA Processor for", name) 133 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor( 134 | dim=3072, rank=args.rank 135 | ) 136 | else: 137 | lora_attn_procs[name] = attn_processor 138 | 139 | dit.set_attn_processor(lora_attn_procs) 140 | 141 | vae.requires_grad_(False) 142 | t5.requires_grad_(False) 143 | clip.requires_grad_(False) 144 | dit = dit.to(torch.float32) 145 | 
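# Only the LoRA parameters are trained: every parameter whose name does not
# contain '_lora' is frozen below, and the optimizer receives just the
# remaining low-rank adapter weights.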
dit.train() 146 | optimizer_cls = torch.optim.AdamW 147 | for n, param in dit.named_parameters(): 148 | if '_lora' not in n: 149 | param.requires_grad = False 150 | else: 151 | print(n) 152 | print(sum([p.numel() for p in dit.parameters() if p.requires_grad]) / 1000000, 'parameters') 153 | optimizer = optimizer_cls( 154 | [p for p in dit.parameters() if p.requires_grad], 155 | lr=args.learning_rate, 156 | betas=(args.adam_beta1, args.adam_beta2), 157 | weight_decay=args.adam_weight_decay, 158 | eps=args.adam_epsilon, 159 | ) 160 | 161 | train_dataloader = loader(**args.data_config) 162 | # Scheduler and math around the number of training steps. 163 | overrode_max_train_steps = False 164 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 165 | if args.max_train_steps is None: 166 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 167 | overrode_max_train_steps = True 168 | 169 | lr_scheduler = get_scheduler( 170 | args.lr_scheduler, 171 | optimizer=optimizer, 172 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 173 | num_training_steps=args.max_train_steps * accelerator.num_processes, 174 | ) 175 | global_step = 0 176 | first_epoch = 0 177 | 178 | dit, optimizer, _, lr_scheduler = accelerator.prepare( 179 | dit, optimizer, deepcopy(train_dataloader), lr_scheduler 180 | ) 181 | 182 | weight_dtype = torch.float32 183 | if accelerator.mixed_precision == "fp16": 184 | weight_dtype = torch.float16 185 | args.mixed_precision = accelerator.mixed_precision 186 | elif accelerator.mixed_precision == "bf16": 187 | weight_dtype = torch.bfloat16 188 | args.mixed_precision = accelerator.mixed_precision 189 | 190 | 191 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 192 | if overrode_max_train_steps: 193 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 194 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 195 | 196 | if accelerator.is_main_process: 197 | accelerator.init_trackers(args.tracker_project_name, {"test": None}) 198 | 199 | timesteps = get_schedule( 200 | 999, 201 | (1024 // 8) * (1024 // 8) // 4, 202 | shift=True, 203 | ) 204 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 205 | 206 | logger.info("***** Running training *****") 207 | logger.info(f" Num Epochs = {args.num_train_epochs}") 208 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 209 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 210 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 211 | logger.info(f" Total optimization steps = {args.max_train_steps}") 212 | if args.resume_from_checkpoint: 213 | if args.resume_from_checkpoint != "latest": 214 | path = os.path.basename(args.resume_from_checkpoint) 215 | else: 216 | # Get the most recent checkpoint 217 | dirs = os.listdir(args.output_dir) 218 | dirs = [d for d in dirs if d.startswith("checkpoint")] 219 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 220 | path = dirs[-1] if len(dirs) > 0 else None 221 | 222 | if path is None: 223 | accelerator.print( 224 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
225 | ) 226 | args.resume_from_checkpoint = None 227 | initial_global_step = 0 228 | else: 229 | accelerator.print(f"Resuming from checkpoint {path}") 230 | accelerator.load_state(os.path.join(args.output_dir, path)) 231 | global_step = int(path.split("-")[1]) 232 | 233 | initial_global_step = global_step 234 | first_epoch = global_step // num_update_steps_per_epoch 235 | 236 | else: 237 | initial_global_step = 0 238 | progress_bar = tqdm( 239 | range(0, args.max_train_steps), 240 | initial=initial_global_step, 241 | desc="Steps", 242 | disable=not accelerator.is_local_main_process, 243 | ) 244 | 245 | for epoch in range(first_epoch, args.num_train_epochs): 246 | train_loss = 0.0 247 | for step, batch in enumerate(train_dataloader): 248 | with accelerator.accumulate(dit): 249 | img, prompts = batch 250 | with torch.no_grad(): 251 | x_1 = vae.encode(img.to(accelerator.device).to(torch.float32)) 252 | inp = prepare(t5=t5, clip=clip, img=x_1, prompt=prompts) 253 | x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 254 | 255 | bs = img.shape[0] 256 | t = torch.tensor([timesteps[random.randint(0, 999)]]).to(accelerator.device) 257 | x_0 = torch.randn_like(x_1).to(accelerator.device) 258 | x_t = (1 - t) * x_1 + t * x_0 259 | bsz = x_1.shape[0] 260 | guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype) 261 | 262 | # Predict the noise residual and compute loss 263 | model_pred = dit(img=x_t.to(weight_dtype), 264 | img_ids=inp['img_ids'].to(weight_dtype), 265 | txt=inp['txt'].to(weight_dtype), 266 | txt_ids=inp['txt_ids'].to(weight_dtype), 267 | y=inp['vec'].to(weight_dtype), 268 | timesteps=t.to(weight_dtype), 269 | guidance=guidance_vec.to(weight_dtype),) 270 | 271 | loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean") 272 | 273 | # Gather the losses across all processes for logging (if we use distributed training). 
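# Note on the loss above: x_t = (1 - t) * x_1 + t * x_0 linearly interpolates between the
# clean VAE latent x_1 (at t = 0) and Gaussian noise x_0 (at t = 1), so the model is trained
# to regress the constant velocity (x_0 - x_1) of that straight path, i.e. a rectified-flow /
# flow-matching MSE objective.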
274 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 275 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 276 | 277 | # Backpropagate 278 | accelerator.backward(loss) 279 | if accelerator.sync_gradients: 280 | accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm) 281 | optimizer.step() 282 | lr_scheduler.step() 283 | optimizer.zero_grad() 284 | 285 | # Checks if the accelerator has performed an optimization step behind the scenes 286 | if accelerator.sync_gradients: 287 | progress_bar.update(1) 288 | global_step += 1 289 | accelerator.log({"train_loss": train_loss}, step=global_step) 290 | train_loss = 0.0 291 | 292 | if not args.disable_sampling and global_step % args.sample_every == 0: 293 | if accelerator.is_main_process: 294 | print(f"Sampling images for step {global_step}...") 295 | sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device) 296 | images = [] 297 | for i, prompt in enumerate(args.sample_prompts): 298 | result = sampler(prompt=prompt, 299 | width=args.sample_width, 300 | height=args.sample_height, 301 | num_steps=args.sample_steps 302 | ) 303 | images.append(wandb.Image(result)) 304 | print(f"Result for prompt #{i} is generated") 305 | # result.save(f"{global_step}_prompt_{i}_res.png") 306 | wandb.log({f"Results, step {global_step}": images}) 307 | 308 | if global_step % args.checkpointing_steps == 0: 309 | if accelerator.is_main_process: 310 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 311 | if args.checkpoints_total_limit is not None: 312 | checkpoints = os.listdir(args.output_dir) 313 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 314 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 315 | 316 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 317 | if len(checkpoints) >= args.checkpoints_total_limit: 318 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 319 | removing_checkpoints = checkpoints[0:num_to_remove] 320 | 321 | logger.info( 322 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 323 | ) 324 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 325 | 326 | for removing_checkpoint in removing_checkpoints: 327 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 328 | shutil.rmtree(removing_checkpoint) 329 | 330 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 331 | 332 | accelerator.save_state(save_path) 333 | unwrapped_model_state = accelerator.unwrap_model(dit).state_dict() 334 | 335 | # save checkpoint in safetensors format 336 | lora_state_dict = {k:unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k} 337 | save_file( 338 | lora_state_dict, 339 | os.path.join(save_path, "lora.safetensors") 340 | ) 341 | 342 | logger.info(f"Saved state to {save_path}") 343 | 344 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 345 | progress_bar.set_postfix(**logs) 346 | 347 | if global_step >= args.max_train_steps: 348 | break 349 | 350 | accelerator.wait_for_everyone() 351 | accelerator.end_training() 352 | 353 | 354 | if __name__ == "__main__": 355 | main() 356 | -------------------------------------------------------------------------------- /src/flux/modules/layers.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | 8 | from ..math import attention, rope 9 | import torch.nn.functional as F 10 | 11 | class EmbedND(nn.Module): 12 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 13 | super().__init__() 14 | self.dim = dim 15 | self.theta = theta 16 | self.axes_dim = axes_dim 17 | 18 | def forward(self, ids: Tensor) -> Tensor: 19 | n_axes = ids.shape[-1] 20 | emb = torch.cat( 21 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 22 | dim=-3, 23 | ) 24 | 25 | return emb.unsqueeze(1) 26 | 27 | 28 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 29 | """ 30 | Create sinusoidal timestep embeddings. 31 | :param t: a 1-D Tensor of N indices, one per batch element. 32 | These may be fractional. 33 | :param dim: the dimension of the output. 34 | :param max_period: controls the minimum frequency of the embeddings. 35 | :return: an (N, D) Tensor of positional embeddings. 36 | """ 37 | t = time_factor * t 38 | half = dim // 2 39 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 40 | t.device 41 | ) 42 | 43 | args = t[:, None].float() * freqs[None] 44 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 45 | if dim % 2: 46 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 47 | if torch.is_floating_point(t): 48 | embedding = embedding.to(t) 49 | return embedding 50 | 51 | 52 | class MLPEmbedder(nn.Module): 53 | def __init__(self, in_dim: int, hidden_dim: int): 54 | super().__init__() 55 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 56 | self.silu = nn.SiLU() 57 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | return self.out_layer(self.silu(self.in_layer(x))) 61 | 62 | 63 | class RMSNorm(torch.nn.Module): 64 | def __init__(self, dim: int): 65 | super().__init__() 66 | self.scale = nn.Parameter(torch.ones(dim)) 67 | 68 | def forward(self, x: Tensor): 69 | x_dtype = x.dtype 70 | x = x.float() 71 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 72 | return (x * rrms).to(dtype=x_dtype) * self.scale 73 | 74 | 75 | class QKNorm(torch.nn.Module): 76 | def __init__(self, dim: int): 77 | super().__init__() 78 | self.query_norm = RMSNorm(dim) 79 | self.key_norm = RMSNorm(dim) 80 | 81 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 82 | q = self.query_norm(q) 83 | k = self.key_norm(k) 84 | return q.to(v), k.to(v) 85 | 86 | class LoRALinearLayer(nn.Module): 87 | def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): 88 | super().__init__() 89 | 90 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 91 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 92 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
93 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 94 | self.network_alpha = network_alpha 95 | self.rank = rank 96 | 97 | nn.init.normal_(self.down.weight, std=1 / rank) 98 | nn.init.zeros_(self.up.weight) 99 | 100 | def forward(self, hidden_states): 101 | orig_dtype = hidden_states.dtype 102 | dtype = self.down.weight.dtype 103 | 104 | down_hidden_states = self.down(hidden_states.to(dtype)) 105 | up_hidden_states = self.up(down_hidden_states) 106 | 107 | if self.network_alpha is not None: 108 | up_hidden_states *= self.network_alpha / self.rank 109 | 110 | return up_hidden_states.to(orig_dtype) 111 | 112 | class FLuxSelfAttnProcessor: 113 | def __call__(self, attn, x, pe, **attention_kwargs): 114 | print('2' * 30) 115 | 116 | qkv = attn.qkv(x) 117 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 118 | q, k = attn.norm(q, k, v) 119 | x = attention(q, k, v, pe=pe) 120 | x = attn.proj(x) 121 | return x 122 | 123 | class LoraFluxAttnProcessor(nn.Module): 124 | 125 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 126 | super().__init__() 127 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 128 | self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) 129 | self.lora_weight = lora_weight 130 | 131 | 132 | def __call__(self, attn, x, pe, **attention_kwargs): 133 | qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight 134 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 135 | q, k = attn.norm(q, k, v) 136 | x = attention(q, k, v, pe=pe) 137 | x = attn.proj(x) + self.proj_lora(x) * self.lora_weight 138 | print('1' * 30) 139 | print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm') 140 | return x 141 | 142 | class SelfAttention(nn.Module): 143 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 144 | super().__init__() 145 | self.num_heads = num_heads 146 | head_dim = dim // num_heads 147 | 148 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 149 | self.norm = QKNorm(head_dim) 150 | self.proj = nn.Linear(dim, dim) 151 | def forward(): 152 | pass 153 | 154 | 155 | @dataclass 156 | class ModulationOut: 157 | shift: Tensor 158 | scale: Tensor 159 | gate: Tensor 160 | 161 | 162 | class Modulation(nn.Module): 163 | def __init__(self, dim: int, double: bool): 164 | super().__init__() 165 | self.is_double = double 166 | self.multiplier = 6 if double else 3 167 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 168 | 169 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 170 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 171 | 172 | return ( 173 | ModulationOut(*out[:3]), 174 | ModulationOut(*out[3:]) if self.is_double else None, 175 | ) 176 | 177 | class DoubleStreamBlockLoraProcessor(nn.Module): 178 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 179 | super().__init__() 180 | self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 181 | self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) 182 | self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 183 | self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) 184 | self.lora_weight = lora_weight 185 | 186 | def forward(self, attn, img, txt, vec, pe, **attention_kwargs): 187 | img_mod1, img_mod2 = attn.img_mod(vec) 188 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 189 | 190 | # prepare image for 
attention 191 | img_modulated = attn.img_norm1(img) 192 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 193 | img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight 194 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 195 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 196 | 197 | # prepare txt for attention 198 | txt_modulated = attn.txt_norm1(txt) 199 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 200 | txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight 201 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 202 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 203 | 204 | # run actual attention 205 | q = torch.cat((txt_q, img_q), dim=2) 206 | k = torch.cat((txt_k, img_k), dim=2) 207 | v = torch.cat((txt_v, img_v), dim=2) 208 | 209 | attn1 = attention(q, k, v, pe=pe) 210 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 211 | 212 | # calculate the img blocks 213 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight 214 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 215 | 216 | # calculate the txt blocks 217 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight 218 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 219 | return img, txt 220 | 221 | class IPDoubleStreamBlockProcessor(nn.Module): 222 | """Attention processor for handling IP-adapter with double stream block.""" 223 | 224 | def __init__(self, context_dim, hidden_dim): 225 | super().__init__() 226 | if not hasattr(F, "scaled_dot_product_attention"): 227 | raise ImportError( 228 | "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
229 | ) 230 | 231 | # Ensure context_dim matches the dimension of image_proj 232 | self.context_dim = context_dim 233 | self.hidden_dim = hidden_dim 234 | 235 | # Initialize projections for IP-adapter 236 | self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) 237 | self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) 238 | 239 | nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) 240 | nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) 241 | 242 | nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) 243 | nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias) 244 | 245 | def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs): 246 | 247 | # Prepare image for attention 248 | img_mod1, img_mod2 = attn.img_mod(vec) 249 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 250 | 251 | img_modulated = attn.img_norm1(img) 252 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 253 | img_qkv = attn.img_attn.qkv(img_modulated) 254 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 255 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 256 | 257 | txt_modulated = attn.txt_norm1(txt) 258 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 259 | txt_qkv = attn.txt_attn.qkv(txt_modulated) 260 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 261 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 262 | 263 | q = torch.cat((txt_q, img_q), dim=2) 264 | k = torch.cat((txt_k, img_k), dim=2) 265 | v = torch.cat((txt_v, img_v), dim=2) 266 | 267 | attn1 = attention(q, k, v, pe=pe) 268 | txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:] 269 | 270 | # print(f"txt_attn shape: {txt_attn.size()}") 271 | # print(f"img_attn shape: {img_attn.size()}") 272 | 273 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn) 274 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 275 | 276 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) 277 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 278 | 279 | 280 | # IP-adapter processing 281 | ip_query = img_q # latent sample query 282 | ip_key = self.ip_adapter_double_stream_k_proj(image_proj) 283 | ip_value = self.ip_adapter_double_stream_v_proj(image_proj) 284 | 285 | # Reshape projections for multi-head attention 286 | ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) 287 | ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) 288 | 289 | # Compute attention between IP projections and the latent query 290 | ip_attention = F.scaled_dot_product_attention( 291 | ip_query, 292 | ip_key, 293 | ip_value, 294 | dropout_p=0.0, 295 | is_causal=False 296 | ) 297 | ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim) 298 | 299 | img = img + ip_scale * ip_attention 300 | 301 | return img, txt 302 | 303 | class DoubleStreamBlockProcessor: 304 | def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): 305 | img_mod1, img_mod2 = attn.img_mod(vec) 306 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 307 | 308 | # prepare image for attention 309 | img_modulated = attn.img_norm1(img) 310 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 311 | 
img_qkv = attn.img_attn.qkv(img_modulated) 312 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 313 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 314 | 315 | # prepare txt for attention 316 | txt_modulated = attn.txt_norm1(txt) 317 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 318 | txt_qkv = attn.txt_attn.qkv(txt_modulated) 319 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 320 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 321 | 322 | # run actual attention 323 | q = torch.cat((txt_q, img_q), dim=2) 324 | k = torch.cat((txt_k, img_k), dim=2) 325 | v = torch.cat((txt_v, img_v), dim=2) 326 | 327 | attn1 = attention(q, k, v, pe=pe) 328 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 329 | 330 | # calculate the img bloks 331 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn) 332 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 333 | 334 | # calculate the txt bloks 335 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) 336 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 337 | return img, txt 338 | 339 | class DoubleStreamBlock(nn.Module): 340 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 341 | super().__init__() 342 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 343 | self.num_heads = num_heads 344 | self.hidden_size = hidden_size 345 | self.head_dim = hidden_size // num_heads 346 | 347 | self.img_mod = Modulation(hidden_size, double=True) 348 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 349 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 350 | 351 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 352 | self.img_mlp = nn.Sequential( 353 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 354 | nn.GELU(approximate="tanh"), 355 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 356 | ) 357 | 358 | self.txt_mod = Modulation(hidden_size, double=True) 359 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 360 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 361 | 362 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 363 | self.txt_mlp = nn.Sequential( 364 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 365 | nn.GELU(approximate="tanh"), 366 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 367 | ) 368 | processor = DoubleStreamBlockProcessor() 369 | self.set_processor(processor) 370 | 371 | def set_processor(self, processor) -> None: 372 | self.processor = processor 373 | 374 | def get_processor(self): 375 | return self.processor 376 | 377 | def forward( 378 | self, 379 | img: Tensor, 380 | txt: Tensor, 381 | vec: Tensor, 382 | pe: Tensor, 383 | image_proj: Tensor = None, 384 | ip_scale: float =1.0, 385 | ) -> tuple[Tensor, Tensor]: 386 | if image_proj is None: 387 | return self.processor(self, img, txt, vec, pe) 388 | else: 389 | return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) 390 | 391 | class IPSingleStreamBlockProcessor(nn.Module): 392 | """Attention processor for handling IP-adapter with single stream block.""" 393 | def __init__(self, context_dim, hidden_dim): 394 | 
super().__init__() 395 | if not hasattr(F, "scaled_dot_product_attention"): 396 | raise ImportError( 397 | "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." 398 | ) 399 | 400 | # Ensure context_dim matches the dimension of image_proj 401 | self.context_dim = context_dim 402 | self.hidden_dim = hidden_dim 403 | 404 | # Initialize projections for IP-adapter 405 | self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False) 406 | self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False) 407 | 408 | nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight) 409 | nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight) 410 | 411 | def __call__( 412 | self, 413 | attn: nn.Module, 414 | x: Tensor, 415 | vec: Tensor, 416 | pe: Tensor, 417 | image_proj: Tensor | None = None, 418 | ip_scale: float = 1.0 419 | ) -> Tensor: 420 | 421 | mod, _ = attn.modulation(vec) 422 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 423 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 424 | 425 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 426 | q, k = attn.norm(q, k, v) 427 | 428 | # compute attention 429 | attn_1 = attention(q, k, v, pe=pe) 430 | 431 | # IP-adapter processing 432 | ip_query = q 433 | ip_key = self.ip_adapter_single_stream_k_proj(image_proj) 434 | ip_value = self.ip_adapter_single_stream_v_proj(image_proj) 435 | 436 | # Reshape projections for multi-head attention 437 | ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) 438 | ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) 439 | 440 | 441 | # Compute attention between IP projections and the latent query 442 | ip_attention = F.scaled_dot_product_attention( 443 | ip_query, 444 | ip_key, 445 | ip_value 446 | ) 447 | ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)") 448 | 449 | attn_out = attn_1 + ip_scale * ip_attention 450 | 451 | # compute activation in mlp stream, cat again and run second linear layer 452 | output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2)) 453 | out = x + mod.gate * output 454 | 455 | return out 456 | 457 | 458 | class SingleStreamBlockLoraProcessor(nn.Module): 459 | def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): 460 | super().__init__() 461 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 462 | self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) 463 | self.lora_weight = lora_weight 464 | 465 | def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 466 | 467 | mod, _ = attn.modulation(vec) 468 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 469 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 470 | qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight 471 | 472 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 473 | q, k = attn.norm(q, k, v) 474 | 475 | # compute attention 476 | attn_1 = attention(q, k, v, pe=pe) 477 | 478 | # compute activation in mlp stream, cat again and run second linear layer 479 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 480 | output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight 481 | output = x + mod.gate * output 482 | return output 483 | 484 | 485 | 
class SingleStreamBlockProcessor: 486 | def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 487 | 488 | mod, _ = attn.modulation(vec) 489 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 490 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 491 | 492 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 493 | q, k = attn.norm(q, k, v) 494 | 495 | # compute attention 496 | attn_1 = attention(q, k, v, pe=pe) 497 | 498 | # compute activation in mlp stream, cat again and run second linear layer 499 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 500 | output = x + mod.gate * output 501 | return output 502 | 503 | class SingleStreamBlock(nn.Module): 504 | """ 505 | A DiT block with parallel linear layers as described in 506 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 507 | """ 508 | 509 | def __init__( 510 | self, 511 | hidden_size: int, 512 | num_heads: int, 513 | mlp_ratio: float = 4.0, 514 | qk_scale: float | None = None, 515 | ): 516 | super().__init__() 517 | self.hidden_dim = hidden_size 518 | self.num_heads = num_heads 519 | self.head_dim = hidden_size // num_heads 520 | self.scale = qk_scale or self.head_dim**-0.5 521 | 522 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 523 | # qkv and mlp_in 524 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 525 | # proj and mlp_out 526 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 527 | 528 | self.norm = QKNorm(self.head_dim) 529 | 530 | self.hidden_size = hidden_size 531 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 532 | 533 | self.mlp_act = nn.GELU(approximate="tanh") 534 | self.modulation = Modulation(hidden_size, double=False) 535 | 536 | processor = SingleStreamBlockProcessor() 537 | self.set_processor(processor) 538 | 539 | 540 | def set_processor(self, processor) -> None: 541 | self.processor = processor 542 | 543 | def get_processor(self): 544 | return self.processor 545 | 546 | def forward( 547 | self, 548 | x: Tensor, 549 | vec: Tensor, 550 | pe: Tensor, 551 | image_proj: Tensor | None = None, 552 | ip_scale: float = 1.0 553 | ) -> Tensor: 554 | if image_proj is None: 555 | return self.processor(self, x, vec, pe) 556 | else: 557 | return self.processor(self, x, vec, pe, image_proj, ip_scale) 558 | 559 | 560 | 561 | class LastLayer(nn.Module): 562 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 563 | super().__init__() 564 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 565 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 566 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) 567 | 568 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 569 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 570 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 571 | x = self.linear(x) 572 | return x 573 | 574 | class ImageProjModel(torch.nn.Module): 575 | """Projection Model 576 | https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28 577 | """ 578 | 579 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 580 | super().__init__() 581 | 582 | self.generator = None 583 | self.cross_attention_dim = cross_attention_dim 584 | 
self.clip_extra_context_tokens = clip_extra_context_tokens 585 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 586 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 587 | 588 | def forward(self, image_embeds): 589 | embeds = image_embeds 590 | clip_extra_context_tokens = self.proj(embeds).reshape( 591 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 592 | ) 593 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 594 | return clip_extra_context_tokens 595 | 596 | --------------------------------------------------------------------------------
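The training loop in `src/train.py` stores only the low-rank tensors (every state-dict key containing `_lora`) in `lora.safetensors` inside each `checkpoint-*` directory. Below is a minimal sketch of how such a checkpoint might be re-attached for inference; `attach_lora`, its `rank`/`lora_weight` arguments, and the example checkpoint path are illustrative assumptions, while `load_flow_model2`, `attn_processors`, `set_attn_processor`, and the two LoRA processor classes are the same ones used in `src/train.py` and `src/flux/modules/layers.py`.

```python
# Hypothetical helper, not part of the repository: re-attach LoRA weights saved by src/train.py.
from safetensors.torch import load_file

from src.flux.modules.layers import (DoubleStreamBlockLoraProcessor,
                                     SingleStreamBlockLoraProcessor)
from src.flux.util import load_flow_model2


def attach_lora(dit, lora_path: str, rank: int = 16, lora_weight: float = 1.0):
    """Rebuild the LoRA attention processors and load the trained low-rank tensors."""
    lora_attn_procs = {}
    for name, processor in dit.attn_processors.items():
        if name.startswith("double_blocks"):
            lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
                dim=3072, rank=rank, lora_weight=lora_weight
            )
        elif name.startswith("single_blocks"):
            lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
                dim=3072, rank=rank, lora_weight=lora_weight
            )
        else:
            lora_attn_procs[name] = processor
    dit.set_attn_processor(lora_attn_procs)

    # The checkpoint contains only the "_lora" tensors, so load it non-strictly on top of the
    # base weights; blocks that were not trained keep their zero-initialised `up` projection
    # and therefore remain a no-op.
    missing, unexpected = dit.load_state_dict(load_file(lora_path), strict=False)
    assert not unexpected, f"unexpected keys in {lora_path}: {unexpected}"
    return dit


# Example usage (paths are illustrative):
# dit = load_flow_model2("flux-dev", device="cuda")
# dit = attach_lora(dit, "lora/checkpoint-2500/lora.safetensors", rank=16)
```

The `rank` passed here has to match the rank used during training, otherwise the saved low-rank tensors will not fit the rebuilt processors.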