252 |
253 | There is also a Chinese tutorial [YouTube video](https://www.youtube.com/watch?v=rRMc5DE4qMo) on how to install and use ICEdit, created by [softicelee2](https://github.com/softicelee2). It's definitely worth a watch!
254 |
255 | ## 💼 Windows one-click package
256 |
257 | Many thanks to [gluttony-10](https://github.com/River-Zhang/ICEdit/issues/23#issue-3050804566), a popular [Bilibili uploader](https://space.bilibili.com/893892)! He made a tutorial ([YouTube](https://youtu.be/C-OpWlJi424) and [Bilibili](https://www.bilibili.com/video/BV1oT5uzzEbs)) on how to install our project on Windows, along with a one-click Windows package: **just unzip it and it's ready to use**. The package is quantized, takes up only 14 GB of disk space, and supports 50-series graphics cards.
258 |
259 | Download link: [Google Drive](https://drive.google.com/drive/folders/16j3wQvWjuzCRKnVolszLmhCtc_yOCqcx?usp=sharing) or [Baidu Wangpan](https://www.bilibili.com/video/BV1oT5uzzEbs/?vd_source=2a911c0bc75f6d9b9d056bf0e7410d45) (see the comment section of the Bilibili video for the link).
260 |
261 |
262 |
263 | # 🔧 Training
264 |
265 | More details can be found here: [Training Code](./train/)
266 |
267 | # 💪 To Do List
268 |
269 | - [x] Inference Code
270 | - [ ] Inference-time Scaling with VLM
271 | - [x] Pretrained Weights
272 | - [x] More Inference Demos
273 | - [x] Gradio demo
274 | - [x] ComfyUI demo (by @[judian17](https://github.com/River-Zhang/ICEdit/issues/1#issuecomment-2846568411); compatible with [nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku), supports high-resolution refinement and FLUX Redux, and runs on a GPU with only 4 GB of VRAM!)
275 | - [x] ComfyUI demo with the normal LoRA (by @[Datou](https://openart.ai/workflows/datou/icedit-moe-lora-flux-fill/QFmaWNKsQo3P5liYz4RB) on OpenArt)
276 | - [x] Official ComfyUI workflow
277 | - [x] Training Code
278 | - [ ] LoRA for higher image resolution (768, 1024)
279 |
280 |
281 |
282 | # 💪 Comparison with Commercial Models
283 |
284 |
285 |
286 |
287 | Compared with commercial models such as Gemini and GPT-4o, our method matches and in some cases surpasses them in character ID preservation and instruction following. It is also fully open-source, with lower cost, faster speed (about 9 seconds per edited image), and strong performance.
288 |
289 |
290 |
291 |
292 |
293 |
294 | # 🌟 Star History
295 |
296 | [Star History Chart](https://www.star-history.com/#River-Zhang/ICEdit&Date)
297 |
298 | # BibTeX
299 | If this work is helpful for your research, please consider citing the following BibTeX entry.
300 |
301 | ```bibtex
302 | @misc{zhang2025ICEdit,
303 | title={In-Context Edit: Enabling Instructional Image Editing with In-Context Generation in Large Scale Diffusion Transformer},
304 | author={Zechuan Zhang and Ji Xie and Yu Lu and Zongxin Yang and Yi Yang},
305 | year={2025},
306 | eprint={2504.20690},
307 | archivePrefix={arXiv},
308 | primaryClass={cs.CV},
309 | url={https://arxiv.org/abs/2504.20690},
310 | }
311 | ```
312 |
--------------------------------------------------------------------------------
/assets/boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/boy.png
--------------------------------------------------------------------------------
/assets/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/girl.png
--------------------------------------------------------------------------------
/assets/hybrid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/hybrid.png
--------------------------------------------------------------------------------
/assets/kaori.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/assets/kaori.jpg
--------------------------------------------------------------------------------
/docs/images/comfyuiexample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/comfyuiexample.png
--------------------------------------------------------------------------------
/docs/images/gpt4o_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/gpt4o_comparison.png
--------------------------------------------------------------------------------
/docs/images/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/gradio.png
--------------------------------------------------------------------------------
/docs/images/lora_scale.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/lora_scale.png
--------------------------------------------------------------------------------
/docs/images/official_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/official_workflow.png
--------------------------------------------------------------------------------
/docs/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/teaser.png
--------------------------------------------------------------------------------
/docs/images/windows_install.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/windows_install.png
--------------------------------------------------------------------------------
/docs/images/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/workflow.png
--------------------------------------------------------------------------------
/docs/images/workflow_t8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/workflow_t8.png
--------------------------------------------------------------------------------
/docs/images/workflow_tutorial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/docs/images/workflow_tutorial.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | diffusers==0.33.0
3 | gradio
4 | numpy
5 | peft
6 | protobuf
7 | sentencepiece
8 | spaces
9 | torch==2.7.0
10 | torchvision
11 | transformers==4.51.3
12 | gguf
13 |
--------------------------------------------------------------------------------
/scripts/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "FluxTransformer2DModel",
3 | "_diffusers_version": "0.32.0.dev0",
4 | "attention_head_dim": 128,
5 | "axes_dims_rope": [
6 | 16,
7 | 56,
8 | 56
9 | ],
10 | "guidance_embeds": true,
11 | "in_channels": 384,
12 | "joint_attention_dim": 4096,
13 | "num_attention_heads": 24,
14 | "num_layers": 19,
15 | "num_single_layers": 38,
16 | "out_channels": 64,
17 | "patch_size": 1,
18 | "pooled_projection_dim": 768
19 | }
20 |
--------------------------------------------------------------------------------
/scripts/gradio_demo.py:
--------------------------------------------------------------------------------
1 | '''
2 | python scripts/gradio_demo.py
3 | '''
4 |
5 | import sys
6 | import os
7 | # workspace_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../icedit"))
8 |
9 | # if workspace_dir not in sys.path:
10 | # sys.path.insert(0, workspace_dir)
11 |
12 | from diffusers import FluxFillPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
13 | import gradio as gr
14 | import numpy as np
15 | import torch
16 | import spaces
17 | import argparse
18 | import random
19 | from diffusers import FluxFillPipeline
20 | from PIL import Image
21 |
22 | from transformers import T5EncoderModel
23 |
24 | MAX_SEED = np.iinfo(np.int32).max
25 | MAX_IMAGE_SIZE = 1024
26 |
27 | current_lora_scale = 1.0
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("--server_name", type=str, default="127.0.0.1")
31 | parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio app")
32 | parser.add_argument("--share", action="store_true")
33 | parser.add_argument("--output-dir", type=str, default="gradio_results", help="Directory to save the output image")
34 | parser.add_argument("--flux-path", type=str, default='black-forest-labs/flux.1-fill-dev', help="Path to the model")
35 | parser.add_argument("--lora-path", type=str, default='RiverZ/normal-lora', help="Path to the LoRA weights")
36 | parser.add_argument("--transformer", type=str, default=None, help="The gguf model of FluxTransformer2DModel")
37 | parser.add_argument("--text_encoder_2", type=str, default=None, help="The gguf model of T5EncoderModel")
38 | parser.add_argument("--enable-model-cpu-offload", action="store_true", help="Enable CPU offloading for the model")
39 | args = parser.parse_args()
40 |
41 | if args.transformer:
42 | args.transformer = os.path.abspath(args.transformer)
43 | transformer = FluxTransformer2DModel.from_single_file(
44 | args.transformer,
45 | config="scripts/config.json",
46 | quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
47 | torch_dtype=torch.bfloat16,
48 | )
49 | else:
50 | transformer = FluxTransformer2DModel.from_pretrained(
51 | args.flux_path,
52 | subfolder="transformer",
53 | torch_dtype=torch.bfloat16,
54 | )
55 |
56 | if args.text_encoder_2:
57 | args.text_encoder_2 = os.path.abspath(args.text_encoder_2)
58 | text_encoder_2 = T5EncoderModel.from_pretrained(
59 | args.flux_path,
60 | subfolder="text_encoder_2",
61 | gguf_file=f"{args.text_encoder_2}",
62 | torch_dtype=torch.bfloat16,
63 | )
64 | else:
65 | text_encoder_2 = T5EncoderModel.from_pretrained(
66 | args.flux_path,
67 | subfolder="text_encoder_2",
68 | torch_dtype=torch.bfloat16,
69 | )
70 |
71 | pipe = FluxFillPipeline.from_pretrained(
72 | args.flux_path,
73 | transformer=transformer,
74 | text_encoder_2=text_encoder_2,
75 | torch_dtype=torch.bfloat16
76 | )
77 | pipe.load_lora_weights(args.lora_path, adapter_name="icedit")
78 | pipe.set_adapters("icedit", 1.0)
79 |
80 | if args.enable_model_cpu_offload:
81 | pipe.enable_model_cpu_offload()
82 | else:
83 | pipe = pipe.to("cuda")
84 |
85 |
86 | @spaces.GPU
87 | def infer(edit_images,
88 | prompt,
89 | seed=666,
90 | randomize_seed=False,
91 | width=1024,
92 | height=1024,
93 | guidance_scale=50,
94 | num_inference_steps=28,
95 | lora_scale=1.0,
96 | progress=gr.Progress(track_tqdm=True)):
97 |
98 | global current_lora_scale
99 |
100 | if lora_scale != current_lora_scale:
101 | print(f"\033[93m[INFO] LoRA scale changed from {current_lora_scale} to {lora_scale}, reloading LoRA weights\033[0m")
102 | pipe.set_adapters("icedit", lora_scale)
103 | current_lora_scale = lora_scale
104 |
105 | image = edit_images
106 |
107 | if image.size[0] != 512:
108 | print("\033[93m[WARNING] We can only deal with the case where the image's width is 512.\033[0m")
109 | new_width = 512
110 | scale = new_width / image.size[0]
111 | new_height = int(image.size[1] * scale)
112 | new_height = (new_height // 8) * 8
113 | image = image.resize((new_width, new_height))
114 | print(f"\033[93m[WARNING] Resizing the image to {new_width} x {new_height}\033[0m")
115 |
116 | image = image.convert("RGB")
117 | width, height = image.size
118 | image = image.resize((512, int(512 * height / width)))
119 | combined_image = Image.new("RGB", (width * 2, height))
120 | combined_image.paste(image, (0, 0))
121 | mask_array = np.zeros((height, width * 2), dtype=np.uint8)
122 | mask_array[:, width:] = 255
123 | mask = Image.fromarray(mask_array)
124 | instruction = f'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but {prompt}'
125 |
126 | if randomize_seed:
127 | seed = random.randint(0, MAX_SEED)
128 |
129 | image = pipe(
130 | prompt=instruction,
131 | image=combined_image,
132 | mask_image=mask,
133 | height=height,
134 | width=width * 2,
135 | guidance_scale=guidance_scale,
136 | num_inference_steps=num_inference_steps,
137 | generator=torch.Generator().manual_seed(seed),
138 | ).images[0]
139 |
140 | w, h = image.size
141 | image = image.crop((w // 2, 0, w, h))
142 |
143 | os.makedirs(args.output_dir, exist_ok=True)
144 |
145 | index = len(os.listdir(args.output_dir))
146 | image.save(f"{args.output_dir}/result_{index}.png")
147 |
148 | return image, seed
149 |
150 | original_examples = [
151 | "a tiny astronaut hatching from an egg on the moon",
152 | "a cat holding a sign that says hello world",
153 | "an anime illustration of a wiener schnitzel",
154 | ]
155 |
156 | new_examples = [
157 | ['assets/girl.png', 'Make her hair dark green and her clothes checked.', 304897401],
158 | ['assets/boy.png', 'Change the sunglasses to a Christmas hat.', 748891420],
159 | ['assets/kaori.jpg', 'Make it a sketch.', 484817364]
160 | ]
161 |
162 | css = """
163 | #col-container {
164 | margin: 0 auto;
165 | max-width: 1000px;
166 | }
167 | """
168 |
169 | with gr.Blocks(css=css) as demo:
170 |
171 | with gr.Column(elem_id="col-container"):
172 | gr.Markdown(f"""# IC-Edit
173 | A demo for [IC-Edit](https://arxiv.org/pdf/2504.20690).
174 | More **open-source**, with **lower costs**, **faster speed** (it takes about 9 seconds to process one image), and **powerful performance**.
175 | For more details, check out our [Github Repository](https://github.com/River-Zhang/ICEdit) and [website](https://river-zhang.github.io/ICEdit-gh-pages/). If our project resonates with you or proves useful, we'd be truly grateful if you could spare a moment to give it a star.
176 | """)
177 | with gr.Row():
178 | with gr.Column():
179 | edit_image = gr.Image(
180 | label='Upload image for editing',
181 | type='pil',
182 | sources=["upload", "webcam"],
183 | image_mode='RGB',
184 | height=600
185 | )
186 | prompt = gr.Text(
187 | label="Prompt",
188 | show_label=False,
189 | max_lines=1,
190 | placeholder="Enter your prompt",
191 | container=False,
192 | )
193 | run_button = gr.Button("Run")
194 |
195 | result = gr.Image(label="Result", show_label=False)
196 |
197 | with gr.Accordion("Advanced Settings", open=True):
198 |
199 | seed = gr.Slider(
200 | label="Seed",
201 | minimum=0,
202 | maximum=MAX_SEED,
203 | step=1,
204 | value=0,
205 | )
206 |
207 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
208 |
209 | with gr.Row():
210 |
211 | width = gr.Slider(
212 | label="Width",
213 | minimum=512,
214 | maximum=MAX_IMAGE_SIZE,
215 | step=32,
216 | value=1024,
217 | visible=False
218 | )
219 |
220 | height = gr.Slider(
221 | label="Height",
222 | minimum=512,
223 | maximum=MAX_IMAGE_SIZE,
224 | step=32,
225 | value=1024,
226 | visible=False
227 | )
228 |
229 | with gr.Row():
230 |
231 | guidance_scale = gr.Slider(
232 | label="Guidance Scale",
233 | minimum=1,
234 | maximum=100,
235 | step=0.5,
236 | value=50,
237 | )
238 |
239 | num_inference_steps = gr.Slider(
240 | label="Number of inference steps",
241 | minimum=1,
242 | maximum=50,
243 | step=1,
244 | value=28,
245 | )
246 |
247 | lora_scale = gr.Slider(
248 | label="LoRA Scale",
249 | minimum=0,
250 | maximum=1.0,
251 | step=0.01,
252 | value=1.0,
253 | )
254 |
255 | def process_example(edit_image, prompt, seed, randomize_seed):
256 | result, seed_out = infer(edit_image, prompt, seed, False, 1024, 1024, 50, 28, 1.0)
257 | return result, seed_out, False
258 |
259 | gr.Examples(
260 | examples=new_examples,
261 | inputs=[edit_image, prompt, seed, randomize_seed],
262 | outputs=[result, seed, randomize_seed],
263 | fn=process_example,
264 | cache_examples=False
265 | )
266 |
267 | gr.on(
268 | triggers=[run_button.click, prompt.submit],
269 | fn=infer,
270 | inputs=[edit_image, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_scale],
271 | outputs=[result, seed]
272 | )
273 |
274 | if __name__ == "__main__":
275 | demo.launch(
276 | server_name=args.server_name,
277 | server_port=args.port,
278 | share=args.share,
279 | inbrowser=True,
280 | )
281 |
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | # Use the modified diffusers & peft library
2 | import sys
3 | import os
4 | # workspace_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../icedit"))
5 |
6 | # if workspace_dir not in sys.path:
7 | # sys.path.insert(0, workspace_dir)
8 |
9 | from diffusers import FluxFillPipeline
10 |
11 | # Below is the original library
12 | import torch
13 | from PIL import Image
14 | import numpy as np
15 | import argparse
16 | import random
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--image", type=str, help="Name of the image to be edited", required=True)
20 | parser.add_argument("--instruction", type=str, help="Instruction for editing the image", required=True)
21 | parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
22 | parser.add_argument("--output-dir", type=str, default=".", help="Directory to save the output image")
23 | parser.add_argument("--flux-path", type=str, default='black-forest-labs/flux.1-fill-dev', help="Path to the model")
24 | parser.add_argument("--lora-path", type=str, default='RiverZ/normal-lora', help="Path to the LoRA weights")
25 | parser.add_argument("--enable-model-cpu-offload", action="store_true", help="Enable CPU offloading for the model")
26 |
27 |
28 | args = parser.parse_args()
29 | pipe = FluxFillPipeline.from_pretrained(args.flux_path, torch_dtype=torch.bfloat16)
30 | pipe.load_lora_weights(args.lora_path)
31 |
32 | if args.enable_model_cpu_offload:
33 | pipe.enable_model_cpu_offload()
34 | else:
35 | pipe = pipe.to("cuda")
36 |
37 | image = Image.open(args.image)
38 | image = image.convert("RGB")
39 |
40 | if image.size[0] != 512:
41 | print("\033[93m[WARNING] We can only deal with the case where the image's width is 512.\033[0m")
42 | new_width = 512
43 | scale = new_width / image.size[0]
44 | new_height = int(image.size[1] * scale)
45 | new_height = (new_height // 8) * 8
46 | image = image.resize((new_width, new_height))
47 | print(f"\033[93m[WARNING] Resizing the image to {new_width} x {new_height}\033[0m")
48 |
49 | instruction = args.instruction
50 |
51 | print(f"Instruction: {instruction}")
52 | instruction = f'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but {instruction}'
53 |
54 | width, height = image.size
55 | combined_image = Image.new("RGB", (width * 2, height))
56 | combined_image.paste(image, (0, 0))
57 | combined_image.paste(image, (width, 0))
58 | mask_array = np.zeros((height, width * 2), dtype=np.uint8)
59 | mask_array[:, width:] = 255
60 | mask = Image.fromarray(mask_array)
61 |
62 | result_image = pipe(
63 | prompt=instruction,
64 | image=combined_image,
65 | mask_image=mask,
66 | height=height,
67 | width=width * 2,
68 | guidance_scale=50,
69 | num_inference_steps=28,
70 | generator=torch.Generator("cpu").manual_seed(args.seed) if args.seed is not None else None,
71 | ).images[0]
72 |
73 | result_image = result_image.crop((width,0,width*2,height))
74 |
75 | os.makedirs(args.output_dir, exist_ok=True)
76 |
77 | image_name = args.image.split("/")[-1]
78 | result_image.save(os.path.join(args.output_dir, f"{image_name}"))
79 | print(f"\033[92mResult saved as {os.path.abspath(os.path.join(args.output_dir, image_name))}\033[0m")
80 |
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | # ICEdit Training Repository
2 |
3 | This repository contains the training code for ICEdit, an instruction-based image editing model that performs edits via in-context conditional generation.
4 |
5 | This codebase is based heavily on the [OminiControl](https://github.com/Yuanshi9815/OminiControl) repository. We thank the authors for their work and contributions to the field!
6 |
7 | ## Setup and Installation
8 |
9 | ```bash
10 | # Create a new conda environment
11 | conda create -n train python=3.10
12 | conda activate train
13 |
14 | # Install requirements
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ## Project Structure
19 |
20 | - `src/`: Source code directory
21 | - `train/`: Training modules
22 | - `train.py`: Main training script
23 | - `data.py`: Dataset classes for handling different data formats
24 | - `model.py`: Model definition using Flux pipeline
25 | - `callbacks.py`: Training callbacks for logging and checkpointing
26 | - `flux/`: Flux model implementation
27 | - `assets/`: Asset files
28 | - `parquet/`: Parquet data files
29 | - `requirements.txt`: Dependency list
30 |
31 | ## Datasets
32 |
33 | Download the training datasets (a subset of OmniEdit) into the `parquet/` directory. You can use the provided script `parquet/prepare.sh`:
34 |
35 | ```bash
36 | cd parquet
37 | bash prepare.sh
38 | ```
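
A quick sanity check of a downloaded shard can be done with the `datasets` library (already listed in `train/requirements.txt`). This is a minimal sketch rather than part of the official pipeline; the shard name below is just an example, and the column names are printed rather than assumed:

```python
# Minimal sketch: inspect one downloaded OmniEdit parquet shard.
# The shard name below is only an example; use any file present in parquet/.
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_files="parquet/train-00000-of-00105.parquet",
    split="train",
)

print(ds.num_rows)       # number of editing examples in this shard
print(ds.column_names)   # available fields (source image, edited image, instruction, ...)
```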
39 |
40 | ## Training
41 |
42 | ```bash
43 | bash train/script/train.sh
44 | ```
45 |
46 | You can modify the training configuration in `train/config/normal_lora.yaml`.
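
For reference, the training code resolves the config path from the `XFL_CONFIG` environment variable and parses it with `yaml.safe_load` (see `get_config` in `src/flux/generate.py`). Below is a minimal sketch of loading and inspecting such a config from the repository root; the printed keys mirror the saved run config under `train/runs/` and may differ from your own file:

```python
# Minimal sketch: read a training config the same way get_config() does.
# Assumes it is run from the repository root; adjust the path otherwise.
import os
import yaml

config_path = os.environ.get("XFL_CONFIG", "train/config/normal_lora.yaml")
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

print(config["flux_path"])            # e.g. black-forest-labs/flux.1-fill-dev
print(config["train"]["batch_size"])  # e.g. 2
print(config["model"])                # attention/LoRA behaviour flags such as union_cond_attn
```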
47 |
--------------------------------------------------------------------------------
/train/assets/book.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/book.jpg
--------------------------------------------------------------------------------
/train/assets/clock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/clock.jpg
--------------------------------------------------------------------------------
/train/assets/coffee.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/coffee.png
--------------------------------------------------------------------------------
/train/assets/monalisa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/monalisa.jpg
--------------------------------------------------------------------------------
/train/assets/oranges.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/oranges.jpg
--------------------------------------------------------------------------------
/train/assets/penguin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/penguin.jpg
--------------------------------------------------------------------------------
/train/assets/room_corner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/room_corner.jpg
--------------------------------------------------------------------------------
/train/assets/vase.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/River-Zhang/ICEdit/a9472d6e71cd15d347759bcddb80441c8cc30da7/train/assets/vase.jpg
--------------------------------------------------------------------------------
/train/parquet/prepare.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p $(dirname "$0")
4 |
5 | BASE_URL_105="https://huggingface.co/datasets/sayakpaul/OmniEdit-mini/resolve/main/data"
6 | BASE_URL_571="https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M/resolve/main/data"
7 |
8 | FILES=(
9 | "train-00053-of-00105.parquet"
10 | "train-00008-of-00105.parquet"
11 | "train-00093-of-00105.parquet"
12 | "train-00097-of-00105.parquet"
13 | "train-00009-of-00105.parquet"
14 | "train-00069-of-00105.parquet"
15 | "train-00029-of-00105.parquet"
16 | "train-00083-of-00105.parquet"
17 | "train-00037-of-00105.parquet"
18 | "train-00079-of-00105.parquet"
19 | "train-00085-of-00105.parquet"
20 | "train-00087-of-00105.parquet"
21 | "train-00038-of-00105.parquet"
22 | "train-00041-of-00105.parquet"
23 | "train-00047-of-00105.parquet"
24 | "train-00145-of-00571.parquet"
25 | "train-00091-of-00105.parquet"
26 | "train-00004-of-00105.parquet"
27 | "train-00014-of-00105.parquet"
28 | "train-00016-of-00105.parquet"
29 | "train-00035-of-00105.parquet"
30 | "train-00017-of-00105.parquet"
31 | "train-00066-of-00105.parquet"
32 | "train-00071-of-00105.parquet"
33 | "train-00043-of-00105.parquet"
34 | "train-00067-of-00105.parquet"
35 | "train-00074-of-00105.parquet"
36 | "train-00001-of-00105.parquet"
37 | "train-00115-of-00571.parquet"
38 | "train-00048-of-00105.parquet"
39 | "train-00064-of-00105.parquet"
40 | "train-00010-of-00105.parquet"
41 | "train-00011-of-00105.parquet"
42 | "train-00062-of-00105.parquet"
43 | "train-00567-of-00571.parquet"
44 | "train-00032-of-00105.parquet"
45 | "train-00070-of-00105.parquet"
46 | "train-00160-of-00571.parquet"
47 | "train-00046-of-00105.parquet"
48 | "train-00073-of-00105.parquet"
49 | "train-00006-of-00105.parquet"
50 | "train-00061-of-00105.parquet"
51 | "train-00050-of-00105.parquet"
52 | "train-00056-of-00105.parquet"
53 | "train-00003-of-00105.parquet"
54 | "train-00012-of-00105.parquet"
55 | "train-00089-of-00105.parquet"
56 | "train-00028-of-00105.parquet"
57 | "train-00015-of-00105.parquet"
58 | "train-00103-of-00105.parquet"
59 | "train-00099-of-00105.parquet"
60 | "train-00020-of-00105.parquet"
61 | "train-00033-of-00105.parquet"
62 | "train-00078-of-00105.parquet"
63 | "train-00000-of-00105.parquet"
64 | "train-00566-of-00571.parquet"
65 | "train-00054-of-00105.parquet"
66 | "train-00044-of-00105.parquet"
67 | "train-00100-of-00571.parquet"
68 | "train-00049-of-00105.parquet"
69 | "train-00019-of-00105.parquet"
70 | "train-00076-of-00105.parquet"
71 | "train-00025-of-00105.parquet"
72 | "train-00081-of-00105.parquet"
73 | "train-00045-of-00105.parquet"
74 | "train-00036-of-00105.parquet"
75 | "train-00080-of-00105.parquet"
76 | "train-00034-of-00105.parquet"
77 | "train-00057-of-00105.parquet"
78 | "train-00082-of-00105.parquet"
79 | "train-00059-of-00105.parquet"
80 | "train-00058-of-00105.parquet"
81 | "train-00013-of-00105.parquet"
82 | "train-00084-of-00105.parquet"
83 | "train-00100-of-00105.parquet"
84 | "train-00090-of-00105.parquet"
85 | "train-00094-of-00105.parquet"
86 | "train-00060-of-00105.parquet"
87 | "train-00175-of-00571.parquet"
88 | "train-00065-of-00105.parquet"
89 | "train-00040-of-00105.parquet"
90 | "train-00023-of-00105.parquet"
91 | "train-00088-of-00105.parquet"
92 | "train-00068-of-00105.parquet"
93 | "train-00027-of-00105.parquet"
94 | "train-00568-of-00571.parquet"
95 | "train-00098-of-00105.parquet"
96 | "train-00031-of-00105.parquet"
97 | "train-00063-of-00105.parquet"
98 | "train-00002-of-00105.parquet"
99 | "train-00007-of-00105.parquet"
100 | "train-00569-of-00571.parquet"
101 | "train-00052-of-00105.parquet"
102 | "train-00102-of-00105.parquet"
103 | "train-00104-of-00105.parquet"
104 | "train-00072-of-00105.parquet"
105 | "train-00051-of-00105.parquet"
106 | "train-00101-of-00105.parquet"
107 | "train-00570-of-00571.parquet"
108 | "train-00095-of-00105.parquet"
109 | "train-00092-of-00105.parquet"
110 | "train-00030-of-00105.parquet"
111 | "train-00055-of-00105.parquet"
112 | "train-00042-of-00105.parquet"
113 | "train-00018-of-00105.parquet"
114 | "train-00096-of-00105.parquet"
115 | "train-00005-of-00105.parquet"
116 | "train-00022-of-00105.parquet"
117 | "train-00086-of-00105.parquet"
118 | "train-00024-of-00105.parquet"
119 | "train-00077-of-00105.parquet"
120 | "train-00075-of-00105.parquet"
121 | "train-00039-of-00105.parquet"
122 | "train-00021-of-00105.parquet"
123 | "train-00130-of-00571.parquet"
124 | "train-00026-of-00105.parquet"
125 | "train-00000-of-00571.parquet"
126 | )
127 |
128 | TOTAL=${#FILES[@]}
129 | CURRENT=0
130 |
131 | for file in "${FILES[@]}"; do
132 | CURRENT=$((CURRENT+1))
133 | echo "[$CURRENT/$TOTAL] $file"
134 |
135 | if [[ $file == *"-of-00105.parquet" ]]; then
136 | BASE_URL=$BASE_URL_105
137 | else
138 | BASE_URL=$BASE_URL_571
139 | fi
140 |
141 | wget -c "$BASE_URL/$file" -O "$file" || {
142 | echo "Download $file failed, trying to continue..."
143 | }
144 |
145 | if [ -f "$file" ]; then
146 | filesize=$(du -h "$file" | cut -f1)
147 | echo "Downloaded : $file ($filesize)"
148 | else
149 | echo "Warning: $file failed!"
150 | fi
151 |
152 | echo "------------------------------------"
153 | done
154 |
155 | echo "All files downloaded!"
--------------------------------------------------------------------------------
/train/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.32.0
2 | datasets==3.6.0
3 | transformers
4 | peft
5 | opencv-python
6 | protobuf
7 | sentencepiece
8 | gradio
9 | jupyter
10 | torchao
11 | lightning
12 | torchvision
13 | prodigyopt
14 | wandb
--------------------------------------------------------------------------------
/train/runs/20250513-085800/config.yaml:
--------------------------------------------------------------------------------
1 | dtype: bfloat16
2 | flux_path: black-forest-labs/flux.1-fill-dev
3 | model:
4 | add_cond_attn: false
5 | latent_lora: false
6 | union_cond_attn: true
7 | use_sep: false
8 | train:
9 | accumulate_grad_batches: 1
10 | batch_size: 2
11 | condition_type: edit
12 | dataloader_workers: 5
13 | dataset:
14 | condition_size: 512
15 | drop_image_prob: 0.1
16 | drop_text_prob: 0.1
17 | image_size: 512
18 | padding: 8
19 | path: parquet/*.parquet
20 | target_size: 512
21 | type: edit_with_omini
22 | gradient_checkpointing: false
23 | lora_config:
24 | init_lora_weights: gaussian
25 | lora_alpha: 32
26 | r: 32
27 | target_modules: (.*x_embedder|.*(?
--------------------------------------------------------------------------------
/train/src/flux/block.py:
--------------------------------------------------------------------------------
16 | ) -> torch.FloatTensor:
17 | batch_size, _, _ = (
18 | hidden_states.shape
19 | if encoder_hidden_states is None
20 | else encoder_hidden_states.shape
21 | )
22 |
23 | with enable_lora(
24 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
25 | ):
26 | # `sample` projections.
27 | query = attn.to_q(hidden_states)
28 | key = attn.to_k(hidden_states)
29 | value = attn.to_v(hidden_states)
30 |
31 | inner_dim = key.shape[-1]
32 | head_dim = inner_dim // attn.heads
33 |
34 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
35 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
36 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
37 |
38 | if attn.norm_q is not None:
39 | query = attn.norm_q(query)
40 | if attn.norm_k is not None:
41 | key = attn.norm_k(key)
42 |
43 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
44 | if encoder_hidden_states is not None:
45 | # `context` projections.
46 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
47 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
48 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
49 |
50 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
51 | batch_size, -1, attn.heads, head_dim
52 | ).transpose(1, 2)
53 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
54 | batch_size, -1, attn.heads, head_dim
55 | ).transpose(1, 2)
56 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
57 | batch_size, -1, attn.heads, head_dim
58 | ).transpose(1, 2)
59 |
60 | if attn.norm_added_q is not None:
61 | encoder_hidden_states_query_proj = attn.norm_added_q(
62 | encoder_hidden_states_query_proj
63 | )
64 | if attn.norm_added_k is not None:
65 | encoder_hidden_states_key_proj = attn.norm_added_k(
66 | encoder_hidden_states_key_proj
67 | )
68 |
69 | # attention
70 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
71 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
72 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
73 |
74 | if image_rotary_emb is not None:
75 | from diffusers.models.embeddings import apply_rotary_emb
76 |
77 | query = apply_rotary_emb(query, image_rotary_emb)
78 | key = apply_rotary_emb(key, image_rotary_emb)
79 |
80 | if condition_latents is not None:
81 | cond_query = attn.to_q(condition_latents)
82 | cond_key = attn.to_k(condition_latents)
83 | cond_value = attn.to_v(condition_latents)
84 |
85 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
86 | 1, 2
87 | )
88 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
90 | 1, 2
91 | )
92 | if attn.norm_q is not None:
93 | cond_query = attn.norm_q(cond_query)
94 | if attn.norm_k is not None:
95 | cond_key = attn.norm_k(cond_key)
96 |
97 | if cond_rotary_emb is not None:
98 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
99 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
100 |
101 | if condition_latents is not None:
102 | query = torch.cat([query, cond_query], dim=2)
103 | key = torch.cat([key, cond_key], dim=2)
104 | value = torch.cat([value, cond_value], dim=2)
105 |
106 | if not model_config.get("union_cond_attn", True):
107 | # If we don't want to use the union condition attention, we need to mask the attention
108 | # between the hidden states and the condition latents
109 | attention_mask = torch.ones(
110 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
111 | )
112 | condition_n = cond_query.shape[2]
113 | attention_mask[-condition_n:, :-condition_n] = False
114 | attention_mask[:-condition_n, -condition_n:] = False
115 | if hasattr(attn, "c_factor"):
116 | attention_mask = torch.zeros(
117 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
118 | )
119 | condition_n = cond_query.shape[2]
120 | bias = torch.log(attn.c_factor[0])
121 | attention_mask[-condition_n:, :-condition_n] = bias
122 | attention_mask[:-condition_n, -condition_n:] = bias
123 | hidden_states = F.scaled_dot_product_attention(
124 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
125 | )
126 | hidden_states = hidden_states.transpose(1, 2).reshape(
127 | batch_size, -1, attn.heads * head_dim
128 | )
129 | hidden_states = hidden_states.to(query.dtype)
130 |
131 | if encoder_hidden_states is not None:
132 | if condition_latents is not None:
133 | encoder_hidden_states, hidden_states, condition_latents = (
134 | hidden_states[:, : encoder_hidden_states.shape[1]],
135 | hidden_states[
136 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
137 | ],
138 | hidden_states[:, -condition_latents.shape[1] :],
139 | )
140 | else:
141 | encoder_hidden_states, hidden_states = (
142 | hidden_states[:, : encoder_hidden_states.shape[1]],
143 | hidden_states[:, encoder_hidden_states.shape[1] :],
144 | )
145 |
146 | with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
147 | # linear proj
148 | hidden_states = attn.to_out[0](hidden_states)
149 | # dropout
150 | hidden_states = attn.to_out[1](hidden_states)
151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
152 |
153 | if condition_latents is not None:
154 | condition_latents = attn.to_out[0](condition_latents)
155 | condition_latents = attn.to_out[1](condition_latents)
156 |
157 | return (
158 | (hidden_states, encoder_hidden_states, condition_latents)
159 | if condition_latents is not None
160 | else (hidden_states, encoder_hidden_states)
161 | )
162 | elif condition_latents is not None:
163 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents
164 | hidden_states, condition_latents = (
165 | hidden_states[:, : -condition_latents.shape[1]],
166 | hidden_states[:, -condition_latents.shape[1] :],
167 | )
168 | return hidden_states, condition_latents
169 | else:
170 | return hidden_states
171 |
172 |
173 | def block_forward(
174 | self,
175 | hidden_states: torch.FloatTensor,
176 | encoder_hidden_states: torch.FloatTensor,
177 | condition_latents: torch.FloatTensor,
178 | temb: torch.FloatTensor,
179 | cond_temb: torch.FloatTensor,
180 | cond_rotary_emb=None,
181 | image_rotary_emb=None,
182 | model_config: Optional[Dict[str, Any]] = {},
183 | ):
184 | use_cond = condition_latents is not None
185 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
186 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
187 | hidden_states, emb=temb
188 | )
189 |
190 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
191 | self.norm1_context(encoder_hidden_states, emb=temb)
192 | )
193 |
194 | if use_cond:
195 | (
196 | norm_condition_latents,
197 | cond_gate_msa,
198 | cond_shift_mlp,
199 | cond_scale_mlp,
200 | cond_gate_mlp,
201 | ) = self.norm1(condition_latents, emb=cond_temb)
202 |
203 | # Attention.
204 | result = attn_forward(
205 | self.attn,
206 | model_config=model_config,
207 | hidden_states=norm_hidden_states,
208 | encoder_hidden_states=norm_encoder_hidden_states,
209 | condition_latents=norm_condition_latents if use_cond else None,
210 | image_rotary_emb=image_rotary_emb,
211 | cond_rotary_emb=cond_rotary_emb if use_cond else None,
212 | )
213 | attn_output, context_attn_output = result[:2]
214 | cond_attn_output = result[2] if use_cond else None
215 |
216 | # Process attention outputs for the `hidden_states`.
217 | # 1. hidden_states
218 | attn_output = gate_msa.unsqueeze(1) * attn_output
219 | hidden_states = hidden_states + attn_output
220 | # 2. encoder_hidden_states
221 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
222 | encoder_hidden_states = encoder_hidden_states + context_attn_output
223 | # 3. condition_latents
224 | if use_cond:
225 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
226 | condition_latents = condition_latents + cond_attn_output
227 | if model_config.get("add_cond_attn", False):
228 | hidden_states += cond_attn_output
229 |
230 | # LayerNorm + MLP.
231 | # 1. hidden_states
232 | norm_hidden_states = self.norm2(hidden_states)
233 | norm_hidden_states = (
234 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
235 | )
236 | # 2. encoder_hidden_states
237 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
238 | norm_encoder_hidden_states = (
239 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
240 | )
241 | # 3. condition_latents
242 | if use_cond:
243 | norm_condition_latents = self.norm2(condition_latents)
244 | norm_condition_latents = (
245 | norm_condition_latents * (1 + cond_scale_mlp[:, None])
246 | + cond_shift_mlp[:, None]
247 | )
248 |
249 | # Feed-forward.
250 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
251 | # 1. hidden_states
252 | ff_output = self.ff(norm_hidden_states)
253 | ff_output = gate_mlp.unsqueeze(1) * ff_output
254 | # 2. encoder_hidden_states
255 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
256 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
257 | # 3. condition_latents
258 | if use_cond:
259 | cond_ff_output = self.ff(norm_condition_latents)
260 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
261 |
262 | # Process feed-forward outputs.
263 | hidden_states = hidden_states + ff_output
264 | encoder_hidden_states = encoder_hidden_states + context_ff_output
265 | if use_cond:
266 | condition_latents = condition_latents + cond_ff_output
267 |
268 | # Clip to avoid overflow.
269 | if encoder_hidden_states.dtype == torch.float16:
270 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
271 |
272 | return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
273 |
274 |
275 | def single_block_forward(
276 | self,
277 | hidden_states: torch.FloatTensor,
278 | temb: torch.FloatTensor,
279 | image_rotary_emb=None,
280 | condition_latents: torch.FloatTensor = None,
281 | cond_temb: torch.FloatTensor = None,
282 | cond_rotary_emb=None,
283 | model_config: Optional[Dict[str, Any]] = {},
284 | ):
285 |
286 | using_cond = condition_latents is not None
287 | residual = hidden_states
288 | with enable_lora(
289 | (
290 | self.norm.linear,
291 | self.proj_mlp,
292 | ),
293 | model_config.get("latent_lora", False),
294 | ):
295 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
296 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
297 | if using_cond:
298 | residual_cond = condition_latents
299 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
300 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
301 |
302 | attn_output = attn_forward(
303 | self.attn,
304 | model_config=model_config,
305 | hidden_states=norm_hidden_states,
306 | image_rotary_emb=image_rotary_emb,
307 | **(
308 | {
309 | "condition_latents": norm_condition_latents,
310 | "cond_rotary_emb": cond_rotary_emb if using_cond else None,
311 | }
312 | if using_cond
313 | else {}
314 | ),
315 | )
316 | if using_cond:
317 | attn_output, cond_attn_output = attn_output
318 |
319 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
320 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
321 | gate = gate.unsqueeze(1)
322 | hidden_states = gate * self.proj_out(hidden_states)
323 | hidden_states = residual + hidden_states
324 | if using_cond:
325 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
326 | cond_gate = cond_gate.unsqueeze(1)
327 | condition_latents = cond_gate * self.proj_out(condition_latents)
328 | condition_latents = residual_cond + condition_latents
329 |
330 | if hidden_states.dtype == torch.float16:
331 | hidden_states = hidden_states.clip(-65504, 65504)
332 |
333 | return hidden_states if not using_cond else (hidden_states, condition_latents)
334 |
--------------------------------------------------------------------------------
/train/src/flux/condition.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, Union, List, Tuple
3 | from diffusers.pipelines import FluxPipeline
4 | from PIL import Image, ImageFilter
5 | import numpy as np
6 | import cv2
7 |
8 | from .pipeline_tools import encode_images
9 |
10 | condition_dict = {
11 | "depth": 0,
12 | "canny": 1,
13 | "subject": 4,
14 | "coloring": 6,
15 | "deblurring": 7,
16 | "depth_pred": 8,
17 | "fill": 9,
18 | "sr": 10,
19 | }
20 |
21 |
22 | class Condition(object):
23 | def __init__(
24 | self,
25 | condition_type: str,
26 | raw_img: Union[Image.Image, torch.Tensor] = None,
27 | condition: Union[Image.Image, torch.Tensor] = None,
28 | mask=None,
29 | position_delta=None,
30 | ) -> None:
31 | self.condition_type = condition_type
32 | assert raw_img is not None or condition is not None
33 | if raw_img is not None:
34 | self.condition = self.get_condition(condition_type, raw_img)
35 | else:
36 | self.condition = condition
37 | self.position_delta = position_delta
38 | # TODO: Add mask support
39 | assert mask is None, "Mask not supported yet"
40 |
41 | def get_condition(
42 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
43 | ) -> Union[Image.Image, torch.Tensor]:
44 | """
45 | Returns the condition image.
46 | """
47 | if condition_type == "depth":
48 | from transformers import pipeline
49 |
50 | depth_pipe = pipeline(
51 | task="depth-estimation",
52 | model="LiheYoung/depth-anything-small-hf",
53 | device="cuda",
54 | )
55 | source_image = raw_img.convert("RGB")
56 | condition_img = depth_pipe(source_image)["depth"].convert("RGB")
57 | return condition_img
58 | elif condition_type == "canny":
59 | img = np.array(raw_img)
60 | edges = cv2.Canny(img, 100, 200)
61 | edges = Image.fromarray(edges).convert("RGB")
62 | return edges
63 | elif condition_type == "subject":
64 | return raw_img
65 | elif condition_type == "coloring":
66 | return raw_img.convert("L").convert("RGB")
67 | elif condition_type == "deblurring":
68 | condition_image = (
69 | raw_img.convert("RGB")
70 | .filter(ImageFilter.GaussianBlur(10))
71 | .convert("RGB")
72 | )
73 | return condition_image
74 | elif condition_type == "fill":
75 | return raw_img.convert("RGB")
76 | return self.condition
77 |
78 | @property
79 | def type_id(self) -> int:
80 | """
81 | Returns the type id of the condition.
82 | """
83 | return condition_dict[self.condition_type]
84 |
85 | @classmethod
86 | def get_type_id(cls, condition_type: str) -> int:
87 | """
88 | Returns the type id of the condition.
89 | """
90 | return condition_dict[condition_type]
91 |
92 | def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
93 | """
94 | Encodes the condition into tokens, ids and type_id.
95 | """
96 | if self.condition_type in [
97 | "depth",
98 | "canny",
99 | "subject",
100 | "coloring",
101 | "deblurring",
102 | "depth_pred",
103 | "fill",
104 | "sr",
105 | ]:
106 | tokens, ids = encode_images(pipe, self.condition)
107 | else:
108 | raise NotImplementedError(
109 | f"Condition type {self.condition_type} not implemented"
110 | )
111 | if self.position_delta is None and self.condition_type == "subject":
112 | self.position_delta = [0, -self.condition.size[0] // 16]
113 | if self.position_delta is not None:
114 | ids[:, 1] += self.position_delta[0]
115 | ids[:, 2] += self.position_delta[1]
116 | type_id = torch.ones_like(ids[:, :1]) * self.type_id
117 | return tokens, ids, type_id
118 |
--------------------------------------------------------------------------------
/train/src/flux/generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml, os
3 | from diffusers.pipelines import FluxPipeline
4 | from typing import List, Union, Optional, Dict, Any, Callable
5 | from .transformer import tranformer_forward
6 | from .condition import Condition
7 |
8 | from diffusers.pipelines.flux.pipeline_flux import (
9 | FluxPipelineOutput,
10 | calculate_shift,
11 | retrieve_timesteps,
12 | np,
13 | )
14 |
15 |
16 | def get_config(config_path: str = None):
17 | config_path = config_path or os.environ.get("XFL_CONFIG")
18 | if not config_path:
19 | return {}
20 | with open(config_path, "r") as f:
21 | config = yaml.safe_load(f)
22 | return config
23 |
24 |
25 | def prepare_params(
26 | prompt: Union[str, List[str]] = None,
27 | prompt_2: Optional[Union[str, List[str]]] = None,
28 | height: Optional[int] = 512,
29 | width: Optional[int] = 512,
30 | num_inference_steps: int = 28,
31 | timesteps: List[int] = None,
32 | guidance_scale: float = 3.5,
33 | num_images_per_prompt: Optional[int] = 1,
34 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35 | latents: Optional[torch.FloatTensor] = None,
36 | prompt_embeds: Optional[torch.FloatTensor] = None,
37 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
38 | output_type: Optional[str] = "pil",
39 | return_dict: bool = True,
40 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
41 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
42 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
43 | max_sequence_length: int = 512,
44 | **kwargs: dict,
45 | ):
46 | return (
47 | prompt,
48 | prompt_2,
49 | height,
50 | width,
51 | num_inference_steps,
52 | timesteps,
53 | guidance_scale,
54 | num_images_per_prompt,
55 | generator,
56 | latents,
57 | prompt_embeds,
58 | pooled_prompt_embeds,
59 | output_type,
60 | return_dict,
61 | joint_attention_kwargs,
62 | callback_on_step_end,
63 | callback_on_step_end_tensor_inputs,
64 | max_sequence_length,
65 | )
66 |
67 |
68 | def seed_everything(seed: int = 42):
69 | torch.backends.cudnn.deterministic = True
70 | torch.manual_seed(seed)
71 | np.random.seed(seed)
72 |
73 |
74 | @torch.no_grad()
75 | def generate(
76 | pipeline: FluxPipeline,
77 | conditions: List[Condition] = None,
78 | config_path: str = None,
79 | model_config: Optional[Dict[str, Any]] = {},
80 | condition_scale: float = 1.0,
81 | default_lora: bool = False,
82 | **params: dict,
83 | ):
84 | model_config = model_config or get_config(config_path).get("model", {})
85 | if condition_scale != 1:
86 | for name, module in pipeline.transformer.named_modules():
87 | if not name.endswith(".attn"):
88 | continue
89 | module.c_factor = torch.ones(1, 1) * condition_scale
90 |
91 | self = pipeline
92 | (
93 | prompt,
94 | prompt_2,
95 | height,
96 | width,
97 | num_inference_steps,
98 | timesteps,
99 | guidance_scale,
100 | num_images_per_prompt,
101 | generator,
102 | latents,
103 | prompt_embeds,
104 | pooled_prompt_embeds,
105 | output_type,
106 | return_dict,
107 | joint_attention_kwargs,
108 | callback_on_step_end,
109 | callback_on_step_end_tensor_inputs,
110 | max_sequence_length,
111 | ) = prepare_params(**params)
112 |
113 | height = height or self.default_sample_size * self.vae_scale_factor
114 | width = width or self.default_sample_size * self.vae_scale_factor
115 |
116 | # 1. Check inputs. Raise error if not correct
117 | self.check_inputs(
118 | prompt,
119 | prompt_2,
120 | height,
121 | width,
122 | prompt_embeds=prompt_embeds,
123 | pooled_prompt_embeds=pooled_prompt_embeds,
124 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
125 | max_sequence_length=max_sequence_length,
126 | )
127 |
128 | self._guidance_scale = guidance_scale
129 | self._joint_attention_kwargs = joint_attention_kwargs
130 | self._interrupt = False
131 |
132 | # 2. Define call parameters
133 | if prompt is not None and isinstance(prompt, str):
134 | batch_size = 1
135 | elif prompt is not None and isinstance(prompt, list):
136 | batch_size = len(prompt)
137 | else:
138 | batch_size = prompt_embeds.shape[0]
139 |
140 | device = self._execution_device
141 |
142 | lora_scale = (
143 | self.joint_attention_kwargs.get("scale", None)
144 | if self.joint_attention_kwargs is not None
145 | else None
146 | )
147 | (
148 | prompt_embeds,
149 | pooled_prompt_embeds,
150 | text_ids,
151 | ) = self.encode_prompt(
152 | prompt=prompt,
153 | prompt_2=prompt_2,
154 | prompt_embeds=prompt_embeds,
155 | pooled_prompt_embeds=pooled_prompt_embeds,
156 | device=device,
157 | num_images_per_prompt=num_images_per_prompt,
158 | max_sequence_length=max_sequence_length,
159 | lora_scale=lora_scale,
160 | )
161 |
162 | # 4. Prepare latent variables
163 | num_channels_latents = self.transformer.config.in_channels // 4
164 | latents, latent_image_ids = self.prepare_latents(
165 | batch_size * num_images_per_prompt,
166 | num_channels_latents,
167 | height,
168 | width,
169 | prompt_embeds.dtype,
170 | device,
171 | generator,
172 | latents,
173 | )
174 |
175 | # 4.1. Prepare conditions
176 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
177 | use_condition = conditions is not None or []
178 | if use_condition:
179 | assert len(conditions) <= 1, "Only one condition is supported for now."
180 | if not default_lora:
181 | pipeline.set_adapters(conditions[0].condition_type)
182 | for condition in conditions:
183 | tokens, ids, type_id = condition.encode(self)
184 | condition_latents.append(tokens) # [batch_size, token_n, token_dim]
185 | condition_ids.append(ids) # [token_n, id_dim(3)]
186 | condition_type_ids.append(type_id) # [token_n, 1]
187 | condition_latents = torch.cat(condition_latents, dim=1)
188 | condition_ids = torch.cat(condition_ids, dim=0)
189 | condition_type_ids = torch.cat(condition_type_ids, dim=0)
190 |
191 | # 5. Prepare timesteps
192 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
193 | image_seq_len = latents.shape[1]
194 | mu = calculate_shift(
195 | image_seq_len,
196 | self.scheduler.config.base_image_seq_len,
197 | self.scheduler.config.max_image_seq_len,
198 | self.scheduler.config.base_shift,
199 | self.scheduler.config.max_shift,
200 | )
201 | timesteps, num_inference_steps = retrieve_timesteps(
202 | self.scheduler,
203 | num_inference_steps,
204 | device,
205 | timesteps,
206 | sigmas,
207 | mu=mu,
208 | )
209 | num_warmup_steps = max(
210 | len(timesteps) - num_inference_steps * self.scheduler.order, 0
211 | )
212 | self._num_timesteps = len(timesteps)
213 |
214 | # 6. Denoising loop
215 | with self.progress_bar(total=num_inference_steps) as progress_bar:
216 | for i, t in enumerate(timesteps):
217 | if self.interrupt:
218 | continue
219 |
220 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
221 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
222 |
223 | # handle guidance
224 | if self.transformer.config.guidance_embeds:
225 | guidance = torch.tensor([guidance_scale], device=device)
226 | guidance = guidance.expand(latents.shape[0])
227 | else:
228 | guidance = None
229 | noise_pred = tranformer_forward(
230 | self.transformer,
231 | model_config=model_config,
232 | # Inputs of the condition (new feature)
233 | condition_latents=condition_latents if use_condition else None,
234 | condition_ids=condition_ids if use_condition else None,
235 | condition_type_ids=condition_type_ids if use_condition else None,
236 | # Inputs to the original transformer
237 | hidden_states=latents,
238 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
239 | timestep=timestep / 1000,
240 | guidance=guidance,
241 | pooled_projections=pooled_prompt_embeds,
242 | encoder_hidden_states=prompt_embeds,
243 | txt_ids=text_ids,
244 | img_ids=latent_image_ids,
245 | joint_attention_kwargs=self.joint_attention_kwargs,
246 | return_dict=False,
247 | )[0]
248 |
249 | # compute the previous noisy sample x_t -> x_t-1
250 | latents_dtype = latents.dtype
251 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
252 |
253 | if latents.dtype != latents_dtype:
254 | if torch.backends.mps.is_available():
255 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
256 | latents = latents.to(latents_dtype)
257 |
258 | if callback_on_step_end is not None:
259 | callback_kwargs = {}
260 | for k in callback_on_step_end_tensor_inputs:
261 | callback_kwargs[k] = locals()[k]
262 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
263 |
264 | latents = callback_outputs.pop("latents", latents)
265 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
266 |
267 | # call the callback, if provided
268 | if i == len(timesteps) - 1 or (
269 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
270 | ):
271 | progress_bar.update()
272 |
273 | if output_type == "latent":
274 | image = latents
275 |
276 | else:
277 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
278 | latents = (
279 | latents / self.vae.config.scaling_factor
280 | ) + self.vae.config.shift_factor
281 | image = self.vae.decode(latents, return_dict=False)[0]
282 | image = self.image_processor.postprocess(image, output_type=output_type)
283 |
284 | # Offload all models
285 | self.maybe_free_model_hooks()
286 |
287 | if condition_scale != 1:
288 | for name, module in pipeline.transformer.named_modules():
289 | if not name.endswith(".attn"):
290 | continue
291 | del module.c_factor
292 |
293 | if not return_dict:
294 | return (image,)
295 |
296 | return FluxPipelineOutput(images=image)
297 |
--------------------------------------------------------------------------------
/train/src/flux/lora_controller.py:
--------------------------------------------------------------------------------
1 | from peft.tuners.tuners_utils import BaseTunerLayer
2 | from typing import List, Any, Optional, Type
3 |
4 |
5 | class enable_lora:
6 | def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7 | self.activated: bool = activated
8 | if activated:
9 | return
10 | self.lora_modules: List[BaseTunerLayer] = [
11 | each for each in lora_modules if isinstance(each, BaseTunerLayer)
12 | ]
13 | self.scales = [
14 | {
15 | active_adapter: lora_module.scaling[active_adapter]
16 | for active_adapter in lora_module.active_adapters
17 | }
18 | for lora_module in self.lora_modules
19 | ]
20 |
21 | def __enter__(self) -> None:
22 | if self.activated:
23 | return
24 |
25 | for lora_module in self.lora_modules:
26 | if not isinstance(lora_module, BaseTunerLayer):
27 | continue
28 | lora_module.scale_layer(0)
29 |
30 | def __exit__(
31 | self,
32 | exc_type: Optional[Type[BaseException]],
33 | exc_val: Optional[BaseException],
34 | exc_tb: Optional[Any],
35 | ) -> None:
36 | if self.activated:
37 | return
38 | for i, lora_module in enumerate(self.lora_modules):
39 | if not isinstance(lora_module, BaseTunerLayer):
40 | continue
41 | for active_adapter in lora_module.active_adapters:
42 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
43 |
44 |
45 | class set_lora_scale:
46 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
47 | self.lora_modules: List[BaseTunerLayer] = [
48 | each for each in lora_modules if isinstance(each, BaseTunerLayer)
49 | ]
50 | self.scales = [
51 | {
52 | active_adapter: lora_module.scaling[active_adapter]
53 | for active_adapter in lora_module.active_adapters
54 | }
55 | for lora_module in self.lora_modules
56 | ]
57 | self.scale = scale
58 |
59 | def __enter__(self) -> None:
60 | for lora_module in self.lora_modules:
61 | if not isinstance(lora_module, BaseTunerLayer):
62 | continue
63 | lora_module.scale_layer(self.scale)
64 |
65 | def __exit__(
66 | self,
67 | exc_type: Optional[Type[BaseException]],
68 | exc_val: Optional[BaseException],
69 | exc_tb: Optional[Any],
70 | ) -> None:
71 | for i, lora_module in enumerate(self.lora_modules):
72 | if not isinstance(lora_module, BaseTunerLayer):
73 | continue
74 | for active_adapter in lora_module.active_adapters:
75 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
76 |
--------------------------------------------------------------------------------
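A minimal usage sketch of the two context managers above, assuming a toy `peft` LoRA module (the `nn.Sequential` model, adapter settings, and the import path `src.flux.lora_controller` are illustrative assumptions, not the repository's documented API): `enable_lora(..., activated=False)` zeroes the LoRA scales inside the block and restores them on exit, while `set_lora_scale` temporarily reweights them.

```python
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer

from src.flux.lora_controller import enable_lora, set_lora_scale  # assumed import path

# Toy module with a LoRA adapter attached via peft (names and sizes are illustrative).
model = get_peft_model(
    nn.Sequential(nn.Linear(8, 8)),
    LoraConfig(r=4, lora_alpha=4, target_modules=["0"]),
)
lora_layers = [m for m in model.modules() if isinstance(m, BaseTunerLayer)]
x = torch.randn(2, 8)

# activated=False: LoRA scales are set to 0 inside the block and restored on exit.
with enable_lora(lora_layers, activated=False):
    y_without_lora = model(x)

# Temporarily run with half the LoRA contribution, then restore the original scales.
with set_lora_scale(lora_layers, scale=0.5):
    y_half_lora = model(x)
```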
/train/src/flux/pipeline_tools.py:
--------------------------------------------------------------------------------
1 | from diffusers.pipelines import FluxPipeline, FluxFillPipeline
2 | from diffusers.utils import logging
3 | from diffusers.pipelines.flux.pipeline_flux import logger
4 | from torch import Tensor
5 | import torch
6 |
7 |
8 | def encode_images(pipeline: FluxPipeline, images: Tensor):
9 | images = pipeline.image_processor.preprocess(images)
10 | images = images.to(pipeline.device).to(pipeline.dtype)
11 | images = pipeline.vae.encode(images).latent_dist.sample()
12 | images = (
13 | images - pipeline.vae.config.shift_factor
14 | ) * pipeline.vae.config.scaling_factor
15 | images_tokens = pipeline._pack_latents(images, *images.shape)
16 | images_ids = pipeline._prepare_latent_image_ids(
17 | images.shape[0],
18 | images.shape[2],
19 | images.shape[3],
20 | pipeline.device,
21 | pipeline.dtype,
22 | )
23 | if images_tokens.shape[1] != images_ids.shape[0]:
24 | images_ids = pipeline._prepare_latent_image_ids(
25 | images.shape[0],
26 | images.shape[2] // 2,
27 | images.shape[3] // 2,
28 | pipeline.device,
29 | pipeline.dtype,
30 | )
31 | return images_tokens, images_ids
32 |
33 |
34 |
35 | def encode_images_fill(pipeline: FluxFillPipeline, image: Tensor, mask_image: Tensor, dtype: torch.dtype, device: str):
36 | images_tokens, images_ids = encode_images(pipeline, image.clone().detach())
37 | height, width = image.shape[-2:]
38 | # print(f"height: {height}, width: {width}")
39 | image = pipeline.image_processor.preprocess(image, height=height, width=width)
40 | mask_image = pipeline.mask_processor.preprocess(mask_image, height=height, width=width)
41 |
42 | masked_image = image * (1 - mask_image)
43 | masked_image = masked_image.to(device=device, dtype=dtype)
44 |
45 | num_channels_latents = pipeline.vae.config.latent_channels
46 | height, width = image.shape[-2:]
47 | device = pipeline._execution_device
48 | mask, masked_image_latents = pipeline.prepare_mask_latents(
49 | mask_image,
50 | masked_image,
51 | image.shape[0],
52 | num_channels_latents,
53 | 1,
54 | height,
55 | width,
56 | dtype,
57 | device,
58 | None,
59 | )
60 | masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
61 | return images_tokens, masked_image_latents, images_ids
62 |
63 |
64 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
65 | # Turn off warnings (CLIP overflow)
66 | logger.setLevel(logging.ERROR)
67 | (
68 | prompt_embeds,
69 | pooled_prompt_embeds,
70 | text_ids,
71 | ) = pipeline.encode_prompt(
72 | prompt=prompts,
73 | prompt_2=None,
74 | prompt_embeds=None,
75 | pooled_prompt_embeds=None,
76 | device=pipeline.device,
77 | num_images_per_prompt=1,
78 | max_sequence_length=max_sequence_length,
79 | lora_scale=None,
80 | )
81 | # Turn on warnings
82 | logger.setLevel(logging.WARNING)
83 | return prompt_embeds, pooled_prompt_embeds, text_ids
84 |
--------------------------------------------------------------------------------
/train/src/flux/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.pipelines import FluxPipeline
3 | from typing import List, Union, Optional, Dict, Any, Callable
4 | from .block import block_forward, single_block_forward
5 | from .lora_controller import enable_lora
6 | from diffusers.models.transformers.transformer_flux import (
7 | FluxTransformer2DModel,
8 | Transformer2DModelOutput,
9 | USE_PEFT_BACKEND,
10 | is_torch_version,
11 | scale_lora_layers,
12 | unscale_lora_layers,
13 | logger,
14 | )
15 | import numpy as np
16 |
17 |
18 | def prepare_params(
19 | hidden_states: torch.Tensor,
20 | encoder_hidden_states: torch.Tensor = None,
21 | pooled_projections: torch.Tensor = None,
22 | timestep: torch.LongTensor = None,
23 | img_ids: torch.Tensor = None,
24 | txt_ids: torch.Tensor = None,
25 | guidance: torch.Tensor = None,
26 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27 | controlnet_block_samples=None,
28 | controlnet_single_block_samples=None,
29 | return_dict: bool = True,
30 | **kwargs: dict,
31 | ):
32 | return (
33 | hidden_states,
34 | encoder_hidden_states,
35 | pooled_projections,
36 | timestep,
37 | img_ids,
38 | txt_ids,
39 | guidance,
40 | joint_attention_kwargs,
41 | controlnet_block_samples,
42 | controlnet_single_block_samples,
43 | return_dict,
44 | )
45 |
46 |
47 | def tranformer_forward(
48 | transformer: FluxTransformer2DModel,
49 | condition_latents: torch.Tensor,
50 | condition_ids: torch.Tensor,
51 | condition_type_ids: torch.Tensor,
52 | model_config: Optional[Dict[str, Any]] = {},
53 | c_t=0,
54 | **params: dict,
55 | ):
56 | self = transformer
57 | use_condition = condition_latents is not None
58 |
59 | (
60 | hidden_states,
61 | encoder_hidden_states,
62 | pooled_projections,
63 | timestep,
64 | img_ids,
65 | txt_ids,
66 | guidance,
67 | joint_attention_kwargs,
68 | controlnet_block_samples,
69 | controlnet_single_block_samples,
70 | return_dict,
71 | ) = prepare_params(**params)
72 |
73 | if joint_attention_kwargs is not None:
74 | joint_attention_kwargs = joint_attention_kwargs.copy()
75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
76 | else:
77 | lora_scale = 1.0
78 |
79 | if USE_PEFT_BACKEND:
80 | # weight the lora layers by setting `lora_scale` for each PEFT layer
81 | scale_lora_layers(self, lora_scale)
82 | else:
83 | if (
84 | joint_attention_kwargs is not None
85 | and joint_attention_kwargs.get("scale", None) is not None
86 | ):
87 | logger.warning(
88 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
89 | )
90 |
91 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
92 | hidden_states = self.x_embedder(hidden_states)
93 | condition_latents = self.x_embedder(condition_latents) if use_condition else None
94 |
95 | timestep = timestep.to(hidden_states.dtype) * 1000
96 |
97 | if guidance is not None:
98 | guidance = guidance.to(hidden_states.dtype) * 1000
99 | else:
100 | guidance = None
101 |
102 | temb = (
103 | self.time_text_embed(timestep, pooled_projections)
104 | if guidance is None
105 | else self.time_text_embed(timestep, guidance, pooled_projections)
106 | )
107 |
108 | cond_temb = (
109 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
110 | if guidance is None
111 | else self.time_text_embed(
112 | torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
113 | )
114 | )
115 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
116 |
117 | if txt_ids.ndim == 3:
118 | logger.warning(
119 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
120 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
121 | )
122 | txt_ids = txt_ids[0]
123 | if img_ids.ndim == 3:
124 | logger.warning(
125 | "Passing `img_ids` 3d torch.Tensor is deprecated."
126 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
127 | )
128 | img_ids = img_ids[0]
129 |
130 | ids = torch.cat((txt_ids, img_ids), dim=0)
131 | image_rotary_emb = self.pos_embed(ids)
132 | if use_condition:
133 | # condition_ids[:, :1] = condition_type_ids
134 | cond_rotary_emb = self.pos_embed(condition_ids)
135 |
136 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
137 |
138 | for index_block, block in enumerate(self.transformer_blocks):
139 | if self.training and self.gradient_checkpointing:
140 | ckpt_kwargs: Dict[str, Any] = (
141 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
142 | )
143 | encoder_hidden_states, hidden_states, condition_latents = (
144 | torch.utils.checkpoint.checkpoint(
145 | block_forward,
146 | self=block,
147 | model_config=model_config,
148 | hidden_states=hidden_states,
149 | encoder_hidden_states=encoder_hidden_states,
150 | condition_latents=condition_latents if use_condition else None,
151 | temb=temb,
152 | cond_temb=cond_temb if use_condition else None,
153 | cond_rotary_emb=cond_rotary_emb if use_condition else None,
154 | image_rotary_emb=image_rotary_emb,
155 | **ckpt_kwargs,
156 | )
157 | )
158 |
159 | else:
160 | encoder_hidden_states, hidden_states, condition_latents = block_forward(
161 | block,
162 | model_config=model_config,
163 | hidden_states=hidden_states,
164 | encoder_hidden_states=encoder_hidden_states,
165 | condition_latents=condition_latents if use_condition else None,
166 | temb=temb,
167 | cond_temb=cond_temb if use_condition else None,
168 | cond_rotary_emb=cond_rotary_emb if use_condition else None,
169 | image_rotary_emb=image_rotary_emb,
170 | )
171 |
172 | # controlnet residual
173 | if controlnet_block_samples is not None:
174 | interval_control = len(self.transformer_blocks) / len(
175 | controlnet_block_samples
176 | )
177 | interval_control = int(np.ceil(interval_control))
178 | hidden_states = (
179 | hidden_states
180 | + controlnet_block_samples[index_block // interval_control]
181 | )
182 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
183 |
184 | for index_block, block in enumerate(self.single_transformer_blocks):
185 | if self.training and self.gradient_checkpointing:
186 | ckpt_kwargs: Dict[str, Any] = (
187 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
188 | )
189 | result = torch.utils.checkpoint.checkpoint(
190 | single_block_forward,
191 | self=block,
192 | model_config=model_config,
193 | hidden_states=hidden_states,
194 | temb=temb,
195 | image_rotary_emb=image_rotary_emb,
196 | **(
197 | {
198 | "condition_latents": condition_latents,
199 | "cond_temb": cond_temb,
200 | "cond_rotary_emb": cond_rotary_emb,
201 | }
202 | if use_condition
203 | else {}
204 | ),
205 | **ckpt_kwargs,
206 | )
207 |
208 | else:
209 | result = single_block_forward(
210 | block,
211 | model_config=model_config,
212 | hidden_states=hidden_states,
213 | temb=temb,
214 | image_rotary_emb=image_rotary_emb,
215 | **(
216 | {
217 | "condition_latents": condition_latents,
218 | "cond_temb": cond_temb,
219 | "cond_rotary_emb": cond_rotary_emb,
220 | }
221 | if use_condition
222 | else {}
223 | ),
224 | )
225 | if use_condition:
226 | hidden_states, condition_latents = result
227 | else:
228 | hidden_states = result
229 |
230 | # controlnet residual
231 | if controlnet_single_block_samples is not None:
232 | interval_control = len(self.single_transformer_blocks) / len(
233 | controlnet_single_block_samples
234 | )
235 | interval_control = int(np.ceil(interval_control))
236 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
237 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
238 | + controlnet_single_block_samples[index_block // interval_control]
239 | )
240 |
241 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
242 |
243 | hidden_states = self.norm_out(hidden_states, temb)
244 | output = self.proj_out(hidden_states)
245 |
246 | if USE_PEFT_BACKEND:
247 | # remove `lora_scale` from each PEFT layer
248 | unscale_lora_layers(self, lora_scale)
249 |
250 | if not return_dict:
251 | return (output,)
252 | return Transformer2DModelOutput(sample=output)
253 |
--------------------------------------------------------------------------------
/train/src/train/callbacks.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | from PIL import Image, ImageFilter, ImageDraw
3 | import numpy as np
4 | from transformers import pipeline
5 | # import cv2
6 | import torch
7 | import os
8 | from datetime import datetime
9 |
10 | try:
11 | import wandb
12 | except ImportError:
13 | wandb = None
14 |
15 | from ..flux.condition import Condition
16 | from ..flux.generate import generate
17 |
18 |
19 | class TrainingCallback(L.Callback):
20 | def __init__(self, run_name, training_config: dict = {}):
21 | self.run_name, self.training_config = run_name, training_config
22 |
23 | self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
24 | self.save_interval = training_config.get("save_interval", 1000)
25 | self.sample_interval = training_config.get("sample_interval", 1000)
26 | self.save_path = training_config.get("save_path", "./output")
27 |
28 | self.wandb_config = training_config.get("wandb", None)
29 | self.use_wandb = (
30 | wandb is not None and os.environ.get("WANDB_API_KEY") is not None
31 | )
32 |
33 | self.total_steps = 0
34 |
35 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
36 | gradient_size = 0
37 | max_gradient_size = 0
38 | count = 0
39 | for _, param in pl_module.named_parameters():
40 | if param.grad is not None:
41 | gradient_size += param.grad.norm(2).item()
42 | max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
43 | count += 1
44 | if count > 0:
45 | gradient_size /= count
46 |
47 | self.total_steps += 1
48 |
49 | # Print training progress every n steps
50 | if self.use_wandb:
51 | report_dict = {
52 | "steps": batch_idx,
53 | "steps": self.total_steps,
54 | "epoch": trainer.current_epoch,
55 | "gradient_size": gradient_size,
56 | }
57 | loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
58 | report_dict["loss"] = loss_value
59 | report_dict["t"] = pl_module.last_t
60 | wandb.log(report_dict)
61 |
62 | if self.total_steps % self.print_every_n_steps == 0:
63 | print(
64 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
65 | )
66 |
67 | # Save LoRA weights at specified intervals
68 | if self.total_steps % self.save_interval == 0:
69 | print(
70 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
71 | )
72 | pl_module.save_lora(
73 | f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
74 | )
75 |
76 | # Generate and save a sample image at specified intervals
77 | if self.total_steps % self.sample_interval == 0:
78 | print(
79 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
80 | )
81 | self.generate_a_sample(
82 | trainer,
83 | pl_module,
84 | f"{self.save_path}/{self.run_name}",
85 | f"lora_{self.total_steps}",
86 | batch["condition_type"][
87 | 0
88 | ], # Use the condition type from the current batch
89 | )
90 |
91 | @torch.no_grad()
92 | def generate_a_sample(
93 | self,
94 | trainer,
95 | pl_module,
96 | save_path,
97 | file_name,
98 | condition_type,
99 | ):
100 |
101 | test_images = [
102 | "assets/coffee.png",
103 | "assets/coffee.png",
104 | "assets/coffee.png",
105 | "assets/coffee.png",
106 | "assets/clock.jpg",
107 | "assets/book.jpg",
108 | "assets/monalisa.jpg",
109 | "assets/oranges.jpg",
110 | "assets/penguin.jpg",
111 | "assets/vase.jpg",
112 | "assets/room_corner.jpg",
113 | ]
114 |
115 | test_instruction = [
116 | "Make the image look like it's from an ancient Egyptian mural.",
117 | 'get rid of the coffee bean.',
118 | 'remove the cup.',
119 | "Change it to look like it's in the style of an impasto painting.",
120 | "Make this photo look like a comic book",
121 | "Give this the look of a traditional Japanese woodblock print.",
122 | 'delete the woman',
123 | "Change the image into a watercolor painting.",
124 | "Make it black and white.",
125 | "Make it pop art.",
126 | 'the sofa is leather, and the wall is black',
127 | ]
128 |
129 | pl_module.flux_fill_pipe.transformer.eval()
130 | for i, name in enumerate(test_images):
131 | test_image = Image.open(name)
132 | combined_image = Image.new('RGB', (test_image.size[0] * 2, test_image.size[1]))
133 | combined_image.paste(test_image, (0, 0))
134 | combined_image.paste(test_image, (test_image.size[0], 0))
135 |
136 | mask = Image.new('L', combined_image.size, 0)
137 | draw = ImageDraw.Draw(mask)
138 | draw.rectangle([test_image.size[0], 0, test_image.size[0] * 2, test_image.size[1]], fill=255)
139 | if condition_type == 'edit_n':
140 | prompt_ = "A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left. \n " + test_instruction[i]
141 | else:
142 | prompt_ = "A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but " + test_instruction[i]
143 |
144 | image = pl_module.flux_fill_pipe(
145 | prompt=prompt_,
146 | image=combined_image,
147 | height=512,
148 | width=1024,
149 | mask_image=mask,
150 | guidance_scale=50,
151 | num_inference_steps=50,
152 | max_sequence_length=512,
153 | generator=torch.Generator("cpu").manual_seed(666)
154 | ).images[0]
155 | image.save(os.path.join(save_path, f'flux-fill-test-{self.total_steps}-{i}-{condition_type}.jpg'))
156 |
157 | pl_module.flux_fill_pipe.transformer.train()
158 |
--------------------------------------------------------------------------------
/train/src/train/data.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFilter, ImageDraw
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torchvision.transforms as T
5 | import random
6 | from io import BytesIO
7 | import glob
8 | from tqdm import tqdm
9 |
10 | class EditDataset_with_Omini(Dataset):
11 | def __init__(
12 | self,
13 | magic_dataset,
14 | omni_dataset,
15 | condition_size: int = 512,
16 | target_size: int = 512,
17 | drop_text_prob: float = 0.1,
18 | return_pil_image: bool = False,
19 | crop_the_noise: bool = True,
20 | ):
21 | self.dataset = [magic_dataset['train'], magic_dataset['dev'], omni_dataset]
22 |
23 | from collections import Counter
24 | tasks = omni_dataset['task']
25 | task_counts = Counter(tasks)
26 | print("\n task type statistic:")
27 | for task, count in task_counts.items():
28 | print(f"{task}: {count} data ({count/len(tasks)*100:.2f}%)")
29 |
30 | self.condition_size = condition_size
31 | self.target_size = target_size
32 | self.drop_text_prob = drop_text_prob
33 | self.return_pil_image = return_pil_image
34 | self.crop_the_noise = crop_the_noise
35 | self.to_tensor = T.ToTensor()
36 |
37 | def __len__(self):
38 | return len(self.dataset[0]) + len(self.dataset[1]) + len(self.dataset[2])
39 |
40 |
41 | def __getitem__(self, idx):
42 | split = 0 if idx < len(self.dataset[0]) else (1 if idx < len(self.dataset[0]) + len(self.dataset[1]) else 2)
43 |
44 | if idx >= len(self.dataset[0]) + len(self.dataset[1]):
45 | idx -= len(self.dataset[0]) + len(self.dataset[1])
46 | elif idx >= len(self.dataset[0]):
47 | idx -= len(self.dataset[0])
48 |
49 | image = self.dataset[split][idx]["source_img" if split != 2 else "src_img"]
50 | instruction = 'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but ' + (self.dataset[split][idx]["instruction"] if split != 2 else random.choice(self.dataset[split][idx]["edited_prompt_list"]))
51 | edited_image = self.dataset[split][idx]["target_img" if split != 2 else "edited_img"]
52 |
53 | if self.crop_the_noise and split <= 1:
54 | image = image.crop((0, 0, image.width, image.height - image.height // 32))
55 | edited_image = edited_image.crop((0, 0, edited_image.width, edited_image.height - edited_image.height // 32))
56 |
57 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB")
58 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB")
59 |
60 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size))
61 | combined_image.paste(image, (0, 0))
62 | combined_image.paste(edited_image, (self.condition_size, 0))
63 |
64 |
65 |
66 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0)
67 | draw = ImageDraw.Draw(mask)
68 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
69 |
70 | mask_combined_image = combined_image.copy()
71 | draw = ImageDraw.Draw(mask_combined_image)
72 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
73 |
74 | if random.random() < self.drop_text_prob:
75 | instruction = " "
76 |
77 | return {
78 | "image": self.to_tensor(combined_image),
79 | "condition": self.to_tensor(mask),
80 | "condition_type": "edit",
81 | "description": instruction,
82 | "position_delta": np.array([0, 0]),
83 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}),
84 | }
85 |
86 | class OminiDataset(Dataset):
87 | def __init__(
88 | self,
89 | base_dataset,
90 | condition_size: int = 512,
91 | target_size: int = 512,
92 | drop_text_prob: float = 0.1,
93 | return_pil_image: bool = False,
94 | specific_task: list = None,
95 | ):
96 | self.base_dataset = base_dataset['train']
97 | if specific_task is not None:
98 | self.specific_task = specific_task
99 | task_indices = [i for i, task in enumerate(self.base_dataset['task']) if task in self.specific_task]
100 | task_set = set([task for task in self.base_dataset['task']])
101 | ori_len = len(self.base_dataset)
102 | self.base_dataset = self.base_dataset.select(task_indices)
103 | print(specific_task, len(self.base_dataset), ori_len)
104 | print(task_set)
105 |
106 | self.condition_size = condition_size
107 | self.target_size = target_size
108 | self.drop_text_prob = drop_text_prob
109 | self.return_pil_image = return_pil_image
110 | self.to_tensor = T.ToTensor()
111 |
112 | def __len__(self):
113 | return len(self.base_dataset)
114 |
115 | def __getitem__(self, idx):
116 | image = self.base_dataset[idx]["src_img"]
117 | instruction = 'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but ' + random.choice(self.base_dataset[idx]["edited_prompt_list"])
118 |
119 | edited_image = self.base_dataset[idx]["edited_img"]
120 |
121 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB")
122 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB")
123 |
124 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size))
125 | combined_image.paste(image, (0, 0))
126 | combined_image.paste(edited_image, (self.condition_size, 0))
127 |
128 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0)
129 | draw = ImageDraw.Draw(mask)
130 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
131 |
132 | mask_combined_image = combined_image.copy()
133 | draw = ImageDraw.Draw(mask_combined_image)
134 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
135 |
136 | if random.random() < self.drop_text_prob:
137 | instruction = ""
138 |
139 | return {
140 | "image": self.to_tensor(combined_image),
141 | "condition": self.to_tensor(mask),
142 | "condition_type": "edit",
143 | "description": instruction,
144 | "position_delta": np.array([0, 0]),
145 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}),
146 | }
147 |
148 |
149 | class EditDataset_mask(Dataset):
150 | def __init__(
151 | self,
152 | base_dataset,
153 | condition_size: int = 512,
154 | target_size: int = 512,
155 | drop_text_prob: float = 0.1,
156 | return_pil_image: bool = False,
157 | crop_the_noise: bool = True,
158 | ):
159 | print('THIS IS MAGICBRUSH!')
160 | self.base_dataset = base_dataset
161 | self.condition_size = condition_size
162 | self.target_size = target_size
163 | self.drop_text_prob = drop_text_prob
164 | self.return_pil_image = return_pil_image
165 | self.crop_the_noise = crop_the_noise
166 | self.to_tensor = T.ToTensor()
167 |
168 | def __len__(self):
169 | return len(self.base_dataset['train']) + len(self.base_dataset['dev'])
170 |
171 | def rgba_to_01_mask(self, image_rgba: Image.Image, reverse: bool = False, return_type: str = "numpy"):
172 | """
173 | Convert an RGBA image to a binary mask with values in the range [0, 1], where 0 represents transparent areas
174 | and 1 represents non-transparent areas. The resulting mask has a shape of (1, H, W).
175 |
176 | :param image_rgba: An RGBA image in PIL format.
177 | :param reverse: If True, reverse the mask, making transparent areas 1 and non-transparent areas 0.
178 | :param return_type: Specifies the return type. "numpy" returns a NumPy array, "PIL" returns a PIL Image.
179 |
180 | :return: The binary mask as a NumPy array or a PIL Image in RGB format.
181 | """
182 | alpha_channel = np.array(image_rgba)[:, :, 3]
183 | image_bw = (alpha_channel != 255).astype(np.uint8)
184 | if reverse:
185 | image_bw = 1 - image_bw
186 | mask = image_bw
187 | if return_type == "numpy":
188 | return mask
189 | else: # return PIL image
190 | mask = Image.fromarray(np.uint8(mask * 255) , 'L').convert('RGB')
191 | return mask
192 |
193 | def __getitem__(self, idx):
194 | split = 'train' if idx < len(self.base_dataset['train']) else 'dev'
195 | idx = idx - len(self.base_dataset['train']) if split == 'dev' else idx
196 | image = self.base_dataset[split][idx]["source_img"]
197 | instruction = 'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left. \n ' + self.base_dataset[split][idx]["instruction"]
198 | edited_image = self.base_dataset[split][idx]["target_img"]
199 |
200 | if self.crop_the_noise:
201 | image = image.crop((0, 0, image.width, image.height - image.height // 32))
202 | edited_image = edited_image.crop((0, 0, edited_image.width, edited_image.height - edited_image.height // 32))
203 |
204 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB")
205 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB")
206 |
207 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size))
208 | combined_image.paste(image, (0, 0))
209 | combined_image.paste(edited_image, (self.condition_size, 0))
210 |
211 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0)
212 | draw = ImageDraw.Draw(mask)
213 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
214 |
215 | mask_combined_image = combined_image.copy()
216 | draw = ImageDraw.Draw(mask_combined_image)
217 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
218 |
219 | if random.random() < self.drop_text_prob:
220 | instruction = " \n "
221 | return {
222 | "image": self.to_tensor(combined_image),
223 | "condition": self.to_tensor(mask),
224 | "condition_type": "edit_n",
225 | "description": instruction,
226 | "position_delta": np.array([0, 0]),
227 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}),
228 | }
229 |
230 | class EditDataset(Dataset):
231 | def __init__(
232 | self,
233 | base_dataset,
234 | condition_size: int = 512,
235 | target_size: int = 512,
236 | drop_text_prob: float = 0.1,
237 | return_pil_image: bool = False,
238 | crop_the_noise: bool = True,
239 | ):
240 | print('THIS IS MAGICBRUSH!')
241 | self.base_dataset = base_dataset
242 | self.condition_size = condition_size
243 | self.target_size = target_size
244 | self.drop_text_prob = drop_text_prob
245 | self.return_pil_image = return_pil_image
246 | self.crop_the_noise = crop_the_noise
247 | self.to_tensor = T.ToTensor()
248 |
249 | def __len__(self):
250 | return len(self.base_dataset['train']) + len(self.base_dataset['dev'])
251 |
252 | def rgba_to_01_mask(self, image_rgba: Image.Image, reverse: bool = False, return_type: str = "numpy"):
253 | """
254 | Convert an RGBA image to a binary mask with values in the range [0, 1], where 0 represents transparent areas
255 | and 1 represents non-transparent areas. The resulting mask has a shape of (1, H, W).
256 |
257 | :param image_rgba: An RGBA image in PIL format.
258 | :param reverse: If True, reverse the mask, making transparent areas 1 and non-transparent areas 0.
259 | :param return_type: Specifies the return type. "numpy" returns a NumPy array, "PIL" returns a PIL Image.
260 |
261 | :return: The binary mask as a NumPy array or a PIL Image in RGB format.
262 | """
263 | alpha_channel = np.array(image_rgba)[:, :, 3]
264 | image_bw = (alpha_channel != 255).astype(np.uint8)
265 | if reverse:
266 | image_bw = 1 - image_bw
267 | mask = image_bw
268 | if return_type == "numpy":
269 | return mask
270 | else: # return PIL image
271 | mask = Image.fromarray(np.uint8(mask * 255) , 'L').convert('RGB')
272 | return mask
273 |
274 | def __getitem__(self, idx):
275 | split = 'train' if idx < len(self.base_dataset['train']) else 'dev'
276 | idx = idx - len(self.base_dataset['train']) if split == 'dev' else idx
277 | image = self.base_dataset[split][idx]["source_img"]
278 | instruction = 'A diptych with two side-by-side images of the same scene. On the right, the scene is exactly the same as on the left but ' + self.base_dataset[split][idx]["instruction"]
279 | edited_image = self.base_dataset[split][idx]["target_img"]
280 |
281 |
282 | if self.crop_the_noise:
283 | image = image.crop((0, 0, image.width, image.height - image.height // 32))
284 | edited_image = edited_image.crop((0, 0, edited_image.width, edited_image.height - edited_image.height // 32))
285 |
286 | image = image.resize((self.condition_size, self.condition_size)).convert("RGB")
287 | edited_image = edited_image.resize((self.target_size, self.target_size)).convert("RGB")
288 |
289 | combined_image = Image.new('RGB', (self.condition_size * 2, self.condition_size))
290 | combined_image.paste(image, (0, 0))
291 | combined_image.paste(edited_image, (self.condition_size, 0))
292 |
293 | mask = Image.new('L', (self.condition_size * 2, self.condition_size), 0)
294 | draw = ImageDraw.Draw(mask)
295 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
296 |
297 | mask_combined_image = combined_image.copy()
298 | draw = ImageDraw.Draw(mask_combined_image)
299 | draw.rectangle([self.condition_size, 0, self.condition_size * 2, self.condition_size], fill=255)
300 |
301 | if random.random() < self.drop_text_prob:
302 | instruction = " "
303 |
304 | return {
305 | "image": self.to_tensor(combined_image),
306 | "condition": self.to_tensor(mask),
307 | "condition_type": "edit",
308 | "description": instruction,
309 | "position_delta": np.array([0, 0]),
310 | **({"pil_image": [edited_image, combined_image]} if self.return_pil_image else {}),
311 | }
312 |
--------------------------------------------------------------------------------
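For reference, here is a self-contained sketch of the diptych layout that all of the datasets above produce: the source image is pasted on the left, the edited target on the right, and the inpainting mask covers the right half. The solid-colour stand-in images and output file names are illustrative only.

```python
from PIL import Image, ImageDraw

size = 512
source = Image.new("RGB", (size, size), "gray")        # stand-in for source_img
edited = Image.new("RGB", (size, size), "lightblue")   # stand-in for target_img / edited_img

# Diptych: source on the left, edited result on the right.
diptych = Image.new("RGB", (size * 2, size))
diptych.paste(source, (0, 0))
diptych.paste(edited, (size, 0))

# Mask: white over the right half, so only the edited side is generated/inpainted.
mask = Image.new("L", (size * 2, size), 0)
ImageDraw.Draw(mask).rectangle([size, 0, size * 2, size], fill=255)

diptych.save("diptych_example.jpg")
mask.save("diptych_mask_example.png")
```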
/train/src/train/model.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | from diffusers.pipelines import FluxPipeline, FluxFillPipeline
3 | import torch
4 | from peft import LoraConfig, get_peft_model_state_dict
5 | import os
6 | import prodigyopt
7 |
8 | from ..flux.transformer import tranformer_forward
9 | from ..flux.condition import Condition
10 | from ..flux.pipeline_tools import encode_images, encode_images_fill, prepare_text_input
11 |
12 |
13 | class OminiModel(L.LightningModule):
14 | def __init__(
15 | self,
16 | flux_fill_id: str,
17 | lora_path: str = None,
18 | lora_config: dict = None,
19 | device: str = "cuda",
20 | dtype: torch.dtype = torch.bfloat16,
21 | model_config: dict = {},
22 | optimizer_config: dict = None,
23 | gradient_checkpointing: bool = False,
24 | use_offset_noise: bool = False,
25 | ):
26 | # Initialize the LightningModule
27 | super().__init__()
28 | self.model_config = model_config
29 |
30 | self.optimizer_config = optimizer_config
31 |
32 | # Load the Flux pipeline
33 | self.flux_fill_pipe = FluxFillPipeline.from_pretrained(flux_fill_id).to(dtype=dtype).to(device)
34 |
35 | self.transformer = self.flux_fill_pipe.transformer
36 | self.text_encoder = self.flux_fill_pipe.text_encoder
37 | self.text_encoder_2 = self.flux_fill_pipe.text_encoder_2
38 | self.transformer.gradient_checkpointing = gradient_checkpointing
39 | self.transformer.train()
40 | # Freeze the Flux pipeline
41 | self.text_encoder.requires_grad_(False)
42 | self.text_encoder_2.requires_grad_(False)
43 | self.flux_fill_pipe.vae.requires_grad_(False).eval()
44 | self.use_offset_noise = use_offset_noise
45 |
46 | if use_offset_noise:
47 | print('[debug] use OFFSET NOISE.')
48 |
49 | self.lora_layers = self.init_lora(lora_path, lora_config)
50 |
51 | self.to(device).to(dtype)
52 |
53 | def init_lora(self, lora_path: str, lora_config: dict):
54 | assert lora_path or lora_config
55 | if lora_path:
56 | # TODO: Implement this
57 | raise NotImplementedError
58 | else:
59 | self.transformer.add_adapter(LoraConfig(**lora_config))
60 | # TODO: Check if this is correct (p.requires_grad)
61 | lora_layers = filter(
62 | lambda p: p.requires_grad, self.transformer.parameters()
63 | )
64 | return list(lora_layers)
65 |
66 | def save_lora(self, path: str):
67 | FluxFillPipeline.save_lora_weights(
68 | save_directory=path,
69 | transformer_lora_layers=get_peft_model_state_dict(self.transformer),
70 | safe_serialization=True,
71 | )
72 | if self.model_config['use_sep']:
73 | torch.save(self.text_encoder_2.shared, os.path.join(path, "t5_embedding.pth"))
74 | torch.save(self.text_encoder.text_model.embeddings.token_embedding, os.path.join(path, "clip_embedding.pth"))
75 |
76 | def configure_optimizers(self):
77 | # Freeze the transformer
78 | self.transformer.requires_grad_(False)
79 | opt_config = self.optimizer_config
80 |
81 | # Set the trainable parameters
82 | self.trainable_params = self.lora_layers
83 |
84 | # Unfreeze trainable parameters
85 | for p in self.trainable_params:
86 | p.requires_grad_(True)
87 |
88 | # Initialize the optimizer
89 | if opt_config["type"] == "AdamW":
90 | optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
91 | elif opt_config["type"] == "Prodigy":
92 | optimizer = prodigyopt.Prodigy(
93 | self.trainable_params,
94 | **opt_config["params"],
95 | )
96 | elif opt_config["type"] == "SGD":
97 | optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
98 | else:
99 | raise NotImplementedError
100 |
101 | return optimizer
102 |
103 | def training_step(self, batch, batch_idx):
104 | step_loss = self.step(batch)
105 | self.log_loss = (
106 | step_loss.item()
107 | if not hasattr(self, "log_loss")
108 | else self.log_loss * 0.95 + step_loss.item() * 0.05
109 | )
110 | return step_loss
111 |
112 | def step(self, batch):
113 | imgs = batch["image"]
114 | mask_imgs = batch["condition"]
115 | condition_types = batch["condition_type"]
116 | prompts = batch["description"]
117 | position_delta = batch["position_delta"][0]
118 |
119 | with torch.no_grad():
120 | prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
121 | self.flux_fill_pipe, prompts
122 | )
123 |
124 | x_0, x_cond, img_ids = encode_images_fill(self.flux_fill_pipe, imgs, mask_imgs, prompt_embeds.dtype, prompt_embeds.device)
125 |
126 | # Prepare t and x_t
127 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
128 | x_1 = torch.randn_like(x_0).to(self.device)
129 |
130 | if self.use_offset_noise:
131 | x_1 = x_1 + 0.1 * torch.randn(x_1.shape[0], 1, x_1.shape[2]).to(self.device).to(self.dtype)
132 |
133 | t_ = t.unsqueeze(1).unsqueeze(1)
134 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
135 |
136 | # Prepare guidance
137 | guidance = (
138 | torch.ones_like(t).to(self.device)
139 | if self.transformer.config.guidance_embeds
140 | else None
141 | )
142 |
143 | # Forward pass
144 | transformer_out = self.transformer(
145 | hidden_states=torch.cat((x_t, x_cond), dim=2),
146 | timestep=t,
147 | guidance=guidance,
148 | pooled_projections=pooled_prompt_embeds,
149 | encoder_hidden_states=prompt_embeds,
150 | txt_ids=text_ids,
151 | img_ids=img_ids,
152 | joint_attention_kwargs=None,
153 | return_dict=False,
154 | )
155 | pred = transformer_out[0]
156 |
157 | # Compute loss
158 | loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
159 | self.last_t = t.mean().item()
160 | return loss
161 |
--------------------------------------------------------------------------------
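A minimal sketch, with random stand-in tensors, of the rectified-flow objective used in `OminiModel.step` above: the transformer sees the interpolant `x_t` between clean packed latents `x_0` and noise `x_1`, and regresses the velocity `x_1 - x_0`. The shapes below are illustrative, not the actual FLUX latent dimensions.

```python
import torch

batch, tokens, dim = 2, 1024, 64
x_0 = torch.randn(batch, tokens, dim)       # clean (packed) image latents
x_1 = torch.randn_like(x_0)                 # pure noise sample
t = torch.sigmoid(torch.randn(batch))       # timesteps in (0, 1)

t_ = t.view(-1, 1, 1)
x_t = (1 - t_) * x_0 + t_ * x_1             # noisy interpolant fed to the transformer

pred = torch.randn_like(x_0)                # stand-in for the transformer's output
loss = torch.nn.functional.mse_loss(pred, x_1 - x_0)
print(loss.item())
```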
/train/src/train/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torch
3 | import lightning as L
4 | import yaml
5 | import os
6 | import random
7 | import time
8 | import numpy as np
9 | from datasets import load_dataset
10 |
11 | from .data import (
12 | EditDataset,
13 | OminiDataset,
14 | EditDataset_with_Omini
15 | )
16 | from .model import OminiModel
17 | from .callbacks import TrainingCallback
18 |
19 |
20 | def get_rank():
21 | try:
22 | rank = int(os.environ.get("LOCAL_RANK"))
23 | except (TypeError, ValueError):
24 | rank = 0
25 | return rank
26 |
27 |
28 | def get_config():
29 | config_path = os.environ.get("XFL_CONFIG")
30 | assert config_path is not None, "Please set the XFL_CONFIG environment variable"
31 | with open(config_path, "r") as f:
32 | config = yaml.safe_load(f)
33 | return config
34 |
35 |
36 | def init_wandb(wandb_config, run_name):
37 | import wandb
38 |
39 | try:
40 | assert os.environ.get("WANDB_API_KEY") is not None
41 | wandb.init(
42 | project=wandb_config["project"],
43 | name=run_name,
44 | config={},
45 | )
46 | except Exception as e:
47 | print("Failed to initialize WanDB:", e)
48 |
49 |
50 | def main():
51 | # Initialize
52 | is_main_process, rank = get_rank() == 0, get_rank()
53 | torch.cuda.set_device(rank)
54 | config = get_config()
55 | training_config = config["train"]
56 | run_name = time.strftime("%Y%m%d-%H%M%S")
57 |
58 | seed = 666
59 | np.random.seed(seed)
60 | torch.manual_seed(seed)
61 | torch.cuda.manual_seed_all(seed)
62 | random.seed(seed)
63 |
64 | # Initialize WandB
65 | wandb_config = training_config.get("wandb", None)
66 | if wandb_config is not None and is_main_process:
67 | init_wandb(wandb_config, run_name)
68 |
69 | print("Rank:", rank)
70 | if is_main_process:
71 | print("Config:", config)
72 |
73 | if 'use_offset_noise' not in config.keys():
74 | config['use_offset_noise'] = False
75 |
76 | # Initialize dataset and dataloader
77 |
78 | if training_config["dataset"]["type"] == "edit":
79 | dataset = load_dataset('osunlp/MagicBrush')
80 | dataset = EditDataset(
81 | dataset,
82 | condition_size=training_config["dataset"]["condition_size"],
83 | target_size=training_config["dataset"]["target_size"],
84 | drop_text_prob=training_config["dataset"]["drop_text_prob"],
85 | )
86 | elif training_config["dataset"]["type"] == "omini":
87 | dataset = load_dataset(training_config["dataset"]["path"])
88 | dataset = OminiDataset(
89 | dataset,
90 | condition_size=training_config["dataset"]["condition_size"],
91 | target_size=training_config["dataset"]["target_size"],
92 | drop_text_prob=training_config["dataset"]["drop_text_prob"],
93 | )
94 |
95 | elif training_config["dataset"]["type"] == "edit_with_omini":
96 | omni = load_dataset("parquet", data_files=os.path.abspath(training_config["dataset"]["path"]), split="train")
97 | magic = load_dataset('osunlp/MagicBrush')
98 | dataset = EditDataset_with_Omini(
99 | magic,
100 | omni,
101 | condition_size=training_config["dataset"]["condition_size"],
102 | target_size=training_config["dataset"]["target_size"],
103 | drop_text_prob=training_config["dataset"]["drop_text_prob"],
104 | )
105 |
106 |
107 | print("Dataset length:", len(dataset))
108 | train_loader = DataLoader(
109 | dataset,
110 | batch_size=training_config["batch_size"],
111 | shuffle=True,
112 | num_workers=training_config["dataloader_workers"],
113 | )
114 |
115 | # Initialize model
116 | trainable_model = OminiModel(
117 | flux_fill_id=config["flux_path"],
118 | lora_config=training_config["lora_config"],
119 | device=f"cuda",
120 | dtype=getattr(torch, config["dtype"]),
121 | optimizer_config=training_config["optimizer"],
122 | model_config=config.get("model", {}),
123 | gradient_checkpointing=training_config.get("gradient_checkpointing", False),
124 | use_offset_noise=config["use_offset_noise"],
125 | )
126 |
127 | # Callbacks for logging and saving checkpoints
128 | training_callbacks = (
129 | [TrainingCallback(run_name, training_config=training_config)]
130 | if is_main_process
131 | else []
132 | )
133 |
134 | # Initialize trainer
135 | trainer = L.Trainer(
136 | accumulate_grad_batches=training_config["accumulate_grad_batches"],
137 | callbacks=training_callbacks,
138 | enable_checkpointing=False,
139 | enable_progress_bar=False,
140 | logger=False,
141 | max_steps=training_config.get("max_steps", -1),
142 | max_epochs=training_config.get("max_epochs", -1),
143 | gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
144 | )
145 |
146 | setattr(trainer, "training_config", training_config)
147 |
148 | # Save config
149 | save_path = training_config.get("save_path", "./output")
150 | if is_main_process:
151 | os.makedirs(f"{save_path}/{run_name}")
152 | with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
153 | yaml.dump(config, f)
154 |
155 | # Start training
156 | trainer.fit(trainable_model, train_loader)
157 |
158 |
159 | if __name__ == "__main__":
160 | main()
161 |
--------------------------------------------------------------------------------
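As `get_config` above shows, `train.py` reads its YAML path from the `XFL_CONFIG` environment variable. A hedged launch sketch follows; the working directory, module path, and config location are assumptions based on the repository layout shown here and may differ from the official training script.

```python
import os
import subprocess

# Run the training entry point as a module from the train/ directory, pointing
# XFL_CONFIG at a config such as the normal_lora.yaml shown below (paths assumed).
env = dict(os.environ, XFL_CONFIG="train/config/normal_lora.yaml")
subprocess.run(["python", "-m", "src.train.train"], check=True, cwd="train", env=env)
```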
/train/train/config/normal_lora.yaml:
--------------------------------------------------------------------------------
1 | flux_path: "black-forest-labs/flux.1-fill-dev"
2 | dtype: "bfloat16"
3 |
4 | model:
5 | union_cond_attn: true
6 | add_cond_attn: false
7 | latent_lora: false
8 | use_sep: false
9 |
10 | train:
11 | batch_size: 2
12 | accumulate_grad_batches: 1
13 | dataloader_workers: 5
14 | save_interval: 1000
15 | sample_interval: 1000
16 | max_steps: -1
17 | gradient_checkpointing: true
18 | save_path: "runs"
19 |
20 | condition_type: "edit"
21 | dataset:
22 | type: "edit_with_omini"
23 | path: "parquet/*.parquet"
24 | condition_size: 512
25 | target_size: 512
26 | image_size: 512
27 | padding: 8
28 | drop_text_prob: 0.1
29 | drop_image_prob: 0.1
30 | # specific_task: ["removal", "style", "attribute_modification", "env", "swap"]
31 |
32 | wandb:
33 | project: "ICEdit"
34 |
35 | lora_config:
36 | r: 32
37 | lora_alpha: 32
38 | init_lora_weights: "gaussian"
39 | target_modules: "(.*x_embedder|.*(?